import nc_time_axis
import numpy as np
import pandas as pd
#from matplotlib import pyplot as plt
from statsmodels.distributions.empirical_distribution import ECDF
import glob, sys, os
import datetime
#import timeit
import cftime
import xarray as xr
import time

#import warnings
#warnings.filterwarnings('ignore')


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# infile  : raw GCM data to be bias corrected
# outfile : destination of the bias-corrected data
# var     : name of the variable to process (also the variable name inside
#           the input, observation and output files)
infile="../data/cmip6/global/historical/ACCESS-CM2/r1i1p1f1/day/county/tasmin_day_ACCESS-CM2_historical_r1i1p1f1_gn_merged.nc"
outfile="../data/cmip6_bc/global/historical/ACCESS-CM2/r1i1p1f1/day/county/tasmin_day_ACCESS-CM2_historical_r1i1p1f1_gn_merged_bc.nc"
var="tasmin"

# Alternative: take the configuration from the command line.
#var=sys.argv[1]
#infile=sys.argv[2]
#outfile=sys.argv[3]

# First and last year of the reference period used for the bias correction.
frefy,lrefy = 1981,2010


# Observations file; its name is derived from the variable name.
obsfile="../data/nclimgrid/{}_day_nclimgrid_19510101-20250131.nc".format(var)


# Long names of the supported temperature variables.
_TEMPERATURE_LONG_NAMES = {
    "tasmin": "minimum temperature",
    "tasmax": "maximum temperature",
    "tas": "mean temperature",
}

# Variable-specific settings. The raw GCM values are transformed as
# value * gcmmult + gcmadd before bias correction.
if var=="pr":
    varunit="mm/day"
    long_name="precipitation"
    # values below traceobs are treated as trace amounts further down
    traceobs=0.5
    tracegcm=traceobs
    gcmmult=86400
    gcmadd=0
    maxratio=2
elif var in _TEMPERATURE_LONG_NAMES:
    gcmmult=1
    gcmadd=-273.15
    varunit="degC"
    long_name=_TEMPERATURE_LONG_NAMES[var]
else:
    # Fail before any data is read: without this check an unsupported
    # variable would be processed with the temperature offset and the run
    # would only stop with a NameError (long_name undefined) when the output
    # attributes are written.
    raise ValueError(
        "unsupported variable {!r}; expected one of {}".format(
            var, ["pr"] + sorted(_TEMPERATURE_LONG_NAMES)
        )
    )



def gettime(start):
    """Print the time elapsed since ``start``.

    ``start`` is a timestamp previously obtained from ``time.perf_counter()``.
    The elapsed time, in seconds, is printed to standard output; nothing is
    returned.
    """
    elapsed = time.perf_counter() - start
    print(f"Execution time: {elapsed:.6f} seconds")


def do_bc(_obs,_refgcm,_gcm):
    """Bias-correct one GCM time series by quantile mapping against observations.

    The series is processed one calendar month of one year at a time. For
    each such slice, the quantile of every value is taken from the empirical
    CDF of the same calendar month in the surrounding years of ``_gcm``, the
    change relative to the reference-period GCM value at that quantile is
    computed (ratio for "pr", "wnd" and "hurs", difference otherwise) and
    applied to the observed reference-period value at the same quantile.
    Equation numbers in the comments are kept from the original code
    (presumably Cannon et al. 2015 - confirm).

    Besides its arguments the function reads these module-level globals:
    ``gyears``/``gmonths`` (year and month of every ``_gcm`` time step),
    ``grefmonths`` (month of every ``_refgcm`` time step), ``omonths`` (month
    of every ``_obs`` time step) and ``var``.

    Parameters
    ----------
    _obs : 1-D array, observations in the reference period.
    _refgcm : 1-D array, GCM data in the reference period.
    _gcm : 1-D array, GCM data to be corrected; same length as ``gmonths``.

    Returns
    -------
    1-D float array of length ``len(gmonths)`` holding the corrected values
    at the positions of the corresponding input values. All values are NaN
    when ``_obs`` contains any NaN.

    The process is terminated with exit status 1 (after printing
    diagnostics) when a corrected value exceeds 10000.
    """
    n_out = len(gmonths)

    # Locations without a complete observed record cannot be corrected.
    if np.isnan(_obs).any():
        return np.full(n_out, np.nan)

    multiplicative = var in ["pr","wnd","hurs"]

    # Month-dependent quantities do not change from year to year, so they
    # are computed once here instead of inside the year loop.
    months = range(1,13)
    month_masks = [gmonths==m for m in months]
    histgcm_by_month = [_refgcm[grefmonths==m] for m in months] #reference period gcm
    histobs_by_month = [_obs[omonths==m] for m in months] #reference period obs

    # Results are written back at the positions they were read from, so the
    # output is aligned with _gcm whatever the order of the time axis.
    output = np.full(n_out, np.nan)
    for year in np.unique(gyears):
        in_year = gyears==year
        # window: years strictly less than 15 years away from the current one
        in_window = (gyears>year-15) & (gyears<year+15)
        for m, month_mask, _histgcm, _histobs in zip(
            months, month_masks, histgcm_by_month, histobs_by_month
        ):
            source = month_mask & in_year
            if not source.any():
                # this month is absent from this year (e.g. partial first or
                # last year); nothing to correct
                continue

            _projgcm = _gcm[source] #raw gcm data - 1 month
            _projgcmwindow = _gcm[in_window & month_mask] #window around given month gcm data
            ecdfprojgcmwindow = ECDF(_projgcmwindow)
            quantprojgcm = ecdfprojgcmwindow(_projgcm) #quantiles
            histvalues = np.quantile(_histgcm, quantprojgcm) #part of eq.4
            obsvalues = np.quantile(_histobs, quantprojgcm) #eq. 5

            if multiplicative:
                proj2histratio = _projgcm/histvalues #eq. 4
                final = obsvalues*proj2histratio #eq. 6
            else:
                proj2histratio = _projgcm-histvalues #eq. 4
                final = obsvalues+proj2histratio #eq. 6

            # Sanity check: implausibly large corrected values abort the run.
            if np.max(final)>10000:
                print("ERROR")
                print(m, year)
                s=np.where(final==np.max(final))[0][0]
                print("final",final[s])

                print("quantprojgcm",quantprojgcm[s])
                print("projgcm",_projgcm[s])
                print("histvalues",histvalues[s])
                print("obs",obsvalues[s])
                print("proj2hist",proj2histratio[s])
                # non-zero status so that callers can detect the failure
                sys.exit(1)

            output[source] = final
    return(output)

# ---------------------------------------------------------------------------
# Read observations and GCM data, prepare them, and load them into memory
# ---------------------------------------------------------------------------
print("reading data")
tim = time.perf_counter()

# observations: only the processed variable is needed
obsdata = xr.open_mfdataset(obsfile)[var]
if var == "pr":
    print("adding trace...")
    # Observed values below the trace threshold are replaced by random
    # values drawn uniformly between traceobs/10 and traceobs. The copy gives
    # the random field the shape and dtype of the observations.
    rdata = np.copy(obsdata)
    rdata[:] = np.random.uniform(traceobs/10, traceobs, rdata.size).reshape(rdata.shape)
    obsdata = obsdata.where(~(obsdata < traceobs), rdata)
    del rdata


# GCM data: apply the multiplicative and additive conversion set above
gcmds = xr.open_mfdataset(infile)
gcmdata = gcmds[var]*gcmmult + gcmadd

# location codes are compared as strings below
gcmdata["nclimcode"] = gcmdata.nclimcode.astype(str)

if var in ("pr", "wnd", "hurs"):
    # everything that is not strictly positive becomes zero
    gcmdata = gcmdata.where(gcmdata > 0, 0)

if var == "pr":
    # same trace treatment as for the observations
    rdata = np.copy(gcmdata)
    rdata[:] = np.random.uniform(traceobs/10, traceobs, rdata.size).reshape(rdata.shape)
    gcmdata = gcmdata.where(~(gcmdata < traceobs), rdata)
    del rdata


# Calendar information of the time axes; do_bc reads these as globals.
_obs_ref_time = obsdata.sel(time=slice(str(frefy), str(lrefy))).time
oyears = _obs_ref_time.dt.year.data
omonths = _obs_ref_time.dt.month.data

_gcm_time = gcmdata.time
gyears = _gcm_time.dt.year.data
gmonths = _gcm_time.dt.month.data
grefmonths = gcmdata.sel(time=slice(str(frefy), str(lrefy))).time.dt.month.data


# Align the observations with the GCM data along nclimcode: keep only the
# codes present in the GCM data, then put them in the GCM order.
sel = obsdata.nclimcode.isin(gcmdata.nclimcode)
obsdata = obsdata.loc[:, sel]
obsdata["nclimcode"] = obsdata.nclimcode.astype(np.dtype("U5"))
obsdata = obsdata.loc[:, gcmdata.nclimcode.data]
gettime(tim)

print("loading data")
tim = time.perf_counter()
obsdata = obsdata.load()
gcmdata = gcmdata.load()
gettime(tim)

print("performing bias correction")
tim = time.perf_counter()

# The reference-period subsets get their own time dimension names so that
# the three inputs of do_bc can have time axes of different lengths.
_refperiod = slice(str(frefy), str(lrefy))
_obsref = obsdata.sel(time=_refperiod).rename({"time": "times1"})
_gcmref = gcmdata.sel(time=_refperiod).rename({"time": "times"})

# do_bc works on 1-D series; vectorize=True applies it along the core
# dimensions for every element of the remaining dimensions.
temp = xr.apply_ufunc(
    do_bc,
    _obsref,
    _gcmref,
    gcmdata,
    input_core_dims=[["times1"], ["times"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
)
gettime(tim)

print("aligning dimensions")
# apply_ufunc puts the core dimension last; restore the (time, nclimcode) order
gcmdatabc=temp.transpose("time","nclimcode")

gcmdatabc=gcmdatabc.astype(np.float32)

if var=="pr":
    # Set corrected precipitation not exceeding the trace threshold to zero.
    # The result has to be stored in gcmdatabc (the array that is written
    # out); it used to be assigned to gcmdata, which is not used afterwards.
    # The condition is written as a negation so that NaN values (locations
    # that were not corrected) stay NaN instead of becoming 0.
    gcmdatabc=gcmdatabc.where(~(gcmdatabc<=traceobs),0)

print("converting to dataset")
outds=gcmdatabc.to_dataset(name=var)




print("adding attributes")


# Global (file-level) attributes of the output file.
attributes={
    "title":"Bias corrected CMIP6 model climate data from historical experiment",
    "source":"bias corrected climate model data",
    "contributor_name":"University of Cape Town",
    "contributor_role":"Performed bias correction of raw CMIP6 data to nClimGrid-Daily data using python code implementing QDM approach",
    "creator_url":"www.csag.uct.ac.za",
    "creator_email":"wolski@csag.uct.ac.za",
    "comment":"This data product is generated by the following processing steps: 1) daily CMIP6 model data for historical and hist-nat experiments, in original spatial resolution, were subset to domain of interest (contiguous USA) 2) GCM data for US counties were derived from the corse grid GCM data using custom python code that implements zonal_stats function in time-optimal manner 3) bias correction at county-level was performed using QDM approach (Cannon et al. 2015) implemented through python code. That bias correction was performed with respect to the reference period of 1981 to 2010, using 30/10 window definition.",
    "data_licence": 'CC BY 4.0',
    "acknowledgement":"This work was conducted under the Towards Equitable and Sustainable Nature-based Solutions Project (TES NbS).The work was carried out with the aid of a grant from the International Development Research Centre, Ottawa, Canada and also supported with funding from the Government of Flanders.",
    "observed_data_project":"nClimGrid-Daily",
    "observed_data_url":"https://www.ncei.noaa.gov/products/land-based-station/nclimgrid-daily",
    "history":"{}: bc.py ".format(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S')),
}

# assign_attrs returns a new Dataset and leaves the original untouched, so
# the result has to be rebound; otherwise the attributes are never written.
outds = outds.assign_attrs(attributes)

# Variable-level attributes.
# NOTE(review): the CF conventions name the units attribute "units"; the key
# "unit" is kept unchanged here so that existing consumers are not affected.
outds[var].attrs={"unit":varunit, "long_name":long_name}

print("saving")
outds.to_netcdf(outfile)
print("written {}".format(outfile))
