import argparse
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch

import transformers
from torch.nn import MSELoss
from transformers import CLIPVisionConfig, Trainer, ViTConfig
from transformers import VisionEncoderDecoderConfig

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
    DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token
from llava.train import *
from llava.datasets.fmri_datasets import make_supervised_fmri_data_module

from PIL import Image

from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset
from llava.model.fmri_encoder.vit3d import CLIPVision3dModelWithProjection
from llava.model.fmri_encoder.cvt3d import CLIPCVT3dModelWithProjection
from llava.model.fmri_encoder.vit3d_decoder import ViT3dDecoderModel, ViT3dWithProjectionModel

# Command-line interface for this training script. The flags are registered in
# a fixed order (which is also the order shown by --help) and parsed at import
# time into the module-level `args` that train() reads.
parser = argparse.ArgumentParser()

# Forwarded to the model config as `num_hidden_layers`.
parser.add_argument("--num-layers", type=int, default=24)
# NOTE(review): parsed but not read anywhere in the visible part of this file.
parser.add_argument("--from-pretrained", type=str, default="")
# Edge length of the cubic patch; also drives volume padding and the datasets.
parser.add_argument("--patch-size", type=int, default=14)
# Used as `per_device_train_batch_size`.
parser.add_argument("--batch-size", type=int, default=32)
# Forwarded to the model config as `hidden_size` / `num_attention_heads`.
parser.add_argument("--hidden-size", type=int, default=768)
parser.add_argument("--num-heads", type=int, default=12)
# Matched in train() to choose the encoder class.
parser.add_argument("--model", type=str, default="vit3d_new")
# Part of the data path and checkpoint name; a value containing "all" selects
# the dataset-level index file instead of a per-subject one.
parser.add_argument("--subject", type=str, default="subj01")
# Last component of the checkpoint directory / run name.
parser.add_argument("--suffix", type=str, default="")
# One or more dataset directory names.
parser.add_argument("--dataset", nargs="+", type=str, default=["nsd"])
# Subject used only for validation; mutually exclusive with --select-subj.
parser.add_argument("--exclude-subj", type=str, default="")
# Passed to the datasets as `return_vae_embeds` and to the config as `with_vae`.
parser.add_argument("--vae", action="store_true")
# Passed to the datasets as `select_brain_region`; None is named "all" in runs.
parser.add_argument("--select-region", nargs="+", type=str, default=None)
# Train and validate on this single subject.
parser.add_argument("--select-subj", type=str, default=None)
# Forwarded to the model config as `flatten_patch`.
parser.add_argument("--flatten", action="store_true")
# Inserted into the index file name (`pretrain{data_path}.json`) and appended
# to the model tag in the checkpoint name.
parser.add_argument("--data-path", type=str, default="")
# Forwarded to the training dataset as `clip_aug`.
parser.add_argument("--clip-aug", type=int, default=0)

args = parser.parse_args()


def calc_image_size(patch_size, base_size=(83, 104, 81)):
    """Round every dimension of ``base_size`` up to a multiple of ``patch_size``.

    Args:
        patch_size: Positive patch edge length that each dimension must be
            divisible by after padding.
        base_size: Unpadded size, one entry per dimension. Defaults to
            ``(83, 104, 81)``.
            NOTE(review): presumably the raw fMRI volume shape — confirm
            against the dataset code.

    Returns:
        A tuple with the same length as ``base_size`` holding the padded
        dimensions; a dimension that is already a multiple is left unchanged.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # -(-dim // patch_size) is ceiling division, i.e. the number of patches
    # needed to cover `dim`.
    return tuple(-(-dim // patch_size) * patch_size for dim in base_size)


def collate_fn(batch):
    """Stack a list of per-sample dicts into one batch dict of tensors.

    Each sample must provide ``'fmri'`` and ``'labels'``; ``'vae_labels'`` is
    optional and falls back to that sample's ``'labels'`` when absent.

    Returns:
        A dict with keys ``'pixel_values'`` (stacked ``'fmri'``), ``'labels'``
        and ``'vae_labels'``, each stacked along a new leading dimension.
    """

    def _vae_target(sample):
        # Samples without VAE embeddings reuse their regular labels.
        if 'vae_labels' in sample:
            return sample['vae_labels']
        return sample['labels']

    stacked_fmri = torch.stack([sample['fmri'] for sample in batch])
    stacked_labels = torch.stack([sample['labels'] for sample in batch])
    stacked_vae = torch.stack([_vae_target(sample) for sample in batch])

    collated = {}
    collated['pixel_values'] = stacked_fmri
    collated['labels'] = stacked_labels
    collated['vae_labels'] = stacked_vae
    return collated


class MyTrainer(Trainer):
    """``Trainer`` that optimises ``outputs.loss + outputs.vae_loss / 81``.

    The model output must expose ``loss`` (logged as ``clip_loss``) and
    ``vae_loss``. During training the combined loss and both terms are logged
    every ``logging_steps`` global steps. During evaluation the per-batch
    values are accumulated and their means are logged as ``eval_loss``,
    ``eval_clip_loss`` and ``eval_vae_loss``.
    """

    # Divisor applied to the VAE term before it is added to the CLIP term.
    # NOTE(review): the value 81 is undocumented in the original code — confirm
    # where it comes from before changing it.
    VAE_LOSS_DIVISOR = 81

    @staticmethod
    def _as_float(value):
        """Return a scalar tensor or a plain number as a Python float."""
        if hasattr(value, "item"):
            return float(value.item())
        return float(value)

    def _init_eval_tracking(self, is_training):
        """Create the evaluation bookkeeping attributes on first use."""
        if hasattr(self, "last_train_state"):
            return
        self.last_train_state = is_training
        self.eval_loss = []
        self.eval_clip_loss = []
        self.eval_vae_loss = []
        self.best_eval_loss = 1e6

    def _flush_eval_losses(self):
        """Log the mean of the buffered eval losses and clear the buffers.

        Does nothing when no evaluation batch has been recorded, so it is safe
        to call more than once.
        """
        if not getattr(self, "eval_loss", None):
            return

        current_eval_loss = sum(self.eval_loss) / len(self.eval_loss)
        self.log({
            "eval_loss": current_eval_loss,
            "eval_clip_loss": sum(self.eval_clip_loss) / len(self.eval_clip_loss),
            "eval_vae_loss": sum(self.eval_vae_loss) / len(self.eval_vae_loss),
        })

        if current_eval_loss < self.best_eval_loss:
            self.best_eval_loss = current_eval_loss

        self.eval_loss = []
        self.eval_clip_loss = []
        self.eval_vae_loss = []

    def evaluate(self, *eval_args, **eval_kwargs):
        """Run the standard evaluation, then log the averaged loss terms.

        Flushing here (rather than only on the next training step) means the
        last evaluation of a run is logged too.
        """
        metrics = super().evaluate(*eval_args, **eval_kwargs)
        self._flush_eval_losses()
        return metrics

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss and record its components.

        Args:
            model: Model called as ``model(**inputs)``; its output must have
                ``loss`` and ``vae_loss`` attributes.
            inputs: Keyword arguments for the model's forward pass.
            return_outputs: If true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted (and ignored) so that ``Trainer``
                versions which pass this keyword can call this override.

        Returns:
            The combined loss, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        outputs = model(**inputs)
        clip_loss = outputs.loss
        vae_loss = outputs.vae_loss
        loss = clip_loss + vae_loss / self.VAE_LOSS_DIVISOR

        is_training = model.training

        if is_training:
            logging_steps = self.state.logging_steps
            # A falsy logging_steps would make the modulo below fail.
            if logging_steps and self.state.global_step % logging_steps == 0:
                self.log({
                    "loss": self._as_float(loss),
                    "clip_loss": self._as_float(clip_loss),
                    # vae_loss may be a plain number rather than a tensor.
                    "vae_loss": self._as_float(vae_loss),
                })

        self._init_eval_tracking(is_training)

        if self.last_train_state is False and is_training:
            # First training step after an eval-mode pass: flush anything that
            # was buffered outside evaluate() (no-op if already flushed).
            self._flush_eval_losses()

        if not is_training:
            self.eval_loss.append(self._as_float(loss))
            self.eval_clip_loss.append(self._as_float(clip_loss))
            self.eval_vae_loss.append(self._as_float(vae_loss))

        self.last_train_state = is_training

        return (loss, outputs) if return_outputs else loss


def train():
    """Build the datasets, model and trainer from the parsed CLI ``args`` and train.

    Reads the module-level ``args``. Checkpoints are written under
    ``./checkpoint/`` in a directory named after the run configuration.

    Raises:
        ValueError: If both ``--exclude-subj`` and ``--select-subj`` are set,
            or if ``--model`` matches none of the known model families.
    """
    # One index file per dataset. A subject containing 'all' uses the
    # dataset-level index; otherwise the per-subject one is used.
    if 'all' in args.subject:
        data_path = [
            f'/mnt/NSD_dataset/datasets/{dataset}/fmris/pretrain{args.data_path}.json'
            for dataset in args.dataset
        ]
    else:
        data_path = [
            f'/mnt/NSD_dataset/datasets/{dataset}/fmris/{args.subject}/pretrain{args.data_path}.json'
            for dataset in args.dataset
        ]

    dataset_name = '_'.join(args.dataset)
    region_name = '_'.join(args.select_region) if args.select_region is not None else "all"

    # Raised explicitly (not asserted) so the check survives `python -O`.
    if args.exclude_subj != "" and args.select_subj is not None:
        raise ValueError("Exclude subject and select subject cannot be set at the same time")

    # None leaves subject selection to the dataset.
    train_subjects = None
    val_subjects = None
    if args.exclude_subj != "":
        # Hold the excluded subject out for validation and train on the rest
        # of the hard-coded subject pool.
        train_subjects = [f"subj0{i}" for i in [1, 2, 5, 7] if f"subj0{i}" != args.exclude_subj]
        val_subjects = [args.exclude_subj]
    elif args.select_subj is not None:
        train_subjects = [args.select_subj]
        val_subjects = [args.select_subj]

    # Branch order matters: "vit3d" must match exactly, the rest are substring
    # tests evaluated top to bottom.
    if args.model == "vit3d":
        model_cls = CLIPVision3dModelWithProjection
        model_str = "vit3d"
    elif "cvt3d" in args.model:
        model_cls = CLIPCVT3dModelWithProjection
        model_str = "cvt3d"
    elif "vitdecoder" in args.model:
        model_cls = ViT3dDecoderModel
        model_str = "vitdecoder"
    elif "new" in args.model:
        model_cls = ViT3dWithProjectionModel
        model_str = "vit3d_new"
    else:
        raise ValueError(f"Unknown model {args.model}")

    if args.data_path:
        model_str = f'{model_str}{args.data_path}'

    is_vae = "vae" if args.vae else "clip"

    # The same string is used for the checkpoint directory and the run name.
    run_id = (
        f"./checkpoint/{model_str}-"
        f"{dataset_name}-{args.subject}-"
        f"{is_vae}-"
        f"{args.patch_size}-{args.num_layers}-"
        f"{args.hidden_size}-{region_name}-{args.exclude_subj}-{args.suffix}"
    )

    # NOTE(review): eval_steps/save_steps presumably have no effect while both
    # strategies are "epoch" — confirm before relying on them.
    training_args = TrainingArguments(
        output_dir=run_id,
        run_name=run_id,
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        eval_steps=100,
        save_steps=100,
        num_train_epochs=300,
        logging_steps=5,
        learning_rate=5e-4,
        save_total_limit=2,
        load_best_model_at_end=True,
        remove_unused_columns=False,
        push_to_hub=False,
        report_to=["wandb"],
        dataloader_num_workers=8,
        lr_scheduler_type="cosine",
        warmup_ratio=0.08,
    )

    dataset_train = fMRIViT3dDataset(
        data_path=data_path,
        is_train=True,
        patch_size=args.patch_size,
        select_subject=train_subjects,
        return_vae_embeds=args.vae,
        select_brain_region=args.select_region,
        clip_aug=args.clip_aug,
    )

    # The validation set is built without `clip_aug`.
    dataset_val = fMRIViT3dDataset(
        data_path=data_path,
        is_train=False,
        patch_size=args.patch_size,
        select_subject=val_subjects,
        return_vae_embeds=args.vae,
        select_brain_region=args.select_region,
    )

    model_args = ViTConfig(
        image_size=calc_image_size(args.patch_size),
        num_hidden_layers=args.num_layers,
        num_channels=1,
        patch_size=(args.patch_size, args.patch_size, args.patch_size),
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_heads,
        projection_dim=1024,
        with_vae=args.vae,
        flatten_patch=args.flatten,
        token_ids=dataset_train.token_ids,
    )
    model = model_cls(
        config=model_args,
    )

    # Cast the weights to match the precision flags on the training arguments.
    if training_args.bf16:
        model.to(torch.bfloat16)
    if training_args.fp16:
        model.to(torch.float16)

    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        data_collator=collate_fn,
    )

    trainer.train()


# Script entry point: start training only when this file is executed directly.
if __name__ == '__main__':
    train()
