#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb  3 18:09:33 2020

@author: evignon
"""

###################################################################################
# Etienne Vignon. 2020/02/03
#     Modified. Frédéric Hourdin 2020/02/04
#
#  This script modifies the "limit.nc" file, used as boundary condition for 
#  the LMDZ model. Used in particular for teaching with light versions of
#  LMDZ.
#
#  The perturbation is controlled by the setup_perturb function just below
#
###################################################################################
# IMPORT

import numpy as np
import sys,subprocess
import matplotlib.pyplot as plt
import os
from netCDF4 import Dataset
import argparse

# Command-line interface.
# The help texts of --nobasemap and --plot used to be a copy-pasted, unrelated
# sentence; they now describe what each option does in this script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--nobasemap",
    action="store_true",
    help="do not use basemap; draw plain longitude/latitude maps instead",
)
parser.add_argument(
    "--plot",
    help="only plot the selected variable from the given file, without modifying limit.nc",
    metavar='limit.nc',
)
parser.add_argument(
    "-v",
    help="name of variable to be modified or ploted  can be: 'SST','RUG','ALB', 'FSIC', 'FLIC'",
    metavar='SST',
)
args = parser.parse_args()

# Variable that is modified or plotted; 'SST' unless -v is given.
variable_name = str(args.v) if args.v else 'SST'

# --plot FILE selects the "plot only" mode handled in the main program.
plot_only = bool(args.plot)
if plot_only:
    file_for_one_plot = str(args.plot)

# Maps are drawn with basemap unless --nobasemap is given.
if_basemap = not args.nobasemap
if if_basemap:
    # Imported only when needed so that --nobasemap works without basemap.
    from mpl_toolkits.basemap import Basemap
    m_ = Basemap(resolution='c', projection='kav7', lat_0=0.0, lon_0=0.0)
    n_graticules = 18
    parallels = np.linspace(-80.0, 90.0, n_graticules)
    meridians = np.linspace(0., 360.0, n_graticules)
    graticules_color = 'grey'
    coastline_color = 'black'


ltime = 0  # time index at which the variable will be plotted
###################################################################################
def plot_one_map(longitudevec, latitudevec, var_, min_, max_, levs, cmap_, title_, position_, fig_):
    """Draw one pseudo-colour map of ``var_`` with its own colour bar.

    Parameters
    ----------
    longitudevec, latitudevec : 1-D coordinate vectors handed to ``np.meshgrid``.
    var_ : field to plot; it is transposed before being drawn.
    min_, max_ : lower and upper bounds of the colour scale.
    levs : rectangle passed to ``fig_.add_axes`` to place the colour bar.
    cmap_ : colormap passed to ``pcolor``.
    title_ : title of the subplot.
    position_ : subplot position passed to ``plt.subplot``.
    fig_ : figure receiving the colour bar.

    The module-level flag ``if_basemap`` selects between the basemap
    projection ``m_`` and a plain longitude/latitude plot.
    """
    axes = plt.subplot(position_)
    lon2d, lat2d = np.meshgrid(longitudevec, latitudevec)
    field = np.transpose(var_)
    colour_scale = dict(vmin=min_, vmax=max_, cmap=cmap_)

    if not if_basemap:
        # Plain plot directly in longitude/latitude coordinates.
        mesh = axes.pcolor(lon2d, lat2d, field, **colour_scale)
        axes.grid()
    else:
        # Graticules and coastlines first, then the field in projected coordinates.
        m_.drawmeridians(meridians, linewidth=1, color=graticules_color, labels=[1, 0, 0, 0])
        m_.drawparallels(parallels, linewidth=1, color=graticules_color, labels=[1, 0, 0, 0])
        m_.drawcoastlines(linewidth=1., color=coastline_color)
        xproj, yproj = m_(lon2d, lat2d)
        mesh = m_.pcolor(xproj, yproj, field, **colour_scale)

    colorbar_axes = fig_.add_axes(levs)
    fig_.colorbar(mesh, cax=colorbar_axes)
    axes.set_title(title_)


###################################################################################
def setup_perturb():
    """Return the user-editable settings that define the perturbation.

    Edit the values below to change what is applied to limit.nc. The
    settings are returned as one tuple, in this order: perturbation
    function, control longitudes/latitudes, variable name, terrain type,
    fraction threshold, ``plus`` switch, inside value, outside value,
    basemap flag and figure name. ``variable_name`` and ``if_basemap`` are
    the module-level values, passed through unchanged.
    """
    # Shape of the perturbation: rectangular box (rect_perturb) or
    # Gaussian bump (gauss_perturb).
    # fonction_perturb=rect_perturb
    fonction_perturb = gauss_perturb

    # Control longitudes and latitudes; their meaning depends on the shape:
    #   rectangle : [min_lon, max_lon, min_lat, max_lat]
    #   gaussian  : [lon_center, lon_width, lat_center, lat_width]
    lonlatbounds = [0, 25, 0, 20]

    # Restriction to a sub-surface: 'land', 'ocean', 'lic', 'sic' or 'all'.
    terrain_type = 'ocean'

    # If terrain_type is not 'all': fraction above which the value is changed
    # (for heterogeneous meshes, between 0 and 1).
    threshold_frac = 0.9

    # perturbed_variable = plus * perturbed_variable + value_in * F(longitude, latitude)
    plus = 1          # 1: add the perturbation, 0: replace the field
    value_in = 3.0    # amplitude of the perturbation

    # Value outside of the mask; "none" keeps the values already in the file.
    value_out = "none"

    # Name of the figure file.
    figure_name = 'mymask.png'

    return (
        fonction_perturb,
        lonlatbounds,
        variable_name,
        terrain_type,
        threshold_frac,
        plus,
        value_in,
        value_out,
        if_basemap,
        figure_name,
    )

###################################################################################
def rect_perturb(lon, lat, masktype, threshold_frac, variable, value_in, lonlatbounds, plus):
    """Apply a uniform perturbation inside a longitude/latitude box.

    Points are selected when they lie inside
    ``lonlatbounds = [min_lon, max_lon, min_lat, max_lat]`` (bounds
    included) and when ``masktype >= threshold_frac``. On those points
    ``variable`` becomes ``plus * variable + value_in``; other points are
    left untouched. ``variable`` is modified in place and also returned.
    """
    west, east = lonlatbounds[0], lonlatbounds[1]
    south, north = lonlatbounds[2], lonlatbounds[3]

    in_lon_range = (lon >= west) & (lon <= east)
    in_lat_range = (lat >= south) & (lat <= north)
    selected = in_lon_range & in_lat_range & (masktype >= threshold_frac)

    variable[selected] = plus * variable[selected] + value_in
    return variable

###################################################################################
def gauss_perturb(lon, lat, masktype, threshold_frac, variable, value_in, lonlatbounds, plus):
    """Apply a Gaussian-shaped perturbation centred on a given point.

    ``lonlatbounds = [lon_center, lon_width, lat_center, lat_width]``.
    Where ``masktype >= threshold_frac``, ``variable`` becomes::

        plus * variable + value_in * exp(-(dlon/lon_width)**2) * exp(-(dlat/lat_width)**2)

    where ``dlon`` is the distance in longitude taking the 360-degree
    periodicity into account. Other points are left untouched.
    ``variable`` is modified in place and also returned.

    ``lon``, ``lat``, ``masktype`` and ``variable`` must share the same
    shape. Any number of dimensions is accepted: the candidate distances
    are stacked along a new last axis instead of the hard-coded axis 2,
    which only worked for 2-D inputs.
    """
    lon0 = lonlatbounds[0]
    dlon = lonlatbounds[1]
    lat0 = lonlatbounds[2]
    dlat = lonlatbounds[3]

    # Distance in longitude, taking the periodicity into account: keep the
    # smallest of the direct distance and the two distances shifted by 360.
    offset = lon - lon0
    candidates = np.stack(
        (np.abs(offset), np.abs(offset + 360.), np.abs(offset - 360.)),
        axis=-1,
    )
    ddlon = np.min(candidates, axis=-1) / dlon
    ddlat = (lat - lat0) / dlat

    # Build the selection once and reuse it on both sides of the assignment.
    selected = masktype >= threshold_frac
    variable[selected] = (
        plus * variable[selected]
        + value_in * np.exp(-ddlon[selected] ** 2) * np.exp(-ddlat[selected] ** 2)
    )
    return variable

###################################################################################
def apply_the_perturb(function_type, plus, lonlatbounds, variable, terrain_type, threshold_frac, value_in, value_out):
    """Apply a perturbation to one variable of 'limit.nc' and save it in place.

    Parameters
    ----------
    function_type : perturbation function, called as
        ``function_type(lon, lat, masktype, threshold_frac, myvar,
        value_in, lonlatbounds, plus)``; it must return the modified field.
    plus, lonlatbounds, threshold_frac, value_in : forwarded unchanged to
        ``function_type``.
    variable : name of the variable of 'limit.nc' to modify.
    terrain_type : 'land', 'ocean', 'sic' or 'lic' restricts the
        perturbation using FTER, FOCE, FSIC or FLIC respectively; any other
        value uses a mask equal to 1 everywhere.
    value_out : if different from "none", the whole field is first reset
        to this value before the perturbation is applied.

    The perturbation passed as ``function_type`` is the one applied; the
    function previously ignored this argument and called the module-level
    ``fonction_perturb`` instead. The file is opened in a ``with`` block so
    that it is closed even if the perturbation raises.
    """
    # Variable of limit.nc used as mask for each terrain type.
    fraction_names = {"land": "FTER", "ocean": "FOCE", "sic": "FSIC", "lic": "FLIC"}

    with Dataset('limit.nc', 'r+') as limit_data:
        latitude = limit_data.variables["latitude"][:]
        longitude = limit_data.variables["longitude"][:]
        myvar = limit_data.variables[variable][:, :]

        # Repeat the coordinates along the first axis so that they have the
        # same shape as the variable.
        n_rows = len(myvar[:, 0])
        lon = np.tile(longitude, (n_rows, 1))
        lat = np.tile(latitude, (n_rows, 1))

        if terrain_type in fraction_names:
            masktype = np.array(limit_data.variables[fraction_names[terrain_type]][:, :])
        else:
            # No restriction: every point passes any threshold <= 1.
            masktype = np.ones(np.shape(myvar))

        if value_out != "none":
            myvar = myvar * 0.0 + value_out

        myvar = function_type(lon, lat, masktype, threshold_frac, myvar, value_in, lonlatbounds, plus)

        # Save the new variable in the netcdf file.
        limit_data[variable][:, :] = myvar

    
###################################################################################
def get_longitude_latitude(file_):
    """Return the ``longitude`` and ``latitude`` variables of a netCDF file.

    The values are read into memory inside a ``with`` block, so the file is
    closed before returning (it used to be left open).

    Returns a ``(longitude, latitude)`` tuple.
    """
    with Dataset(file_) as limitnew_data:
        longitude = limitnew_data.variables["longitude"][:]
        latitude = limitnew_data.variables["latitude"][:]
    return longitude, latitude

###################################################################################
def get_longitude_latitude_vector(file_):
    """Return the sorted unique longitudes and latitudes found in ``file_``.

    The coordinates are read with ``get_longitude_latitude`` and reduced
    with ``np.unique``. Returns a ``(longitudes, latitudes)`` tuple.
    """
    all_lon, all_lat = get_longitude_latitude(file_)
    distinct_lon = np.unique(all_lon)
    distinct_lat = np.unique(all_lat)
    return distinct_lon, distinct_lat

###################################################################################
def get_var_and_varmat(file_, lonvec_, latvec_, lon_, lat_, var_):
    """Read a variable at time index ``ltime`` and lay it out on a lon/lat matrix.

    Parameters
    ----------
    file_ : netCDF file to read.
    lonvec_, latvec_ : longitudes and latitudes defining the matrix axes.
    lon_, lat_ : longitude and latitude of each point of the variable.
    var_ : name of the variable to read.

    Returns
    -------
    (myvar_, myvarmat_) : the squeezed variable at index ``ltime``, and a
        matrix of shape ``(len(lonvec_), len(latvec_))``. A cell receives
        the value of the point whose coordinates match exactly, and stays
        NaN unless exactly one point matches.
    """
    # Read the data inside a ``with`` block so that the file gets closed.
    with Dataset(file_) as limit_data:
        myvar_ = np.squeeze(limit_data.variables[var_][ltime, :])

    myvarmat_ = np.full((len(lonvec_), len(latvec_)), np.nan)
    for indlat, lati in enumerate(latvec_):
        # The latitude test does not depend on the inner loop: do it once.
        on_this_latitude = (lat_ == lati)
        for indlon, longi in enumerate(lonvec_):
            prov = myvar_[on_this_latitude & (lon_ == longi)]
            if len(prov) == 1:
                # Extract the element explicitly: assigning a length-1 array
                # to a scalar cell relies on a deprecated NumPy conversion.
                myvarmat_[indlon, indlat] = np.asarray(prov).item()
    return myvar_, myvarmat_

    
###################################################################################
def initialize_graphics():
    """Set the matplotlib text sizes and return a new 12 x 20 figure."""
    medium_size = 13
    bigger_size = 14

    # Character sizes, as (rc group, option) pairs applied in order.
    rc_settings = (
        ('font', {'size': medium_size}),          # default text sizes
        ('axes', {'titlesize': medium_size}),     # axes title
        ('axes', {'labelsize': medium_size}),     # x and y labels
        ('xtick', {'labelsize': medium_size}),    # x tick labels
        ('ytick', {'labelsize': medium_size}),    # y tick labels
        ('legend', {'fontsize': medium_size}),    # legend
        ('figure', {'titlesize': bigger_size}),   # figure title
    )
    for group, options in rc_settings:
        plt.rc(group, **options)

    return plt.figure(figsize=(12, 20))
    

###################################################################################
def plot_one_variable(file_, var_):
    """Plot variable ``var_`` of ``file_`` on a single map and show it.

    The colour scale spans the minimum and maximum of the variable as read
    by ``get_var_and_varmat``.
    """
    lon_axis, lat_axis = get_longitude_latitude_vector(file_)
    lon_points, lat_points = get_longitude_latitude(file_)
    values, values_on_grid = get_var_and_varmat(
        file_, lon_axis, lat_axis, lon_points, lat_points, var_
    )

    figure = initialize_graphics()
    lowest = np.min(values)
    highest = np.max(values)
    colorbar_rect = [0.91, 0.41, 0.025, 0.45]
    plot_one_map(
        lon_axis, lat_axis, values_on_grid, lowest, highest,
        colorbar_rect, "gnuplot2_r", "New", 311, figure,
    )
    plt.show()

###################################################################################
def plot_the_perturb(variable, figure_name, if_basemap):
    """Plot the new field, the old field and their difference, and save the figure.

    The new field is read from 'limit.nc' and the old one from
    'limit_back.nc'. Three maps are stacked in one figure: New, Old and
    New-Old. The figure is written to ``figure_name``.

    Parameters
    ----------
    variable : name of the variable to plot.
    figure_name : file name passed to ``savefig``.
    if_basemap : not used in this function; ``plot_one_map`` reads the
        module-level flag. Kept so that existing calls keep working.

    Fixes over the previous version: the undefined ``close_graphics`` and
    ``fig_`` names raised ``NameError`` before anything was saved, and the
    second panel was titled "New" although it shows the old field.
    """
    longitudevec, latitudevec = get_longitude_latitude_vector('limit.nc')
    longitudenew, latitudenew = get_longitude_latitude('limit.nc')
    myvarnew, myvarnewmat = get_var_and_varmat(
        'limit.nc', longitudevec, latitudevec, longitudenew, latitudenew, variable)
    myvarold, myvaroldmat = get_var_and_varmat(
        'limit_back.nc', longitudevec, latitudevec, longitudenew, latitudenew, variable)

    # Minimum and maximum values of field and differences
    # ---------------------------------------------------
    mini = np.min([np.min(myvarnew), np.min(myvarold)])
    maxi = np.max([np.max(myvarnew), np.max(myvarold)])
    vardiff = myvarnewmat - myvaroldmat
    # Symmetric range with a 20 % margin; nanmax skips the unfilled cells.
    maxdiff = np.nanmax(np.abs(vardiff))
    mind = -1.2 * maxdiff
    maxd = 1.2 * maxdiff

    # Plot 3 maps : New, Old and New-Old
    # ----------------------------------
    fig = initialize_graphics()
    field_colorbar = [0.91, 0.41, 0.025, 0.45]
    diff_colorbar = [0.91, 0.13, 0.025, 0.2]
    plot_one_map(longitudevec, latitudevec, myvarnewmat, mini, maxi, field_colorbar, "gnuplot2_r", "New", 311, fig)
    plot_one_map(longitudevec, latitudevec, myvaroldmat, mini, maxi, field_colorbar, "gnuplot2_r", "Old", 312, fig)
    plot_one_map(longitudevec, latitudevec, vardiff, mind, maxd, diff_colorbar, "bwr", "New-Old", 313, fig)
    fig.suptitle(variable, fontsize=18)

    # Save first, then release the figure.
    # NOTE(review): the original called an undefined close_graphics(fig);
    # closing the figure after saving is the assumed intent - confirm.
    fig.savefig(figure_name)
    plt.close(fig)

        
#==============================================================================
# main program shapemylimit
#
# program that changes the value of a variable in limit.nc
# over a masked area
#
# Required packages:
#
# numpy, sys, netcdf4, subprocess, os, argparse, matplotlib
# (and basemap, unless the --nobasemap option is given)
#==============================================================================


# Plot-only mode: show the requested variable and stop without touching limit.nc.
if plot_only:
    plot_one_variable(file_for_one_plot, variable_name)
    # sys.exit() rather than quit(): quit() is only provided by the site module.
    sys.exit()


fonction_perturb,lonlatbounds,variable_name,terrain_type,threshold_frac,plus,value_in,value_out,if_basemap,figure_name=setup_perturb()

#---------------------------- DO NOT TOUCH BELOW ------------------------------

# Check if variable name is correct:
list_var = ['SST', 'RUG', 'ALB', 'FSIC', 'FLIC']
if variable_name not in list_var:
    sys.exit(variable_name  +' is not a valid variable name')


# Check if limit.nc exists:
if not os.path.exists('limit.nc'):
    sys.exit('I do not find a limit.nc file')

# Copy the limit.nc
# An existing backup is only overwritten after an explicit "yes".
if os.path.exists('limit_back.nc'):
    print("A limit_back.nc file already exists. The next command will erase it. Do you really want to proceed?")
    answer = "none"
    while answer != "yes":
        answer = input("Enter yes or no ")
        if answer == "no":
            sys.exit("save carefully your limit_back.nc")
        if answer != "yes":
            print("Please enter yes or no ")

# Argument list without a shell, and check_call so that a failed backup
# stops the script before limit.nc is modified.
cmd1 = ['cp', '-f', 'limit.nc', 'limit_back.nc']
subprocess.check_call(cmd1)

# Apply the mask

apply_the_perturb(fonction_perturb,plus,lonlatbounds,variable_name,terrain_type,threshold_frac,value_in, value_out)


# Make a plot

plot_the_perturb(variable_name,figure_name,if_basemap)

