import os 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import geopandas as gpd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap, BoundaryNorm
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
import time
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import rioxarray as rxr
import xarray as xr
import sys


# Predictions are launched through Slurm with the year as the only argument,
# so separate jobs can run different years on all available nodes/GPUs.
# Usage: python <this script> <year>
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <year>")
year = sys.argv[1]


start_time = time.time()

# Shared min-max scaler; normalize_image re-fits it on every image it scales.
scaler = MinMaxScaler()

# Load the stored min-max normalization values. Only columns '6', '7' and '8'
# are kept; presumably one per input band and one row for min, one for max
# (TODO confirm against the CSV producer).
min_max = pd.read_csv("min_max_overlapping.csv").reset_index(drop=True)
min_max = min_max[['6', '7', '8']]


# Function to normalize images using min-max scaler
def normalize_image(img):
    """Min-max scale an image using the stored ``min_max`` rows as anchors.

    The ``min_max`` rows are stacked on top of the image's pixels before
    ``scaler.fit_transform``, so the fitted range always includes the stored
    values. The anchor rows are then dropped again. ``MinMaxScaler`` ignores
    NaN during fitting and keeps it during transform; those pixels are then
    set to -1.

    NOTE(review): if an image contains values outside the stored min/max,
    the fitted range widens for that image only, so scaling is not fixed
    across images. Confirm the CSV holds the global extremes.

    Parameters
    ----------
    img : numpy.ndarray
        3-D array ``(rows, cols, bands)``. ``bands`` must equal the number of
        columns in ``min_max``. NaN marks missing pixels.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``img``. Values are scaled to the
        scaler's feature range, missing pixels are -1, and everything is
        rounded to 3 decimals.
    """
    # Number of anchor rows to strip after scaling. Previously hard-coded as 2,
    # which silently broke if the CSV row count ever changed.
    n_anchor_rows = len(min_max)
    img_shape = img.shape

    pixels = pd.DataFrame(img.reshape(-1, img_shape[2]), columns=min_max.columns)
    stacked = pd.concat([min_max, pixels]).reset_index(drop=True)
    scaled = scaler.fit_transform(stacked)

    # Keep only the image pixels and restore the spatial layout.
    out = scaled[n_anchor_rows:].reshape(img_shape)
    out[np.isnan(out)] = -1
    return np.round(out, 3)

# Load the trained Keras model used for the final predictions. The
# segmentation_models metric objects are supplied through custom_objects so
# Keras can resolve them while deserializing. The keys presumably match the
# names the model was compiled with; all metrics use a 0.5 threshold.
model = tf.keras.models.load_model(
    "combined_sliding_final_update_nodeep_2023.tf",
    custom_objects={
        "precision": sm.metrics.Precision(threshold=0.5),
        "recall": sm.metrics.Recall(threshold=0.5),
        "f1-score": sm.metrics.FScore(threshold=0.5),
        "iou_score": sm.metrics.IOUScore(threshold=0.5),
    },
)

# Per-year input folder holds the files prepared by the download_all_landsat
# script (noted as 128x128 chunks); predictions go to a matching per-year
# output folder, which is created if missing.
in_path, out_path = f"local_predict_all/{year}", f"prism_preds/{year}"
os.makedirs(out_path, exist_ok=True)


# Sliding-window inference settings: 128x128 tiles overlapping by 64 pixels
# (stride 64). Overlapping predictions are averaged to smooth tile seams.
TILE_SIZE = (128, 128)
OVERLAP = 64
# Number of tiles passed to each model.predict call.
BATCH_SIZE = 224 // 2
# Raw sentinel values treated as missing, alongside NaN:
# 0 = no data, -999 = cloud, -998 = crop, -997 = water.
NODATA_VALUES = (0, -999, -998, -997)
# normalize_image writes -1 for missing pixels. The same value pads images
# smaller than one tile, so the padding looks like missing data to the model.
NORMALIZED_FILL = -1


def _mask_nodata(raw):
    """Return a float copy of *raw*, rounded to 3 decimals, with sentinels and NaN set to NaN."""
    img = np.round(raw.astype(float), 3)
    img[np.isin(img, NODATA_VALUES) | np.isnan(img)] = np.nan
    return img


def _pad_to_tile(image, tile_size=TILE_SIZE, fill_value=NORMALIZED_FILL):
    """Pad *image* (rows, cols, bands) at the bottom/right so each side is at least one tile."""
    pad_h = max(tile_size[0] - image.shape[0], 0)
    pad_w = max(tile_size[1] - image.shape[1], 0)
    if pad_h == 0 and pad_w == 0:
        return image
    return np.pad(
        image,
        ((0, pad_h), (0, pad_w), (0, 0)),
        mode="constant",
        constant_values=fill_value,
    )


def _tile_starts(length, tile, overlap):
    """Return tile start offsets along one axis that together cover ``[0, length)``.

    Requires ``length >= tile`` and ``0 <= overlap < tile``.
    """
    starts = list(range(0, length - tile + 1, tile - overlap))
    # The regular stride can stop short of the far edge when
    # (length - tile) is not a multiple of the stride. Add one tile flush with
    # the edge so the last rows/columns are not left unpredicted.
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


def create_overlapping_tiles(image, tile_size=TILE_SIZE, overlap=OVERLAP):
    """Cut *image* (rows, cols, bands) into overlapping tiles that cover every pixel.

    Returns ``(tiles, positions)``. ``tiles`` is an array of shape
    ``(n, tile_rows, tile_cols, bands)``; ``positions`` holds the matching
    ``(row, col)`` top-left offsets, in row-major order.

    Raises ValueError if *overlap* is not in ``[0, tile)`` or *image* is
    smaller than one tile (pad it first).
    """
    tile_h, tile_w = tile_size
    if not 0 <= overlap < min(tile_h, tile_w):
        raise ValueError(f"overlap must be in [0, {min(tile_h, tile_w)}), got {overlap}")
    height, width = image.shape[:2]
    if height < tile_h or width < tile_w:
        raise ValueError(f"image {height}x{width} is smaller than tile {tile_h}x{tile_w}")

    positions = [
        (x, y)
        for x in _tile_starts(height, tile_h, overlap)
        for y in _tile_starts(width, tile_w, overlap)
    ]
    tiles = np.array([image[x:x + tile_h, y:y + tile_w, :] for x, y in positions])
    return tiles, positions


def reconstruct_image(tiles, positions, image_shape, tile_size=TILE_SIZE):
    """Average per-tile predictions back into one float32 ``(rows, cols)`` map."""
    tile_h, tile_w = tile_size
    summed = np.zeros(image_shape[:2], dtype=np.float32)
    count = np.zeros(image_shape[:2], dtype=np.float32)
    for tile, (x, y) in zip(tiles, positions):
        summed[x:x + tile_h, y:y + tile_w] += tile.squeeze()
        count[x:x + tile_h, y:y + tile_w] += 1
    # max(count, 1) avoids dividing by zero for any pixel no tile touched.
    return summed / np.maximum(count, 1)


def _predict_tiles(tiles, batch_size=BATCH_SIZE):
    """Run ``model.predict`` over *tiles* in fixed-size batches and stack the outputs."""
    num_batches = (len(tiles) + batch_size - 1) // batch_size
    print(f"Processing {len(tiles)} tiles in {num_batches} batches...")
    predictions = [
        model.predict(tiles[start:start + batch_size], verbose=0)
        for start in range(0, len(tiles), batch_size)
    ]
    return np.concatenate(predictions, axis=0)


def _write_prediction(prediction, crs, transform, output_file):
    """Save a 2-D prediction as a single-band LZW GeoTIFF with the given CRS/transform.

    Writes to ``<output_file>.part`` and renames it into place. An
    interrupted job therefore never leaves a truncated file that the
    "already exists" check would later skip as finished.
    """
    xr_pred = xr.DataArray(
        np.ascontiguousarray(prediction)[np.newaxis, ...],  # add band dimension
        dims=("band", "y", "x"),
        coords={"band": [1]},
    )
    xr_pred = xr_pred.rio.write_crs(crs).rio.write_transform(transform)
    tmp_file = output_file + ".part"
    xr_pred.rio.to_raster(tmp_file, driver="GTiff", compress="LZW")
    os.replace(tmp_file, output_file)


# Predict every GeoTIFF for this year and save one output raster per input.
# Files that already have an output are skipped, so the job can be resumed.
all_tifs = sorted(name for name in os.listdir(in_path) if name.endswith('.tif'))

for tif_name in all_tifs:
    output_file = os.path.join(out_path, tif_name)

    if os.path.exists(output_file):
        print(f"Skipping {tif_name}, output file already exists.")
        continue

    # Read pixels (bands moved last) and georeferencing, then release the file.
    with rxr.open_rasterio(os.path.join(in_path, tif_name)) as img_rio:
        raw = np.moveaxis(img_rio.to_numpy(), 0, 2)
        crs = img_rio.rio.crs
        transform = img_rio.rio.transform()

    img = normalize_image(_mask_nodata(raw))

    # Images smaller than one tile are padded, predicted, then cropped back.
    height, width = img.shape[:2]
    padded = _pad_to_tile(img)
    tiles, positions = create_overlapping_tiles(padded)
    final_tiles = _predict_tiles(tiles)
    final_prediction = reconstruct_image(final_tiles, positions, padded.shape)[:height, :width]

    _write_prediction(final_prediction, crs, transform, output_file)
    print(f"Processed file saved at: {output_file}")


# Stop the clock and report this run's total wall-clock time, in hours.
SECONDS_PER_HOUR = 3600
end_time = time.time()
elapsed_time = (end_time - start_time) / SECONDS_PER_HOUR

print(f"Prediction, processing, and reconstruction complete in {elapsed_time:.2f} hours")