
import os
import sys

import json
import argparse

import scipy
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GPT2TokenizerFast
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from diffusers import StableUnCLIPImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionXLImg2ImgPipeline
from diffusers import StableDiffusionXLAdapterPipeline
import torch
import matplotlib.gridspec as gridspec

from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset
from llava.model.fmri_encoder.vit3d import CLIPVision3dModelWithProjection
from llava.train import DataArguments

# Command-line interface. Every option below is read by the __main__ driver
# further down; paths derived from them live under the hard-coded NSD data root.
parser = argparse.ArgumentParser(
    description="Decode generated fMRI embeddings into images with Stable UnCLIP.",
)

parser.add_argument(
    "--device",
    type=str,
    default="cuda:0",
    help="torch device the diffusion pipeline runs on (e.g. cuda:0, cpu)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="the seed (for reproducible sampling)",
)

parser.add_argument(
    "--dataset",
    type=str,
    default="nsd",
    help="dataset name; selects the dataset directory under the data root",
)

parser.add_argument(
    "--subject",
    type=str,
    default="subj01",
    help="subject id such as subj01, or 'all' to load the pooled pretrain.json",
)

parser.add_argument(
    "--model",
    type=str,
    default="",
    help=(
        "name of the fmri2embeds result directory to decode; a name containing "
        "'vae' also loads VAE latents and enables the --noise-levels sweep"
    ),
)

parser.add_argument(
    "--batch-size",
    type=int,
    default=10,
    help="number of samples sent to the pipeline per call",
)

parser.add_argument(
    "--caption-type",
    type=str,
    default="none",
    help=(
        "text prompt source: 'none' (empty prompt), 'caption' (dataset caption), "
        "or 'caption_gen' (LLM-generated caption)"
    ),
)

parser.add_argument(
    "--llm",
    type=str,
    default='',
    help="generated-caption source under llava_captions/; used with caption types containing 'gen'",
)

parser.add_argument(
    "--caption-ids",
    type=int,
    default=0,
    help="index of the generated caption to use; -1 reads per-sample ids from best_ids.json",
)

parser.add_argument(
    "--select-subject",
    type=str,
    default=None,
    help="subject filter forwarded to the dataset; also picks the best_ids offset when --caption-ids is -1",
)

parser.add_argument(
    "--noise-levels",
    nargs='+',
    type=float,
    default=[0.99, 0.98, 0.97, 0.96, 0.95, 0.94, 0.93, 0.92, 0.91, 0, 0.15, 0.3, 0.45, 0.6, 0.75, 0.85, 0.9],
    help="Gaussian-noise blend weights applied to the VAE latents (0 keeps the latents unchanged); ignored unless --model contains 'vae'",
)

# NOTE: parsed at import time because the driver below reads these globals.
args = parser.parse_args()

device = args.device

# Root directory that holds every dataset's fMRI inputs and generated results.
NSD_ROOT = '/mnt/NSD_dataset/datasets'

# Offset of each subject's first sample inside a best_ids.json list; used with
# --select-subject to map a dataset index to its row in that list.
SUBJECT_BIAS = {
    'subj01': 0,
    'subj02': 2770,
    'subj05': 5540,
    'subj07': 8310,
}


def _clean_caption(raw):
    """Normalize a generated caption for use as a diffusion prompt.

    Removes the ``<s>``/``</s>`` markers and newline/tab/carriage-return
    characters, collapses double spaces, trims the ends, then drops the
    "I see a serene scene of " and "I see " phrases wherever they occur.
    """
    text = raw
    for token in ("<s>", "</s>", "\n", "\t", "\r"):
        text = text.replace(token, "")
    text = text.replace("  ", " ").strip()
    # The longer phrase must be removed first, or "I see " would break it up.
    for phrase in ("I see a serene scene of ", "I see "):
        text = text.replace(phrase, "")
    return text


def _select_caption(data, caption_type, caption_idx):
    """Return the prompt for one dataset sample.

    ``caption_type`` containing "caption_gen" picks generated caption
    ``caption_idx``; otherwise containing "caption" picks the last dataset
    caption; anything else yields an empty prompt.
    """
    if "caption_gen" in caption_type:
        return _clean_caption(data['captions_gen']["captions"][caption_idx])
    if "caption" in caption_type:
        return data['captions'][-1]
    return ""


def _mix_latents(vae_embeds, noise_level, generator, device):
    """Blend stacked VAE latents with seeded Gaussian noise.

    Returns ``latents * (1 - noise_level) + noise * noise_level`` in float16 on
    ``device``, or None when no latents were collected, in which case the
    pipeline samples its own starting latents.
    """
    if not vae_embeds:
        return None
    latents = torch.stack(vae_embeds).to(device).half()
    # Independent noise per sample, shaped like the latents so the resolution
    # is not hard-coded. randn_tensor needs a torch.device when a generator is
    # given, hence latents.device rather than the raw device string.
    noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
    return latents * (1. - noise_level) + noise * noise_level


def _save_images(images, subjects, output_root, subject_counter):
    """Save images as ``<output_root>/<subject>/<NNNNNN>.png``.

    Numbering is per subject. ``subject_counter`` is updated in place so
    numbering continues across batches.
    """
    for image, subject in zip(images, subjects):
        subject_dir = f'{output_root}/{subject}'
        if subject not in subject_counter:
            subject_counter[subject] = 0
            os.makedirs(subject_dir, exist_ok=True)
        image.save(f'{subject_dir}/{subject_counter[subject]:06}.png')
        subject_counter[subject] += 1


if __name__ == '__main__':
    # Decode fmri2embeds outputs into images with Stable UnCLIP, once per noise level.
    sorted_ids = None
    output_suffix = ""
    if args.caption_type == 'caption_gen':
        llm_name = args.llm.split('/')[-1].replace('.json', '')
        output_suffix = f"_{llm_name}"
        if args.caption_ids == -1:
            best_ids_path = f'{NSD_ROOT}/{args.dataset}/results/{args.subject}/llava_captions/{llm_name}/best_ids.json'
            with open(best_ids_path) as f:
                sorted_ids = json.load(f)
    elif args.caption_type == 'caption':
        output_suffix = "_caption"

    # Resolve the best_ids offset up front instead of failing with a bare
    # KeyError part-way through generation.
    caption_offset = 0
    if sorted_ids is not None:
        if args.select_subject not in SUBJECT_BIAS:
            parser.error(
                f"--caption-ids -1 requires --select-subject to be one of {sorted(SUBJECT_BIAS)}"
            )
        caption_offset = SUBJECT_BIAS[args.select_subject]

    if args.subject == 'all':
        data_path = f'{NSD_ROOT}/{args.dataset}/fmris/pretrain.json'
    else:
        data_path = f'{NSD_ROOT}/{args.dataset}/fmris/{args.subject}/pretrain.json'

    embeds_gen_dir = f'{NSD_ROOT}/{args.dataset}/results/{args.subject}/fmri2embeds/{args.model}'

    if 'gen' in args.caption_type:
        args.caption_path = f'{NSD_ROOT}/{args.dataset}/results/{args.subject}/llava_captions/{args.llm}'
    else:
        args.caption_path = None

    dataset_val = fMRIViT3dDataset(
        data_path=data_path,
        is_train=False,
        return_fmris=False,
        return_embeds=True,
        return_embeds_gen=embeds_gen_dir,
        return_vae_embeds_gen=('vae' in args.model),
        return_images=False,
        return_captions=True,
        return_subject=True,
        select_subject=args.select_subject,
        return_captions_gen=args.caption_path,
    )

    ## Create the pipeline
    pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    # Without VAE latents there is nothing to blend noise into: run a single pass.
    if 'vae' not in args.model:
        args.noise_levels = [0]

    num_samples = len(dataset_val)
    for noise_level in args.noise_levels:
        args.noise_level = noise_level
        args.output_dir = f'{NSD_ROOT}/{args.dataset}/results/{args.subject}/embeds2images/{args.model}{output_suffix}_{args.noise_level}_{args.caption_ids}'
        os.makedirs(args.output_dir, exist_ok=True)

        # Re-seed per noise level so each level's output is reproducible
        # regardless of which other levels ran before it.
        generator = torch.Generator(device=device).manual_seed(args.seed)

        subject_counter = {}
        with torch.no_grad():
            for start in tqdm(range(0, num_samples, args.batch_size), desc=f"noise level {args.noise_level}"):
                stop = min(start + args.batch_size, num_samples)

                vision_embeds_gen = []
                vae_embeds_gen = []
                captions = []
                subjects = []
                for sample_idx in range(start, stop):
                    data = dataset_val[sample_idx]
                    vision_embeds_gen.append(data['labels_gen'])
                    if "vae_labels_gen" in data:
                        vae_embeds_gen.append(data['vae_labels_gen'])
                    if sorted_ids is not None:
                        caption_idx = sorted_ids[caption_offset + sample_idx][0]
                    else:
                        caption_idx = args.caption_ids
                    captions.append(_select_caption(data, args.caption_type, caption_idx))
                    subjects.append(data['subject'])

                images_gen = pipe(
                    prompt=captions,
                    image_embeds=torch.stack(vision_embeds_gen).to(device).half(),
                    num_images_per_prompt=1,
                    guidance_scale=7.5,
                    latents=_mix_latents(vae_embeds_gen, args.noise_level, generator, device),
                    generator=generator,
                ).images

                _save_images(images_gen, subjects, args.output_dir, subject_counter)
