import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint
from models.model import Flatten, AppendRepeat
from ANODE.anode.conv_models import ConvODEFunc, MNISTConvODEFunc

# Upper bound on the number of steps an adaptive ODE solver may take; it is
# handed to torchdiffeq through the ``max_num_steps`` solver option.
MAX_NUM_STEPS: int = 1_000


class ChannelAugment(nn.Module):
    """
    Augment the input by padding zero to the channel dimension.

    Appends ``augment_dim`` all-zero channels after the existing channels of a
    4-D input, turning shape (batch, channels, height, width) into
    (batch, channels + augment_dim, height, width).
    """
    def __init__(self, augment_dim):
        super(ChannelAugment, self).__init__()
        # Number of zero channels appended by forward().
        self.augment_dim = augment_dim

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Shape (batch_size, channels, height, width)

        Returns
        -------
        torch.Tensor
            Shape (batch_size, channels + augment_dim, height, width) with the
            same dtype and device as ``x``; the appended channels are zero.
        """
        batch_size, _, height, width = x.shape
        # new_zeros allocates directly on x's device with x's dtype: no CPU
        # allocation + copy, and no dtype mismatch in torch.cat (e.g. for
        # half-precision inputs).
        aug = x.new_zeros(batch_size, self.augment_dim, height, width)
        return torch.cat([x, aug], 1)


class ConvODENet(nn.Module):
    """Creates an ODEBlock with a convolutional ODEFunc followed by a Linear
    layer.

    The input is passed through an input projection, integrated by the ODE
    block from t=0 to t=1, then flattened and mapped to ``output_dim`` by a
    linear layer.

    Parameters
    ----------
    device : torch.device
        Device on which the ODE block and output layer are created; inputs are
        moved to this device before integration.

    img_size : tuple of ints
        Tuple of (channels, height, width).

    num_filters : int
        Number of convolutional filters.

    output_dim : int
        Dimension of output after hidden layer. Should be 1 for regression or
        num_classes for classification.

    augment_dim: int
        Number of augmentation channels to add. If 0 does not augment ODE.
        Only 0 is currently supported.

    time_dependent : bool
        If True adds time as input, making ODE time dependent.

    non_linearity : string
        One of 'relu' and 'softplus'

    tol : float
        Error tolerance.

    adjoint : bool
        If True calculates gradient with adjoint method, otherwise
        backpropagates directly through operations of ODE solver.

    label_proj_strategy : str
        How a 10-dimensional label vector is projected to the ODE state shape:
        'repeat' (Linear to the channel count, then ``AppendRepeat`` over the
        spatial size) or 'reshape' (Linear to the flattened state, then
        unflatten). 'mlp' is not implemented.

    in_proj : bool, str or nn.Module
        Input projection. An ``nn.Module`` is used as-is; False/'identity'
        gives an identity; True/'linear' gives flatten + Linear + unflatten;
        'conv1x1'/'conv3x3' give a single 1->1 channel convolution (only for
        MNIST-sized inputs).

    mid_conv : int
        Forwarded to ConvODEFunc.
    """
    def __init__(self, device, img_size, num_filters, output_dim=10, augment_dim=0, time_dependent=True,
                 non_linearity='relu', tol=1e-3, adjoint=False, label_proj_strategy='repeat', in_proj=False, mid_conv=1):
        super(ConvODENet, self).__init__()
        self.device = device
        self.img_size = img_size
        self.num_filters = num_filters
        self.augment_dim = augment_dim
        self.output_dim = output_dim
        # Size of the flattened ODE state: (channels + augment_dim) * H * W.
        self.flattened_dim = (img_size[0] + augment_dim) * img_size[1] * img_size[2]
        self.time_dependent = time_dependent
        self.tol = tol
        # Unflattened ODE state shape (channels, height, width).
        state_shape = (img_size[0] + augment_dim, img_size[1], img_size[2])

        # Channel augmentation is not wired up; only augment_dim == 0 works.
        assert augment_dim == 0, 'Not implemented yet'
        if isinstance(in_proj, nn.Module):
            self.in_projection = in_proj
        elif in_proj == 'identity' or in_proj is False:
            self.in_projection = nn.Identity()
        elif in_proj == 'linear' or in_proj is True:
            self.in_projection = nn.Sequential(
                nn.Flatten(),
                nn.Linear(self.flattened_dim, self.flattened_dim),
                nn.Unflatten(1, state_shape),
            )
        elif in_proj in ['conv1x1', 'conv3x3']:
            assert self.flattened_dim == 28*28, self.flattened_dim  # only for MNIST
            ksize = 1 if in_proj == 'conv1x1' else 3
            self.in_projection = nn.Sequential(
                nn.Conv2d(1, 1, ksize, padding=ksize//2),
            )
        else:
            raise ValueError(f'Invalid in_proj {type(in_proj)} {in_proj}')

        odefunc = ConvODEFunc(device, img_size, num_filters, augment_dim=augment_dim,
                              time_dependent=time_dependent, non_linearity=non_linearity, mid_conv=mid_conv).to(self.device)
        self.odeblock = ODEBlock(device, odefunc, is_conv=True, tol=tol,
                                 adjoint=adjoint).to(self.device)

        self.out_projection = nn.Sequential(
            Flatten(),
            nn.Linear(self.flattened_dim, self.output_dim).to(self.device),
        )

        if label_proj_strategy == 'repeat':
            self.label_projection = nn.Sequential(
                nn.Linear(10, img_size[0] + augment_dim),
                AppendRepeat(img_size[1:]),
            )
        elif label_proj_strategy == 'reshape':
            self.label_projection = nn.Sequential(
                nn.Linear(10, self.flattened_dim),
                nn.Unflatten(1, state_shape),
            )
        elif label_proj_strategy == 'mlp':
            raise NotImplementedError
        else:
            raise ValueError(f'Unknown label_proj_strategy: {label_proj_strategy}')

    def forward(self, x, return_features=False, method="dopri5"):
        """Project ``x``, integrate it from t=0 to t=1 and apply the output layer.

        Parameters
        ----------
        x : torch.Tensor
            Model input, fed to ``in_projection``.
        return_features : bool
            If True, also return the ODE state at t=1.
        method : str
            torchdiffeq solver name passed to the ODE block.

        Returns
        -------
        ``pred``, or ``(features, pred)`` when ``return_features`` is True.
        """
        x = self.in_projection(x)
        # Move onto the configured device instead of hard-coding ``.cuda()``,
        # so the model also runs on CPU or on a specific GPU index.
        features = self.odeblock(x.to(self.device), method=method)
        pred = self.out_projection(features)
        if return_features:
            return features, pred
        return pred

    def get_traj(self, x, timesteps=100+1, method='dopri5'):
        '''
        Return the ODE trajectory of ``x`` and the prediction at its last point.

        note: should +1 to timesteps since it is both start & end inclusive.

        ``timesteps`` is either an int (number of evenly spaced points on
        [0, 1]) or a tensor of evaluation times; it is forwarded to
        ``ODEBlock.trajectory``.
        '''
        x = self.in_projection(x)
        # Same device handling as forward().
        out = self.odeblock.trajectory(x.to(self.device), timesteps, method=method)
        return out, self.out_projection(out[-1])

    def pred_v(self, z, t):
        """Evaluate the dynamics ``odefunc(t, z)`` at state ``z`` and time ``t``.

        The ODE function's ``nfe`` counter is reset to 0 before the call.
        """
        self.odeblock.odefunc.nfe = 0
        return self.odeblock.odefunc(t, z)
    
class ODEBlock(nn.Module):
    """Solves ODE defined by odefunc.

    Parameters
    ----------
    device : torch.device
        Stored as ``self.device``; not otherwise used by this block.

    odefunc : ODEFunc instance or anode.conv_models.ConvODEFunc instance
        Function defining dynamics of system. It is called as
        ``odefunc(t, x)`` by the solver and must have a writable ``nfe``
        attribute, which is reset to 0 at the start of every solve.

    is_conv : bool
        If True, treats odefunc as a convolutional model. Currently only
        stored; augmentation is expected to be done by the caller.

    tol : float
        Error tolerance, used for both ``rtol`` and ``atol``.

    adjoint : bool
        If True calculates gradient with adjoint method, otherwise
        backpropagates directly through operations of ODE solver.
    """
    # torchdiffeq solvers that step on a fixed grid: they have no step budget,
    # and torchdiffeq warns about ``max_num_steps`` as an unexpected option.
    _FIXED_GRID_METHODS = frozenset({'euler', 'midpoint', 'rk4'})

    def __init__(self, device, odefunc, is_conv=False, tol=1e-3, adjoint=False):
        super(ODEBlock, self).__init__()
        self.adjoint = adjoint
        self.device = device
        self.is_conv = is_conv
        self.odefunc = odefunc
        self.tol = tol

    @classmethod
    def _solver_options(cls, method):
        """Return the ``options`` argument for odeint given solver ``method``.

        Fixed-grid solvers get None; adaptive solvers are capped at
        MAX_NUM_STEPS steps.
        """
        if method in cls._FIXED_GRID_METHODS:
            return None
        return {'max_num_steps': MAX_NUM_STEPS}

    def forward(self, x, eval_times=None, method='dopri5'):
        """Solves ODE starting from x.

        Parameters
        ----------
        x : torch.Tensor
            Initial state: shape (batch_size, self.odefunc.data_dim) for vector
            odefuncs, presumably (batch_size, channels, height, width) for
            convolutional ones. No augmentation channels are added here; the
            caller (e.g. an input projection) must provide them.

        eval_times : None or torch.Tensor
            If None, returns solution of ODE at final time t=1. If torch.Tensor
            then returns full ODE trajectory evaluated at points in eval_times.

        method : str
            torchdiffeq solver name, e.g. 'dopri5' or 'euler'.

        Returns
        -------
        torch.Tensor
            The state at t=1 if ``eval_times`` is None, otherwise the solver
            output with time as the leading dimension.
        """
        # Forward pass corresponds to solving ODE, so reset number of function
        # evaluations counter
        self.odefunc.nfe = 0

        if eval_times is None:
            integration_time = torch.tensor([0, 1]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        solver = odeint_adjoint if self.adjoint else odeint
        out = solver(self.odefunc, x, integration_time,
                     rtol=self.tol, atol=self.tol, method=method,
                     options=self._solver_options(method))

        if eval_times is None:
            return out[1]  # Return only final time
        return out

    def trajectory(self, x, timesteps, method='dopri5'):
        """Returns ODE trajectory.

        Parameters
        ----------
        x : torch.Tensor
            Initial state, as for ``forward``.

        timesteps : int or torch.Tensor
            If int, number of evenly spaced evaluation times on [0, 1]
            (both ends included). If a tensor, used directly as the
            evaluation times.

        method : str
            torchdiffeq solver name.

        Raises
        ------
        ValueError
            If ``timesteps`` is neither an int nor a tensor.
        """
        if isinstance(timesteps, int):
            integration_time = torch.linspace(0., 1., timesteps)
        elif isinstance(timesteps, torch.Tensor):
            integration_time = timesteps  # used as-is; caller controls range/order
        else:
            raise ValueError('timesteps should be int or torch.Tensor')

        return self.forward(x, eval_times=integration_time, method=method)


class MNISTConvODENet(ConvODENet):
    '''
    ConvODENet but not using ChannelAugment and using dynamics model defined in num_filters dim.

    The input image (``img_size`` = (channels, height, width)) is mapped by the
    input projection into a latent state of shape (num_filters, height, width),
    on which MNISTConvODEFunc defines the dynamics. ``forward``, ``get_traj``
    and ``pred_v`` are inherited from ConvODENet.

    Parameters
    ----------
    device : torch.device
        Device for the ODE block and output layer.
    img_size : tuple of ints
        (channels, height, width) of the raw input.
    num_filters : int
        Number of channels of the latent ODE state.
    output_dim : int
        Output dimension of the final linear layer.
    time_dependent : bool
        Stored as an attribute; not passed to MNISTConvODEFunc here.
    non_linearity : str
        Forwarded to MNISTConvODEFunc.
    tol : float
        Solver error tolerance.
    adjoint : bool
        If True, gradients are computed with the adjoint method.
    label_proj_strategy : str
        'repeat' or 'reshape': how a 10-dimensional label vector is projected
        into the latent state shape. 'mlp' is not implemented.
    in_proj : bool, str or nn.Module
        Input projection: an nn.Module is used as-is; False/'identity' gives an
        identity; True/'linear' gives flatten + Linear to the latent size +
        unflatten; 'conv1x1'/'conv3x3' give a 1 -> num_filters convolution
        (only for MNIST-sized inputs).
    mid_conv : int
        Forwarded to MNISTConvODEFunc.
    '''
    def __init__(self, device, img_size, num_filters, output_dim=10, time_dependent=True,
                 non_linearity='relu', tol=1e-3, adjoint=False, label_proj_strategy='repeat', in_proj=False, mid_conv=1):
        # Deliberately bypass ConvODENet.__init__ (which builds a
        # ConvODEFunc-based block) and only run nn.Module.__init__.
        super(ConvODENet, self).__init__()
        self.device = device
        self.img_size = img_size
        self.num_filters = num_filters
        self.output_dim = output_dim
        # This variant never pads channels; kept for attribute parity with
        # ConvODENet, which exposes augment_dim.
        self.augment_dim = 0
        self.flattened_dim = img_size[0] * img_size[1] * img_size[2]
        # Flattened size of the num_filters-channel latent ODE state.
        self.latent_dim = num_filters * img_size[1] * img_size[2]
        self.time_dependent = time_dependent
        self.tol = tol
        latent_shape = (num_filters, img_size[1], img_size[2])

        if isinstance(in_proj, nn.Module):
            self.in_projection = in_proj
        elif in_proj == 'identity' or in_proj is False:
            # Identity keeps img_size[0] channels while out_projection expects
            # num_filters * H * W features, so this presumably only works when
            # img_size[0] == num_filters.
            self.in_projection = nn.Identity()
        elif in_proj == 'linear' or in_proj is True:
            self.in_projection = nn.Sequential(
                nn.Flatten(),
                nn.Linear(self.flattened_dim, self.latent_dim),
                nn.Unflatten(1, latent_shape),
            )
        elif in_proj in ['conv1x1', 'conv3x3']:
            assert self.flattened_dim == 28*28, self.flattened_dim  # only for MNIST
            ksize = 1 if in_proj == 'conv1x1' else 3
            self.in_projection = nn.Sequential(
                nn.Conv2d(1, num_filters, ksize, padding=ksize//2),
            )
        else:
            raise ValueError(f'Invalid in_proj {type(in_proj)} {in_proj}')

        odefunc = MNISTConvODEFunc(device, num_filters, non_linearity=non_linearity, mid_conv=mid_conv).to(self.device)
        self.odeblock = ODEBlock(device, odefunc, is_conv=True, tol=tol, adjoint=adjoint).to(self.device)

        self.out_projection = nn.Sequential(
            Flatten(),
            nn.Linear(self.latent_dim, self.output_dim).to(self.device),
        )

        if label_proj_strategy == 'repeat':
            self.label_projection = nn.Sequential(
                nn.Linear(10, num_filters),
                AppendRepeat(img_size[1:]),
            )
        elif label_proj_strategy == 'reshape':
            self.label_projection = nn.Sequential(
                nn.Linear(10, self.latent_dim),
                nn.Unflatten(1, latent_shape),
            )
        elif label_proj_strategy == 'mlp':
            raise NotImplementedError
        else:
            raise ValueError(f'Unknown label_proj_strategy: {label_proj_strategy}')
