import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .helpers import (
    cosine_beta_schedule,
    extract,
    apply_conditioning,
)
from .temporal import TemporalUnet


@torch.no_grad()
def default_sample_fn(model, x, cond, t, td=None):
    '''
        One stochastic reverse-diffusion step: returns mean + std * noise.

        model : object exposing p_mean_variance(x=, cond=, t=, td=) that returns
            (mean, _, log_variance)
        t : per-sample timesteps indexing the first dim of x; samples whose
            timestep is 0 receive no noise, so their output is the mean itself
    '''
    mean, _, log_var = model.p_mean_variance(x=x, cond=cond, t=t, td=td)
    std = (0.5 * log_var).exp()

    # Broadcast the per-sample "final step" flag over every non-batch dimension.
    is_final_step = (t == 0).reshape(-1, *([1] * (x.dim() - 1)))
    eps = torch.randn_like(x).masked_fill(is_final_step, 0.0)

    return mean + std * eps

@torch.no_grad()
def low_temperature_sample_fn(model, x, cond, t, td=None):
    '''
        Reverse-diffusion step like default_sample_fn, but the Gaussian noise is
        scaled by 0.5 (a "low temperature" draw).

        model : object exposing p_mean_variance(x=, cond=, t=, td=) that returns
            (mean, _, log_variance)
        t : per-sample timesteps indexing the first dim of x; samples whose
            timestep is 0 receive no noise
    '''
    mean, _, log_var = model.p_mean_variance(x=x, cond=cond, t=t, td=td)
    std = (0.5 * log_var).exp()

    halved = torch.randn_like(x).mul(0.5)
    # Broadcast the per-sample "final step" flag over every non-batch dimension.
    is_final_step = (t == 0).reshape(-1, *([1] * (x.dim() - 1)))
    halved = halved.masked_fill(is_final_step, 0.0)

    return mean + std * halved

def make_timesteps(batch_size, i, device):
    '''
        Return a 1-D int64 tensor of length batch_size, every entry equal to i,
        allocated on `device`.
    '''
    return torch.full(
        size=(batch_size,),
        fill_value=i,
        dtype=torch.long,
        device=device,
    )


class GaussianDiffusion(nn.Module):
    '''
        Diffusion model over trajectories, using a TemporalUnet as the network.

        Trajectories have shape (batch, horizon, transition_dim) (the shape
        built in conditional_sample). `cond` is a dict keyed by horizon index;
        apply_conditioning is applied to noisy inputs, reconstructions and
        samples -- presumably it writes the conditioning values into those
        timesteps (TODO confirm against helpers.apply_conditioning).

        horizon : length of the trajectory (dim 1 of the sampled tensor)
        transition_dim : size of each timestep's feature vector (last dim)
        n_timesteps : number of diffusion steps
        loss_type : 'l2' (mean squared error) or 'l1' (mean absolute error);
            anything else raises ValueError
        predict_epsilon : if True the network predicts the added noise,
            otherwise it predicts the clean trajectory x_0 directly
        normalize_denoised : if True, the two halves of the last dimension of
            the reconstructed x_0 are L2-normalised independently
            (transition_dim must be even)
        td_condition : if True, sampling blends conditional and unconditional
            network outputs weighted by condition_guidance_w
        condition_guidance_w : guidance weight used when td_condition is set
        condition_dropout : forwarded to TemporalUnet
    '''

    def __init__(self, horizon, transition_dim, n_timesteps=1000,
        loss_type='l2', predict_epsilon=True, normalize_denoised=False,
        td_condition=False, condition_guidance_w=1.2, condition_dropout=0.25
    ):
        super().__init__()
        # Resolve the loss first so an unsupported loss_type fails here instead of
        # surfacing as a missing `loss_fn` attribute in the middle of training.
        loss_fn = self._build_loss_fn(loss_type)

        self.horizon = horizon
        self.transition_dim = transition_dim
        self.model = TemporalUnet(
            horizon=horizon,
            transition_dim=transition_dim,
            td_condition=td_condition,
            condition_dropout=condition_dropout
        )
        self.normalize_denoised = normalize_denoised
        self.loss_fn = loss_fn

        self.td_condition = td_condition
        self.condition_guidance_w = condition_guidance_w
        self.condition_dropout = condition_dropout

        self.n_timesteps = int(n_timesteps)
        self.predict_epsilon = predict_epsilon

        self._register_noise_schedule(cosine_beta_schedule(n_timesteps))

    @staticmethod
    def _build_loss_fn(loss_type):
        '''
            Map a loss_type name to a (prediction, target) -> scalar loss function.
            Raises ValueError for unsupported names.
        '''
        losses = {'l1': F.l1_loss, 'l2': F.mse_loss}
        try:
            return losses[loss_type]
        except KeyError:
            raise ValueError(
                f"unsupported loss_type {loss_type!r}; expected one of {sorted(losses)}"
            ) from None

    def _register_noise_schedule(self, betas):
        '''
            Register every per-step constant derived from the 1-D `betas` tensor as a
            buffer (buffer names and registration order match the state_dict layout).
        '''
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Match dtype/device so the concatenation never depends on type promotion.
        alphas_cumprod_prev = torch.cat([
            torch.ones(1, dtype=alphas_cumprod.dtype, device=alphas_cumprod.device),
            alphas_cumprod[:-1],
        ])

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        ## log calculation clipped because the posterior variance
        ## is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
            torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt (rather than np.sqrt) keeps these pure torch tensor ops.
        self.register_buffer('posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
            (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def _normalize_halves(self, x):
        '''
            L2-normalise the first and second halves of the last dimension of x
            independently. transition_dim must be even; otherwise the split does
            not yield exactly two chunks and the unpacking fails.
        '''
        first, second = torch.split(x, self.transition_dim // 2, dim=-1)
        return torch.cat([
            F.normalize(first, p=2.0, dim=-1),
            F.normalize(second, p=2.0, dim=-1),
        ], dim=-1)

    #------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        '''
            if self.predict_epsilon, model output is (scaled) noise;
            otherwise, model predicts x0 directly
        '''
        if self.predict_epsilon:
            # extract presumably gathers the per-sample coefficient for t, shaped to
            # broadcast against x_t (TODO confirm in helpers.extract).
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        '''
            Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).
        '''
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t, td=None):
        '''
            Predict x_0 from x_t with the network and return the posterior
            (mean, variance, log_variance) of x_{t-1}. `td` is only forwarded to
            the network when td_condition is set.
        '''
        # `epsilon` is the raw network output: noise if predict_epsilon, else x_0.
        if self.td_condition:
            # Classifier-free-guidance style blend of conditional and
            # unconditional (forced-dropout) outputs.
            epsilon_cond = self.model(x, cond, t, td, use_dropout=False)
            epsilon_uncond = self.model(x, cond, t, td, force_dropout=True)
            epsilon = epsilon_uncond + self.condition_guidance_w * (epsilon_cond - epsilon_uncond)
        else:
            epsilon = self.model(x, cond, t)
        x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)
        x_recon = apply_conditioning(x_recon, cond)

        if self.normalize_denoised:
            x_recon = self._normalize_halves(x_recon)

        return self.q_posterior(x_start=x_recon, x_t=x, t=t)

    @torch.no_grad()
    def p_sample_loop(self, shape, cond, td=None, sample_fn=None, return_process=False, **sample_kwargs):
        '''
            Run the reverse chain from t = n_timesteps - 1 down to t = 0.

            sample_fn : per-step sampler called as
                sample_fn(self, x, cond, t, td, **sample_kwargs); None (default)
                means low_temperature_sample_fn, looked up at call time
            Returns x, or (x, process) when return_process is True, where process
            lists x after the initial conditioning and after every step.
        '''
        if sample_fn is None:
            sample_fn = low_temperature_sample_fn
        device = self.betas.device

        batch_size = shape[0]
        # The starting noise is drawn at half the unit-Gaussian scale.
        x = 0.5 * torch.randn(shape, device=device)
        x = apply_conditioning(x, cond)
        if return_process:
            process = [x]
        for i in reversed(range(0, self.n_timesteps)):
            t = make_timesteps(batch_size, i, device)
            x = sample_fn(self, x, cond, t, td, **sample_kwargs)
            x = apply_conditioning(x, cond)
            if return_process:
                process.append(x)
        if return_process:
            return x, process
        return x

    @torch.no_grad()
    def conditional_sample(self, cond, td=None, horizon=None, **sample_kwargs):
        '''
            cond : dict mapping horizon index -> conditioning tensor; the batch
                size is taken from len(cond[0])
            horizon : trajectory length; defaults to self.horizon
        '''
        batch_size = len(cond[0])
        horizon = horizon or self.horizon
        shape = (batch_size, horizon, self.transition_dim)

        return self.p_sample_loop(shape, cond, td, **sample_kwargs)

    #------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        '''
            Sample x_t ~ q(x_t | x_0); fresh Gaussian noise is drawn when `noise` is None.
        '''
        if noise is None:
            noise = torch.randn_like(x_start)

        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        return sample

    def p_losses(self, x_start, cond, t, td):
        '''
            Training loss at timesteps t: noise-prediction loss when
            predict_epsilon, otherwise x_0-reconstruction loss.
        '''
        noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = apply_conditioning(x_noisy, cond)

        x_recon = self.model(x_noisy, cond, t, td)

        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            # Zero the conditioned timesteps in both tensors so they do not
            # contribute. masked_fill is out of place: mutating the network output
            # in place is rejected by autograd when that output was saved for backward.
            cond_mask = torch.zeros_like(noise, dtype=torch.bool)
            for k in cond:
                cond_mask[:, k, :] = True
            loss = self.loss_fn(
                x_recon.masked_fill(cond_mask, 0.),
                noise.masked_fill(cond_mask, 0.),
            )
        else:
            if self.normalize_denoised:
                x_recon = self._normalize_halves(x_recon)
            x_recon = apply_conditioning(x_recon, cond)
            loss = self.loss_fn(x_recon, x_start)

        return loss

    def loss(self, x, cond, td=None):
        '''
            Draw one uniform diffusion timestep per sample and return p_losses.
        '''
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device, dtype=torch.long)
        return self.p_losses(x, cond, t, td)

    def forward(self, cond, *args, **kwargs):
        return self.conditional_sample(cond, *args, **kwargs)
