mport numpy as np
import matplotlib.pyplot as plt
import netCDF4 as nc

def load_netcdf_data(filename):
    """
    Charge le fichier NetCDF et affiche les informations
    """
    dataset = nc.Dataset(filename, 'r')
    
    print("Dimensions:", dataset.dimensions.keys())
    print("Variables:", dataset.variables.keys())
    
    # Affichage des valeurs Ls
    if 'Ls' in dataset.variables:
        ls_values = dataset.variables['Ls'][:]
        print(f"Valeurs Ls: min={ls_values.min():.2f},
max={ls_values.max():.2f}")
        print(f"Premières valeurs Ls: {ls_values[:10]}")
    
    return dataset

def find_closest_ls_indices(ls_array, target_values):
    """
    Trouve les indices des valeurs Ls les plus proches des valeurs
cibles
    """
    closest_indices = {}
    actual_values = {}
    
    for target in target_values:
        idx = np.argmin(np.abs(ls_array - target))
        closest_indices[target] = idx
        actual_values[target] = ls_array[idx]
        print(f"Ls cible: {target}°, Ls trouvé: {ls_array[idx]:.2f}°,
indice: {idx}")
    
    return closest_indices, actual_values

def calculate_daily_average(dataset, start_idx):
    """
    Calcule la moyenne journalière (4 outputs par jour) à partir de
start_idx
    """
    end_idx = min(start_idx + 4, dataset.variables['Ls'].shape[0])
    
    # Extraction des données pour la journée
    # mtot: (Ls, latitude, longitude)
    mtot_day = dataset.variables['mtot'][start_idx:end_idx, :, :]
    
    # u et v: (Ls, alt, latitude, longitude) - on prend le premier
    # niveau d'altitude
    u_day = dataset.variables['u'][start_idx:end_idx, 0, :, :]
    v_day = dataset.variables['v'][start_idx:end_idx, 0, :, :]
    
    # Moyenne journalière
    mtot_avg = np.mean(mtot_day, axis=0)
    u_avg = np.mean(u_day, axis=0)
    v_avg = np.mean(v_day, axis=0)
    
    return mtot_avg, u_avg, v_avg

def create_mars_climate_map(longitude, latitude, mtot, u, v, target_ls,
actual_ls, output_dir="./"):
    """
    Crée une carte climatique simple pour Mars
    """
    fig, ax = plt.subplots(figsize=(15, 10))  # Figure plus grande
    
    # Création des grilles de coordonnées
    LON, LAT = np.meshgrid(longitude, latitude)
    
    # Conversion mtot en microns précipitables
    mtot_microns = mtot * 1000
    
    # Carte de couleur pour mtot
    levels = np.linspace(np.nanmin(mtot_microns),
np.nanmax(mtot_microns), 20)
    contour = ax.contourf(LON, LAT, mtot_microns, levels=levels,
cmap='viridis')
    
    # Barre de couleur pour mtot
    cbar = plt.colorbar(contour, ax=ax, shrink=0.8)
    cbar.set_label('Humidité (μm précipitables)', fontsize=12)
    
    # Masquer les pôles (±90°) car ils sont biaisés
    mask_poles = (np.abs(LAT) < 89.5)
    
    # Calcul de l'intensité du vent
    wind_speed = np.sqrt(u**2 + v**2)
    max_speed = np.nanmax(wind_speed[mask_poles])
    
    # Masquer les données aux pôles
    u_masked = np.where(mask_poles, u, np.nan)
    v_masked = np.where(mask_poles, v, np.nan)
    wind_speed_masked = np.where(mask_poles, wind_speed, np.nan)
    
    # Flèches de vent - taille proportionnelle à l'intensité
    # Scale ajusté pour des flèches bien visibles
    scale_value = max_speed * 10  # Ajustez ce facteur pour la taille
des flèches
    
    quiver = ax.quiver(LON, LAT, u_masked, v_masked,
                      scale=scale_value, scale_units='width',
                      color='white', alpha=0.8, width=0.003)
    
    # Légendes pour les flèches - plusieurs tailles de référence
    reference_speeds = []
    if max_speed > 20:
        reference_speeds = [5, 10, 20]
    elif max_speed > 10:
        reference_speeds = [2, 5, 10]
    else:
        reference_speeds = [1, 2, 5]
    
    # Ajout des légendes de taille - positionnées dans la marge
    for i, ref_speed in enumerate(reference_speeds):
        if ref_speed <= max_speed:
            ax.quiverkey(quiver, 0.02, 0.98 - i*0.04, ref_speed, 
                        f'{ref_speed} m/s', labelpos='E', 
                        coordinates='axes', color='black', 
                        fontproperties={'size': 11, 'weight': 'bold'},
                        labelsep=0.1)
    
    # Configuration des axes
    ax.set_xlabel('Longitude (°)', fontsize=12)
    ax.set_ylabel('Latitude (°)', fontsize=12)
    ax.set_title(f'Mars - Climat (Ls = {actual_ls:.2f}°)', fontsize=14,
pad=15)
    
    # Grille
    ax.grid(True, alpha=0.3)
    
    # Limites des axes - éviter les pôles
    ax.set_xlim(longitude.min(), longitude.max())
    ax.set_ylim(max(latitude.min(), -89), min(latitude.max(), 89))
    
    # Sauvegarde haute qualité
    filename = f"{output_dir}mars_climate_Ls_{target_ls:03d}.png"
    plt.savefig(filename, dpi=300, bbox_inches='tight', 
                facecolor='white', edgecolor='none')
    plt.close()
    
    print(f"Carte sauvegardée: {filename}")
    print(f"  - Humidité: {np.nanmin(mtot_microns):.1f} -
{np.nanmax(mtot_microns):.1f} μm")
    print(f"  - Vent max: {max_speed:.1f} m/s (hors pôles)")
    print(f"  - Données masquées aux pôles (|lat| > 89.5°)")
    
    return filename

def main(filename):
    # Configuration
    target_ls_values = [0, 30, 60, 90, 120, 150, 180, 210, 240, 270,
300, 330]
    output_directory = "./"
    
    print(f"Utilisation du fichier: {filename}")
    
    try:
        # Chargement des données
        print("=== Chargement du fichier NetCDF ===")
        dataset = load_netcdf_data(filename)
        
        # Extraction des coordonnées - noms corrects
        longitude = dataset.variables['longitude'][:]
        latitude = dataset.variables['latitude'][:]
        ls_array = dataset.variables['Ls'][:]
        
        print(f"\nGrille: {len(latitude)} lat × {len(longitude)} lon")
        print(f"Latitude: {latitude.min():.1f}° à
{latitude.max():.1f}°")
        print(f"Longitude: {longitude.min():.1f}° à
{longitude.max():.1f}°")
        
        # Vérification des dimensions
        print(f"\nDimensions des variables:")
        print(f"  mtot: {dataset.variables['mtot'].shape}")
        print(f"  u: {dataset.variables['u'].shape}")
        print(f"  v: {dataset.variables['v'].shape}")
        
        # Trouve les indices Ls les plus proches
        print(f"\n=== Recherche des valeurs Ls ===")
        ls_indices, actual_ls_values = find_closest_ls_indices(ls_array,
target_ls_values)
        
        # Création des cartes
        print(f"\n=== Création des cartes ===")
        saved_files = []
        
        for target_ls in target_ls_values:
            if target_ls in ls_indices:
                print(f"\nTraitement Ls = {target_ls}°...")
                
                # Calcul de la moyenne journalière
                start_idx = ls_indices[target_ls]
                mtot_avg, u_avg, v_avg =
calculate_daily_average(dataset, start_idx)
                
                # Création de la carte
                filename_out = create_mars_climate_map(
                    longitude, latitude, mtot_avg, u_avg, v_avg,
                    target_ls, actual_ls_values[target_ls],
                    output_directory
                )
                saved_files.append(filename_out)
        
        print(f"\n=== Résumé ===")
        print(f"Cartes créées: {len(saved_files)}")
        for f in saved_files:
            print(f"  - {f}")
        
        # Fermeture du fichier
        dataset.close()
        
    except FileNotFoundError:
        print(f"Erreur: Fichier '{filename}' introuvable")
    except KeyError as e:
        print(f"Erreur: Variable manquante: {e}")
        print("Variables disponibles:", list(dataset.variables.keys())
if 'dataset' in locals() else "N/A")
    except Exception as e:
        print(f"Erreur: {e}")

if __name__ == "__main__":
    main("diagfi.nc")  # Modifiez le nom du fichier ici
