import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.basemap import Basemap
import netCDF4
import numpy as np
import pandas as pd
import math as ma
import calendar as cd
from scipy.spatial.distance import pdist, squareform


def nan_helper(y):
    """Locate the NaN entries of ``y``.

    Returns a pair ``(nan_mask, to_indices)``: ``nan_mask`` is the result of
    ``np.isnan(y)`` and ``to_indices`` is a callable returning the first
    array of ``z.nonzero()`` for whatever ``z`` it is given (typically a
    boolean mask such as ``nan_mask`` or its negation).
    """
    def to_indices(z):
        return z.nonzero()[0]

    nan_mask = np.isnan(y)
    return nan_mask, to_indices


def find_index_nearest(array,value):
    """Return the position of the element of ``array`` closest to ``value``.

    Closeness is the absolute difference ``|array - value|``; the position
    of its minimum is returned.
    """
    deviation = np.abs(array - value)
    return deviation.argmin()

def months_to_seasons(array):
    """Placeholder for a month-to-season conversion (never implemented).

    The original body was an unfinished draft: it listed the calendar months
    of each season (spring 3-5, summer 6-8, autumn 9-11, winter 12, 1, 2),
    then called ``len()`` without an argument -- which raises ``TypeError``
    on every call -- and ended with a bare ``return``.  The intended result
    cannot be recovered from that draft, so the failure is made explicit
    instead of surfacing as an unrelated ``TypeError``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        'months_to_seasons is an unfinished stub: the conversion of a %s '
        'to seasons was never implemented' % type(array).__name__)

def histo(map_obs,map_mod,mask_05,nb_zones,zones,ind_zones,bins,noun,yvec):
    """Plot per-region histograms of observed and modelled values (4x3 grid).

    For every region the observation field is multiplied by the region mask
    read from ``mask_05``; cells whose product is exactly 0 become NaN, and
    the same cells are blanked in the model field so both histograms are
    built from the same cells.  The figure is written to a hard-coded path
    and then closed.

    Args:
        map_obs, map_mod: observation / model arrays, broadcastable against
            one 2-D slice of the ``mask_region`` variable.
        mask_05: path of the netCDF file holding ``mask_region``.
        nb_zones: number of regions to draw; at most 12 (one per panel) are
            shown and panels without a region are hidden.
        zones: region names, used for titles and progress output.
        ind_zones: first-axis index into ``mask_region`` for each region.
        bins: forwarded unchanged to ``Axes.hist``.
        noun, yvec: text fragments used in the title and the file name.
    """
    fig, axes = plt.subplots(4,3,figsize=(20,15))
    fig.suptitle(noun+' Distribution Comparison of Soil Moisture [m3/m3]'+yvec, fontsize=18)
    n_panels = min(nb_zones, axes.size)
    # Open the mask file once and make sure it is closed even if a read fails.
    nc = netCDF4.Dataset(mask_05)
    try:
        mask_var = nc.variables["mask_region"]
        region_masks = [mask_var[ind_zones[zo],:,:] for zo in range(n_panels)]
    finally:
        nc.close()
    for zo, panel in enumerate(axes.flat):
        if zo >= n_panels:
            # Fewer regions than panels: leave the spare panel blank.
            panel.set_visible(False)
            continue
        print(zones[zo])
        aux_obs = map_obs*region_masks[zo]
        aux_obs[aux_obs==0] = np.nan
        # Adding and subtracting aux_obs copies its NaN pattern onto the model.
        aux_mod = map_mod+aux_obs-aux_obs
        panel.hist(aux_obs.flatten(), bins, facecolor='Maroon', alpha=0.5, label='ESA-CCI')
        panel.hist(aux_mod.flatten(), bins, facecolor='DarkGreen', alpha=0.5, label='ORCHIDEE')
        panel.set_title(zones[zo],fontsize=20)
        if zo == 0:
            panel.legend(loc='upper right',prop={'size':18})
    plt.subplots_adjust(left=0.15)
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/hist_regions_'+noun+'_'+yvec+'_summer.png')
    plt.close()

def seasonal_cycle(map_obs,map_mod,areacell,mask_05,nb_zones,zones,ind_zones,noun,yvec,ntime):
    """Plot the area-weighted mean annual cycle per region (4x3 grid).

    For each region, time steps ``t, t+12, t+24, ...`` (below ``ntime``) are
    averaged for ``t`` in 0..11, giving twelve maps per data set.  Each map
    is then averaged over the region with ``areacell`` as weights.  Cells
    where the region mask is 0 are set to NaN and ignored, and the model is
    blanked wherever the masked observations are NaN.  The figure is written
    to a hard-coded path and then closed.

    Args:
        map_obs, map_mod: arrays indexed as ``[time, :, :]``.
        areacell: cell weights, broadcastable against the region mask.
        mask_05: path of the netCDF file holding ``mask_region``.
        nb_zones: number of regions to draw; at most 12 (one per panel) are
            shown and panels without a region are hidden.
        zones: region names, used for titles and progress output.
        ind_zones: first-axis index into ``mask_region`` for each region.
        noun, yvec: text fragments used in the title and the file name.
        ntime: exclusive upper bound on the time steps that are used.
    """
    fig, axes = plt.subplots(4,3,figsize=(20,15))
    fig.suptitle(noun+' seasonal_cycle '+yvec, fontsize=18)
    n_panels = min(nb_zones, axes.size)
    # Open the mask file once and make sure it is closed even if a read fails.
    nc = netCDF4.Dataset(mask_05)
    try:
        mask_var = nc.variables["mask_region"]
        region_masks = [mask_var[ind_zones[zo],:,:] for zo in range(n_panels)]
    finally:
        nc.close()
    x = np.arange(12)
    mons = cd.month_abbr[1:13]
    for zo, panel in enumerate(axes.flat):
        if zo >= n_panels:
            # Fewer regions than panels: leave the spare panel blank.
            panel.set_visible(False)
            continue
        print(zones[zo])
        # astype replaces np.float_, which no longer exists in NumPy 2.
        reg_mask_05 = region_masks[zo].astype(np.float64)
        reg_mask_05[reg_mask_05==0] = np.nan
        areacellreg = areacell*reg_mask_05
        aux_obs = map_obs*reg_mask_05
        # Adding and subtracting aux_obs copies its NaN pattern onto the model.
        aux_mod = map_mod+aux_obs-aux_obs
        obs_season = np.zeros((12,aux_obs.shape[1],aux_mod.shape[2]))
        mod_season = np.zeros((12,aux_obs.shape[1],aux_mod.shape[2]))
        for t in range(12):
            obs_season[t,:,:] = np.nanmean(aux_obs[t:ntime:12,:,:],axis=0)
            mod_season[t,:,:] = np.nanmean(aux_mod[t:ntime:12,:,:],axis=0)
        # Area-weighted regional mean of each of the twelve maps.
        total_area = np.nansum(areacellreg)
        obs = np.nansum(areacellreg*obs_season,axis=(1,2))/total_area
        mod = np.nansum(areacellreg*mod_season,axis=(1,2))/total_area
        panel.plot(x,obs,color="black",label='ESA-CCI')
        panel.plot(x,mod,color="dodgerblue", label='ORCHIDEE')
        panel.set_xticks(x)
        panel.set_xticklabels(mons)
        panel.set_title(zones[zo],fontsize=20)
        if zo == 0:
            panel.legend(loc='upper right',prop={'size':18})
    plt.subplots_adjust(left=0.15)
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/seasonal_cycle_'+noun+'_'+yvec+'.png')
    plt.close()

def timeserie(map_obs,map_mod,mask_05,nb_zones,zones,ind_zones,noun,yvec,dates):
    """Plot regional-mean time series of observations and model (4x3 grid).

    For every region the observation field is multiplied by the region mask
    read from ``mask_05``; cells whose product is exactly 0 become NaN and
    the same cells are blanked in the model field.  Both fields are then
    averaged over axes 1 and 2 (NaN ignored) and drawn against ``dates``.
    The figure is written to a hard-coded path and then closed.

    Args:
        map_obs, map_mod: 3-D arrays whose axes 1 and 2 match the mask.
        mask_05: path of the netCDF file holding ``mask_region``.
        nb_zones: number of regions to draw; at most 12 (one per panel) are
            shown and panels without a region are hidden.
        zones: region names, used for titles and progress output.
        ind_zones: first-axis index into ``mask_region`` for each region.
        noun, yvec: text fragments used in the title and the file name.
        dates: x values, one per entry along axis 0 of the data.
    """
    fig, axes = plt.subplots(4,3,figsize=(20,15))
    fig.suptitle(noun+' Timeserie '+yvec, fontsize=18)
    n_panels = min(nb_zones, axes.size)
    # Open the mask file once and make sure it is closed even if a read fails.
    nc = netCDF4.Dataset(mask_05)
    try:
        mask_var = nc.variables["mask_region"]
        region_masks = [mask_var[ind_zones[zo],:,:] for zo in range(n_panels)]
    finally:
        nc.close()
    for zo, panel in enumerate(axes.flat):
        if zo >= n_panels:
            # Fewer regions than panels: leave the spare panel blank.
            panel.set_visible(False)
            continue
        print(zones[zo])
        aux_obs = map_obs*region_masks[zo]
        aux_obs[aux_obs==0] = np.nan
        # Adding and subtracting aux_obs copies its NaN pattern onto the model.
        aux_mod = map_mod+aux_obs-aux_obs
        obs = np.nanmean(aux_obs,axis=(1,2))
        mod = np.nanmean(aux_mod,axis=(1,2))
        panel.plot(dates,obs,color="black",label='ESA-CCI')
        panel.plot(dates,mod,color="dodgerblue", label='ORCHIDEE')
        panel.set_title(zones[zo],fontsize=20)
        if zo == 0:
            panel.legend(loc='upper right',prop={'size':18})
    plt.subplots_adjust(left=0.15)
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/timeserie_'+noun+'_'+yvec+'_test.png')
    plt.close()

def spacor(map_obs,map_mod,mask_05,nmonths,nb_zones,zones,yvec,noun,ind_zones,dates):
    """Plot, per region, the spatial correlation between model and observations.

    For each of the first ``nmonths`` time steps the masked observation and
    model maps are flattened and correlated with ``cor_serie``.  The series
    of coefficients is smoothed with a 30-step rolling mean (a value is
    produced as soon as one sample is available) and plotted against
    ``dates``.  One figure per region is written to a hard-coded path.

    Args:
        map_obs, map_mod: arrays indexed as ``[time, :, :]``.
        mask_05: path of the netCDF file holding ``mask_region``.
        nmonths: number of time steps to correlate.
        nb_zones: number of regions to process.
        zones: region names, used in titles, file names and progress output.
        yvec, noun: text fragments used in the title and the file name.
        ind_zones: first-axis index into ``mask_region`` for each region.
        dates: x values, one per time step.
    """
    # Open the mask file once and make sure it is closed even if a read fails.
    nc = netCDF4.Dataset(mask_05)
    try:
        mask_var = nc.variables["mask_region"]
        region_masks = [mask_var[ind_zones[zo],:,:] for zo in range(nb_zones)]
    finally:
        nc.close()
    for zo, reg_mask_05 in enumerate(region_masks):
        aux_obs = map_obs*reg_mask_05
        aux_obs[aux_obs==0] = np.nan
        # Adding and subtracting aux_obs copies its NaN pattern onto the model.
        aux_mod = map_mod+aux_obs-aux_obs
        obs_mod = np.zeros(nmonths)
        for i in range(nmonths):
            obs_mod[i] = cor_serie(aux_obs[i,:,:].flatten(),aux_mod[i,:,:].flatten())
        # Series.rolling replaces pd.rolling_mean, which pandas has removed.
        obs_mod = pd.Series(obs_mod).rolling(30, min_periods=1).mean()
        fig, ax = plt.subplots()
        ax.plot(dates, obs_mod, color = 'gold', label="monthly, ave=%.2f" % np.nanmean(np.array(obs_mod)))
        ax.set_ylabel(noun+' Correlation coefficient')
        ax.set_ylim([0.,1.])
        plt.legend(loc='lower right')
        plt.xlabel('Time')
        plt.xticks(rotation=70)
        plt.title(noun+' '+zones[zo]+' '+yvec)
        plt.tight_layout()
        plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/spacor_'+noun+'_'+zones[zo]+'_'+yvec+'_test.png')
        print(zones[zo])
        plt.close()


def bias_map(data_mod,data_obs,lat,lon,yvec,noun):
    """Draw a global map of the mean bias (model minus observations).

    The bias comes from ``mbe_alldata``; non-finite cells are masked before
    plotting.  The map uses a Miller projection with a fixed colour range of
    -0.1 to 0.1, shows the mean of the plotted field in the title, and is
    written to a hard-coded path.

    Note the argument order: the model comes first here, whereas
    ``mbe_alldata`` takes the observations first.
    """
    diff = np.ma.masked_invalid(mbe_alldata(data_obs,data_mod))
    plt.figure()
    m = Basemap(projection='mill', llcrnrlat=-90, urcrnrlat=90, llcrnrlon=-180,urcrnrlon=180, resolution='c')
    m.drawcoastlines(linewidth=0.5)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which recent
    # Matplotlib releases no longer provide.
    cmap = plt.get_cmap('RdBu')
    cmesh = m.pcolormesh(lon, lat, diff, shading="flat", cmap=cmap, latlon=True, vmin=-0.1, vmax=0.1)
    cbar = m.colorbar(cmesh, pad = 0.08)
    cbar.set_label('Bias Error (m3/m3)')
    plt.title(noun+', Mean = %.3f' % np.nanmean(diff))
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/bias_'+noun+'_'+yvec+'_test.png')
    plt.close()


def cor_map(sm_mod,sm_obs,lat,lon,yvec,noun):
    """Draw a global map of the correlation between model and observations.

    The coefficients come from ``cor_alldata``; non-finite cells are masked
    before plotting.  The map uses a Miller projection with a fixed colour
    range of -1 to 1, shows the mean of the plotted field in the title, and
    is written to a hard-coded path.

    Note the argument order: the model comes first here, whereas
    ``cor_alldata`` takes the observations first.
    """
    diff = np.ma.masked_invalid(cor_alldata(sm_obs,sm_mod))
    plt.figure()
    m = Basemap(projection='mill', llcrnrlat=-90, urcrnrlat=90, llcrnrlon=-180,urcrnrlon=180, resolution='c')
    m.drawcoastlines(linewidth=0.5)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which recent
    # Matplotlib releases no longer provide.
    cmap = plt.get_cmap('jet')
    cmesh = m.pcolormesh(lon, lat, diff, shading="flat", cmap=cmap, latlon=True, vmin=-1., vmax=1.)
    cbar = m.colorbar(cmesh, pad = 0.08)
    cbar.set_label('Correlation Coefficient')
    plt.title(noun+', Mean = %.3f' % np.nanmean(diff))
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/cor_'+noun+'_'+yvec+'_test.png')
    plt.close()


def rmsd_map(data_mod,data_obs,lat,lon,yvec,noun):
    """Draw a global map of the normalised RMSD between model and observations.

    The values come from ``rmsd_alldata``; non-finite cells are masked
    before plotting.  The map uses a Miller projection with a fixed colour
    range of 0 to 1, shows the mean of the plotted field in the title, and
    is written to a hard-coded path.

    Note the argument order: the model comes first here, whereas
    ``rmsd_alldata`` takes the observations first.
    """
    diff = np.ma.masked_invalid(rmsd_alldata(data_obs,data_mod))
    plt.figure()
    m = Basemap(projection='mill', llcrnrlat=-90, urcrnrlat=90, llcrnrlon=-180,urcrnrlon=180, resolution='c')
    m.drawcoastlines(linewidth=0.5)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap (reached here through
    # plt.cm), which recent Matplotlib releases no longer provide.
    cmap = plt.get_cmap('YlOrRd')
    cmesh = m.pcolormesh(lon, lat, diff, shading="flat", cmap=cmap, latlon=True, vmin=0, vmax=1)
    cbar = m.colorbar(cmesh, pad = 0.08)
    cbar.set_label('normalised RMSD')
    plt.title(noun+', Mean = %.3f' % np.nanmean(diff))
    plt.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/rmsd_'+noun+'_'+yvec+'_test.png')
    plt.close()


def mbe_alldata(sm_obs,sm_mod):
    """Mean bias error: average of (model - observation) along axis 0.

    NaN differences are ignored in the average.
    """
    bias = sm_mod - sm_obs
    return np.nanmean(bias, axis=0)


def rmsd_alldata(sm_obs,sm_mod):
    """Normalised RMSD between the model and rescaled observations.

    All statistics are taken along axis 0 and ignore NaN.  The observations
    are first rescaled to the model's mean and standard deviation.  The root
    mean square difference between the model and these rescaled observations
    is then divided by the spread between the 90th and 10th percentile of
    the original observations.
    """
    obs_mean = np.nanmean(sm_obs,axis=0)
    obs_std = np.nanstd(sm_obs,axis=0)
    mod_mean = np.nanmean(sm_mod,axis=0)
    mod_std = np.nanstd(sm_mod,axis=0)
    # Observations expressed in the model's mean / standard deviation.
    rescaled_obs = (sm_obs-obs_mean)*mod_std/obs_std+mod_mean
    p90 = np.nanpercentile(sm_obs,90,axis=0)
    p10 = np.nanpercentile(sm_obs,10,axis=0)
    obs_range = p90-p10
    squared_error = (sm_mod-rescaled_obs)**2
    return np.sqrt(np.nanmean(squared_error,axis=0))/obs_range


def lagk_autocor(data, k):
    """Lag-``k`` autocorrelation of ``data``.

    The covariance (``cov_serie``) between the series without its last ``k``
    samples and the series without its first ``k`` samples is divided by the
    NaN-ignoring variance of the whole series.
    """
    n_samples = len(data)
    leading = data[0:n_samples-k]
    lagged = data[k:n_samples]
    return cov_serie(leading,lagged)/np.nanvar(data)


def cov_alldata(sm_obs,sm_mod):
    """Covariance along axis 0 between two ``[time, :, :]`` arrays.

    Each array has its own NaN-ignoring time mean removed; the products of
    the two anomaly fields are then averaged over time, ignoring NaN.
    """
    n_steps = sm_obs.shape[0]
    obs_anom = np.zeros(sm_obs.shape)
    mod_anom = np.zeros(sm_mod.shape)
    # Broadcasting removes the time mean from every step at once.
    obs_anom[:n_steps,:,:] = sm_obs[:n_steps,:,:]-np.nanmean(sm_obs,axis=0)
    mod_anom[:n_steps,:,:] = sm_mod[:n_steps,:,:]-np.nanmean(sm_mod,axis=0)
    return np.nanmean(mod_anom*obs_anom,axis=0)


def cor_alldata(sm_obs,sm_mod):
    """Correlation along axis 0 between observations and model.

    This is ``cov_alldata`` divided by the product of the two NaN-ignoring
    standard deviations taken along axis 0.
    """
    covariance = cov_alldata(sm_obs,sm_mod)
    spread = np.nanstd(sm_mod,axis=0)*np.nanstd(sm_obs,axis=0)
    return covariance/spread


def cov_serie(sm_obs,sm_mod):
    """Covariance of two series, ignoring NaN.

    Each series is centred on its own NaN-ignoring mean; the average of the
    anomaly products skips every pair in which either value is NaN.
    """
    mod_anom = sm_mod-np.nanmean(sm_mod)
    obs_anom = sm_obs-np.nanmean(sm_obs)
    return np.nanmean(mod_anom*obs_anom)


def cor_serie(sm_obs,sm_mod):
    """Correlation of two series, ignoring NaN.

    This is ``cov_serie`` divided by the product of the two NaN-ignoring
    standard deviations.
    """
    covariance = cov_serie(sm_obs,sm_mod)
    spread = np.nanstd(sm_mod)*np.nanstd(sm_obs)
    return covariance/spread


def day_to_month(cum_array,day):
    """Return the first position whose entry in ``cum_array`` is not below ``day``.

    The array is scanned from position 0.  If ``day`` exceeds every entry,
    the scan runs off the end and the indexing error propagates.
    """
    position = 0
    while True:
        if not day > cum_array[position]:
            return position
        position += 1


def autocor(map_obs,map_mod,mask_05,nb_zones,zones,noun,ind_zones,cut,nmonths):
    """Plot, per region, the autocorrelation of observed and modelled means.

    Each field is masked to the region (mask value 0 becomes NaN), averaged
    over its last two axes ignoring NaN, and its autocorrelation is computed
    with ``lagk_autocor`` for lags ``0 .. cut-1``.  Coefficients that are
    exactly 0 are treated as missing and filled by PCHIP interpolation.  The
    first lag at which a curve falls below 0.37 is shown in the legend; it
    is NaN when the curve never gets that low.  One figure per region is
    written to a hard-coded path.

    Args:
        map_obs, map_mod: arrays whose last two axes match the mask.
        mask_05: path of the netCDF file holding ``mask_region``.
        nb_zones: number of regions to process.
        zones: region names, used in titles, file names and progress output.
        noun: text fragment used in the file name.
        ind_zones: first-axis index into ``mask_region`` for each region.
        cut: number of lags to evaluate.
        nmonths: unused; kept so existing callers keep working.
    """
    threshold = 0.37
    kvec = range(cut)
    for i in range(nb_zones):
        print(zones[i])
        nc = netCDF4.Dataset(mask_05)
        try:
            reg_mask_05 = nc.variables["mask_region"][ind_zones[i],:,:]
        finally:
            nc.close()
        # astype replaces np.float_, which no longer exists in NumPy 2.
        reg_mask_05 = reg_mask_05.astype(np.float64)
        reg_mask_05[reg_mask_05==0] = np.nan
        curves = []
        lag_times = []
        for field in (map_obs, map_mod):
            series = np.array(np.nanmean(np.nanmean(field*reg_mask_05,axis=-1),axis=-1))
            acf = np.zeros(cut)
            for k in kvec:
                acf[k] = lagk_autocor(series,k)
            acf[acf==0] = np.nan
            acf = pd.Series(acf).interpolate(method='pchip')
            # First lag below the threshold; NaN instead of an unbound
            # variable when the curve never crosses it.
            below = np.flatnonzero(acf.values < threshold)
            lag_times.append(below[0] if below.size else np.nan)
            curves.append(acf)
        autocor_obs, autocor_mod = curves
        l_obs, l_mod = lag_times
        f, ax = plt.subplots(figsize=(30,20))
        plt.plot(kvec, autocor_obs, color = 'dodgerblue',  linewidth = 3., label='MJung (lag_time=%.0f)' % l_obs)
        plt.plot(kvec, autocor_mod, color = 'limegreen',  linewidth = 3., label='v5.67PDay01(lag_time=%.0f)' % l_mod)
        ax.axhline(y=0.37,c="black",linewidth=1., linestyle='--')
        ax.axhline(y=0.,c="black",linewidth=1.)
        plt.ylim([-0.2,1.2])
        plt.xlim([0.,cut])
        plt.xlabel('Time Lag [d]',fontsize=25)
        plt.ylabel('Autocorrelation coefficient',fontsize=25)
        plt.setp(ax, yticks=[0.,0.5,1.], xticks=np.linspace(0,cut,4))
        plt.setp(ax.get_xticklabels(), fontsize=20)
        plt.setp(ax.get_yticklabels(), fontsize=20)
        plt.legend(loc='upper right', ncol=2, fontsize=25)
        plt.title(zones[i]+" ",fontsize=30)
        plt.savefig('/data/sli/PC/autocorr/output/autocor_'+noun+'_'+zones[i]+'.png')
        plt.close()


def autocor_multi(map_obs,map_mod1,map_mod2,map_mod3,map_mod4,map_mod5,map_mod6,map_mod7,map_mod8,map_mod9,map_mod10, mask_05,nb_zones,zones,noun,ind_zones,cut,nmonths):
    """Plot, per region, the autocorrelation of observations and ten model runs.

    Each field is masked to the region (mask value 0 becomes NaN), averaged
    over its last two axes ignoring NaN, and its autocorrelation is computed
    with ``lagk_autocor`` for lags ``0 .. cut-1``.  Coefficients that are
    exactly 0 are treated as missing and filled by PCHIP interpolation.
    Only the observation curve reports a lag time in the legend: the first
    lag at which it falls below 0.5, or NaN when it never does.  One figure
    per region is written to a hard-coded path.

    Args:
        map_obs: observation array whose last two axes match the mask.
        map_mod1 .. map_mod10: model arrays, drawn in this order with fixed
            colours and legend labels.
        mask_05: path of the netCDF file holding ``mask_region``.
        nb_zones: number of regions to process.
        zones: region names, used in titles, file names and progress output.
        noun: text fragment used in the file name.
        ind_zones: first-axis index into ``mask_region`` for each region.
        cut: number of lags to evaluate.
        nmonths: unused; kept so existing callers keep working.
    """
    kvec = range(cut)
    default = 0.5 #0.37
    # (field, line colour, legend label), in drawing order.
    model_runs = [
        (map_mod1, 'limegreen', 'IPSL-CM5A-MR'),
        (map_mod2, 'peru', 'CM605-LR-pdCtrl-01'),
        (map_mod3, 'green', 'CM605.calv-LR-pdCtrl-02'),
        (map_mod4, 'red', 'CM605.dt20-LR-pdCtrl-02'),
        (map_mod5, 'yellow', 'CM605.GUST-LR-pdCtrl-01'),
        (map_mod6, 'magenta', 'CM605.NOSU-LR-pdCtrl-03'),
        (map_mod7, 'blue', 'CM605.THC1-LR-pdCtrl-01'),
        (map_mod8, 'brown', 'CM605.Z0-LR-pdCtrl-01'),
        (map_mod9, 'olive', 'NPv5.70PDctrl01'),
        (map_mod10, 'cyan', '5.71vd2qsp375'),
    ]

    def regional_acf(field, reg_mask):
        # Regional mean series (NaN cells ignored), one coefficient per lag.
        series = np.array(np.nanmean(np.nanmean(field*reg_mask,axis=-1),axis=-1))
        acf = np.zeros(cut)
        for k in kvec:
            acf[k] = lagk_autocor(series,k)
        # Exact zeros are treated as missing and filled by interpolation.
        acf[acf==0] = np.nan
        return pd.Series(acf).interpolate(method='pchip')

    for i in range(nb_zones):
        print(zones[i])
        nc = netCDF4.Dataset(mask_05)
        try:
            reg_mask_05 = nc.variables["mask_region"][ind_zones[i],:,:]
        finally:
            nc.close()
        # astype replaces np.float_, which no longer exists in NumPy 2.
        reg_mask_05 = reg_mask_05.astype(np.float64)
        reg_mask_05[reg_mask_05==0] = np.nan

        autocor_obs = regional_acf(map_obs, reg_mask_05)
        model_curves = [regional_acf(field, reg_mask_05) for field, _, _ in model_runs]

        # First lag below the threshold; NaN instead of an unbound variable
        # when the observation curve never crosses it.
        below = np.flatnonzero(autocor_obs.values < default)
        l_obs = below[0] if below.size else np.nan

        f, ax = plt.subplots(figsize=(30,20))
        plt.plot(kvec, autocor_obs, color = 'black',  linewidth = 3., label='MJung (lag_time=%.0f)' % l_obs)
        for (_, colour, label), curve in zip(model_runs, model_curves):
            plt.plot(kvec, curve, color = colour,  linewidth = 3., label=label)

        ax.axhline(y=0.,c="black",linewidth=1.)
        plt.ylim([-0.6,1.2])
        plt.xlim([0.,cut])
        plt.xlabel('Time Lag [d]',fontsize=25)
        plt.ylabel('Autocorrelation coefficient',fontsize=25)
        plt.setp(ax, yticks=[-0.5,-0.25,0,0.25,0.5,0.75,1.], xticks=np.linspace(0,cut,4))
        plt.setp(ax.get_xticklabels(), fontsize=20)
        plt.setp(ax.get_yticklabels(), fontsize=20)
        plt.legend(loc='upper right', ncol=2, fontsize=25)
        plt.title(zones[i]+" ",fontsize=30)
        plt.savefig('/data/sli/PC/autocorr/output/autocor_'+noun+'_'+zones[i]+'.png')
        plt.close()

def semivario(map_obs_m,map_mod_m,mask_05,nb_zones,zones,noun,ind_zones,days,nlag,lmax):
    """Plot, per region, the semivariograms of the time-mean obs and model maps.

    The region's bounding box is taken from the cells where the mask equals
    1.  Inside that box both fields are averaged over axis 0 (ignoring NaN)
    and multiplied by the mask.  Every cell whose observed value is not NaN
    contributes a (lat, lon, value) row for both data sets.  The rows go
    through ``semivariogram``, the semivariances are smoothed with a
    5-point rolling mean, and a (0, 0) origin point is prepended before
    plotting.  One figure per region is written to a hard-coded path.

    Args:
        map_obs_m, map_mod_m: arrays indexed as ``[time, lat, lon]``.
        mask_05: path of the netCDF file holding ``mask_region``, ``lat``
            and ``lon``.
        nb_zones: number of regions to process.
        zones: region names, used in titles, file names and progress output.
        noun: text fragment used in the file name.
        ind_zones: first-axis index into ``mask_region`` for each region.
        days: unused; kept so existing callers keep working.
        nlag, lmax: forwarded to ``semivariogram``.
    """
    for zo in range(nb_zones):
        print(zones[zo])
        nc = netCDF4.Dataset(mask_05)
        try:
            rmask05 = nc.variables["mask_region"][ind_zones[zo],:,:]
            lons_05 = nc.variables["lon"][:]
            lats_05 = nc.variables["lat"][:]
        finally:
            nc.close()
        # Bounding box of the region inside the global grid.
        inds05 = np.where(rmask05==1)
        rows = slice(np.min(inds05[0]), np.max(inds05[0])+1)
        cols = slice(np.min(inds05[1]), np.max(inds05[1])+1)
        rmask05 = rmask05[rows,cols]
        lats_05 = lats_05[rows]
        lons_05 = lons_05[cols]
        plt.figure()
        # NOTE(review): unlike autocor, mask value 0 is not turned into NaN
        # here, so box cells outside the region may enter with value 0
        # unless the mask is read as a masked array -- confirm intended.
        gpp_obs_m = np.nanmean(map_obs_m[:,rows,cols],axis=0)*rmask05
        gpp_mod_m = np.nanmean(map_mod_m[:,rows,cols],axis=0)*rmask05
        n_valid = np.count_nonzero(~np.isnan(gpp_obs_m[:,:]))
        data_obs = np.zeros((n_valid,3))
        data_mod = np.zeros((n_valid,3))
        l = 0
        for i in range(len(lats_05)):
            for j in range(len(lons_05)):
                if ~np.isnan(gpp_obs_m[i,j]):
                    data_obs[l,0] = lats_05[i]
                    data_obs[l,1] = lons_05[j]
                    data_obs[l,2] = gpp_obs_m[i,j]
                    data_mod[l,0] = lats_05[i]
                    data_mod[l,1] = lons_05[j]
                    data_mod[l,2] = gpp_mod_m[i,j]
                    l += 1
        D_obs_m,G_obs_m = semivariogram(data_obs, nlag, lmax)
        D_mod_m,G_mod_m = semivariogram(data_mod, nlag, lmax)
        # Series.rolling replaces pd.rolling_mean, which pandas has removed.
        G_obs_m = pd.Series(G_obs_m).rolling(5, min_periods=1).mean().values
        G_mod_m = pd.Series(G_mod_m).rolling(5, min_periods=1).mean().values
        # Anchor both curves at the origin.
        D_obs_m = np.insert(D_obs_m,0,0)
        G_obs_m = np.insert(G_obs_m,0,0)
        D_mod_m = np.insert(D_mod_m,0,0)
        G_mod_m = np.insert(G_mod_m,0,0)
        plt.plot(D_obs_m,G_obs_m,color='cyan',label='MJUNG')
        plt.plot(D_mod_m,G_mod_m,color='limegreen',label='v5.67PDay01')
        # Keep only the first handle for each legend label.
        handles, labels = plt.gca().get_legend_handles_labels()
        newLabels, newHandles = [], []
        for handle, label in zip(handles, labels):
            if label not in newLabels:
                newLabels.append(label)
                newHandles.append(handle)
        plt.legend(newHandles,newLabels,ncol=3,loc='upper center',fontsize=10)
        plt.xlabel('Distance [km]')
        plt.ylabel('Semi-Variance')
        plt.title(zones[zo]+" ")
        plt.savefig('/data/sli/PC/autocorr/output/semivario_'+noun+'_'+zones[zo]+'.png')
        plt.close()


def semivario_multi(map_obs_m,map_mod_m1,map_mod_m2,map_mod_m3,map_mod_m4,map_mod_m5,map_mod_m6,map_mod_m7,map_mod_m8,map_mod_m9,map_mod_m10,mask_05,nb_zones,zones,noun,ind_zones,days,nlag,lmax):
    """Plot, per region, the semivariograms of observations and ten model runs.

    The region's bounding box is taken from the cells where the mask equals
    1.  Inside that box every field is averaged over axis 0 (ignoring NaN)
    and multiplied by the mask.  Every cell whose observed value is not NaN
    contributes a (lat, lon, value) row for every data set.  The rows go
    through ``semivariogram``, the semivariances are smoothed with a
    5-point rolling mean, and a (0, 0) origin point is prepended before
    plotting.  One figure per region is written to a hard-coded path.

    Args:
        map_obs_m: observation array indexed as ``[time, lat, lon]``.
        map_mod_m1 .. map_mod_m10: model arrays with the same indexing,
            drawn in this order with fixed colours.
        mask_05: path of the netCDF file holding ``mask_region``, ``lat``
            and ``lon``.
        nb_zones: number of regions to process.
        zones: region names, used in titles, file names and progress output.
        noun: text fragment used in the file name.
        ind_zones: first-axis index into ``mask_region`` for each region.
        days: unused; kept so existing callers keep working.
        nlag, lmax: forwarded to ``semivariogram``.
    """
    # (field, line colour, line label), in drawing order: observations first.
    datasets = [
        (map_obs_m, 'black', 'MJUNG'),
        (map_mod_m1, 'limegreen', 'IPSL-CM5A-MR'),
        (map_mod_m2, 'peru', 'CM605-LR-pdCtrl-01'),
        (map_mod_m3, 'green', 'CM605.calv-LR-pdCtrl-02'),
        (map_mod_m4, 'red', 'CM605.dt20-LR-pdCtrl-02'),
        (map_mod_m5, 'yellow', 'CM605.GUST-LR-pdCtrl-01'),
        (map_mod_m6, 'magenta', 'CM605.NOSU-LR-pdCtrl-03'),
        (map_mod_m7, 'blue', 'CM605.THC1-LR-pdCtrl-01'),
        (map_mod_m8, 'brown', 'CM605.Z0-LR-pdCtrl-01'),
        (map_mod_m9, 'olive', 'NPv5.70PDctrl01'),
        (map_mod_m10, 'cyan', '5.71vd2qsp375'),
    ]
    # Short names shown in the legend, in the same order as the lines.
    legend_labels = ['OBS','IPSL-CM5A','CM605-LR','CM605.calv','CM605.dt20','CM605.GUST','CM605.NOSU','CM605.THC1','CM605.Z0','NPv5.70PDctrl01','v5.71vd2qsp375']
    for zo in range(nb_zones):
        print(zones[zo])
        nc = netCDF4.Dataset(mask_05)
        try:
            rmask05 = nc.variables["mask_region"][ind_zones[zo],:,:]
            lons_05 = nc.variables["lon"][:]
            lats_05 = nc.variables["lat"][:]
        finally:
            nc.close()
        # Bounding box of the region inside the global grid.
        inds05 = np.where(rmask05==1)
        rows = slice(np.min(inds05[0]), np.max(inds05[0])+1)
        cols = slice(np.min(inds05[1]), np.max(inds05[1])+1)
        rmask05 = rmask05[rows,cols]
        lats_05 = lats_05[rows]
        lons_05 = lons_05[cols]
        plt.figure()
        # NOTE(review): mask value 0 is not turned into NaN here, so box
        # cells outside the region may enter with value 0 unless the mask
        # is read as a masked array -- confirm intended.
        fields = [np.nanmean(data[:,rows,cols],axis=0)*rmask05 for data, _, _ in datasets]
        gpp_obs_m = fields[0]
        n_valid = np.count_nonzero(~np.isnan(gpp_obs_m[:,:]))
        tables = [np.zeros((n_valid,3)) for _ in fields]
        l = 0
        for i in range(len(lats_05)):
            for j in range(len(lons_05)):
                # The observations decide which cells are used by every data set.
                if ~np.isnan(gpp_obs_m[i,j]):
                    for table, field in zip(tables, fields):
                        table[l,0] = lats_05[i]
                        table[l,1] = lons_05[j]
                        table[l,2] = field[i,j]
                    l += 1

        curves = []
        for table in tables:
            D_m,G_m = semivariogram(table, nlag, lmax)
            # Series.rolling replaces pd.rolling_mean, which pandas has removed.
            G_m = pd.Series(G_m).rolling(5, min_periods=1).mean().values
            # Anchor the curve at the origin.
            curves.append((np.insert(D_m,0,0), np.insert(G_m,0,0)))

        for (D_m, G_m), (_, colour, label) in zip(curves, datasets):
            plt.plot(D_m,G_m,color=colour,label=label)

        plt.legend(legend_labels, ncol=3,loc='upper left',fontsize=10)
        plt.xlabel('Distance [km]')
        plt.ylabel('Semi-Variance')
        plt.title(zones[zo]+" ")
        plt.savefig('/data/sli/PC/autocorr/output/semivario_'+noun+'_'+zones[zo]+'.png')
        plt.close()


def find_index_nearest(array,value):
    """
    Find the index of the entry of ``array`` nearest to ``value``.

    ``array`` may be any array-like (NumPy array, list, tuple); it is
    converted with ``np.asarray`` so that plain sequences work as well.
    The position of the smallest absolute difference is returned.
    """
    deviation = np.abs(np.asarray(array) - value)
    return deviation.argmin()
def grid_distances(latlon):
    """
    Compute the great-circle distances between points.

    ``latlon`` is an (N, 2) array-like of latitude / longitude in degrees.
    The result is an (N, N) array whose strict lower triangle holds the
    distances (``dist[j, i]`` for ``j > i``) on a sphere of radius 6375,
    the value the file uses as the Earth radius in km.  The diagonal and
    the upper triangle stay 0.
    """
    # Earth radius used throughout the file.
    RTerre = 6375.
    latlon = np.asarray(latlon, dtype=float)
    # Problem size
    N = latlon.shape[0]
    # sin and cos work in radians
    lats = latlon[:,0]*np.pi/180.
    lons = latlon[:,1]*np.pi/180.
    sin_lat = np.sin(lats)
    cos_lat = np.cos(lats)
    dist = np.zeros((N,N))
    # One vectorised column per point: distances from point i to every
    # later point.  This avoids the N*N scalar Python loop without needing
    # extra N x N temporaries.
    for i in range(N-1):
        hav = np.abs((1-sin_lat[i]*sin_lat[i+1:]-cos_lat[i]*cos_lat[i+1:]*np.cos(lons[i]-lons[i+1:]))/2.)
        # Rounding can push the term marginally above 1, outside arcsin's domain.
        hav = np.minimum(hav, 1.)
        dist[i+1:,i] = RTerre*2.*np.arcsin(np.sqrt(hav))
    return dist
def semivariogram(SM, n_lag, lmax):
    """
    Compute an empirical semivariogram with ``n_lag`` bins.

    ``SM`` holds one point per row: its first two columns are passed to
    ``grid_distances`` and the third is the value.  For every pair of points
    no farther apart than ``lmax``, half the squared value difference is
    kept.  Pairs are sorted by distance and cut into ``n_lag`` consecutive
    groups of equal size, the last group also taking the remainder.
    Returns the mean distance and mean semivariance of each group.
    """
    values = SM[:,2]
    positions = np.arange(len(values))
    # True below the diagonal: each pair of points is counted once.
    lower = positions[:,None] > positions[None,:]
    pair_gamma = (0.5*(values[None,:] - values[:,None])**2)[lower]
    pair_dist = grid_distances(SM[:,:2])[lower]
    # Drop pairs farther apart than lmax.
    in_range = pair_dist<=lmax
    pair_gamma = pair_gamma[in_range]
    pair_dist = pair_dist[in_range]
    # Equal-count groups taken in order of increasing distance.
    group_size = int(len(pair_dist)/n_lag)
    order = np.argsort(pair_dist)
    DE = np.zeros(n_lag)
    GE = np.zeros(n_lag)
    for g in range(n_lag):
        start = group_size*g
        # The last group runs to the end so no pair is left out.
        stop = group_size*(g+1) if g<n_lag-1 else None
        members = order[start:stop]
        DE[g] = pair_dist[members].mean()
        GE[g] = pair_gamma[members].mean()
    return DE,GE
def cvmodel(model_type,model_par,lags):
    '''
    Evaluate a theoretical semivariogram model at the given lags.

    model_type is 'spherical', 'linear' or 'exponential'; model_par is a
    mapping with the keys 'nugget', 'range' and 'sill'.  Returns the model
    values for lags and raises ValueError for any other model_type.
    '''
    nugget = model_par['nugget']
    reach = model_par['range']
    sill = model_par['sill']
    if model_type == 'spherical':
        # Cubic rise up to the range, constant sill beyond it.
        rising = (sill-nugget)*(1.5*lags/reach - 0.5*(lags/reach)**3.0)*(lags<=reach)
        plateau = (sill-nugget)*(lags>reach)
        return nugget + rising + plateau
    if model_type == 'linear':
        return nugget + (sill-nugget)*lags/reach
    if model_type == 'exponential':
        return nugget + (sill-nugget)*(1 - np.exp(-3*lags/reach))
    raise ValueError('model_type should be spherical or linear or exponential')


# ScatterPlots
def scatterplot_spatial(obs,mod,yvec):
    """Scatter the flattened model values against the flattened observations.

    The figure is written to a hard-coded path whose file name ends with
    ``yvec``, then closed.
    """
    fig, ax = plt.subplots()
    ax.plot(obs.flatten(), mod.flatten(), "o", color='dodgerblue', label='CL3', alpha=0.5)
    ax.legend()
    ax.set_xlabel('ESA CCI [m3/m3]')
    ax.set_ylabel('ORCHIDEE [m3/m3]')
    fig.savefig('/Users/cmagand/pro/CMIP6/GRAPHS/temporal/scatter_'+yvec+'.png')
    plt.close()
