
import torch
import torch.nn as nn
import math
from torchdiffeq import odeint, odeint_adjoint
from .model import Scale
import numpy as np
from diffusion_utils import *
from einops import rearrange, repeat

MAX_NUM_STEPS = 1000  # Step cap passed to torchdiffeq as options['max_num_steps'] by ODEBlock.forward (all methods except 'euler')

class TimestepEmbedder(nn.Module):
    """
    Turn scalar (possibly fractional) timesteps into feature vectors.

    A fixed sinusoidal code of width ``frequency_embedding_size`` is passed
    through ``self.mlp``. That projection depends on the flags:

    - identity when ``MLP`` is False;
    - a single Linear when ``mlp_only_linear`` is True;
    - otherwise Linear -> SiLU -> Linear, projecting to ``hidden_size``.
    """
    def __init__(self, hidden_size, MLP=True, frequency_embedding_size=256, mlp_only_linear=False):
        super().__init__()
        if not MLP:
            projection = nn.Identity()
        elif mlp_only_linear:
            projection = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        else:
            projection = nn.Sequential(
                nn.Linear(frequency_embedding_size, hidden_size, bias=True),
                nn.SiLU(),
                nn.Linear(hidden_size, hidden_size, bias=True),
            )
        self.mlp = projection
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a batch of timesteps.

        :param t: 1-D tensor holding N (possibly fractional) timesteps.
        :param dim: width D of every embedding; an odd D gets one trailing
                    zero column.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) tensor laid out as [cos(t * f), sin(t * f)].
        """
        # Reference: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        phases = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2 == 1:
            # Pad odd widths with a single zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))

def modulate(x, shift, scale):
    """Return ``x`` scaled by ``1 + scale`` and then offset by ``shift``."""
    gain = 1 + scale
    return x * gain + shift


class tMLPBlock(nn.Module):
    """Linear layer followed by a normalization layer.

    ``fc1`` expects ``hidden_dim + t_dim`` input features, i.e. the caller
    concatenates the time features with the hidden state beforehand.
    ``norm`` is a factory called with ``hidden_dim``.
    """
    def __init__(self, t_dim, hidden_dim, norm):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim + t_dim, hidden_dim)
        self.norm1 = norm(hidden_dim)

    def forward(self, x):
        return self.norm1(self.fc1(x))

class tLinear(nn.Module):
    """Linear layer applied to the features with the time column appended.

    ``forward(t, x)`` concatenates ``t`` after ``x`` along dim 1, which is
    why ``fc1`` takes ``in_dim + 1`` inputs.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim + 1, out_dim)

    def forward(self, t, x):
        joined = torch.cat((x, t), dim=1)
        return self.fc1(joined)
    
class MLPBlock(nn.Module):
    """Linear -> normalization -> ReLU.

    ``add_norm`` is a factory called with ``out_dim`` to build the norm layer.
    """
    def __init__(self, in_dim, out_dim, add_norm):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.norm1 = add_norm(out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.norm1(self.fc1(x))
        return self.relu(hidden)

class tSequential(nn.Sequential):
    """nn.Sequential that also passes a time tensor to its tLinear layers.

    ``forward(t, x)`` runs the contained modules in order. ``tLinear`` layers
    are called as ``layer(t, x)``; every other module is called as ``layer(x)``.
    """
    def forward(self, t, x):
        out = x
        for layer in self:
            out = layer(t, out) if isinstance(layer, tLinear) else layer(out)
        return out


class ODEFunc(nn.Module):
    """MLP modeling the derivative of ODE system.

    Parameters
    ----------
    device : torch.device
        Kept as an attribute for backward compatibility. ``forward`` builds
        its time column on the device of ``x`` instead, so the module keeps
        working after being moved with ``.to()``.

    data_dim : int
        Dimension of data.

    hidden_dim : int
        Dimension of hidden layers.

    augment_dim: int
        Dimension of augmentation. The MLP input/output width is
        ``data_dim + augment_dim``.

    time_dependent : bool
        If True, the time ``t`` is concatenated as an extra column to the
        input of every linear layer, making the ODE time dependent.

    non_linearity : str
        One of 'relu', 'softplus', 'silu' or 'swish'. Any other value raises
        ValueError.

    time_modulation : str
        'fourier' and 'adaln' are not implemented and raise
        NotImplementedError. Any other value adds no modulation.

    add_norm : callable or None
        Currently ignored: the normalization layers are always nn.Identity.

    h_add_blocks : int
        Number of extra time-conditioned hidden blocks (tMLPBlock) applied
        after ``fc2``. They are only used when ``time_dependent`` is True.
    """
    def __init__(self, device, data_dim, hidden_dim, augment_dim=0,
                 time_dependent=False, non_linearity='relu', time_modulation='none',
                 add_norm=None, h_add_blocks=0,):
        super(ODEFunc, self).__init__()
        self.device = device
        self.augment_dim = augment_dim
        self.data_dim = data_dim
        self.input_dim = data_dim + augment_dim
        t_dim = 1
        self.hidden_dim = hidden_dim
        self.nfe = 0  # Number of function evaluations
        self.time_modulation = time_modulation
        # NOTE: the add_norm argument is overridden here, so every norm layer
        # below is an identity regardless of what the caller passed.
        add_norm = lambda dim: nn.Identity()
        if self.time_modulation in ('fourier', 'adaln'):
            # Previously `assert 0`, which `python -O` strips, silently
            # accepting a mode that has no implementation.
            raise NotImplementedError(
                f'time_modulation={self.time_modulation!r} is not implemented')

        self.time_dependent = time_dependent
        if time_dependent:
            # Each layer sees the time as one extra input column.
            self.fc1 = nn.Linear(self.input_dim + t_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim + t_dim, hidden_dim)
            self.fc3 = nn.Linear(hidden_dim + t_dim, self.input_dim)
        else:
            self.fc1 = nn.Linear(self.input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, hidden_dim)
            self.fc3 = nn.Linear(hidden_dim, self.input_dim)
        self.norm1 = add_norm(hidden_dim)
        self.norm2 = add_norm(hidden_dim)

        if non_linearity == 'relu':
            self.non_linearity = nn.ReLU(inplace=False)
        elif non_linearity == 'softplus':
            self.non_linearity = nn.Softplus()
        elif non_linearity in ['silu', 'swish']:
            self.non_linearity = nn.SiLU()
        else:
            # Fail at construction rather than with an AttributeError on the
            # first forward call.
            raise ValueError(f'Invalid non_linearity {non_linearity}')

        self.h_add_blocks = h_add_blocks
        if h_add_blocks > 0:
            self.tmlp_blocks = nn.ModuleList([tMLPBlock(t_dim, hidden_dim, add_norm) for _ in range(h_add_blocks)])

    def forward(self, t, x):
        """Evaluate the vector field at time ``t`` and state ``x``.

        Parameters
        ----------
        t : torch.Tensor
            Current time. Either a single value (e.g. shape (1,)) or one value
            per row with shape (batch_size, 1). It is broadcast against a
            (batch_size, 1) column.

        x : torch.Tensor
            Shape (batch_size, input_dim)

        Returns
        -------
        torch.Tensor
            Shape (batch_size, input_dim).
        """
        # Forward pass of model corresponds to one function evaluation, so
        # increment counter
        self.nfe += 1

        if not self.time_dependent:
            out = self.non_linearity(self.norm1(self.fc1(x)))
            out = self.non_linearity(self.norm2(self.fc2(out)))
            return self.fc3(out)

        # Build the time column on x's device/dtype so the concatenations below
        # never mix devices (self.device can be stale after .to()).
        t_vec = torch.ones(x.shape[0], 1, device=x.device, dtype=x.dtype) * t
        out = self.fc1(torch.cat([t_vec, x], 1))
        out = self.non_linearity(self.norm1(out))
        out = self.fc2(torch.cat([t_vec, out], 1))
        out = self.non_linearity(self.norm2(out))

        if self.h_add_blocks > 0:
            for block in self.tmlp_blocks:
                out = self.non_linearity(block(torch.cat([t_vec, out], 1)))

        return self.fc3(torch.cat([t_vec, out], 1))


class ODEBlock(nn.Module):
    """Integrate the dynamics given by ``odefunc`` with torchdiffeq.

    Parameters
    ----------
    device : torch.device
        Stored on the module; not used by this class.

    odefunc : nn.Module
        Right-hand side ``f(t, x)`` of the ODE. Its ``nfe`` attribute is reset
        to 0 at the start of every solve.

    is_conv : bool
        Stored on the module; not used by this class.

    tol : float
        Used as both ``rtol`` and ``atol`` of the solver.

    adjoint : bool
        If True, solve with ``odeint_adjoint`` (gradients via the adjoint
        method). Otherwise solve with ``odeint`` and backpropagate directly
        through the solver operations.

    t_final : float
        End of the default integration interval ``[0, t_final]``.
    """
    def __init__(self, device, odefunc, is_conv=False, tol=1e-3, adjoint=False, t_final=1.0):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.device = device
        self.is_conv = is_conv
        self.adjoint = adjoint
        self.tol = tol
        self.t_final = t_final

    def forward(self, x, eval_times=None, method='dopri5'):
        """Solve the ODE with initial state ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Initial state, shape (batch_size, self.odefunc.data_dim).

        eval_times : None or torch.Tensor
            None returns only the state at ``t_final``. A tensor returns the
            full solution evaluated at those times (cast to ``x``'s type).

        method : str
            torchdiffeq solver name. Every method except 'euler' receives
            ``{'max_num_steps': MAX_NUM_STEPS}`` as options.
        """
        # A forward pass is one solve: restart the function-evaluation count.
        self.odefunc.nfe = 0

        final_only = eval_times is None
        if final_only:
            integration_time = torch.tensor([0, self.t_final]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        solver = odeint_adjoint if self.adjoint else odeint
        options = None if method == 'euler' else {'max_num_steps': MAX_NUM_STEPS}
        solution = solver(self.odefunc, x, integration_time,
                          rtol=self.tol, atol=self.tol, method=method,
                          options=options)

        # With the default [0, t_final] grid, index 1 is the final state.
        return solution[1] if final_only else solution

    def trajectory(self, x, timesteps, method='dopri5'):
        """Return the ODE trajectory starting from ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Initial state, shape (batch_size, self.odefunc.data_dim).

        timesteps : int or torch.Tensor
            An int gives that many evenly spaced points on ``[0, t_final]``.
            A tensor is used as the evaluation grid exactly as given.

        Raises
        ------
        ValueError
            If ``timesteps`` is neither an int nor a tensor.
        """
        if isinstance(timesteps, int):
            grid = torch.linspace(0., self.t_final, timesteps)
        elif isinstance(timesteps, torch.Tensor):
            grid = timesteps  # used verbatim: caller is responsible for its range
        else:
            raise ValueError('timesteps should be int or torch.Tensor')
        return self.forward(x, eval_times=grid, method=method)

class PaddingLayer(nn.Module):
    """Flatten the input and zero-pad it along the feature axis.

    Parameters
    ----------
    input_dim : int
        Feature count used to compute the padding width.
    output_dim : int
        Target feature count. ``output_dim - input_dim`` zero columns are
        added; the width is taken from these settings, not from ``x``.
    mode : int
        0 appends the zeros after the features; 1 prepends them. Any other
        value raises ValueError in ``forward``.
    """
    def __init__(self, input_dim, output_dim, mode=0):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mode = mode

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(x.shape[0], -1)
        if self.mode not in (0, 1):
            raise ValueError(f'Invalid mode {self.mode}')
        # new_zeros matches x's dtype and device. This avoids a CPU allocation
        # plus copy, and stops a float16 input being promoted by torch.cat.
        zeros = x.new_zeros(x.shape[0], self.output_dim - self.input_dim)
        parts = [x, zeros] if self.mode == 0 else [zeros, x]
        return torch.cat(parts, dim=1)

class SlicingLayer(nn.Module):
    """Keep the last ``output_dim`` columns of a (batch, features) input.

    ``input_dim`` is stored for reference only; the slice is computed from the
    actual width of ``x``.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        # `x[:, -k:]` returns *all* columns when k == 0. Computing the start
        # explicitly keeps that case empty. Clamping at 0 preserves the
        # previous "return everything" behaviour when k exceeds the width.
        start = max(x.shape[1] - self.output_dim, 0)
        return x[:, start:]

class ODENet(nn.Module):
    """Time-conditioned denoiser for a diffusion model, built around ODEFunc.

    Despite the name, ``forward`` does not integrate an ODE. Each call:

    - flattens ``x`` and appends the noisy label ``yt``;
    - applies ``in_projection``, conditioned on the normalized timestep when
      the projection accepts it;
    - evaluates ``odefunc`` once at that timestep;
    - maps the result through ``out_projection``.

    The constructor also precomputes the diffusion schedule tensors used by
    ``q_sample``, ``ddpm_sample`` and ``ddim_sample``.

    Parameters
    ----------
    device : torch.device
        Device on which the schedule tensors are created; also forwarded to
        ODEFunc.

    data_dim : int
        Dimension of the flattened input ``x``. ``output_dim`` is added to it
        (``self.data_dim = data_dim + output_dim``) because ``forward``
        concatenates ``yt`` to ``x``.

    hidden_dim : int
        Dimension of the hidden layers of ODEFunc.

    latent_dim : int or None
        Width of the ODEFunc state. Replaced by the (label-augmented) data
        dimension when None or when ``augment_dim > 0``.

    output_dim : int
        Dimension of the prediction. Presumably also the width of ``yt``
        (it is added to ``data_dim`` for the concatenation) — confirm.

    augment_dim : int
        When > 0, the in-projection is a PaddingLayer that appends
        ``augment_dim`` zero columns. This takes precedence over every string
        ``in_proj`` option, but not over an ``in_proj`` nn.Module.

    time_dependent : bool
        If True, ODEFunc receives the time as an extra input column.

    non_linearity : str
        Activation used by ODEFunc ('relu', 'softplus', 'silu' or 'swish').

    tol, adjoint : float, bool
        Stored on the module; not used by the code in this class.

    in_proj, out_proj, label_proj : nn.Module, bool or str
        Projection selectors; see the branches in ``__init__`` for the
        accepted strings. An nn.Module is used as-is.

    proj_norm : str
        'bn', 'ln' or 'none': the normalization inserted in the MLP
        projections. Any other value raises ValueError.

    in_proj_scale, label_proj_scale, t_final :
        Accepted for compatibility; not used by the code in this class.

    time_modulation : str
        Forwarded to ODEFunc.

    f_add_blocks, h_add_blocks, g_add_blocks : int
        Extra blocks in the 'mlp' in-projection, in ODEFunc, and in the
        out-projection respectively. ``g_add_blocks > 0`` replaces whatever
        ``out_proj`` selected.
    """
    def __init__(self, device, data_dim, hidden_dim, latent_dim=None, output_dim=1,
                 augment_dim=0, time_dependent=False, non_linearity='relu',
                 tol=1e-3, adjoint=False, in_proj=False, out_proj=False, label_proj=False,
                 proj_norm='none', in_proj_scale=None, label_proj_scale=None, t_final=1.,
                 time_modulation='none', f_add_blocks=0, h_add_blocks=0, g_add_blocks=0):
        super().__init__()
        self.device = device
        # forward() concatenates yt (output_dim columns) onto the flattened x.
        data_dim = data_dim + output_dim
        self.data_dim = data_dim
        self.hidden_dim = hidden_dim
        self.augment_dim = augment_dim
        self.output_dim = output_dim
        self.time_dependent = time_dependent
        self.tol = tol

        # Diffusion schedule.
        # NOTE: these are plain tensor attributes, not registered buffers.
        # Module.to() therefore does not move them, and they are not part of
        # state_dict(). Presumably a linear beta schedule from 1e-4 to 0.02
        # over 1000 steps (see make_beta_schedule) — confirm.
        self.betas = make_beta_schedule('linear', 1000, 1e-4, 0.02).to(device)
        self.betas_sqrt = torch.sqrt(self.betas)
        self.alphas = 1 - self.betas
        self.one_minus_betas_sqrt = torch.sqrt(self.alphas)
        alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device), alphas_cumprod[:-1]], dim=0
        )
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.posterior_mean_coeff_1 = (
                self.betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coeff_2 = (
                torch.sqrt(self.alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        )
        posterior_variance = (
                self.betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = posterior_variance
        self.logvar = self.betas.log()

        self.tau = None  # precision for test NLL computation

        # Number of diffusion steps; forward() divides the timesteps by it.
        self.T = 1000

        if in_proj == 'padding':
            if label_proj != 'padding':
                raise ValueError('padding only for both in_proj and label_proj')
            if out_proj != 'padding':
                raise ValueError('padding only for both in_proj and out_proj')
            # output_dim more zero columns: the latent becomes [x, yt, 0...].
            data_dim = data_dim + output_dim

        if latent_dim is None or augment_dim > 0:
            latent_dim = data_dim
        self.latent_dim = latent_dim

        if proj_norm == 'bn':
            def add_norm(dim): return nn.BatchNorm1d(dim, affine=True)
        elif proj_norm == 'ln':
            def add_norm(dim): return nn.LayerNorm(dim, elementwise_affine=True)
        elif proj_norm == 'none':
            def add_norm(dim): return nn.Identity()
        else:
            raise ValueError(f'Invalid proj_norm {proj_norm}')

        # ODEFunc's own norm layers are always identities.
        self.odefunc = ODEFunc(device, self.latent_dim, hidden_dim, augment_dim,
                                time_dependent, non_linearity, time_modulation=time_modulation, add_norm=lambda dim: nn.Identity(),
                                h_add_blocks=h_add_blocks)

        self.adjoint = adjoint

        if isinstance(in_proj, nn.Module):
            self.in_projection = in_proj
        elif self.augment_dim > 0:
            latent_dim = self.odefunc.input_dim
            self.in_projection = PaddingLayer(self.data_dim, latent_dim, mode=0)
        elif in_proj == 'identity' or in_proj is False:
            self.in_projection = nn.Flatten()
        elif in_proj == 'linear' or in_proj is True:
            in_proj_layers = [nn.Flatten(), tLinear(data_dim, self.latent_dim)]
            self.in_projection = tSequential(
                *in_proj_layers
            )
        elif in_proj == 'mlp':
            latent_dim = self.odefunc.input_dim
            in_projection = [
                nn.Flatten(),
                tLinear(data_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                tLinear(latent_dim, latent_dim),
            ]
            for _ in range(f_add_blocks):
                in_projection.extend([
                    add_norm(latent_dim),
                    nn.ReLU(),
                    tLinear(latent_dim, latent_dim),
                ])
            self.in_projection = tSequential(
                *in_projection
            )
        elif in_proj == 'mlp2':
            latent_dim = self.odefunc.input_dim
            self.in_projection = tSequential(
                nn.Flatten(),
                tLinear(data_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                tLinear(latent_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                tLinear(latent_dim, latent_dim),
            )
        elif in_proj == 'padding':
            in_dim = data_dim - output_dim
            out_dim = data_dim
            self.in_projection = PaddingLayer(in_dim, out_dim, mode=0)
        elif in_proj in ['conv1x1', 'conv3x3']:
            # NOTE(review): data_dim already includes output_dim here, so this
            # only holds when output_dim == 0; forward() also feeds a 2-D
            # tensor to Conv2d. This path looks unfinished — confirm.
            assert data_dim == 28*28, data_dim # only for MNIST
            ksize = 1 if in_proj == 'conv1x1' else 3
            self.in_projection = nn.Sequential(
                nn.Conv2d(1, 1, ksize, padding=ksize//2),
                nn.Flatten(),
            )
        else:
            raise ValueError(f'Invalid in_proj {type(in_proj)} {in_proj}')

        if isinstance(out_proj, nn.Module):
            self.out_projection = out_proj
        elif out_proj == 'linear' or out_proj is True:
            self.out_projection = nn.Linear(self.odefunc.input_dim,
                                          self.output_dim,
                                          )
        elif out_proj == 'identity' or out_proj is False:
            self.out_projection = nn.Identity()
        elif out_proj == 'mlp':
            latent_dim = self.odefunc.input_dim
            self.out_projection = nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, self.output_dim)
            )
        elif out_proj == 'mlp2':
            latent_dim = self.odefunc.input_dim
            self.out_projection = nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
                add_norm(latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, self.output_dim)
            )
        elif out_proj == 'padding':
            in_dim = data_dim
            out_dim = output_dim
            self.out_projection = SlicingLayer(in_dim, out_dim)
        else:
            raise ValueError(f'Invalid out_proj {type(out_proj)} {out_proj}')

        # Extra blocks replace whatever out_proj selected above. latent_dim is
        # whatever value the branches above left in the local variable.
        if g_add_blocks > 0:
            out_proj = [MLPBlock(latent_dim, latent_dim, add_norm) for _ in range(g_add_blocks)]
            out_proj += [nn.Linear(latent_dim, self.output_dim)]
            self.out_projection = nn.Sequential(*out_proj)

        if isinstance(label_proj, nn.Module):
            self.label_projection = label_proj
        elif label_proj == 'linear' or label_proj is True:
            latent_dim = self.odefunc.input_dim
            self.label_projection = nn.Linear(self.output_dim, latent_dim)
        elif label_proj == 'mlp':
            # NOTE(review): unlike 'linear', this uses self.latent_dim, which
            # differs from odefunc.input_dim when augment_dim > 0 — confirm.
            latent_dim = self.odefunc.input_dim
            self.label_projection = nn.Sequential(
                nn.Linear(self.output_dim, self.latent_dim),
                add_norm(self.latent_dim),
                nn.ReLU(),
                nn.Linear(self.latent_dim, self.latent_dim),
            )
        elif label_proj == 'padding':
            in_dim = output_dim
            out_dim = data_dim
            self.label_projection = PaddingLayer(in_dim, out_dim, mode=1)
        else:
            self.label_projection = nn.Identity()

    def forward(self, x, yt, y_0_hat, t, return_features=False):
        """Predict from input ``x``, noisy label ``yt`` and timestep ``t``.

        Parameters
        ----------
        x : torch.Tensor
            Batch-first input; flattened to (B, -1).
        yt : torch.Tensor
            Noisy label, concatenated to the flattened ``x`` along dim 1. Its
            shape is cached as ``self.latent_shape`` on the first call.
        y_0_hat : torch.Tensor
            Not used here. Presumably kept so the sampling helpers can call
            the model uniformly — confirm.
        t : torch.Tensor
            1-D tensor of timesteps, shape (B,). It is divided by ``self.T``
            and reshaped to (B, 1).
        return_features : bool
            If True, return ``(features, pred)``, where ``features`` is the
            ODEFunc output.
        """
        # Cache yt's shape; get_traj uses it to size the sampled labels.
        if not hasattr(self, 'latent_shape'):
            self.latent_shape = yt.shape
        B = x.shape[0]
        x = x.view(B, -1)
        t = t.float() / self.T
        t = repeat(t, 'b -> b c', c=1)
        # concat x, yt
        x = torch.cat([x, yt], dim=1)
        x = self._apply_in_projection(t, x)
        features = self.odefunc(t, x)
        pred = self.out_projection(features)
        if return_features:
            return features, pred
        return pred

    def _apply_in_projection(self, t, x):
        """Call ``in_projection`` with the time input only when it accepts one.

        Which call signature each projection gets:

        - tSequential projections ('linear', 'mlp', 'mlp2') take ``(t, x)``.
        - nn.Flatten, PaddingLayer and plain nn.Sequential take only ``x``.
          They previously raised TypeError when handed ``(t, x)``, which
          included the default ``in_proj=False``.
        - Any other user-supplied module still receives ``(t, x)``.
        """
        proj = self.in_projection
        # tSequential subclasses nn.Sequential, so it must be excluded first.
        if not isinstance(proj, tSequential) and isinstance(
                proj, (nn.Flatten, PaddingLayer, nn.Sequential)):
            return proj(x)
        return proj(t, x)

    def get_traj(self, x, timesteps=100+1, method='ddim', eta=0, noise='rand'):
        '''
        Run DDIM sampling (via ddim_sample) for inputs ``x``.

        timestep: int
            note: should do +1 to timesteps since it is both start & end inclusive.
        method : str
            Currently ignored; ddim_sample is always used.
        eta : float
            Forwarded to ddim_sample.
        noise : 'rand', 'zeros' or other
            'rand' draws standard-normal noise and 'zeros' uses zeros, both
            shaped like the labels. Any other value is passed through to
            ddim_sample unchanged.

        Returns ``(traj, traj[-1])``. ``forward`` must have been called at
        least once first, because the label shape comes from
        ``self.latent_shape``.
        '''
        # TODO: DDIM or DDPM
        B = x.shape[0]
        shape = (B, *self.latent_shape[1:])
        y = torch.zeros(shape, device=x.device)
        if noise == 'rand':
            noise = torch.randn_like(y)
        elif noise == 'zeros':
            noise = torch.zeros_like(y)
        traj = self.ddim_sample(x, y, n_steps=timesteps, eta=eta, noise=noise)
        return traj, traj[-1]

    def q_sample(self, y, t, noise=None):
        """Forward-diffuse ``y`` to step ``t`` using the module-level q_sample.

        A zero tensor is passed in the second argument position.
        """
        return q_sample(y, torch.zeros_like(y), self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t, noise)

    def ddpm_sample(self, x, y1, n_steps=1000, noise=None):
        """Run p_sample_loop with zero ``y_0_hat`` / ``y_T_mean``.

        ``y1`` only supplies the shape of those zeros; ``noise`` is not used.
        """
        y_0_hat = torch.zeros_like(y1)
        y_T_mean = torch.zeros_like(y1)
        return p_sample_loop(self, x, y_0_hat, y_T_mean, n_steps, self.alphas, self.one_minus_alphas_bar_sqrt)

    def ddim_sample(self, x, y1, n_steps=1000, noise=None, eta=0):
        """Run ddim_p_sample_loop with zero ``y_0_hat`` / ``y_T_mean``.

        ``y1`` only supplies the shape of those zeros; ``noise`` and ``eta``
        are forwarded.
        """
        y_0_hat = torch.zeros_like(y1)
        y_T_mean = torch.zeros_like(y1)
        return ddim_p_sample_loop(self, x, y_0_hat, y_T_mean, n_steps, self.alphas, self.one_minus_alphas_bar_sqrt, noise, eta=eta)
    

class PreservingLinear(nn.Module):
    """Linear layer whose output is divided by the volume factor of its weight.

    The factor is ``sqrt(det(G))``. ``G`` is the smaller Gram matrix of the
    weight ``W`` (shape ``(output_dim, input_dim)``):

    - ``W^T W`` when ``output_dim >= input_dim`` (unchanged from before);
    - ``W W^T`` when ``output_dim < input_dim``.

    In the second case ``W^T W`` is rank-deficient, and the old formula
    divided by (numerically) zero.

    NOTE(review): the bias is included in the rescaling. Whether this makes
    the map volume-preserving depends on the intended definition — confirm.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x):
        z = self.fc(x)

        weight = self.fc.weight  # (output_dim, input_dim)
        out_dim, in_dim = weight.shape
        # Use the Gram matrix that is full rank for a generic W.
        gram = weight.T @ weight if out_dim >= in_dim else weight @ weight.T
        volume = torch.sqrt(torch.det(gram))

        return z / volume.unsqueeze(-1)
