import pathlib
import warnings

from nfmc_jax.flows.debug import Tag
import json
from typing import Optional, Dict

import numpy as np
from natsort import natsort

from flask import Flask, render_template, request, redirect, url_for
import argparse

# Command-line interface. Parsed at import time because the module-level Flask app and
# dataset manager below are built from these values.
parser = argparse.ArgumentParser(description='Serve a web dashboard of flow sample paths.')
# Experiments in the directory subtree rooted at --path are the parents of "raw_flow_data" subdirectories.
# required=True: without it a missing --path only fails later, inside pathlib.Path(None).
parser.add_argument('--path', action='store', dest='path', required=True,
                    help='Root directory searched recursively for experiments '
                         '(directories containing a "raw_flow_data" subdirectory).')
parser.add_argument('--max-epochs', action='store', dest='max_epochs', type=int,
                    help='Only load the first MAX_EPOCHS steps of each data split (default: all).')
parser.add_argument('--max-samples', action='store', dest='max_samples', type=int,
                    help='Only load the first MAX_SAMPLES samples of each data split (default: all).')
args = parser.parse_args()
root = pathlib.Path(args.path)


class PathsDataset:
    """Flow sample paths of a single experiment, read from ``<path>/raw_flow_data/paths``.

    Each split (train / validation / generative) lives in its own subdirectory (named by
    ``Tag``) as a sequence of ``.npy`` files that are stacked along a leading "steps" axis.
    After loading, each split is stored in ``numpy_data`` as an array of shape
    ``(steps, dimensions, samples, layers)`` together with its per-dimension minima and
    maxima; splits that were not found stay ``None``.
    """

    # (display label, numpy_data key, info key holding the sample count) for every split.
    _SPLIT_SUMMARY = (
        ('Training', 'train', 'nTrainingSamples'),
        ('Validation', 'validation', 'nValidationSamples'),
        ('Generative', 'generative', 'nGenerativeSamples'),
    )

    def __init__(self, path: pathlib.Path, max_samples: Optional[int] = 100, max_epochs: Optional[int] = None):
        """
        Args:
            path: Experiment directory, i.e. the parent of its ``raw_flow_data`` subdirectory.
            max_samples: Keep only the first ``max_samples`` samples of every split
                (``None`` keeps all). This keeps the rendered HTML page small.
            max_epochs: Keep only the first ``max_epochs`` steps of every split
                (``None`` keeps all).
        """
        self.path = path
        self.raw_data_directory = self.path / 'raw_flow_data'
        self.paths_directory = self.raw_data_directory / 'paths'
        self.train_paths_directory = self.paths_directory / Tag.train
        self.validation_paths_directory = self.paths_directory / Tag.validation
        self.generative_paths_directory = self.paths_directory / Tag.generative
        self.max_samples = max_samples
        self.max_epochs = max_epochs

        # Every split starts as None ("not loaded"), training included, so a missing
        # training directory is reported and serialized like the other missing splits
        # instead of as a meaningless 0-d placeholder array.
        self.numpy_data: Dict[str, Optional[np.ndarray]] = dict(
            train=None,
            train_minima=None,
            train_maxima=None,
            validation=None,
            validation_minima=None,
            validation_maxima=None,
            generative=None,
            generative_minima=None,
            generative_maxima=None
        )
        self.info = dict(
            nSteps=0,
            nDimensions=0,
            nLayers=0,
            nTrainingSamples=0,
            nValidationSamples=0,
            nGenerativeSamples=0,
        )

    @staticmethod
    def compute_limits(data: np.ndarray):
        """Return per-dimension ``(minima, maxima)`` of a ``(steps, dimensions, samples, layers)`` array.

        Every axis except the dimension axis (1) is reduced, so both results have shape
        ``(dimensions,)``.
        """
        minima = np.min(data, axis=(0, 2, 3))
        maxima = np.max(data, axis=(0, 2, 3))
        return minima, maxima

    def print_data_stats(self):
        """Print a human-readable summary of the experiment and of every split."""
        print('=' * 100)
        print('Experiment info')
        print(f'- Path: {self.path.absolute()}')
        print(f'- Number of steps: {self.info["nSteps"]}')
        print(f'- Number of dimensions: {self.info["nDimensions"]}')
        print(f'- Flow layers: {self.info["nLayers"]}')
        for label, key, samples_key in self._SPLIT_SUMMARY:
            print('-' * 100)
            split = self.numpy_data[key]
            if split is None:
                print(f'- No {label.lower()} data')
                continue
            print(f'- {label} data size: {split.nbytes / 10 ** 6} MB')
            print(f'- {label} samples: {self.info[samples_key]}')
            print(f'- {label} data shape: {split.shape}')
            print(f'- {label} data minima: {self.numpy_data[key + "_minima"]}')
            print(f'- {label} data maxima: {self.numpy_data[key + "_maxima"]}')
        print('=' * 100)

    def _load_split(self, key: str, directory: pathlib.Path, samples_info_key: str,
                    reverse_layers: bool = False):
        """Load one split from ``directory`` into ``numpy_data[key]`` and update ``info``.

        Only a warning is emitted, and nothing is changed, when ``directory`` contains no
        ``.npy`` files.
        """
        # Natural sort so numbered file names are ordered numerically ("10" after "9").
        numpy_files = natsort.natsorted(list(directory.glob('*.npy')))
        if not numpy_files:
            warnings.warn(f"No .npy files found in {directory.absolute()}")
            return

        data = np.stack([np.load(str(f)) for f in numpy_files])
        data = np.transpose(data, (0, 3, 2, 1))  # (steps, dimensions, samples, layers)
        data = data[:self.max_epochs, :, :self.max_samples, :]
        if reverse_layers:
            data = np.flip(data, axis=3)  # Reverse the layer direction
        minima, maxima = self.compute_limits(data)

        # Presumably shared by all splits; whichever split is loaded last sets these.
        self.info['nSteps'] = data.shape[0]
        self.info['nDimensions'] = data.shape[1]
        self.info[samples_info_key] = data.shape[2]
        self.info['nLayers'] = data.shape[3] - 1

        self.numpy_data[key] = data
        self.numpy_data[f"{key}_minima"] = minima
        self.numpy_data[f"{key}_maxima"] = maxima

    def load_training_data(self):
        """Load the training split, if any ``.npy`` files exist for it."""
        self._load_split('train', self.train_paths_directory, 'nTrainingSamples')

    def load_validation_data(self):
        """Load the validation split, if any ``.npy`` files exist for it."""
        self._load_split('validation', self.validation_paths_directory, 'nValidationSamples')

    def load_generative_data(self):
        """Load the generative split, if any ``.npy`` files exist for it.

        Its layer axis is flipped after truncation; presumably generative paths are stored
        in the opposite layer order from the other splits (TODO confirm).
        """
        self._load_split('generative', self.generative_paths_directory, 'nGenerativeSamples',
                         reverse_layers=True)

    def load(self):
        """Load all three splits (missing ones only warn) and print a summary."""
        self.load_training_data()
        self.load_validation_data()
        self.load_generative_data()

        self.print_data_stats()

    @property
    def data(self):
        """``numpy_data`` converted to nested Python lists; splits that are not loaded become ``[]``."""
        return {k: (v.tolist() if v is not None else []) for k, v in self.numpy_data.items()}


class DatasetManager:
    """Discovers experiment directories below a root and keeps exactly one of them loaded."""

    def __init__(self, path: pathlib.Path, **kwargs):
        """
        Args:
            path: Root directory searched recursively for ``raw_flow_data`` subdirectories;
                the parent of each one is treated as an experiment directory.
            **kwargs: Forwarded to every ``PathsDataset`` this manager creates.

        Raises:
            ValueError: If no experiment directory is found below ``path``.
        """
        self.path = path
        self.experiment_directories = natsort.natsorted([p.parent for p in self.path.rglob('raw_flow_data')])
        if not self.experiment_directories:
            raise ValueError(f"No experiment directories found in {str(self.path)}")
        print(f'Total experiments: {len(self.experiment_directories)}')

        self.dataset_kwargs = kwargs

        self.active_index = 0
        self.dataset: PathsDataset = self._load_dataset(self.active_index)

    def _load_dataset(self, dataset_index: int) -> "PathsDataset":
        """Create and load the ``PathsDataset`` for experiment ``dataset_index``."""
        dataset = PathsDataset(self.experiment_directories[dataset_index], **self.dataset_kwargs)
        dataset.load()
        return dataset

    def set_active_dataset(self, dataset_index):
        """Make experiment ``dataset_index`` (0-based) active, loading it if it changed.

        Raises:
            IndexError: If ``dataset_index`` is outside ``[0, number of experiments)``.
                The current selection is left untouched in that case.
        """
        if dataset_index == self.active_index:
            return
        # Reject negatives explicitly: Python indexing would otherwise wrap them around
        # and silently select a different experiment.
        if not 0 <= dataset_index < len(self.experiment_directories):
            raise IndexError(
                f"Experiment index {dataset_index} out of range "
                f"(0..{len(self.experiment_directories) - 1})"
            )

        # Load first and switch afterwards, so a failure while loading cannot leave
        # active_index pointing at an experiment other than the one held in self.dataset.
        new_dataset = self._load_dataset(dataset_index)
        self.dataset = new_dataset
        self.active_index = dataset_index

    def reload_active_dataset(self):
        """Re-read the active experiment's files from disk."""
        self.dataset.load()

    @property
    def relative_paths(self):
        """Experiment directories as strings relative to the search root."""
        return [str(p.relative_to(self.path)) for p in self.experiment_directories]

    @property
    def meta_info(self):
        """Active experiment index, both 0-based and 1-based (the form numbering)."""
        return {
            "experiment_id": self.active_index,
            "default_experiment_id_flask": self.active_index + 1
        }


app = Flask(__name__)
# Single module-level manager shared by all routes; the CLI truncation limits are
# forwarded to every PathsDataset it builds.
dataset_manager = DatasetManager(
    root,
    max_samples=args.max_samples,
    max_epochs=args.max_epochs,
)


@app.route('/reload', methods=['POST'])
def reload():
    """Rescan ``root`` for experiments, then redirect to the home page.

    The fresh manager starts again at the first experiment. The command-line epoch and
    sample limits are passed again, so reloaded datasets are truncated the same way as
    the ones loaded at startup (previously they fell back to PathsDataset's defaults).
    """
    global dataset_manager

    print('Reloading experiment list')
    dataset_manager = DatasetManager(root, max_epochs=args.max_epochs, max_samples=args.max_samples)

    return redirect(url_for('home'))


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the dashboard, switching experiments when the form selects another one.

    The ``experimentRadio`` form field carries a 1-based experiment number. A missing,
    non-integer or out-of-range value keeps the current experiment instead of failing
    the request.
    """
    global dataset_manager

    current_id = dataset_manager.active_index
    # type=int makes .get return the default (None) for values that are not integers.
    selected = request.form.get('experimentRadio', type=int)
    # -1 because form numbering starts at 1, DatasetManager indices at 0.
    new_experiment_id = current_id if selected is None else selected - 1
    if not 0 <= new_experiment_id < len(dataset_manager.experiment_directories):
        new_experiment_id = current_id
    dataset_manager.set_active_dataset(new_experiment_id)

    return render_template(
        'flow_dashboard_home.html',
        data=dataset_manager.dataset.data,
        experiment_info=dataset_manager.dataset.info,
        experiment_paths=dataset_manager.relative_paths,
        meta_info=dataset_manager.meta_info
    )


if __name__ == '__main__':
    # Start Flask's built-in server with its default settings. Argument parsing and
    # dataset loading already happened at module import time above.
    app.run()
