import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
import os
from copy import deepcopy
from MyDiffusers import *
from flowsrepo import *
import time
from PIL import Image
import numpy as np
import cv2
from diffusers import StableDiffusionPipeline
import argparse

def get_attention_maps_median(
    attention_maps,
    threshold_bin=3,
    num_of_maps_to_avg=10,
    maps_to_sum=(4, 5),
    expand=True,
    expand_size=2,
):
    """Sum selected attention maps per batch item and binarize them by a median threshold.

    ``attention_maps.attention_history`` is a list of tensors that are
    concatenated along dim 0. The indexing below requires the concatenated
    tensor to be 7-D; presumably it is laid out as
    (step, cfg-branch, batch, H, W, ?, map) — TODO confirm against the
    attention processor that records the history.

    Args:
        attention_maps: object exposing an ``attention_history`` list of tensors.
        threshold_bin: a pixel is 1 in the binary map when its summed value is
            ``>= threshold_bin * median(summed map)``, otherwise 0.
        num_of_maps_to_avg: number of trailing history entries averaged
            (dims 0 and 5 are both averaged). Note that ``0`` slices as
            ``[-0:]``, i.e. every entry.
        maps_to_sum: indices along the last dim that are summed per batch item.
        expand: if true, OR each binary map with copies rolled by
            ``+/- expand_size`` along both spatial axes. ``torch.roll`` wraps
            around the borders, and for ``expand_size > 1`` the intermediate
            offsets are not filled (this is not a full dilation).
        expand_size: roll distance used when ``expand`` is true.

    Returns:
        ``(all_attn_maps, attn_maps, bin_attn_maps)``: the concatenated
        history tensor, a list with one summed float map per batch item, and a
        list with one integer {0, 1} map per batch item.
    """
    all_attn_maps = torch.cat(attention_maps.attention_history)

    # Average the trailing entries over dims 0 and 5, then keep index 1 of the
    # original dim 1 (presumably the conditional CFG branch — TODO confirm).
    mean_attn_maps = (
        all_attn_maps[-num_of_maps_to_avg:].mean(dim=(0, 5)).detach().cpu()[1, :]
    )

    batch_size = all_attn_maps.shape[2]

    attn_maps = []
    bin_attn_maps = []

    for b in range(batch_size):
        # Start from real zeros rather than ``0 * map0`` so that NaN/inf in map
        # index 0 cannot leak into the sum when 0 is not in ``maps_to_sum``.
        attn_map0 = torch.zeros_like(mean_attn_maps[b, :, :, 0])
        for m in maps_to_sum:
            attn_map0 += mean_attn_maps[b, :, :, m]

        threshold = threshold_bin * np.median(attn_map0.numpy())
        binary_attn_map0 = torch.where(attn_map0 < threshold, 0, 1)
        attn_maps.append(attn_map0)

        if expand:
            # The right-hand side is evaluated before the in-place add, so all
            # four rolls see the un-expanded map.
            binary_attn_map0 += (
                torch.roll(binary_attn_map0, expand_size, dims=0)
                + torch.roll(binary_attn_map0, -expand_size, dims=0)
                + torch.roll(binary_attn_map0, expand_size, dims=1)
                + torch.roll(binary_attn_map0, -expand_size, dims=1)
            )
        bin_attn_maps.append(binary_attn_map0.clamp(0, 1))

    return all_attn_maps, attn_maps, bin_attn_maps

def pos_manage_latent(
    back_latent,
    bee_latent,
    binary_map,
    pos_x=(20,),
    pos_y=(20,),
    lim_pos_x=64,
    lim_pos_y=64,
    num_subjects=1,
    scale=False,
    scale_num=0.8,
):
    """Place shifted copies of a masked subject latent onto a background latent.

    The subject region of ``bee_latent`` (``binary_map`` resized to 64x64) is
    rolled with ``torch.roll`` (so copies wrap around the borders) so that the
    mask's centre of mass lands on ``(pos_x[s], pos_y[s])`` for every placed
    subject ``s``. The background is masked out under the original region and
    under every copy. The vacated original region is refilled with the
    background that lies under the first placed copy. The result is clamped to
    the value range of the (unmasked) ``bee_latent``.

    Args:
        back_latent: background latent, broadcastable with a (1, 1, 64, 64)
            mask; presumably (1, C, 64, 64) — TODO confirm.
        bee_latent: latent containing the subject (presumably same shape as
            ``back_latent``).
        binary_map: 2-D subject mask; resized to 64x64 with nearest
            interpolation.
        pos_x: per-subject target row for the centre of mass (applied on dim -2).
        pos_y: per-subject target column for the centre of mass (dim -1).
        lim_pos_x, lim_pos_y: exclusive upper bounds. A subject is placed only
            when ``1 < pos_x[s] < lim_pos_x`` and ``1 < pos_y[s] < lim_pos_y``;
            otherwise it is skipped.
        num_subjects: number of leading entries of ``pos_x``/``pos_y`` to use.
        scale: if true, the placed copies and their mask are scaled with
            ``TF.affine(..., scale=scale_num)``.
        scale_num: scale factor used when ``scale`` is true.

    Returns:
        ``(new_latent, placed_mask, start_mask)``:
        - the composited latent;
        - the union of the placed-copy masks, clamped to [0, 1] (scaled when
          ``scale``);
        - the resized original mask of shape (1, 1, 64, 64).
    """
    # Resize the mask to the 64x64 latent grid; "nearest" keeps it binary-valued.
    bin_map = torch.nn.functional.interpolate(
        binary_map.unsqueeze(0).unsqueeze(0).to(torch.float32),
        size=[64, 64],
        mode="nearest",
    )

    # Centre of mass (row, col) of the mask; an empty mask gives NaN -> (0, 0).
    nonzero_coords = torch.nonzero(bin_map[0, 0])
    center_of_mass = torch.nan_to_num(torch.mean(nonzero_coords.float(), dim=0))
    center_row = int(center_of_mass[0])
    center_col = int(center_of_mass[1])

    # The output clamp range comes from the subject latent before masking.
    max_lat = bee_latent.max()
    min_lat = bee_latent.min()
    bee_latent = bee_latent * bin_map

    many_bee_latent = torch.zeros_like(bee_latent)  # placed copies + hole fill
    many_bin_map = bin_map.clone()  # original region + every placed region
    scale_many_bee_binmap = torch.zeros_like(bin_map)  # placed regions only
    hole_filled = False

    for s in range(num_subjects):
        if not (1 < pos_x[s] < lim_pos_x and 1 < pos_y[s] < lim_pos_y):
            continue  # target out of range: this subject is not placed

        roll_x = int(pos_x[s]) - center_row
        roll_y = int(pos_y[s]) - center_col

        shifted_bee_latent = torch.roll(
            bee_latent, shifts=(roll_x, roll_y), dims=(-2, -1)
        )
        # NOTE(review): the first two columns are restored from the unshifted
        # masked latent; the intent is not visible from here — confirm.
        shifted_bee_latent[..., :2] = bee_latent[..., :2]

        shifted_bin_map = torch.roll(
            bin_map, shifts=(roll_x, roll_y), dims=(-2, -1)
        )

        if not hole_filled:
            # Refill the vacated original region with the background found
            # under this copy, rolled back onto the original position. Keyed on
            # the first *placed* subject (not simply s == 0), so an
            # out-of-range first subject does not leave a zero-valued hole.
            latent_to_repl = back_latent * shifted_bin_map
            many_bee_latent += torch.roll(
                latent_to_repl, shifts=(-roll_x, -roll_y), dims=(-2, -1)
            )
            hole_filled = True

        scale_many_bee_binmap += shifted_bin_map
        many_bee_latent += shifted_bee_latent
        many_bin_map += shifted_bin_map

    # NOTE(review): if no subject is placed, the original region is still masked
    # out of the background below and nothing refills it (it stays 0) — confirm.
    many_bin_map = many_bin_map.clamp(0, 1)
    not_bee_latent = (1 - many_bin_map) * back_latent

    if scale:
        scale_many_bee_latent = TF.affine(
            many_bee_latent, angle=0, translate=[0, 0], shear=0, scale=scale_num
        )
        scale_many_bee_binmap = TF.affine(
            scale_many_bee_binmap, angle=0, translate=[0, 0], shear=0, scale=scale_num
        )

        new_latent = (not_bee_latent + scale_many_bee_latent).clamp(min_lat, max_lat)
        return new_latent, scale_many_bee_binmap.clamp(0, 1), bin_map

    new_latent = not_bee_latent.clamp(min_lat, max_lat) + many_bee_latent.clamp(
        min_lat, max_lat
    )
    new_latent = new_latent.clamp(min_lat, max_lat)

    return new_latent, scale_many_bee_binmap.clamp(0, 1), bin_map

# Supported --example values, each mapped to a zero-argument factory for its
# image warper. The classes come from the star imports at the top of the file;
# construction is deferred so it runs after the default torch device is set.
_WARPER_FACTORIES = {
    "satellite": lambda: SatelliteFlow(N=64),
    "dragons": lambda: BattleOfDragonsFlow(N=64),
    "earth": lambda: SphereFlow(N=64, radius=0.83),
    "meltingman": lambda: MeltingManFlow(N=64),
    "birds": lambda: Birds(N=64),
    "glass": lambda: GlassFlow(N=64),
    "meltingman_fluid": lambda: MeltingManFluidFlow(N=64),
}

parser = argparse.ArgumentParser()
parser.add_argument("--tau", type=int, default=400)
parser.add_argument("--num_inference_steps", type=int, default=200)
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--tag", type=str, default="default")
parser.add_argument(
    "--example",
    type=str,
    required=True,
    # Unknown examples are rejected at parse time with a usage message instead
    # of a bare NotImplementedError after the device has been configured.
    choices=sorted(_WARPER_FACTORIES),
)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--crossframeattention_pattern", type=str, default="[[0,0],[1,1],[0,1]]")
parser.add_argument("--invert", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--spatialeta", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--interpolationmode", type=str, choices=["bilinear","nearest","bicubic"], default="nearest")
args = parser.parse_args()

##### Parameters
torch.manual_seed(11)
device = args.device
torch.set_default_device(device)
image_warper = _WARPER_FACTORIES[args.example]()
prompt, negative_prompt = image_warper.get_default_prompt()

# Timestep axis used by the partial generation/inversion:
# clean |-------------------|> noise
# 0     |----tau------------|> 1000
# 0     |---*-----*------*--|> * num inference steps
tau = args.tau
num_inference_steps = args.num_inference_steps
guidance_scale = args.guidance_scale
# Each run writes into its own timestamped directory.
folder_path = f"output/video/{args.tag}/{time.time()}_{args.example}"
os.makedirs(folder_path, exist_ok=True)
SDM = StableDiffusionManager(device, tau)


def save_output(output, name):
    """Write each image in ``output["images"]`` to ``{folder_path}/{name}_{i}.png``.

    ``folder_path`` is the module-level run directory and ``i`` is the image's
    index in the list. Each item only needs a ``save(path)`` method
    (presumably PIL images returned by the pipeline).
    """
    images = output["images"]
    for position, image in enumerate(images):
        destination = f"{folder_path}/{name}_{position}.png"
        image.save(destination)


# Processor for plain single-frame passes: video_length=1,
# crossframe_attn="disabled", should_record_history=False.
single_attention_processor = get_attention_processor(
    video_length=1, crossframe_attn="disabled", should_record_history=False
)

# Processor used for frame generation, built from the
# --crossframeattention_pattern string (default "[[0,0],[1,1],[0,1]]").
cross_attention_processor = get_attention_processor_from_pattern(args.crossframeattention_pattern)

SDM.pipeline.unet.set_attn_processor(single_attention_processor)
if args.example == 'birds':
    # Birds: generate an image from seeded noise while recording attention,
    # then derive the subject mask from the recorded attention maps.
    seed = 801
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # NOTE(review): this switches the global default device to CPU, and it is
    # never switched back to `device` for the rest of the script — confirm
    # this is intended.
    torch.set_default_device('cpu')
    # Noise is drawn on CPU and then moved to SDM.device (presumably so the
    # sample does not depend on the GPU RNG — TODO confirm).
    z = torch.randn(1, 4, 64, 64, device='cpu').to(SDM.device)
    # Duplicate the top-left 20x20 noise patch into rows 40:60 (all channels).
    z[:, :, 40:60, :20] = z[:, :, :20, :20]
    # Recording processor: attention_history is read by
    # get_attention_maps_median after generation. The meaning of
    # filter_latent_dimension=32 is not visible from this file.
    attn = MyCrossFrameAttnProcessor2_0(video_length=1)
    attn.attention_history = []
    attn.filter_latent_dimension = 32
    attn.should_record_history = True
    SDM.pipeline.unet.set_attn_processor(attn)

    _, mon = SDM.partial_generation_remaining(
        z = z,
        prompt=prompt,
        guidance_scale=7.5,
        eta=0,
        negative_prompt = [negative_prompt],
        num_inference_steps=num_inference_steps,
    )
    # Sum attention map index 6, threshold at 7x its median, no expansion;
    # the mask of the last batch item becomes the subject mask.
    all_att_maps, att_maps_beftau, binary_maps_beftau = get_attention_maps_median(
        attn,
        threshold_bin=7.0,
        num_of_maps_to_avg=10,
        maps_to_sum=[6],
        expand=False,
        expand_size=1,
    )
    sel_bin_map = binary_maps_beftau[-1].to(device)
else:
    # Other examples: encode the example's default image (resized to 512x512)
    # and invert it with guidance_scale=0.0. No subject mask is used.
    base_img = image_warper.get_default_image().resize((512,512))
    z0 = SDM.image_to_latent(base_img)
    _, mon = SDM.partial_inversion(
        z=z0,
        prompt=prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=0.0,
    )
    sel_bin_map = None

# The last latent recorded in `mon` (presumably the latent at timestep tau —
# TODO confirm against StableDiffusionManager) is the fixed reference latent.
# z_tau holds the previous frame's latent and starts as a copy of it.
z_tau_orig = mon["latents"][-1].to(device)
z_tau = z_tau_orig.clone()

# Per-frame loop: build a warped latent for each frame step, regenerate it with
# the cross-frame attention processor, save the images, then prepare `z_tau`
# (the previous-frame latent) for the next iteration.
framesteps = image_warper.get_default_framesteps()
for f, framestep in enumerate(framesteps):
    if args.example == 'birds':
        # Agent positions for frame f come from np.nonzero of each agent's
        # entry in allag_pos_t[f] (presumably one occupancy grid per agent —
        # TODO confirm). np.nonzero is computed once per agent.
        agentsnum = len(image_warper.allag_pos_t[f])
        agent_nonzero = [np.nonzero(image_warper.allag_pos_t[f][p]) for p in range(agentsnum)]
        pos_xt = [nz[0] for nz in agent_nonzero]
        pos_yt = [nz[1] for nz in agent_nonzero]
        sel_bin_map = sel_bin_map.to(device)
        warped_latent, agents_binmap, starting_binmap = pos_manage_latent(
            back_latent=z_tau_orig,
            bee_latent=z_tau_orig,
            binary_map=sel_bin_map.clone(),
            pos_x=pos_xt,
            pos_y=pos_yt,
            lim_pos_x=64,
            lim_pos_y=64,
            num_subjects=agentsnum,
            scale=False,
            scale_num=0.7,
        )
        # eta is 1 over the original and placed subject regions and over the
        # last 10 rows, 0 elsewhere.
        spatial_eta = (agents_binmap + starting_binmap).clamp(0, 1).to(device)
        spatial_eta[..., -10:, :] = 1
    else:
        # Warp the reference latent with this example's warper for the frame step.
        warped_latent = image_warper.warp_latent_and_correct(
            t=framestep,
            original_frame=z_tau_orig,
            alphabar_tau=SDM.alphabar[tau].item(),
            previous_frame=z_tau,
            mode=args.interpolationmode,
        )
        spatial_eta = image_warper.get_spatial_eta(t=framestep)

    # A float eta (including float subclasses such as np.float64) is used
    # as-is. A tensor eta must be a single (1, 1, 64, 64) map; it is
    # replicated once per frame in the batch and zeroed for every frame except
    # the last one (the warped frame being generated).
    if not isinstance(spatial_eta, float):
        if not isinstance(spatial_eta, torch.Tensor):
            raise TypeError(
                f"spatial_eta must be a float or torch.Tensor, got {type(spatial_eta).__name__}"
            )
        if spatial_eta.shape != (1, 1, 64, 64):
            raise ValueError(
                f"spatial_eta must have shape (1, 1, 64, 64), got {tuple(spatial_eta.shape)}"
            )
        spatial_eta = spatial_eta.repeat(cross_attention_processor.video_length, 1, 1, 1)
        spatial_eta[:-1] = 0.0
        print(spatial_eta.abs().sum())

    # Batch for the cross-frame processor: [reference, warped] for a 2-frame
    # pattern, otherwise [reference, previous frame, warped].
    if cross_attention_processor.video_length == 2:
        z_pair = torch.cat([z_tau_orig, warped_latent], dim=0)
    else:
        z_pair = torch.cat([z_tau_orig, z_tau, warped_latent], dim=0)

    # Generate with cross attention
    SDM.pipeline.unet.set_attn_processor(cross_attention_processor)
    frame, mon = SDM.partial_generation(
        z=z_pair,
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        eta=spatial_eta if args.spatialeta else 0.0,
        guidance_scale=guidance_scale,
        # The positive prompt is passed as the "negative" for every frame but
        # the last; only the generated frame gets the real negative prompt.
        negative_prompt=[prompt] * (len(z_pair)-1) + [negative_prompt],
    )

    save_output(frame, f"frame_{f:03}")

    # Prepare z_tau for the next iteration: either reuse the warped latent
    # directly, or invert the last generated frame with single-frame attention.
    if not args.invert:
        print('Warning: not inverting')
        z_tau = warped_latent
    else:
        SDM.pipeline.unet.set_attn_processor(single_attention_processor)
        _, mon = SDM.partial_inversion(
            z=mon["latents"][-1][[-1]].detach().to(device),
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
        )
        z_tau = mon["latents"][-1].detach().to(device)