import numpy as np
import torch
from tqdm import trange
from torch import nn
import copy
import time

def to_device(x, device='cuda'):
    """Recursively move the tensors contained in ``x`` to ``device``.

    - Tensors are moved with ``Tensor.to(device)``.
    - Dicts (including subclasses) are rebuilt as plain dicts with every value moved.
    - Lists and tuples are rebuilt element-wise; namedtuples keep their type.
    - Any other value is returned unchanged after printing a notice. Previously
      it was replaced by ``None``, which silently corrupted batches.

    Args:
        x: tensor, or a (possibly nested) dict/list/tuple of values.
        device: target device passed to ``Tensor.to``.

    Returns:
        An object of the same structure with tensors on ``device``.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        moved = [to_device(v, device) for v in x]
        if isinstance(x, tuple):
            # namedtuples take positional fields; plain tuples take one iterable.
            return type(x)(*moved) if hasattr(x, '_fields') else tuple(moved)
        return moved
    print(f'Unrecognized type in `to_device`: {type(x)}')
    return x

def batch_to_device(batch, device='cuda:0'):
    """Return a new batch of the same type with every field passed through ``to_device``.

    ``batch`` is expected to be namedtuple-like: it exposes ``_fields`` and its
    type can be rebuilt from the field values given positionally.
    """
    batch_cls = type(batch)
    moved_fields = (to_device(getattr(batch, name), device) for name in batch._fields)
    return batch_cls(*moved_fields)

@torch.jit.script
def compute_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Pairwise Gaussian kernel matrix between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D) with the same D as ``x``.

    Returns:
        (N, M) tensor whose [i, j] entry is
        exp(-mean_d((x[i, d] - y[j, d]) ** 2) / D).
    """
    dim = x.shape[1]
    # Broadcasting (N, 1, D) against (1, M, D) yields the same (N, M, D)
    # differences as explicit tiling, without materialising two repeated copies.
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    return torch.exp(-torch.mean(diff ** 2, dim=2) / dim)

@torch.jit.script
def compute_mmd(x, y):
    """Squared Maximum Mean Discrepancy estimate between sample sets ``x`` and ``y``.

    Uses ``compute_kernel`` and averages over full kernel matrices (diagonal
    included): mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y)).
    Returns a scalar tensor.
    """
    mean_xx = compute_kernel(x, x).mean()
    mean_yy = compute_kernel(y, y).mean()
    mean_xy = compute_kernel(x, y).mean()
    return mean_xx + mean_yy - 2 * mean_xy

class EMA():
    '''
        Exponential moving average of model parameters.

        Each update sets the averaged value to
        ``beta * old + (1 - beta) * new``.
    '''
    def __init__(self, beta):
        super().__init__()
        # Decay: weight kept on the previous average at each update.
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        '''
            Blend `current_model`'s parameters into `ma_model` in place.

            Parameters are paired positionally through `.parameters()`, so both
            models should share the same architecture. Only parameters (not
            buffers) are averaged; `current_model` is not modified.
        '''
        with torch.no_grad():
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                # copy_ writes into the existing storage instead of rebinding
                # `.data` to a newly allocated tensor on every update.
                ma_params.copy_(self.update_average(ma_params, current_params))

    def update_average(self, old, new):
        '''Return `new` if there is no previous average, else the beta-weighted blend.'''
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

class PrefDiffuserTrainer():
    '''
        Trainer for a diffusion model (`de_model`) conditioned on the output of
        an encoder (`en_model`). An EMA copy of the diffusion model is kept in
        `self.ema_model`.

        `en_model.repre_type` selects how the encoder output is used:
          * 'vec'  -- `en_model.forward` returns a single tensor, which is
                      detached and passed to `de_model.loss` as the condition.
          * 'dist' -- `en_model.forward` returns (mean, phi_std); phi_std is
                      exponentiated into a diagonal covariance (i.e. treated as
                      a log-variance). A KL term between the Gaussian produced
                      by `de_model.generate` and the encoder's Gaussian is
                      added to the loss.
    '''

    # Decay factor handed to EMA().
    EMA_DECAY = 0.995
    # Until this many steps have run, step_ema() copies the live weights into
    # the EMA model instead of averaging them.
    EMA_START_STEP = 2000
    # step_ema() is invoked once every this many calls to train_step().
    EMA_UPDATE_EVERY = 10
    # Weight of the KL "info" loss in the total objective ('dist' mode only).
    INFO_LOSS_WEIGHT = 0.1

    def __init__(
        self,
        en_model, # encoder
        de_model, # diffusion model
        optimizer,
        batch_size,
        get_batch,
        device,
    ):
        super().__init__()
        self.optimizer = optimizer
        self.batch_size = batch_size
        # Called as get_batch(batch_size) -> (states, actions, timesteps, mask).
        self.get_batch = get_batch
        # Extra entries merged into the logs returned by train_iteration().
        self.diagnostics = dict()
        self.en_model = en_model
        self.de_model = de_model

        # Not used by this class's methods.
        self.device = device
        self.ema = EMA(self.EMA_DECAY)
        self.ema_model = copy.deepcopy(self.de_model)
        self.reset_parameters()
        # Number of optimisation steps taken; drives the EMA schedule.
        self.step = 0
        # Number of train_step() calls made through train_iteration().
        self.count = 0

    def reset_parameters(self):
        '''Copy the live diffusion model's weights into the EMA model.'''
        self.ema_model.load_state_dict(self.de_model.state_dict())

    def step_ema(self):
        '''
            Update the EMA model. Before EMA_START_STEP steps it simply mirrors
            the live model; afterwards it is blended in with EMA_DECAY.
        '''
        if self.step < self.EMA_START_STEP:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.de_model)

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        '''
            Run `num_steps` calls to train_step() and summarise the losses.

            Args:
                num_steps: number of optimisation steps to run.
                iter_num: iteration index, only used in the printed header.
                print_logs: if True, print every log entry.

            Returns:
                dict with 'training/time', the mean and std of each loss under
                'training/<name>_loss_mean' / 'training/<name>_loss_std'
                (name in diffusion, inv, info, simi), plus self.diagnostics.
        '''
        # Order must match the tuple returned by train_step().
        loss_names = ('diffusion', 'inv', 'info', 'simi')
        history = {name: [] for name in loss_names}
        train_start = time.time()

        self.de_model.train()
        for _ in trange(num_steps, desc='train_step'):
            for name, value in zip(loss_names, self.train_step()):
                history[name].append(value)
            self.count += 1

        logs = {'training/time': time.time() - train_start}
        for name, values in history.items():
            logs[f'training/{name}_loss_mean'] = np.mean(values)
            logs[f'training/{name}_loss_std'] = np.std(values)
        logs.update(self.diagnostics)

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    @staticmethod
    def _diag_gaussian(mean, log_var):
        '''Multivariate normal with loc `mean` and diagonal covariance exp(log_var).'''
        return torch.distributions.MultivariateNormal(
            loc=mean, covariance_matrix=torch.diag_embed(torch.exp(log_var)))

    @staticmethod
    def _to_float(value):
        '''Detach a scalar tensor and return it as a Python number.'''
        return value.detach().cpu().item()

    def train_step(self):
        '''
            Sample one batch, compute the loss, take an optimiser step and, every
            EMA_UPDATE_EVERY steps, update the EMA model.

            Returns:
                (diff_loss, inv_loss, info_loss, simi_loss) as Python numbers;
                info_loss and simi_loss are 0.0 in 'vec' mode.

            Raises:
                ValueError: if en_model.repre_type is neither 'vec' nor 'dist'.
        '''
        repre_type = self.en_model.repre_type
        if repre_type not in ('vec', 'dist'):
            # Fail fast instead of hitting an unbound local further down.
            raise ValueError(
                f"Unsupported en_model.repre_type {repre_type!r}; expected 'vec' or 'dist'")

        # assumes states is (batch, horizon, state_dim) -- TODO confirm with get_batch
        states, actions, timesteps, mask = self.get_batch(self.batch_size)
        # The first state of each segment is the condition used for planning.
        conditions = states[:, 0, :]
        # Join actions and states along the last (feature) axis to form the trajectory.
        trajectories = torch.concat([actions, states], dim=-1)

        info_loss = simi_loss = None
        if repre_type == 'vec':
            # Detached, so no gradient reaches en_model through this condition.
            phis = self.en_model.forward(states, actions, timesteps, mask).detach()
            diff_loss, inv_loss = self.de_model.loss(trajectories, conditions, phis)
            total_loss = diff_loss + inv_loss
        else:
            phi_mean, phi_std = self.en_model.forward(states, actions, timesteps, mask)
            phi_dist = self._diag_gaussian(phi_mean, phi_std)
            # The third value returned by de_model.loss is not used in the objective.
            diff_loss, inv_loss, _ = self.de_model.loss(trajectories, conditions, phi_mean, phi_std)

            # Information term: batch-mean KL(generated || encoder) between the
            # Gaussian from de_model.generate and the encoder's Gaussian.
            # NOTE(review): phi_mean/phi_std are not detached, so this term can
            # back-propagate into en_model if the optimizer owns its
            # parameters -- confirm this is intended.
            generated_mean, generated_log_var, simi_loss = self.de_model.generate(conditions, phi_mean)
            generated_dist = self._diag_gaussian(generated_mean, generated_log_var)
            info_loss = torch.distributions.kl_divergence(generated_dist, phi_dist).mean()
            total_loss = diff_loss + inv_loss + self.INFO_LOSS_WEIGHT * info_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        if self.step % self.EMA_UPDATE_EVERY == 0:
            self.step_ema()
        self.step += 1

        return (
            self._to_float(diff_loss),
            self._to_float(inv_loss),
            0.0 if info_loss is None else self._to_float(info_loss),
            0.0 if simi_loss is None else self._to_float(simi_loss),
        )