"""
Sample images with universal guidance, using an ensemble of reward models as the guidance signal.
"""

import argparse
import os
import ctypes
from time import time, localtime, strftime

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from functools import partial
from torchvision.utils import make_grid, save_image

from guided_diffusion import dist_util, logger
from guided_diffusion.transfer_learning import create_pretrained_model
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
    pretrained_ImageNet_defaults,
)

def reward_setup_0(kwargs):
    """Build reward type 0 by passing the ``kwargs`` mapping to ``create_pretrained_model``."""
    model = create_pretrained_model(**kwargs)
    return model

def reward_setup_1(kwargs):
    """Placeholder builder for reward type 1 (not implemented; always returns None)."""
    return None

def reward_setup_2(kwargs):
    """Placeholder builder for reward type 2 (not implemented; always returns None)."""
    return None

def reward_setup_3(kwargs):
    """Placeholder builder for reward type 3 (not implemented; always returns None)."""
    return None

def reward_setup_4(kwargs):
    """Placeholder builder for reward type 4 (not implemented; always returns None)."""
    return None

# Registry mapping a reward-type index (0-4) to the id() of its builder function
# reward_setup_<index>.
# NOTE(review): the values are raw object ids, so turning one back into a
# callable needs ctypes pointer casting; storing the functions themselves
# would be simpler and safer.
REWARD_SETUP_IDS = {
    type_idx: id(builder)
    for type_idx, builder in enumerate(
        (reward_setup_0, reward_setup_1, reward_setup_2, reward_setup_3, reward_setup_4)
    )
}


class OptimizerDetails:
    """Plain attribute container for the guidance / optimization settings.

    ``main`` creates one instance, assigns its fields from the command-line
    arguments and passes it to ``diffusion.ddim_sample_loop_operation`` as
    ``operation``. Every field, including ``Aug`` (which ``main`` assigns),
    starts with a placeholder value here. Reading a field that a caller never
    set therefore yields ``None``/``False``/``0`` instead of raising
    ``AttributeError``.
    """

    def __init__(self):
        # Guidance objective: how many recurrences to run, the function that
        # evaluates the guidance models, and the optimizer / loss applied to it.
        self.num_recurrences = None
        self.operation_func = None
        self.optimizer = None  # handle it on string level
        self.lr = None
        self.loss_func = None
        # Optimization controls (main fills these from --backward_steps,
        # --optim_loss_cutoff and --optim_warm_start).
        self.backward_steps = 0
        self.loss_cutoff = None
        self.lr_scheduler = None
        self.warm_start = None
        # NOTE(review): old_img / fact are read by the sampler; their exact
        # meaning is not visible in this file.
        self.old_img = None
        self.fact = 0.5
        # Progress output: print flag, print interval and output folder.
        self.print = False
        self.print_every = None
        self.folder = None
        self.tv_loss = None
        # Forward / original guidance switches and weights.
        self.use_forward = False
        self.forward_guidance_wt = 0
        self.other_guidance_func = None
        self.other_criterion = None
        self.original_guidance = False
        self.sampling_type = None
        self.loss_save = None
        # Augmentation setting. main assigns operation.Aug, so it is declared
        # here as well to keep instances complete when configured elsewhere.
        self.Aug = None


def _load_reward_models(args):
    """Build the reward-model ensemble listed in ``args.reward_paths``.

    Each member is created with ``create_pretrained_model`` using the
    ``pretrained_ImageNet_defaults`` subset of ``args``. Its checkpoint is
    loaded on CPU, then the model is moved to ``dist_util.dev()`` and switched
    to eval mode.

    Returns:
        list: one model per entry of ``args.reward_paths``, in the same order.
    """
    # The constructor kwargs are the same for every member, so build them once.
    kwargs = args_to_dict(args, pretrained_ImageNet_defaults().keys())
    reward_list = []
    for reward_path in args.reward_paths:
        # NOTE(review): --reward_types / --reward_names are parsed but not used
        # yet; every member shares one architecture rather than dispatching to
        # reward_setup_<type>.
        reward = create_pretrained_model(**kwargs)
        reward.load_state_dict(
            dist_util.load_state_dict(reward_path, map_location="cpu")
        )
        reward.to(dist_util.dev())
        reward.eval()
        reward_list.append(reward)
    return reward_list


def _save_sample_grids(sample_tensor, save_dir, per_grid=100, nrow=10):
    """Write ``sample_tensor`` (N, C, H, W) to ``save_dir`` as ``batch_{i}.png`` grids.

    Each grid holds up to ``per_grid`` images, ``nrow`` per row. The last grid
    takes the remainder, so no sample is dropped when N is not a multiple of
    ``per_grid``, including when N is smaller than ``per_grid``.
    """
    for grid_idx, begin in enumerate(range(0, sample_tensor.size(0), per_grid)):
        img_grid = make_grid(sample_tensor[begin:begin + per_grid], nrow=nrow)
        save_image(img_grid, os.path.join(save_dir, f"batch_{grid_idx}.png"))


def main():
    """Run reward-ensemble guided sampling and save the results.

    Loads the diffusion model and every reward model given on the command line.
    It then samples ``args.num_samples`` images across all ranks with
    ``diffusion.ddim_sample_loop_operation``. Rank 0 writes them to an ``.npz``
    file and to PNG grids, both under the logger directory.
    """
    parser = create_argparser()
    args = parser.parse_args()
    if not args.reward_paths:
        # Without at least one reward model the ensemble averages below would
        # divide by zero; fail before loading anything expensive.
        parser.error("--reward_paths requires at least one reward checkpoint")

    # NOTE(review): %I is a 12-hour clock without an AM/PM marker, so these
    # tags are ambiguous across noon; kept because output names embed them.
    start = strftime("%m%d_%I:%M:%S", localtime(time()))

    dist_util.setup_dist()
    logger.configure(args.log_dir)

    time_tag = strftime("%m%d_%I:%M:%S", localtime(time()))
    logger.log(vars(args))

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading reward models...")
    reward_list = _load_reward_models(args)
    num_rewards = len(reward_list)

    log_sigmoid = th.nn.LogSigmoid()

    def cond_fn(x, t, y=None):
        """Gradient w.r.t. ``x`` of the ensemble-mean log-sigmoid reward, scaled by ``original_guidance_wt``."""
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_probs = sum(log_sigmoid(reward(x_in, t)) for reward in reward_list) / num_rewards
            return th.autograd.grad(log_probs.sum(), x_in)[0] * args.original_guidance_wt

    def model_fn(x, t, y=None, args=None, model=None):
        # Class labels are forwarded only to class-conditional models.
        return model(x, t, y if args.class_cond else None)

    def operation_func(x, t=None):
        """Evaluate every reward model on ``x``, passing ``t`` when it is given."""
        if t is None:
            return [reward(x) for reward in reward_list]
        return [reward(x, t) for reward in reward_list]

    def loss_func(reward_vals, *_unused):
        """Ensemble mean of ``-log sigmoid(reward)``; extra positional args are ignored."""
        return sum(-log_sigmoid(rval) for rval in reward_vals) / num_rewards

    ##### operation #####
    operation = OptimizerDetails()
    operation.num_recurrences = args.num_recurrences
    operation.operation_func = operation_func
    operation.other_guidance_func = None

    operation.optimizer = 'Adam'
    operation.lr = args.optim_lr
    operation.loss_func = loss_func
    operation.other_criterion = None

    operation.backward_steps = args.backward_steps
    operation.loss_cutoff = args.optim_loss_cutoff  # 0.00001
    operation.tv_loss = args.optim_tv_loss

    operation.use_forward = args.use_forward
    operation.forward_guidance_wt = args.forward_guidance_wt

    operation.original_guidance = args.original_guidance
    operation.sampling_type = args.optim_sampling_type

    operation.warm_start = args.optim_warm_start  # False
    operation.print = args.optim_print
    operation.print_every = 10
    operation.folder = logger.get_dir()  # results_folder
    if args.optim_print:
        os.makedirs(f'{operation.folder}/samples', exist_ok=True)
    operation.Aug = args.optim_aug
    #####################

    logger.log("sampling... ")
    all_images = []
    # Each round appends world_size batches of batch_size images.
    while len(all_images) * args.batch_size < args.num_samples:
        # See https://github.com/arpitbansal297/Universal-Guided-Diffusion/blob/b3af48f78d7bec105f3ea1579faf8602c520ed1e/Guided_Diffusion_Imagenet/Guided/helpers.py#L245
        if args.target_class is None:
            classes = th.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
        else:
            classes = th.full(
                (args.batch_size,), int(args.target_class),
                device=dist_util.dev(), dtype=th.int64,
            )
        model_kwargs = {"y": classes}

        sample = diffusion.ddim_sample_loop_operation(
            partial(model_fn, model=model, args=args),
            (args.batch_size, args.image_channels, args.image_size, args.image_size),
            operated_image=None,
            operation=operation,
            clip_denoised=args.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=dist_util.dev(),
            progress=args.progressive,
        )
        # Rescale [-1, 1] -> uint8 [0, 255] and move channels last (NCHW -> NHWC).
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1).contiguous()

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend(gathered.cpu().numpy() for gathered in gathered_samples)
        logger.log(f"created {len(all_images) * args.batch_size} samples")

    hparam_tag = f"[bwd{args.backward_steps}_lr_{args.optim_lr}]"
    if args.use_forward:
        hparam_tag += f"_[fwd_wt_{args.forward_guidance_wt}]"
    if args.original_guidance:
        hparam_tag += f"_[org_wt_{args.original_guidance_wt}]"

    # The last round may overshoot num_samples; trim to the requested count.
    arr = np.concatenate(all_images, axis=0)[: args.num_samples]

    out_dir = logger.get_dir()
    if dist.get_rank() == 0:
        shape_str = "x".join(str(dim) for dim in arr.shape)
        out_path = os.path.join(out_dir, f"samples_{args.expr_name}_{hparam_tag}_{time_tag}_{shape_str}.npz")
        logger.log(f"saving to {out_path}")
        np.savez(out_path, arr)

    logger.log("sampling complete")

    # Same root as the .npz above (the configured logger directory) rather than
    # the raw --log_dir string, so both outputs land in one place.
    sample_save_dir = os.path.join(out_dir, f"sample_imgs_{args.expr_name}_{hparam_tag}_{time_tag}")
    if dist.get_rank() == 0:
        os.makedirs(sample_save_dir, exist_ok=True)

    dist.barrier()

    end = strftime("%m%d_%I:%M:%S", localtime(time()))
    logger.log(f"start : {start}, end : {end}")

    if dist.get_rank() == 0:
        # Only rank 0 writes images, so only it builds the float NCHW copy.
        sample_tensor = th.tensor(np.transpose(arr.astype(np.float32) / 255., (0, 3, 1, 2)))
        _save_sample_grids(sample_tensor, sample_save_dir)


def create_argparser():
    """Build the command-line parser for ensemble-guided sampling.

    The script-specific defaults below are merged with
    ``model_and_diffusion_defaults()`` and then ``pretrained_ImageNet_defaults()``.
    On shared keys the later source wins, because this is a plain ``dict.update``.
    The merged mapping is registered on the parser through
    ``add_dict_to_argparser``. The three reward-ensemble options take one or
    more values each.
    """
    defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "model_path": "",
        "image_channels": 1,
        "target_class": None,
        "log_dir": "",
        # operation
        "num_recurrences": 1,
        "optim_lr": 0.01,
        "backward_steps": 0,  # multi-gpu sampling is supported only for 0 (otherwise, single gpu sampling is supported)
        "optim_loss_cutoff": 0.0,
        "optim_tv_loss": False,
        "use_forward": True,
        "original_guidance": False,
        "original_guidance_wt": 0.0,
        "forward_guidance_wt": 1.0,
        "optim_sampling_type": 'ddpm',
        "optim_warm_start": False,
        "optim_print": False,  # save samples
        "progressive": False,  # print tqdm
        "optim_aug": None,
        "expr_name": "ensemble",
    }
    for extra_defaults in (model_and_diffusion_defaults(), pretrained_ImageNet_defaults()):
        defaults.update(extra_defaults)

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    reward_flag_specs = (
        ("--reward_paths", str, "list of path to each reward model"),
        ("--reward_types", int, "list of integers indicating the types of each reward model"),
        ("--reward_names", str, "list of reward model name"),
    )
    for flag, value_type, help_text in reward_flag_specs:
        parser.add_argument(flag, type=value_type, nargs='+', help=help_text)
    return parser


# Script entry point: sampling starts only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()