import os
from omegaconf import OmegaConf
import argparse

import torch
import torchvision

from pipeline.model_scope_vlcm_pipeline import ModelScopeVideoLCMPipeline
from utils.common_utils import set_torch_2_attn
from utils.utils import instantiate_from_config
from scripts.evaluation.funcs import load_model_checkpoint, batch_ddim_sampling
from scheduler.vlcm_scheduler import VLCMScheduler
from pipeline.vlcm_pipeline import VideoLatentConsistencyModelPipeline

from transformers import CLIPTokenizer, CLIPTextModel
from model_scope.utils.lora_handler import LoraHandler
from model_scope.models.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL


def save_videos(batch_tensors, save_dir, prompt, index, fps=16):
    """Encode every video in a batch as an H.264 ``.mp4`` file.

    Each video is written to ``<save_dir>/<prompt>-<idx>.mp4``, where ``idx``
    is the entry of ``index`` paired with that video.

    Args:
        batch_tensors: Sized iterable of 4-D video tensors with values
            nominally in [-1, 1] (anything outside is clamped). The inline
            layout comment implies a c,t,h,w input layout -- TODO confirm
            against the pipelines that produce these tensors.
        save_dir: Directory the files are written into.
        prompt: Text prompt; used verbatim as the file-name stem.
        index: One file-name suffix per video; must have the same length as
            ``batch_tensors``.
        fps: Frame rate passed to the encoder.
    """
    assert len(index) == len(batch_tensors)
    for clip_id, clip in zip(index, batch_tensors):
        # Bring the clip to the CPU as float and bound it to [-1, 1].
        frames = clip.detach().cpu().float().clamp(-1.0, 1.0)
        frames = frames.permute(1, 0, 2, 3)  # t,c,h,w
        # Rescale [-1, 1] -> [0, 255]; the cast to uint8 truncates.
        pixels = ((frames + 1.0) / 2.0 * 255).to(torch.uint8)
        # Move the channel axis last before handing the frames to the writer.
        pixels = pixels.permute(0, 2, 3, 1)

        out_path = os.path.join(save_dir, f"{prompt}-{clip_id}.mp4")
        torchvision.io.write_video(
            out_path,
            pixels,
            fps=fps,
            video_codec="h264",
            options={"crf": "10"},
        )


def _load_prompts(path="./prompts/all_dimension.txt"):
    """Return the stripped, non-empty lines of the prompt file at ``path``."""
    with open(path, "r") as f:
        # Skip blank lines so they are not sampled as empty prompts.
        return [line.strip() for line in f if line.strip()]


def _build_model_scope_pipeline(ckpt_dir):
    """Build the ModelScope video-LCM pipeline from the checkpoint in ``ckpt_dir``."""
    pretrained_model_path = "ali-vilab/text-to-video-ms-1.7b"
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_path, subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_path, subfolder="text_encoder"
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    teacher_unet = UNet3DConditionModel.from_pretrained(
        pretrained_model_path, subfolder="unet"
    )

    # Re-create the UNet with an extra time-conditioning projection, then copy
    # the teacher weights into it. strict=False because the new projection
    # has no counterpart in the teacher's state dict.
    unet = UNet3DConditionModel.from_config(
        teacher_unet.config, time_cond_proj_dim=256
    )
    unet.load_state_dict(teacher_unet.state_dict(), strict=False)
    del teacher_unet
    set_torch_2_attn(unet)

    use_unet_lora = True
    lora_manager = LoraHandler(
        version="cloneofsimo",
        use_unet_lora=use_unet_lora,
        save_for_webui=True,
    )
    lora_path = f"{ckpt_dir}/unet_lora.pt"
    if os.path.exists(lora_path):
        # Checkpoint stores only the LoRA weights.
        lora_manager.add_lora_to_model(
            use_unet_lora,
            unet,
            lora_manager.unet_replace_modules,
            lora_path=lora_path,
            dropout=0.1,
            r=32,
        )
    else:
        # No LoRA file: inject LoRA modules without weights, then load the
        # full UNet state dict saved next to it.
        lora_manager.add_lora_to_model(
            use_unet_lora,
            unet,
            lora_manager.unet_replace_modules,
            dropout=0.1,
            r=32,
        )
        # Deserialize on the CPU; load_state_dict copies into the existing
        # parameters and the caller moves the whole pipeline to CUDA later.
        state_dict = torch.load(f"{ckpt_dir}/unet.pt", map_location="cpu")
        unet.load_state_dict(state_dict)

    lora_manager.deactivate_lora_train([unet], True)
    # Switch to inference mode so the LoRA dropout (p=0.1) is not applied while
    # sampling, matching the VideoCrafter2 branch. NOTE(review): from_config
    # presumably returns the module in training mode -- confirm.
    unet.eval()

    noise_scheduler = VLCMScheduler(
        linear_start=0.00085,
        linear_end=0.012,
    )
    return ModelScopeVideoLCMPipeline(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=noise_scheduler,
    )


def _build_vc2_pipeline(run_name, ckpt_dir):
    """Build the VideoCrafter2 video-LCM pipeline from the checkpoint in ``ckpt_dir``."""
    config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
    model_config = config.pop("model", OmegaConf.create())
    pretrained_t2v = instantiate_from_config(model_config)
    pretrained_t2v = load_model_checkpoint(
        pretrained_t2v,
        "PATH_TO_THE_VIDEOCRAFTER2_CHECKPOINT",
    )

    unet_config = model_config["params"]["unet_config"]
    unet_config["params"]["time_cond_proj_dim"] = 256
    unet = instantiate_from_config(unet_config)

    if "lora" in run_name:
        # LoRA run: start from the pretrained UNet weights (strict=False as the
        # new UNet was configured with an extra time_cond_proj_dim), then
        # inject the trained LoRA weights.
        unet.load_state_dict(
            pretrained_t2v.model.diffusion_model.state_dict(), strict=False
        )

        use_unet_lora = True
        lora_manager = LoraHandler(
            version="cloneofsimo",
            use_unet_lora=use_unet_lora,
            save_for_webui=True,
            unet_replace_modules=["UNetModel"],
        )
        lora_manager.add_lora_to_model(
            use_unet_lora,
            unet,
            lora_manager.unet_replace_modules,
            lora_path=f"{ckpt_dir}/unet_lora.pt",
            dropout=0.1,
            r=64,
        )
    else:
        # Full fine-tune: the checkpoint holds the whole UNet state dict.
        # Deserialize on the CPU; the caller moves the pipeline to CUDA later.
        unet_state_dict = torch.load(f"{ckpt_dir}/unet.pt", map_location="cpu")
        unet.load_state_dict(unet_state_dict)

    unet.eval()
    pretrained_t2v.model.diffusion_model = unet
    scheduler = VLCMScheduler(
        linear_start=model_config["params"]["linear_start"],
        linear_end=model_config["params"]["linear_end"],
    )
    return VideoLatentConsistencyModelPipeline(
        pretrained_t2v, scheduler, model_config
    )


@torch.no_grad()
def main(args):
    """Sample five videos per prompt and save them as ``.mp4`` files.

    Videos are written to
    ``<save_root>/<run_name>/ckpt_<num_iters>/<num_infer_steps>_steps``.
    The ModelScope pipeline is used when ``run_name`` contains
    ``"model_scope"``; otherwise the VideoCrafter2 pipeline is used.

    Args:
        args: Namespace providing ``save_root``, ``run_name``, ``num_iters``,
            ``num_infer_steps`` and ``seed``.
    """
    save_dir = os.path.join(
        args.save_root,
        args.run_name,
        f"ckpt_{args.num_iters}",
        f"{args.num_infer_steps}_steps",
    )
    os.makedirs(save_dir, exist_ok=True)

    # `is not None` rather than truthiness, so that a seed of 0 is honoured.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    prompt_list = _load_prompts()

    ckpt_dir = f"output/{args.run_name}/checkpoint-{args.num_iters}"
    is_model_scope = "model_scope" in args.run_name
    if is_model_scope:
        pipeline = _build_model_scope_pipeline(ckpt_dir)
        fps = 8
        # One call produces all five videos.
        batches = [(args.seed, [0, 1, 2, 3, 4])]
    else:
        pipeline = _build_vc2_pipeline(args.run_name, ckpt_dir)
        fps = 16
        # Five videos in two calls (3 + 2), the second with seed + 1.
        # NOTE(review): presumably split to limit peak GPU memory -- confirm.
        batches = [(args.seed, [0, 1, 2]), (args.seed + 1, [3, 4])]

    pipeline.to("cuda")
    for prompt in prompt_list:
        for batch_seed, indices in batches:
            # A fresh generator per call: every prompt uses the same seeds.
            videos = pipeline(
                prompt=prompt,
                frames=16,
                num_inference_steps=args.num_infer_steps,
                num_videos_per_prompt=len(indices),
                generator=torch.Generator(device="cuda").manual_seed(batch_seed),
            )
            save_videos(videos, save_dir, prompt, indices, fps=fps)


if __name__ == "__main__":
    # Command-line entry point: every option is a typed flag with a default,
    # so the script can be launched without any arguments.
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        # Root directory under which the sampled videos are stored.
        ("--save_root", str, "sampled_vbench_videos"),
        # Training run name; selects the checkpoint folder and the backbone.
        ("--run_name", str, "rg_vlcm_lora_vc2_vi_clip2_5_hpsv2_2_consec_frame_motion"),
        # Training iteration of the checkpoint to load.
        ("--num_iters", int, 10000),
        # Number of inference steps used for sampling.
        ("--num_infer_steps", int, 4),
        # Random seed for sampling.
        ("--seed", int, 42),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    args = cli.parse_args()

    main(args)
