


import os
import sys
import copy
import json
from typing import Dict, List, Optional, Union, Any

import numpy as np

import torch
from einops import rearrange
from matplotlib import pyplot as plt
from torch.nn import functional as F
from PIL import Image

from torch.utils.data.dataset import Dataset
import transformers

from llava.train import rank0_print, DataArguments, preprocess_multimodal, preprocess, DataCollatorForSupervisedDataset


# Spatial size (three dimensions) that fMRIViT3dDataset resamples every fMRI
# volume to, and on which it builds the brain-region mask.
# NOTE(review): axis order (x/y/z vs z/y/x) is not visible here - confirm against the data.
FMRI_SHAPE = (83, 104, 81)


def calculate_total_mean_variance_from_std(mean_vector, std_dev_vector):
    """Return the supplied statistics unchanged as a ``(mean, std)`` pair.

    The function currently performs no computation; it only packs its two
    arguments into a tuple, in the order they were given.
    """
    stats = (mean_vector, std_dev_vector)
    return stats


class fMRIViT3dDataset(Dataset):
    """Dataset for supervised fine-tuning of a 3D ViT on fMRI volumes.

    Each index file given in ``data_path`` is a JSON document from which the
    keys ``'train'`` / ``'val'`` (lists of samples), ``'mean'`` / ``'std'``
    (paths to ``.npy`` statistics) and, when a brain region is selected,
    ``'atlas'`` (subject -> atlas JSON path) are read.  Every sample is a dict
    providing ``'fmri'``, ``'image'``, ``'subject'`` and ``'ids'``; further keys
    (``'vision_embeds'``, ``'caption'``) are read only when the matching
    ``return_*`` option is enabled.

    ``__getitem__`` returns a dict whose entries depend on the ``return_*``
    flags passed to the constructor.
    """

    def __init__(
        self,
        data_path: Union[str, List[str]],
        is_train: bool,
        # data_args: DataArguments,
        patch_size: int = 14,
        return_fmris: bool = True,
        return_embeds: bool = True,
        return_vae_embeds: bool = False,
        return_images: bool = False,
        return_captions: bool = False,
        return_captions_gen: Optional[str] = None,
        return_embeds_gen: Optional[str] = None,
        return_images_gen: Optional[str] = None,
        return_vae_embeds_gen: bool = False,
        return_tokens: bool = False,
        return_net_embeds: Optional[List[str]] = None,
        return_subject: bool = False,
        return_image_ids: bool = False,
        requires_padding: bool = True,
        requires_norm: bool = True,
        select_subject: Optional[Union[str, List[str]]] = None,
        return_image_type: str = "pil",
        return_image_size: Optional[int] = None,
        select_brain_region: Optional[Union[str, List[str]]] = None,
        image_fns: Optional[List[Any]] = None,
        clip_aug: int = 0,
    ):
        """Load the sample lists and per-dataset statistics.

        Args:
            data_path: One index JSON path or a list of them.
            is_train: Use the ``'train'`` split (and mix-up of repeated
                recordings) when true, otherwise the ``'val'`` split.
            patch_size: Cubic patch edge used for padding and for the token mask.
            return_fmris: Add the processed volume under ``"fmri"``.
            return_embeds: Add the sample's vision embedding under ``"labels"``.
            return_vae_embeds: Add the VAE embedding under ``"vae_labels"``.
            return_images: Add the stimulus image under ``"images"``.
            return_captions: Add the sample's ``'caption'`` under ``"captions"``.
            return_captions_gen: Path of a JSON file indexed by sample position,
                returned under ``"captions_gen"``.
            return_embeds_gen: Root directory of generated embeddings
                (``<root>/<subject>/<ids>.npy``), returned under ``"labels_gen"``.
            return_images_gen: Root directory of generated images
                (``<root>/<subject>/<ids>.png``), returned under ``"images_gen"``.
            return_vae_embeds_gen: Add generated VAE embeddings; the file path is
                derived from ``return_embeds_gen``, which must therefore be set.
            return_tokens: Load ``nsd_gpt2_tokens.json`` (located three path
                components above the first index file) and return one randomly
                chosen token sequence per sample under ``"tokens"``.
            return_net_embeds: Directories holding ``<ids>.npy`` embeddings.
            return_subject: Add the subject identifier under ``"subject"``.
            return_image_ids: Add the integer image id under ``"image_ids"``.
            requires_padding: Zero-pad volumes to a multiple of ``patch_size``.
            requires_norm: Normalise volumes with the dataset mean / std.
            select_subject: Keep only samples of these subject(s).
            return_image_type: ``"np"`` converts returned images to arrays.
            return_image_size: If set, returned images are resized to a square
                of this edge length.
            select_brain_region: Atlas region name(s) used to mask volumes and
                to compute ``token_ids``.
            image_fns: Callables applied to each returned image; results are
                stored under ``"<key>_<index>"``.
            clip_aug: Number of augmented vision-embedding variants to sample from.

        Raises:
            ValueError: If ``return_tokens`` or ``select_brain_region`` is
                requested but ``data_path`` is empty.
        """
        super().__init__()

        if isinstance(data_path, str):
            data_path = [data_path]

        self.means, self.stds = {}, {}
        self.training = is_train
        self.list_data_dict = []
        data_dict = None
        for path in data_path:
            # The directory that contains ".../fmris" names the dataset.
            dataset_name = path.split("/fmris")[0].split("/")[-1]
            data_dict = self._load_json(path)

            rank0_print("Formatting fMRI dataset to train ViT3D model")
            self.list_data_dict.extend(data_dict['train'] if is_train else data_dict['val'])

            mean = torch.tensor(np.load(data_dict['mean']))
            std = torch.tensor(np.load(data_dict['std']))
            mean, std = calculate_total_mean_variance_from_std(mean, std)

            self.means[dataset_name] = mean
            self.stds[dataset_name] = std

        self.fmri_shape = FMRI_SHAPE

        # F.pad expects (before, after) pairs starting from the LAST dimension.
        self.padding = []
        if requires_padding:
            for dim in reversed(self.fmri_shape):
                pad_size = (patch_size - dim % patch_size) % patch_size
                before = pad_size // 2
                # An odd remainder puts the extra voxel on the trailing side.
                self.padding.extend([before, pad_size - before])

        self.return_fmris = return_fmris
        self.return_embeds = return_embeds
        self.return_vaes_embeds = return_vae_embeds
        self.return_images = return_images
        self.return_embeds_gen = return_embeds_gen
        self.return_images_gen = return_images_gen
        self.return_vae_embeds_gen = return_vae_embeds_gen
        self.return_net_embeds = return_net_embeds
        self.return_caps = return_captions
        self.return_subject = return_subject
        self.return_image_ids = return_image_ids
        self.return_image_type = return_image_type
        self.return_image_size = return_image_size
        self.select_brain_region = (
            [select_brain_region] if isinstance(select_brain_region, str) else select_brain_region
        )
        self.image_fns = image_fns
        self.clip_aug = clip_aug
        self.requires_norm = requires_norm

        self.caps_gen = self._load_json(return_captions_gen) if return_captions_gen else None

        if return_tokens:
            if not data_path:
                raise ValueError("return_tokens=True requires at least one data_path")
            # data_path is a list at this point, so the token file is located
            # relative to the first index file.
            self.tokens = self._load_json(self._token_file_for(data_path[0]))
        else:
            self.tokens = None

        if select_subject is not None:
            if isinstance(select_subject, str):
                select_subject = [select_subject]
            self.list_data_dict = [
                sample for sample in self.list_data_dict if sample['subject'] in select_subject
            ]

        self.subj_mask = None
        self.token_ids = None
        if self.select_brain_region:
            if data_dict is None:
                raise ValueError("select_brain_region requires at least one data_path")
            # NOTE: only the atlas table of the LAST index file is used.
            subj_paths = data_dict['atlas']

            subj_mask = {}
            whole_mask = torch.zeros((1, 1, *self.fmri_shape), dtype=torch.uint8)
            for subj in subj_paths:
                atlas_json = self._load_json(subj_paths[subj].replace("atlas", "atlas_general"))

                # atlas_json[0] is the label volume, atlas_json[1] maps region name -> label id.
                region_name2ids = atlas_json[1]
                atlas = rearrange(torch.tensor(atlas_json[0]), 'z y x -> x y z')
                mask = torch.zeros_like(atlas, dtype=torch.uint8)

                for region in self.select_brain_region:
                    mask[atlas == region_name2ids[region]] = 1
                subj_mask[subj] = mask
                # Union of all subjects' masks, resampled to the common volume size.
                whole_mask |= F.interpolate(
                    mask.unsqueeze(0).unsqueeze(0).float(),
                    size=self.fmri_shape,
                    mode='trilinear',
                    align_corners=False
                ).byte()

            # Number of selected voxels inside each non-overlapping patch.
            # NOTE(review): this grid is computed on the unpadded volume while
            # __getitem__ may pad volumes - confirm how token_ids are consumed.
            patch_counts = F.conv3d(
                whole_mask.float(),
                torch.ones(1, 1, patch_size, patch_size, patch_size).float(),
                stride=(patch_size, patch_size, patch_size),
            )
            # Compare as float: casting counts (up to patch_size**3) to uint8
            # first could wrap a multiple of 256 to 0 and drop the patch.
            token_mask = (patch_counts > 0).flatten(2).squeeze()

            self.token_ids = torch.where(token_mask)[0].tolist()
            self.subj_mask = subj_mask

    @staticmethod
    def _load_json(path: str) -> Any:
        """Parse the JSON file at *path*, closing the file afterwards."""
        with open(path, "r") as handle:
            return json.load(handle)

    @staticmethod
    def _token_file_for(data_file: str) -> str:
        """Return the ``nsd_gpt2_tokens.json`` path three components above *data_file*."""
        return '/'.join(data_file.split('/')[:-3]) + '/nsd_gpt2_tokens.json'

    def _load_fmri(self, fmri_file, dataset_name, subj) -> torch.Tensor:
        """Load, mask, resample, normalise and pad the volume(s) of one sample."""
        if isinstance(fmri_file, list):
            repeats = [torch.tensor(np.load(f)).float().unsqueeze(0) for f in fmri_file]
            if not self.training:
                # Evaluation: average all repeated recordings.
                fmri = torch.cat(repeats, dim=0).mean(dim=0, keepdim=True)
            elif len(repeats) > 1:
                # Training: randomly select two recordings and mix them up.
                lam = np.random.beta(0.5, 0.5)
                ids = torch.randperm(len(repeats))[:2]
                fmri = lam * repeats[ids[0]] + (1 - lam) * repeats[ids[1]]
            else:
                fmri = repeats[0]
        else:
            fmri = torch.tensor(np.load(fmri_file)).float().unsqueeze(0)

        if 'bold' in dataset_name:
            # Swap the first and last spatial axes for these datasets.
            fmri = fmri.transpose(1, 3)

        if self.subj_mask:
            fmri = fmri * self.subj_mask[subj]

        # Compare spatial dims only: the tensor carries a leading channel axis,
        # so comparing the full shape with the 3-tuple is always unequal.
        if tuple(fmri.shape[-3:]) != tuple(self.fmri_shape):
            fmri = F.interpolate(
                fmri.unsqueeze(0), size=self.fmri_shape, mode='trilinear', align_corners=False
            )
            fmri = fmri.squeeze(0)
        if self.requires_norm:
            fmri = (fmri - self.means[dataset_name]) / (self.stds[dataset_name] + 1e-5)
        if self.padding:
            fmri = F.pad(fmri, pad=self.padding, mode='constant', value=0.)
        return fmri

    def _add_image_outputs(self, results: Dict[str, Any], image, key: str) -> None:
        """Store *image* under *key* plus one ``"<key>_NN"`` entry per ``image_fns`` callable."""
        if self.image_fns:
            # The callables receive the image before resizing / array conversion.
            for ids, fn in enumerate(self.image_fns):
                results[f"{key}_{ids:02}"] = fn(image)

        if self.return_image_size:
            image = image.resize((self.return_image_size, self.return_image_size))
        if self.return_image_type == "np":
            image = np.array(image)
        results[key] = image

    def __len__(self):
        """Return the number of (subject-filtered) samples."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, Any]:
        """Build the output dict for sample *i* according to the ``return_*`` options."""
        sources = self.list_data_dict[i]
        fmri_file = sources['fmri']
        image_file = sources['image']
        # The image id is the last "_"-separated chunk of the path, without extension.
        img_idx = image_file.split("_")[-1].split(".")[0]
        dataset_name = sources["image"].split("/images/")[0].split("/")[-1]
        subj = sources["subject"]
        fmri_ids = sources["ids"]

        results = {}

        if self.return_fmris:
            results["fmri"] = self._load_fmri(fmri_file, dataset_name, subj)

        if self.return_embeds:
            vision_embeds_file = sources['vision_embeds']
            if self.clip_aug != 0:
                # aug == 0 keeps the original embedding, 1..clip_aug pick an augmented one.
                aug = np.random.randint(0, self.clip_aug + 1)
                if aug != 0:
                    vision_embeds_file = vision_embeds_file.replace(
                        '/vision_embeds/', '/vision_embeds_aug/'
                    ).replace('.npy', f'_{aug:03}.npy')

            results["labels"] = torch.tensor(np.load(vision_embeds_file))

        if self.return_images:
            image = Image.open(image_file).convert("RGB")
            self._add_image_outputs(results, image, "images")

        if self.return_embeds_gen:
            vision_embeds_gen_file = os.path.join(self.return_embeds_gen, f"{subj}/{fmri_ids:06}.npy")
            results["labels_gen"] = torch.tensor(np.load(vision_embeds_gen_file))

        if self.return_vae_embeds_gen:
            # The VAE file path is derived from the generated-embedding path,
            # so return_embeds_gen must be set for this option to work.
            vision_embeds_gen_file = os.path.join(self.return_embeds_gen, f"{subj}/{fmri_ids:06}.npy")
            vae_embeds_gen_file = vision_embeds_gen_file.replace("fmri2embeds", "fmri2vae").replace('.npy', '.pt')
            results["vae_labels_gen"] = torch.load(vae_embeds_gen_file)

        if self.return_images_gen:
            image_gen_file = os.path.join(self.return_images_gen, f"{subj}/{fmri_ids:06}.png")
            image_gen = Image.open(image_gen_file).convert("RGB")
            self._add_image_outputs(results, image_gen, "images_gen")

        if self.return_vaes_embeds:
            vae_embeds_file = sources['vision_embeds'].replace("vision", "vae").replace('.npy', '.pt')
            results["vae_labels"] = torch.load(vae_embeds_file)

        if self.return_net_embeds:
            for net_key in self.return_net_embeds:
                net_embeds_file = os.path.join(net_key, f"{fmri_ids:06}.npy")
                # The result key is built from the last three path components.
                key = "_".join(net_key.split("/")[-3:])
                results[key] = torch.tensor(np.load(net_embeds_file))

        if self.tokens:
            tokens = self.tokens[int(img_idx)]["tokens"]
            # Pick one of the available token sequences at random.
            rand = np.random.randint(0, len(tokens))
            results["tokens"] = tokens[rand]

        if self.return_caps:
            results["captions"] = sources.get("caption", None)

        if self.caps_gen:
            results["captions_gen"] = self.caps_gen[i]

        if self.return_subject:
            results["subject"] = subj

        if self.return_image_ids:
            results["image_ids"] = int(img_idx)

        return results


if __name__ == '__main__':
    # Manual smoke test: build the validation split for one subject with the
    # 'nsdgeneral' region selected, then print sizes, token ids and shapes.
    demo_kwargs = dict(
        data_path='/mnt/NSD_dataset/datasets/nsd/fmris/pretrain_new.json',
        is_train=False,
        select_brain_region=['nsdgeneral'],
        select_subject='subj07',
    )
    demo_set = fMRIViT3dDataset(**demo_kwargs)

    print(len(demo_set))
    selected_tokens = demo_set.token_ids
    print(selected_tokens)
    print(len(selected_tokens))

    # Shapes of the per-subject region masks.
    for subject_key in ("subj01", "subj02"):
        print(demo_set.subj_mask[subject_key].shape)

    # Walk the whole dataset and report each processed volume's shape.
    for sample in demo_set:
        print(sample["fmri"].shape)
