import os
from typing import Dict, List, Tuple

import numpy as np
import torch
import xarray as xr
from scipy.io import netcdf
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset


############################## real world data ##############################
class SSTDataset(Dataset):
    """Sea-surface-temperature time series, one sample per grid location.

    Reads ``sst.wkmean.1990-present.nc`` from ``data_path`` and standardises
    the ``sst`` variable with a single global mean/std. It is stored as
    ``self.data["sst"]`` with shape ``[lat * lon, time, 1]``. Row ``i``
    corresponds to ``np.unravel_index(i, (lat_dim, lon_dim))``.

    Args:
        data_path: directory containing ``sst.wkmean.1990-present.nc``.
        mode: ``"test"`` returns the fixed window ``self.time_indices``; any
            other value returns a randomly placed window of ``chunk_size`` steps.
        start_idx: accepted for interface compatibility; not used.
        chunk_size: window length in time steps; a falsy value means the
            whole series.
    """

    def __init__(
        self,
        data_path: str,
        mode="train",
        start_idx=0,
        chunk_size: int = 52 * 4,
    ) -> None:
        super().__init__()

        path = os.path.join(data_path, "sst.wkmean.1990-present.nc")
        # The context manager closes the file. _to_native copies every array,
        # so nothing keeps referencing the file's memory map afterwards.
        with netcdf.NetCDFFile(path, "r") as sst_netcdf:
            self.data = {k: self._to_native(v[:]) for k, v in sst_netcdf.variables.items()}

        # Standardise with one global mean/std over all times and locations.
        m, std = self.data["sst"].mean(), self.data["sst"].std()
        self.data["sst"] = (self.data["sst"] - m) / std  # [ts, lat, lon], e.g. [1727, 180, 360]

        self.time_steps = self.data["sst"].shape[0]
        self.chunk_size = chunk_size or self.time_steps
        self.lat_dim = self.data["lat"].shape[0]
        self.lon_dim = self.data["lon"].shape[0]

        # Test-mode input window: the chunk_size steps preceding the final
        # chunk_size steps (predict_collate_fn appends those as the horizon).
        self.time_indices = np.arange(self.time_steps - 2 * self.chunk_size, self.time_steps - self.chunk_size)

        # Shift time so the first step is 0.
        self.data["time"] = self.data["time"] - self.data["time"][0]

        # [ts, lat, lon] -> [lat * lon, ts, 1]: one row per location.
        self.data["sst"] = self.data["sst"].reshape(self.time_steps, -1).T[..., None]

        self.mode = mode
        # NOTE: no split over locations; every mode keeps all locations
        # (original intent: "to reproduce everything use the whole dataset").

    @staticmethod
    def _to_native(array: np.ndarray, dtype=None) -> np.ndarray:
        """Return an in-memory copy of ``array`` in native byte order.

        Replaces ``array.byteswap().newbyteorder()``, because
        ``ndarray.newbyteorder`` was removed in NumPy 2.0. ``np.array`` always
        copies, so the result does not refer to the file's memory map.

        Args:
            array: array read from a netCDF variable.
            dtype: optional target dtype; defaults to ``array``'s dtype in
                native byte order.
        """
        target = array.dtype.newbyteorder("=") if dtype is None else np.dtype(dtype)
        return np.array(array, dtype=target)

    def __len__(self):
        """Number of grid locations (``lat_dim * lon_dim``)."""
        return self.data["sst"].shape[0]

    def __getitem__(self, index) -> Dict:
        """Return the series of location ``index``.

        In ``"test"`` mode the window is ``self.time_indices``. Otherwise a
        random start is drawn from ``[0, time_steps - 2 * chunk_size)``, or 0
        if that range is empty. This keeps training windows clear of the final
        ``chunk_size`` steps that ``predict_collate_fn`` uses as the horizon.
        """
        if self.mode == "test":
            return {
                "index": index,
                "states": self.data["sst"][index, self.time_indices],
            }

        if self.time_steps - 2 * self.chunk_size > 0:
            time_index = np.random.randint(self.time_steps - 2 * self.chunk_size)
        else:
            time_index = 0

        return {
            "index": index,
            "time_index": time_index,
            "states": self.data["sst"][index, time_index : time_index + self.chunk_size],
        }

    def collate_fn(self, batch: List[Dict]):
        """Collate non-test items and pair each with a nearby-longitude augmentation.

        For each item, the augmentation location has the same latitude. Its
        longitude is drawn uniformly from
        ``[max(0, lon - 5), min(lon_dim, lon + 5))``, and its states cover the
        same time window. Items must carry ``"time_index"``, i.e. be produced
        in non-test mode.

        Returns:
            dict with ``"index"`` of shape ``[2, B]`` (originals, augmentations)
            and ``"states"`` of shape ``[2, B, window, 1]``.
        """
        indices, states, aug_indices, aug_states = [], [], [], []
        grid = (self.lat_dim, self.lon_dim)
        for b in batch:
            start = b["time_index"]
            indices.append(b["index"])
            states.append(b["states"])
            lat, lon = np.unravel_index(b["index"], grid)
            aug_lon = np.random.randint(max(0, lon - 5), min(self.lon_dim, lon + 5))  # sampled longitude index
            aug_index = np.ravel_multi_index((lat, aug_lon), grid)
            aug_indices.append(aug_index)
            aug_states.append(self.data["sst"][aug_index, start : start + self.chunk_size])

        return {
            "index": torch.stack([torch.tensor(indices), torch.tensor(aug_indices)], dim=0),
            "states": torch.stack([torch.from_numpy(np.stack(states)), torch.from_numpy(np.stack(aug_states))], dim=0),
        }

    def predict_collate_fn(self, batch):
        """Collate items with the test input window plus the forecast horizon.

        Each item's states span ``2 * chunk_size`` steps starting at
        ``self.time_indices[0]``, i.e. the test input window followed by the
        next ``chunk_size`` steps.

        Returns:
            dict with ``"index"`` of shape ``[B]`` and ``"states"`` of shape
            ``[B, 2 * chunk_size, 1]``.
        """
        time_indices_w_forecast = np.arange(self.time_indices[0], self.time_indices[0] + 2 * self.chunk_size)
        indices = [b["index"] for b in batch]
        states = [self.data["sst"][i, time_indices_w_forecast] for i in indices]
        return {"index": torch.tensor(indices), "states": torch.from_numpy(np.stack(states))}


class ERA5Dataset(Dataset):
    """ERA5 fields from a NetCDF file served as per-location time series.

    Loads either ``sst`` (one channel) or ``wind`` (``u10`` and ``v10``
    stacked as two channels). Index 0 of each variable's second axis is
    taken; this is presumably a singleton axis (TODO confirm against the
    file). Each channel is standardised over all times and locations with
    ``StandardScaler``. The result is stored as ``self.data["states"]`` with
    shape ``[lat * lon, time, state_dim]``. ``mode="train"`` keeps the first
    90% of locations and ``mode="test"`` the remainder; any other mode keeps
    all of them.

    Args:
        data_path: path of the NetCDF file.
        mode: split selector (see above).
        variable_name: ``"sst"`` or ``"wind"``.
        chunk_size: window length in time steps; ``None``/0 means the whole
            series.

    Raises:
        ValueError: if ``variable_name`` is not ``"sst"`` or ``"wind"``.
    """

    def __init__(
        self,
        data_path: str,
        mode: str = "train",
        variable_name: str = "sst",
        chunk_size: int = None,
    ) -> None:
        super().__init__()
        # Explicit check (not assert) so validation survives `python -O`.
        if variable_name not in ("wind", "sst"):
            raise ValueError("valid variable names are: 'wind' or 'sst'")

        # The context manager closes the file. _to_native copies every array,
        # so nothing keeps referencing the file's memory map afterwards.
        with netcdf.NetCDFFile(data_path, "r") as dataset:
            self.data = {
                k: self._to_native(dataset.variables[k][:]) for k in ("longitude", "latitude", "time")
            }
            if variable_name == "sst":
                states = self._to_native(dataset.variables["sst"][:], float)[:, 0, ..., None]  # [time, lat, lon, 1]
            else:
                u10 = self._to_native(dataset.variables["u10"][:], float)[:, 0]
                v10 = self._to_native(dataset.variables["v10"][:], float)[:, 0]
                states = np.stack([u10, v10], -1)  # [time, lat, lon, 2]

        self.data["lat"] = self.data.pop("latitude")
        self.data["lon"] = self.data.pop("longitude")

        self.time_steps = self.data["time"].shape[0]
        self.chunk_size = chunk_size or self.time_steps
        self.lat_dim = self.data["lat"].shape[0]
        self.lon_dim = self.data["lon"].shape[0]

        # Shift time so the first step is 0.
        self.data["time"] = self.data["time"] - self.data["time"][0]

        # Standardise each channel over all (time, location) pairs, then lay
        # out as [lat * lon, ts, state_dim].
        state_dim = states.shape[-1]
        scaled = StandardScaler().fit_transform(states.reshape(-1, state_dim))
        self.data["states"] = scaled.reshape(self.time_steps, -1, state_dim).transpose(1, 0, 2)

        # Split over locations: first 90% for training, the rest for testing.
        split = int(0.9 * self.data["states"].shape[0])
        if mode == "train":
            self.data["states"] = self.data["states"][:split]
        elif mode == "test":
            self.data["states"] = self.data["states"][split:]

    @staticmethod
    def _to_native(array: np.ndarray, dtype=None) -> np.ndarray:
        """Return an in-memory copy of ``array`` in native byte order.

        Replaces ``array.byteswap().newbyteorder()``, because
        ``ndarray.newbyteorder`` was removed in NumPy 2.0. ``np.array`` always
        copies, so the result does not refer to the file's memory map.

        Args:
            array: array read from a netCDF variable.
            dtype: optional target dtype; defaults to ``array``'s dtype in
                native byte order.
        """
        target = array.dtype.newbyteorder("=") if dtype is None else np.dtype(dtype)
        return np.array(array, dtype=target)

    def __len__(self):
        """Number of locations in the selected split."""
        return self.data["states"].shape[0]

    def __getitem__(self, index) -> Dict:
        """Return a ``chunk_size`` window of location ``index`` at a random start.

        Starts are drawn uniformly from ``0 .. time_steps - chunk_size``
        inclusive. If ``chunk_size >= time_steps``, the start is 0.
        """
        if self.chunk_size < self.time_steps:
            # randint's upper bound is exclusive; +1 makes the last full window
            # (starting at time_steps - chunk_size) reachable.
            time_index = np.random.randint(self.time_steps - self.chunk_size + 1)
        else:
            time_index = 0
        return {
            "index": index,
            "time_index": time_index,
            "states": self.data["states"][index, time_index : time_index + self.chunk_size],
        }

    def collate_fn(self, batch):
        """Collate items, pairing each with a random location over the same window.

        The augmentation location is drawn uniformly from this split.

        Returns:
            dict with ``"index"`` of shape ``[2, B]`` (originals, augmentations)
            and ``"states"`` of shape ``[2, B, window, state_dim]``.
        """
        indices, states, aug_indices, aug_states = [], [], [], []
        for b in batch:
            start = b["time_index"]
            indices.append(b["index"])
            states.append(b["states"])
            aug_index = np.random.randint(len(self))
            aug_indices.append(aug_index)
            aug_states.append(self.data["states"][aug_index, start : start + self.chunk_size])

        return {
            "index": torch.stack([torch.tensor(indices), torch.tensor(aug_indices)], dim=0),
            "states": torch.stack([torch.from_numpy(np.stack(states)), torch.from_numpy(np.stack(aug_states))], dim=0),
        }


# Default zarr store (Google Cloud Storage URL) opened by WeatherBench2Dataset.
GS_PATH = (
    "gs://weatherbench2/datasets/era5_biweekly/"
    "1959-2023_01_10-1h-240x121_equiangular_with_poles_conservative.zarr/"
)


class WeatherBench2Dataset(Dataset):
    """One zarr variable served as per-location time series, loaded lazily.

    The variable is assumed to be laid out as ``[time, lon, lat]``; this
    matches the ``(lon_dim, lat_dim)`` unravel order in ``__getitem__``. Each
    index addresses one grid location. The first ``int(0.9 * N)`` locations
    form the non-test split and all remaining locations the test split. Only
    every ``time_interval``-th time step is returned.

    Args:
        data_path: zarr store path/URL passed to ``xr.open_zarr``.
        mode: ``"test"`` selects the held-out locations; any other value
            selects the rest.
        variable_name: data variable read from the store.
    """

    def __init__(
        self,
        data_path: str = GS_PATH,
        mode: str = "train",
        variable_name: str = "sea_surface_temperature",
    ) -> None:
        super().__init__()
        # For now only sst data is supported; wind would require stacking the
        # u and v components.
        self.ds = xr.open_zarr(data_path)[variable_name]  # assumed [ts, lon, lat]
        self.lat_dim = self.ds.latitude.shape[0]
        self.lon_dim = self.ds.longitude.shape[0]

        num_locations = self.lat_dim * self.lon_dim
        train_size = int(0.9 * num_locations)
        if mode != "test":
            self.offset = 0
            self.num_samples = train_size
        else:
            self.offset = train_size
            # Take every remaining location; int(0.1 * N) could drop one
            # because of truncation.
            self.num_samples = num_locations - train_size

        # Every 30th step. The original note said "consider monthly data",
        # which holds only if the store is daily. TODO confirm its resolution.
        self.time_interval = 30
        # ds[::interval] yields ceil(T / interval) steps; floor division would
        # undercount by one whenever T % interval != 0.
        self.time_steps = len(range(0, self.ds.time.shape[0], self.time_interval))

    def __len__(self):
        """Number of locations in the selected split."""
        return self.num_samples

    def __getitem__(self, index):
        """Load the subsampled series of split-local location ``index``."""
        lon, lat = np.unravel_index(index + self.offset, (self.lon_dim, self.lat_dim))
        return {"index": index, "states": self.ds[:: self.time_interval, lon, lat].values}

    def collate_fn(self, batch):
        """Collate items, pairing each with a uniformly random location from the split.

        Returns:
            dict with ``"index"`` of shape ``[2, B]`` (originals, augmentations)
            and ``"states"`` of shape ``[2, B, time_steps]``.
        """
        indices, states, aug_indices, aug_states = [], [], [], []
        for b in batch:
            indices.append(b["index"])
            states.append(b["states"])
            aug_item = self.__getitem__(np.random.randint(len(self)))
            aug_indices.append(aug_item["index"])
            aug_states.append(aug_item["states"])

        return {
            "index": torch.stack([torch.tensor(indices), torch.tensor(aug_indices)], dim=0),
            "states": torch.stack([torch.from_numpy(np.stack(states)), torch.from_numpy(np.stack(aug_states))], dim=0),
        }
