#!/usr/bin/env python
# coding=utf-8

#import matplotlib
#matplotlib.use('Agg')

### STATION SETTINGS ###
# Site to extract from the model grids, and the year whose files are looked up
# in the THREDDS catalogue below.
station_name = "Zeppelin"
# Decimal degrees; matched against the files' longitude/latitude axes.
target_lon, target_lat = 11.887, 78.907
year = 2015

import threddsclient
import os
import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, cm, maskoceans
import time
import calendar
from datetime import datetime, timedelta
import glob
import sys
from matplotlib.cbook import get_sample_data
import matplotlib.image as image
from matplotlib.colors import BoundaryNorm, LogNorm, Normalize
import matplotlib.dates as mdates
from matplotlib import dates

def _find_dataset_url(urls, marker, year):
    """Return the first URL in ``urls`` ending in ``_<year>.nc`` that contains ``marker``.

    Raises:
        LookupError: if nothing matches. The message names the pattern that
            was searched for; a bare ``next()`` would instead raise an
            unexplained ``StopIteration``.
    """
    suffix = f"_{year}.nc"
    for url in urls:
        if url.endswith(suffix) and marker in url:
            return url
    raise LookupError(
        f"No dataset matching '*{marker}*{suffix}' found in the THREDDS catalogue."
    )


histFP_urls = threddsclient.opendap_urls(
    'https://thredds.nilu.no/thredds/catalog/flexpart/IRISCC/catalog.html'
)

# Anthropogenic BC file and wildfire BC file for the selected year.
antbc = _find_dataset_url(histFP_urls, "FLEXPART_BC_", year)
bbbc = _find_dataset_url(histFP_urls, "fireBC", year)

print("antbc =", antbc)
print("bbbc  =", bbbc)

# ========== READ NC ==========
def readNC_fast(
    fname, target_lon, target_lat,
    chunk_t=24, idZy=None, idZx=None,
    ):
    """
    Stream the BC_conc_BB column at a single grid point out of a NetCDF file.

    Only the (time, level) column at (idZy, idZx) is read, in requests of
    ``chunk_t`` time steps, so the full field is never held in memory.
    No vertical integration is done.

    Parameters
    ----------
    fname : str
        Path or OPeNDAP URL accepted by ``netCDF4.Dataset``.
    target_lon, target_lat : float
        Location used to pick the nearest grid column when ``idZx`` / ``idZy``
        are not given. Longitude differences are wrapped into [-180, 180), so
        grids stored as 0..360 or -180..180 are both handled.
    chunk_t : int
        Number of time steps read per request; must be >= 1.
    idZy, idZx : int, optional
        Explicit latitude / longitude indices; override the nearest search.

    Returns
    -------
    lons, lats : ndarray
        The file's ``longitude`` and ``latitude`` variables.
    times : list of datetime
        ``time`` values converted as "days since 1970-01-01".
        NOTE(review): the variable's ``units`` attribute is not checked.
    hei : ndarray
        The file's ``altitude`` variable.
    nct_point : ndarray, float32, shape (nt, nz)
        BC_conc_BB at the selected column.
    idZy, idZx : int
        The grid indices actually used.

    Raises
    ------
    ValueError
        If ``chunk_t`` < 1 or BC_conc_BB is not 4-D.
    IndexError
        If the resolved indices fall outside the grid.

    Assumes BC_conc_BB shape: (nt, nz, ny, nx)
    """
    if chunk_t < 1:
        raise ValueError(f"chunk_t must be >= 1, got {chunk_t!r}")

    print("Reading:", fname)

    with Dataset(fname, "r") as nc:
        # Disable automatic masking so reads return plain ndarrays. As a
        # consequence, fill values come through unmasked. "altitude" is the
        # vertical coordinate actually read below; "height" is kept for
        # files that carry it. Only variables present in the file are touched.
        for v in ("longitude", "latitude", "time", "height", "altitude", "BC_conc_BB"):
            if v in nc.variables:
                nc[v].set_auto_mask(False)

        lons = nc["longitude"][:]
        lats = nc["latitude"][:]
        tvals = nc["time"][:]  # named tvals so the 'time' module is not shadowed
        hei = nc["altitude"][:]

        # Resolve the grid point by nearest neighbour unless indices are given.
        # Assumes 1-D longitude/latitude axes (argmin yields a flat index) -- TODO confirm.
        if idZx is None:
            dlon = (lons - target_lon + 180.0) % 360.0 - 180.0
            idZx = int(np.abs(dlon).argmin())
        if idZy is None:
            idZy = int(np.abs(lats - target_lat).argmin())
        print("Point indices (idZy,idZx):", idZy, idZx)

        mr = nc["BC_conc_BB"]
        if len(mr.shape) != 4:
            raise ValueError(
                f"BC_conc_BB has shape {mr.shape}; expected (nt, nz, ny, nx)."
            )
        nt, nz, ny, nx = mr.shape

        # Validate indices (negative indices are rejected, not wrapped).
        if not (0 <= idZy < ny and 0 <= idZx < nx):
            raise IndexError(
                f"(idZy,idZx)=({idZy},{idZx}) out of bounds for (ny,nx)=({ny},{nx})."
            )

        # Output: time x height at the point.
        nct_point = np.zeros((nt, nz), dtype=np.float32)

        for t0 in range(0, nt, chunk_t):
            t1 = min(t0 + chunk_t, nt)
            # Slice is (t1 - t0, nz); the assignment casts it to float32.
            nct_point[t0:t1, :] = mr[t0:t1, :, idZy, idZx]

    # Time conversion (assumed "days since 1970-01-01").
    base = datetime(1970, 1, 1)
    times = [base + timedelta(days=float(t)) for t in tvals]

    return lons, lats, times, hei, nct_point, idZy, idZx

# Read both BC components at the station. The grid point is resolved
# separately for each file, so the grids need not be identical.
lons, lats, tt, hei, nct1, idZy, idZx = readNC_fast(antbc, target_lon, target_lat, chunk_t=24)
_, _, tt_bb, _, nct2, _, _ = readNC_fast(bbbc, target_lon, target_lat, chunk_t=24)

print(nct1.shape, nct2.shape)

# The two fields are summed and stacked below, so they must line up.
if nct1.shape != nct2.shape:
    raise ValueError(
        f"Anthropogenic {nct1.shape} and wildfire {nct2.shape} BC arrays differ "
        "in shape; they cannot be stacked or summed."
    )
if tt_bb != tt:
    print("WARNING: anthropogenic and wildfire files have different time axes; "
          "the anthropogenic time axis is used for plotting.")


#### PLOT FUNCTIONS BELOW ####
def lineplot_full(ax, dat, nct1, nct2, idZy=None, idZx=None, level=0,
                  label1="Anthropogenic", label2="Wildfire",
                  ylabel=r'Surface BC (ng m$\mathrm{\mathsf{^{-3}}}$)',
                  title=f'Surface concentration of Black Carbon at {station_name}',
                  ylim=None):
    """
    Draw a stacked area plot of two series on ``ax``.

    ``nct1`` is filled from 0, ``nct2`` is stacked on top of it, and the
    total is outlined in black.

    Parameters
    ----------
    ax : matplotlib Axes
    dat : sequence of datetime
        Time axis, one entry per sample.
    nct1, nct2 : array-like
        Accepted shapes:
        - 1-D (nt,);
        - 2-D (nt, nz), from which column ``level`` is used;
        - 3-D (nt, ny, nx), from which point (``idZy``, ``idZx``) is used.
    idZy, idZx : int, optional
        Grid indices, required only for 3-D inputs.
    level : int
        Vertical index used for 2-D inputs.
    label1, label2, ylabel, title : str
        Legend labels, y-axis label and title. The default title is built
        from the module-level ``station_name`` when this function is defined.
    ylim : (ymin, ymax), optional
        Explicit y-limits. Otherwise the axis runs from 0 to 1.1 times the
        largest finite total, or from 0 to 1 when no finite total is positive.

    Raises
    ------
    ValueError
        If an input has an unsupported rank, a 3-D input lacks indices, the
        series lengths differ, or the series are empty.
    """

    dat = np.asarray(dat)
    nct1 = np.asarray(nct1)
    nct2 = np.asarray(nct2)

    def select_series(arr):
        # Reduce an input of any supported rank to a 1-D time series.
        if arr.ndim == 1:
            return arr
        elif arr.ndim == 2:      # (nt, nz)
            return arr[:, level]
        elif arr.ndim == 3:      # (nt, ny, nx)
            if idZy is None or idZx is None:
                raise ValueError("idZy and idZx required for (nt, ny, nx)")
            return arr[:, idZy, idZx]
        else:
            raise ValueError(f"Unsupported array shape: {arr.shape}")

    a = select_series(nct1)
    b = select_series(nct2)

    if not (len(dat) == len(a) == len(b)):
        raise ValueError("Time and data lengths do not match")
    if len(dat) == 0:
        # Without this, set_xlim(dat[0], ...) below fails with a bare IndexError.
        raise ValueError("Cannot plot empty series")

    total = a + b

    # stacked fill
    ax.fill_between(dat, 0.0, a, alpha=0.7, label=label1)
    ax.fill_between(dat, a, total, alpha=0.7, label=label2)

    # total outline
    ax.plot(dat, total, color="k", linewidth=1.2)

    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_title(title, fontsize=14)

    fmt = mdates.DateFormatter('%d-%b')
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=15))
    ax.xaxis.set_major_formatter(fmt)
    ax.tick_params(axis="x", rotation=30)
    ax.set_xlim(dat[0], dat[-1])

    # --- automatic y-limits ---
    if ylim is not None:
        ax.set_ylim(ylim)
    else:
        # Only finite values are used. An inf would make set_ylim raise, and
        # all-NaN data would otherwise trigger a nanmax RuntimeWarning.
        finite = total[np.isfinite(total)]
        ymax = finite.max() if finite.size else 0.0
        ax.set_ylim(0, ymax * 1.1 if ymax > 0 else 1.0)   # 10% headroom

    ax.legend(frameon=False)


def cross(fig, ax, tt, hei, nct,
          ylabel=r'Height (km)',
          title=f'Vertical cross-section of Black Carbon concentrations at {station_name}',
          log=False, log_base=2,
          min_v=0.0, max_v=6.0, freq=0.5,
          ylim=(0, 20.01)):
    """
    Draw a filled time-height contour plot with a horizontal colorbar below it.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure that owns ``ax``; the colorbar axes are added to it.
    ax : matplotlib Axes
        Target axes. They are cleared first, and any colorbar axes created
        by a previous call on the same ``ax`` are removed.
    tt : sequence of datetime, length nt
    hei : sequence, length nz
        Vertical coordinate. It is multiplied by 1e-3 for the y axis, which
        converts metres to km if ``hei`` is in metres (assumed, not checked).
    nct : array-like, shape (nt, nz) or (nz, nt)
        Field at a single point. When nt == nz the (nt, nz) layout is
        assumed, matching readNC_fast's output.
    ylabel, title : str
        Axis label and title. The default title is built from the
        module-level ``station_name`` when this function is defined.
    log : bool
        Use a logarithmic colour scale; requires 0 < min_v < max_v.
    log_base : float
        Base whose integer powers are used as colorbar ticks in log mode.
    min_v, max_v : float
        Colour-scale limits; min_v < max_v is required.
    freq : float
        Tick spacing for the linear colorbar; must be > 0.
    ylim : (ymin, ymax) or None
        Y-axis limits in km. None keeps matplotlib's autoscaling.

    Returns
    -------
    The QuadContourSet returned by ``ax.contourf``.

    Raises
    ------
    ValueError
        If ``nct`` does not fit the grid or the colour-scale settings are invalid.
    """

    cmap = plt.get_cmap("rainbow")

    nt, nz = len(tt), len(hei)

    # --- Build time-height grid: x and y have shape (nz, nt) ---
    x, y = np.meshgrid(tt, np.asarray(hei) * 1e-3)

    # --- Orient data to the grid ---
    Z = np.asarray(nct)
    if Z.shape == (nt, nz):
        # Checked first so that square data (nt == nz) is read as (nt, nz).
        # Otherwise it would be left untransposed and plotted with the time
        # and height axes swapped.
        Z = Z.T
    elif Z.shape != (nz, nt):
        raise ValueError(
            f"Z shape {Z.shape} does not match grid shape {x.shape}. "
            "Expected nct to be (nt,nz) or (nz,nt)."
        )

    # --- Clear axis (useful for loop/animation) ---
    ax.clear()

    # --- Choose normalization + levels + ticks ---
    min_v = float(min_v)
    max_v = float(max_v)
    if log:
        if not (min_v > 0 and max_v > min_v):
            raise ValueError("For log scaling, require 0 < min_v < max_v")

        # np.array (not asarray) always copies, so the caller's array is
        # never modified by the clipping below.
        Z = np.array(Z, dtype=float)
        # Replace non-positive values with min_v so log plotting never crashes.
        Z[Z <= 0] = min_v

        # Older matplotlib: LogNorm has no 'base' kwarg
        norm = LogNorm(vmin=min_v, vmax=max_v)

        # Explicit log-spaced contour levels (prevents NaN/locator issues)
        pmin = np.log(min_v) / np.log(log_base)
        pmax = np.log(max_v) / np.log(log_base)
        clevs = np.power(float(log_base), np.linspace(pmin, pmax, 256))

        # Ticks at integer powers of the base (2^n etc.) within [min_v, max_v]
        pmin_i = int(np.floor(pmin))
        pmax_i = int(np.ceil(pmax))
        tick_positions = np.power(float(log_base),
                                  np.arange(pmin_i, pmax_i + 1, dtype=float))
        tick_positions = tick_positions[(tick_positions >= min_v) & (tick_positions <= max_v)]
    else:
        if not max_v > min_v:
            raise ValueError("For linear scaling, require min_v < max_v")
        if not freq > 0:
            raise ValueError("freq must be > 0")

        norm = Normalize(vmin=min_v, vmax=max_v)
        clevs = np.linspace(min_v, max_v, 256)

        tick_positions = np.arange(min_v, max_v + freq, freq, dtype=float)
        # When (max_v - min_v) is not a multiple of freq, arange overshoots
        # max_v; drop those ticks. Labels are derived below, so they stay in sync.
        tick_positions = tick_positions[tick_positions <= max_v + 1e-9 * freq]

    tick_labels = [f"{t:g}" for t in tick_positions]

    # --- Plot ---
    cs1 = ax.contourf(x, y, Z, clevs, cmap=cmap, extend="both", norm=norm)

    # --- Axis formatting ---
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_title(title, fontsize=14)

    ax.set_xlim(tt[0], tt[-1])
    if ylim is not None:
        ax.set_ylim(ylim)

    hfmt = dates.DateFormatter("%d-%b")
    ax.xaxis.set_major_locator(dates.DayLocator(interval=15))
    ax.xaxis.set_major_formatter(hfmt)
    ax.tick_params(axis="x", rotation=30)

    # --- Remove any previous colorbars (prevents stacking in a loop) ---
    for cax in getattr(ax, "_cross_cbar_axes", []):
        try:
            cax.remove()
        except Exception:
            # Best effort: the colorbar axes may already have been removed,
            # e.g. when the figure was cleared.
            pass
    ax._cross_cbar_axes = []

    # --- Colorbar placement under the subplot using bbox ---
    bbox = ax.get_position()
    fraction = 0.05  # colorbar height = 5% of axes height
    pad = 0.15       # gap = 15% of axes height

    cb_height = fraction * bbox.height
    cb_pad = pad * bbox.height

    cbaxes = fig.add_axes([
        bbox.x0,
        bbox.y0 - cb_pad - cb_height,
        bbox.width,
        cb_height
    ])
    ax._cross_cbar_axes.append(cbaxes)

    cbar = fig.colorbar(cs1, cax=cbaxes, ticks=tick_positions, orientation="horizontal")
    cbar.minorticks_off()
    cbar.set_ticklabels(tick_labels)
    # The plotted field is black carbon in ng m^-3 (the label previously
    # said SO2 in ug m^-3, a copy-paste error).
    cbar.set_label(r"BC concentration (ng m$\mathrm{^{-3}}$)", fontsize=16)

    return cs1
    
    
################################
# Colour scale for the cross-section. These values equal cross()'s defaults;
# set log=True with min_v > 0 for a logarithmic scale. They were previously
# referenced without being defined, which raised NameError.
log = False
log_base = 2
min_v = 0.0
max_v = 6.0
freq = 0.5

fig = plt.figure(figsize=(10, 10))

F1 = fig.add_subplot(211)
lineplot_full(F1, tt, nct1, nct2, level=0,
              label1="Anthropogenic", label2="Wildfire",
              ylabel=r'Surface BC (ng m$\mathrm{\mathsf{^{-3}}}$)',
              title=f'Surface concentration of Black Carbon at {station_name}',
              ylim=None)

F2 = fig.add_subplot(212)
cross(fig, F2, tt, hei, nct1 + nct2,
      log=log, log_base=log_base,
      min_v=min_v, max_v=max_v, freq=freq)

# pyplot.show() runs the GUI event loop and blocks until the window is closed.
# Figure.show() does not, so the window vanishes when run as a script.
plt.show()
