import netCDF4 as nc 
import sys,os  

# Command-line arguments: input file, number of HALO points to remove on each
# side of both horizontal axes, and an optional output file name.
# Two positional arguments are mandatory (sys.argv[1] and sys.argv[2]).
if len(sys.argv) < 3:
    print("Usage: %s input.nc nbHaloPointsToRemove [output.nc]" % sys.argv[0])
    sys.exit(1)

# Input file name
fi = sys.argv[1]
try:
    nbHalo = int(sys.argv[2])
except ValueError:
    print("nbHaloPointsToRemove must be an integer, got %r" % sys.argv[2])
    sys.exit(1)
# At least one HALO point is required: the horizontal coordinates are built
# from midpoints of consecutive values, which needs one value beyond the last
# kept point, and a 0 would turn "nbHalo:-nbHalo" slices into empty ranges.
if nbHalo < 1:
    print("nbHaloPointsToRemove must be >= 1, got %d" % nbHalo)
    sys.exit(1)

# Output file ?
if len(sys.argv) == 4:
    out = sys.argv[3]
else:
    # Prefix only the file name, so "dir/in.nc" -> "dir/formatted_in.nc"
    # (prefixing the whole path would target a "formatted_dir" directory).
    indir, inname = os.path.split(fi)
    out = os.path.join(indir, "formatted_" + inname)

# The output file is opened in 'w' mode later on; refuse to overwrite the input.
if os.path.realpath(out) == os.path.realpath(fi):
    print("Output file must differ from input file: %s" % out)
    sys.exit(1)

###------------ This should change for other input data format -----------###
# Input names (dimensions and 1D variables) that are renamed in the output
# file; each dimension and its matching 1D variable map to the same new name.
changeVar = dict(
    ni="W_E_direction",
    nj="S_N_direction",
    level="vertical_levels",
    XHAT="W_E_direction",
    YHAT="S_N_direction",
    ZHAT="vertical_levels",
)
# Input dimensions recreated in the output file (after HALO removal)
keepDim = ["ni", "nj", "level", "time"]
# Input variables copied to the output file; every other variable is dropped
keepVar = set(["XHAT", "YHAT", "ZHAT", "RVT", "RCT", "PABST", "THT"])
###-----------------------------------------------------------------------###


# Open input file (read-only) and create the output file
f = nc.Dataset(fi, 'r')
print("Creating output file: %s" % out)
g = nc.Dataset(out, 'w')

# To copy the global attributes of the input file to the output file.
# getncattr/setncattr are used instead of getattr/setattr because
# netCDF4.Dataset reserves some Python attribute names (e.g. "name", "path"):
# with getattr/setattr, a global attribute carrying such a name would be
# resolved against the Python object (wrong value, or an error) instead of
# the netCDF attribute.
print("Copying global attributes to output file")
for attname in f.ncattrs():
    g.setncattr(attname, f.getncattr(attname))

# To copy the dimensions of the input file to the output file.
# Only dimensions listed in keepDim are recreated, renamed through changeVar,
# with HALO points removed: one level at each end of the vertical axis and
# nbHalo points at each end of every other kept axis. Any dimension whose
# name contains "time" is recreated with size 0, i.e. as an unlimited dimension.
print("Copying dimensions to output file")
for dimname, dim in f.dimensions.items():
    if dimname not in keepDim:
        continue
    newdimname = changeVar.get(dimname, dimname)
    if "time" in newdimname:
        npts = 0
    else:
        halo = 1 if "lev" in newdimname else nbHalo
        npts = len(dim) - 2 * halo  # remove HALO points
        # A size of 0 would silently create an extra unlimited dimension and a
        # negative one fails inside netCDF; report the real cause instead.
        if npts <= 0:
            raise ValueError(
                "dimension %s (length %d) is too short to remove %d HALO "
                "point(s) on each side" % (dimname, len(dim), halo))
    g.createDimension(newdimname, npts)
    print("%s of length %d" % (newdimname, npts))

# To copy the variables of the netCDF file.
# Only variables listed in keepVar are considered. 4D fields are copied with
# HALO points removed; 1D fields are treated as coordinates and converted to
# midpoints of consecutive values, shifted so the domain origin is 0.
# Fields of any other rank are reported and skipped.
print("Copying fields to output file")
for varname, ncvar in f.variables.items():
    if varname not in keepVar:  # extract only relevant fields
        continue

    # Check the rank BEFORE creating the output variable, so that skipped
    # fields are really absent from the output instead of left empty.
    ndim = len(ncvar.dimensions)
    if ndim not in (1, 4):
        print("%s is of shape %s and thus will not be copied to output file"
              % (varname, ncvar.shape))
        continue

    # Define dims from the dims of this var in input file. Names containing
    # "_" (except "*dir*" names and "vertical_levels") lose their last two
    # characters first -- presumably a staggering suffix such as "_u";
    # TODO confirm against the input files.
    tmpdim = [d if ("_" not in d) or ("dir" in d) or (d == "vertical_levels")
              else d[:-2] for d in ncvar.dimensions]
    dims = [changeVar.get(d, d) for d in tmpdim]

    # Do we need to change varname?
    newvarname = changeVar.get(varname, varname)

    # Create variable with modified or same name, with same dims as in input file
    var = g.createVariable(newvarname, ncvar.dtype, dims)
    if ndim == 4:
        # remove HALO points: one level at each end of axis 1, nbHalo points
        # at each end of the two last axes (explicit end indices stay valid
        # whatever the value of nbHalo)
        ny, nx = ncvar.shape[2], ncvar.shape[3]
        var[:, :, :, :] = ncvar[:, 1:-1, nbHalo:ny - nbHalo, nbHalo:nx - nbHalo]
    else:  # 1D fields are necessarily coordinates
        npts = ncvar.shape[0]
        if "direction" in newvarname:
            # Midpoints of the kept values; the second term ends at the first
            # right-hand HALO value. The explicit end index replaces the former
            # nbHalo == 1 special case (there, "-nbHalo+1" would be 0).
            # Subtracting the first kept value puts the origin at 0 because
            # htcp assumes domain origin is 0.
            var[:] = (0.5 * (ncvar[nbHalo:npts - nbHalo]
                             + ncvar[nbHalo + 1:npts - nbHalo + 1])
                      - ncvar[nbHalo])
        else:
            # Vertical coordinate, one HALO level at each end. Like the
            # horizontal case, the first kept value is SUBTRACTED (it used to
            # be added) because htcp assumes domain origin is 0.
            var[:] = 0.5 * (ncvar[1:-1] + ncvar[2:]) - ncvar[1]

# Close the input dataset, then the output dataset
for dataset in (f, g):
    dataset.close()
