import numpy as np
import pickle
from pathlib import Path

import sys
sys.path.append('../../experiment_utils/')
sys.path.append('../experiment_utils/')
sys.path.append('experiment_utils/')
from utils import get_spatial_binned_data
from stdata.grids import create_spatial_grid, create_geopandas_spatial_grid

import stdata
import geopandas as gpd

from stdata.utils import save_to_pickle
import matplotlib.pyplot as plt
from stgp.data import SpatioTemporalData

import pandas as pd
import numpy as np
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm

import numpy as np
from sklearn.model_selection import train_test_split
from stdata.utils import datetime_to_epoch
from stdata.model_selection import normalise


# NUM_FOLDS: how many independent shuffled train/test splits setup() writes.
# SUBSAMPLE_SIZE: rows drawn from each fold's training split for 'train_sub'.
NUM_FOLDS, SUBSAMPLE_SIZE = 5, 2000

def get_data_file_names(fold):
    """Return the pickle file names used to store one fold.

    Args:
        fold: fold index; it is interpolated verbatim into each name.

    Returns:
        dict with keys 'raw', 'train' and 'test' (in that order), each
        mapping to '<key>_data_<fold>.pickle'.
    """
    return dict(
        raw=f'raw_data_{fold}.pickle',
        train=f'train_data_{fold}.pickle',
        test=f'test_data_{fold}.pickle',
    )

def make_spatial_grid(data_df, n_x=None, n_y=None, x_col=None, y_col=None, padding_min=0, padding_max=0):
    """Build a regular spatial grid GeoDataFrame covering the extent of data_df.

    The bounding box is the min/max of ``x_col``/``y_col``, widened by
    ``padding_min`` on the lower edges and ``padding_max`` on the upper
    edges, and split into cells of width (extent / n_x) and height
    (extent / n_y) via ``create_geopandas_spatial_grid``.

    Each returned row gets integer ``x_loc``/``y_loc`` indices and a
    ``grid_id`` equal to its index label.

    NOTE(review): n_x and n_y default to None but are divided by, so callers
    must supply them. The x_loc/y_loc assignment assumes
    create_geopandas_spatial_grid returns cells in x-major order (every y
    for the first x, then the next x) and exactly count_x * count_y rows;
    confirm against that helper.
    """
    x_vals = data_df[x_col]
    y_vals = data_df[y_col]
    lo_x, hi_x = np.min(x_vals) - padding_min, np.max(x_vals) + padding_max
    lo_y, hi_y = np.min(y_vals) - padding_min, np.max(y_vals) + padding_max

    cell_w = (hi_x - lo_x) / n_x
    cell_h = (hi_y - lo_y) / n_y

    grid_gdf = create_geopandas_spatial_grid(lo_x, hi_x, lo_y, hi_y, cell_w, cell_h)

    # Per-axis cell counts from a float arange. Nominally this gives n - 1,
    # but floating-point rounding can push it to n.
    count_x = len(np.arange(lo_x, hi_x - cell_w, cell_w))
    count_y = len(np.arange(lo_y, hi_y - cell_h, cell_h))

    # x index is held constant while y cycles fastest; both are (N, 1) columns.
    grid_gdf['x_loc'] = np.repeat(np.arange(count_x), count_y)[:, None]
    grid_gdf['y_loc'] = np.tile(np.arange(count_y), count_x)[:, None]
    grid_gdf['grid_id'] = grid_gdf.index

    return grid_gdf

def select_region(data_df):
    """Return the rows of data_df whose position lies inside a fixed box.

    The box is lon in [-86.8, -85.5] and lat in [26.3, 27.5]; both bounds
    are inclusive. Rows with NaN lon/lat are excluded.
    """
    inside_lon = data_df['lon'].between(-86.8, -85.5)
    inside_lat = data_df['lat'].between(26.3, 27.5)
    return data_df[inside_lon & inside_lat]

def discretize_space(_df, grid_gdf):
    """Bin the lon/lat points of _df onto grid_gdf.

    Thin wrapper around get_spatial_binned_data. It uses 'lon' as the x
    column and 'lat' as the y column, and requests only the binned result
    (return_grid_details=False).
    """
    binning_kwargs = dict(
        grid_gdf=grid_gdf,
        x_col='lon',
        y_col='lat',
        return_grid_details=False,
    )
    return get_spatial_binned_data(_df, **binning_kwargs)

def get_binned_data(hourly_date, grid_gdf):
    """Spatially bin observations per (date, hour) and attach time/location columns.

    Args:
        hourly_date: DataFrame with at least 'date', 'hour', 'lon', 'lat',
            'u' and 'v' columns. The caller's frame is not modified.
        grid_gdf: grid passed to discretize_space for every (date, hour) group.

    Returns:
        The binned rows where at least one of 'u'/'v' is present. Added
        columns are 'datetime' (date + hour), 'epoch' (via
        datetime_to_epoch), and 'lon'/'lat', which are overwritten with the
        centroid of each row's geometry.
    """
    hourly_date = hourly_date.copy()

    disc_hourly_df = hourly_date.groupby(['date', 'hour']).apply(
        lambda _df: discretize_space(_df, grid_gdf)
    )

    # reset_index() raises ValueError if a column already carries the name of
    # an index level being moved out. Drop any group-key columns duplicated by
    # the index (previously only 'hour' was dropped, which broke when 'date'
    # was also present, and raised if 'hour' was absent).
    clashing_cols = [
        level for level in disc_hourly_df.index.names
        if level is not None and level in disc_hourly_df.columns
    ]
    disc_hourly_df = disc_hourly_df.drop(columns=clashing_cols).reset_index()

    # Keep bins where at least one velocity component was observed.
    keep_mask = ~(disc_hourly_df['u'].isna() & disc_hourly_df['v'].isna())

    # Build the timestamp arithmetically. Formatting "<date>  <hour>:00:00"
    # produces an unparseable string when 'date' is already a datetime,
    # because str(Timestamp) then includes a time component.
    disc_hourly_df['datetime'] = (
        pd.to_datetime(disc_hourly_df['date'])
        + pd.to_timedelta(disc_hourly_df['hour'], unit='h')
    )
    disc_hourly_df['epoch'] = datetime_to_epoch(disc_hourly_df['datetime'])

    # Drop empty bins. Copy so the column assignments below act on an
    # independent frame rather than a filtered view.
    disc_hourly_df = disc_hourly_df[keep_mask].copy()

    # Represent each bin by the centroid of its geometry.
    centroids = disc_hourly_df.geometry.centroid
    disc_hourly_df['lon'] = centroids.x
    disc_hourly_df['lat'] = centroids.y

    return disc_hourly_df

def _load_drifter_data(datasets_root):
    """Load laser.csv, clean it, and keep observations from 2016-01-22 to 2016-01-28."""
    data_df = pd.read_csv(datasets_root / 'laser.csv')

    # Replace id with a dense integer group number.
    data_df['id'] = data_df.groupby('id').ngroup()

    # Book keeping: parse timestamps and remember each row's original position.
    data_df['datetime'] = pd.to_datetime(data_df['datetime'])
    data_df['gid'] = data_df.index

    data_df['date'] = pd.to_datetime(data_df['date'])

    # Round datetime to the nearest 15 minutes. The data was smoothed from
    # 5-minute to 15-minute intervals but has small artifacts in the datetime.
    data_df['datetime'] = data_df['datetime'].dt.round('15min')
    data_df['hour'] = data_df['datetime'].dt.hour

    # Only select numeric / bookkeeping columns.
    data_df = data_df[
        ['id', 'datetime', 'date', 'hour', 'lat', 'lon', 'position_error', 'u', 'v', 'velocity_error', 'gid']
    ]

    in_window = (data_df['date'] >= '2016-01-22') & (data_df['date'] <= '2016-01-28')
    # Copy so callers can add columns without touching a filtered view.
    return data_df[in_window].copy()


def _select_vis_epochs(sorted_epochs, num_plots=5):
    """Return every (len // num_plots)-th epoch of sorted_epochs, starting from the first.

    The stride is clamped to at least 1 so fewer than num_plots epochs do not
    produce a zero-step arange. The last epoch is included only when it falls
    on the stride.
    """
    step = max(1, len(sorted_epochs) // num_plots)
    return sorted_epochs[np.arange(0, len(sorted_epochs), step)]


def setup(datasets_root, experiment_root):
    """Build and pickle NUM_FOLDS random train/test splits of the laser data.

    Reads ``datasets_root / 'laser.csv'``, creates ``experiment_root / 'results'``
    and ``experiment_root / 'data'``, and for each fold writes the three
    pickles named by ``get_data_file_names`` into the data directory. Each
    fold is a shuffled 90/10 train/test split seeded by the fold index.

    Pickle contents per fold:
        train: 'train' and 'train_sub' (up to SUBSAMPLE_SIZE training rows),
            each {'X', 'Y'} with X columns ['epoch_norm', 'lon', 'lat'] and
            Y columns ['u', 'v'].
        test: 'test', 'all' (every row of the fold) and 'vis_0', 'vis_1', ...
            Each vis entry has X = [normalised epoch, create_spatial_grid
            points over the training lon/lat range] and Y = None.
        raw: the fold DataFrame ('df') and the unnormalised vis epochs.

    All 'epoch_norm' values are produced by ``normalise`` against the
    training split's epochs.

    Args:
        datasets_root: Path of the directory containing laser.csv.
        experiment_root: Path under which results/ and data/ are created.
    """
    results_path = experiment_root / 'results'
    data_path = experiment_root / 'data'

    # Ensure results and data directories exist.
    results_path.mkdir(exist_ok=True)
    data_path.mkdir(exist_ok=True)

    obs_df = _load_drifter_data(datasets_root)
    obs_df['epoch'] = datetime_to_epoch(obs_df['datetime'])

    X_cols = ['epoch_norm', 'lon', 'lat']
    Y_cols = ['u', 'v']

    # TODO: the test sets of different folds are not guaranteed to be disjoint.
    for fold in range(NUM_FOLDS):
        data_fold_df = obs_df.copy()

        train_df, test_df = train_test_split(data_fold_df, random_state=fold, shuffle=True, test_size=0.1)
        # Copy the row subsets so adding 'epoch_norm' acts on independent frames.
        train_df = train_df.copy()
        test_df = test_df.copy()

        train_df['epoch_norm'] = normalise(train_df['epoch'], train_df['epoch'])
        test_df['epoch_norm'] = normalise(test_df['epoch'], train_df['epoch'])
        data_fold_df['epoch_norm'] = normalise(data_fold_df['epoch'], train_df['epoch'])

        # Raw data (all rows of the fold).
        X_all = np.array(data_fold_df[X_cols])
        Y_all = np.array(data_fold_df[Y_cols])

        # ===== TRAINING DATA =====
        X_train = np.array(train_df[X_cols])
        Y_train = np.array(train_df[Y_cols])

        # DataFrame.sample raises if n exceeds the number of rows, so cap it.
        sub_train_df = train_df.sample(min(SUBSAMPLE_SIZE, len(train_df)), random_state=0)
        X_train_sub = np.array(sub_train_df[X_cols])
        Y_train_sub = np.array(sub_train_df[Y_cols])

        # ===== TEST DATA =====
        X_test = np.array(test_df[X_cols])
        Y_test = np.array(test_df[Y_cols])

        # ===== VISUALISATION DATA =====
        X_grid = create_spatial_grid(
            np.min(train_df['lon']),
            np.max(train_df['lon']),
            np.min(train_df['lat']),
            np.max(train_df['lat']),
            50,
            50,
        )

        # unique() preserves order of appearance, so sort to pick epochs in time order.
        vis_plots_epochs = _select_vis_epochs(np.sort(data_fold_df['epoch'].unique()))
        vis_plots_epochs_norm = normalise(vis_plots_epochs, train_df['epoch'])

        # One input matrix per vis epoch: the epoch column prepended to the grid.
        X_vis_arr = [
            np.hstack([np.ones([X_grid.shape[0], 1]) * vis_epoch, X_grid])
            for vis_epoch in vis_plots_epochs_norm
        ]

        fnames = get_data_file_names(fold)

        print(f'======== {fold} ======')
        print(f'X_train: {X_train.shape}')
        print(f'Y_train: {Y_train.shape}')
        print(f'X_train_sub: {X_train_sub.shape}')
        print(f'Y_train_sub: {Y_train_sub.shape}')
        print(f'X_test: {X_test.shape}')
        print(f'Y_test: {Y_test.shape}')
        print(f'X_all: {X_all.shape}')
        print(f'Y_all: {Y_all.shape}')
        for X_vis in X_vis_arr:
            print(f'X_vis: {X_vis.shape}')
        print(f'==================')

        train_data = {
            'train': {'X': X_train, 'Y': Y_train},
            'train_sub': {'X': X_train_sub, 'Y': Y_train_sub},
        }

        test_data = {
            'test': {'X': X_test, 'Y': Y_test},
            'all': {'X': X_all, 'Y': Y_all},
        }
        # Vis points have no observed targets.
        for vis_idx, X_vis in enumerate(X_vis_arr):
            test_data[f'vis_{vis_idx}'] = {'X': X_vis, 'Y': None}

        # Save unnormalised data for plotting.
        raw_data = {
            'df': data_fold_df,
            'vis_plots_epochs': vis_plots_epochs,
        }

        outputs = {'train': train_data, 'test': test_data, 'raw': raw_data}
        for kind in outputs:
            print(f"saving to {data_path / fnames[kind]}")
        for kind, payload in outputs.items():
            save_to_pickle(payload, data_path / fnames[kind])

if __name__ == '__main__':
    # Both paths are resolved relative to the current working directory.
    setup(
        datasets_root=Path('../../data/laser/data'),
        experiment_root=Path('.'),
    )
