import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Public API of this module: a star-import exposes only the ConcatDimwiseMLP class.
__all__ = ['ConcatDimwiseMLP']

# Lookup table from an activation name to the element-wise function that
# implements it; indexed with the `nonlinearity` string chosen by the user.
nonlinear_fn = dict(
    tanh=torch.tanh,
    elu=F.elu,
)


class ConcatDimwiseMLP(nn.Module):
    """Bank of independent MLPs, one per input dimension, stored as one flat parameter matrix.

    Each of the ``input_dims`` dimensions owns an MLP with layer widths
    ``[1 + naux, *hidden_dims, 1]``.  All weights of one dimension are packed
    into a single row of ``self.params``; every row starts from the same
    random draw.

    ``layer_type`` selects how the ``t`` argument of ``forward`` enters each layer:

    * ``'ignore'``       -- ``y = xW + b``
    * ``'concat'``       -- ``y = xW + b + b_t * t``
    * ``'concatsquash'`` -- ``y = (xW + b) * sigmoid(g_Wt * t + g_bt) + b_t * t``

    ``am_rank`` is the rank of the optional additive correction that
    ``forward`` builds from ``am_params``; ``self.am_nparams`` is the number
    of values that correction consumes per dimension.
    """

    def __init__(self, input_dims, hidden_dims, naux, nonlinearity='elu', layer_type='concatsquash', am_rank=0):
        supported = ('ignore', 'concat', 'concatsquash')
        assert layer_type in supported, 'layer_type {} not supported'.format(layer_type)
        super().__init__()

        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.naux = naux
        self.nonlinearity = nonlinearity
        self.layer_type = layer_type
        self.am_rank = am_rank

        self._setup_params()

    def _init_weights(self, nin, nout, with_bias=True):
        """Draw a flattened ``(nin, nout)`` weight and, optionally, an ``(nout,)`` bias.

        Returns a list ``[weight_flat]`` or ``[weight_flat, bias_flat]``.
        """
        # calculate_gain('leaky_relu', a) is sqrt(2 / (1 + a^2)); with a = sqrt(5)
        # the weight bound below works out to roughly 1 / sqrt(nin).
        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
        weight_bound = math.sqrt(3.0) * gain / math.sqrt(nin)
        bias_bound = 1 / math.sqrt(nin)

        weight = torch.Tensor(nin, nout)
        bias = torch.Tensor(nout)
        with torch.no_grad():
            weight.uniform_(-weight_bound, weight_bound)
            # The bias is sampled even when it is not returned, so the random
            # stream advances by the same amount in both cases.
            bias.uniform_(-bias_bound, bias_bound)

        pieces = [weight.view(-1)]
        if with_bias:
            pieces.append(bias.view(-1))
        return pieces

    def _setup_params(self):
        """Create ``self.params``, ``self.shapes`` and ``self.am_nparams``.

        Per layer the flat layout is: weight, bias, then the time-conditioning
        vectors required by ``layer_type`` (``b_t`` for ``'concat'``;
        ``b_t, g_Wt, g_bt`` for ``'concatsquash'``).
        """
        widths = [1 + self.naux, *self.hidden_dims, 1]
        chunks = []
        layer_shapes = []
        am_total = 0
        for fan_in, fan_out in zip(widths, widths[1:]):
            chunks += self._init_weights(fan_in, fan_out)
            layer_shapes.append((fan_in, fan_out))
            # Low-rank factors (fan_in x rank and rank x fan_out) plus one bias vector.
            am_total += (fan_in + fan_out) * self.am_rank + fan_out
            if self.layer_type == 'concat':
                chunks += self._init_weights(1, fan_out, with_bias=False)
            elif self.layer_type == 'concatsquash':
                chunks += self._init_weights(1, fan_out, with_bias=False)
                chunks += self._init_weights(1, fan_out, with_bias=True)

        # One row of parameters, copied (not aliased) once per input dimension.
        one_row = torch.cat(chunks).view(1, -1)
        per_dim = one_row.expand(self.input_dims, -1).contiguous().detach()
        self.params = nn.Parameter(per_dim)
        self.am_nparams = am_total
        self.shapes = layer_shapes

    def _expand_per_dim(self, flat, batchsize, shape):
        """Turn a ``(B or 1, input_dims, prod(shape))`` slice into ``(batchsize * input_dims, *shape)``.

        Row ``b * input_dims + d`` of the result holds the block for batch
        element ``b`` and dimension ``d``.
        """
        blocks = flat.view(-1, self.input_dims, *shape)
        blocks = blocks.expand(batchsize, self.input_dims, *shape)
        return blocks.reshape(-1, *shape)

    def forward(self, t, x, nx, params=None, am_params=None):
        """Apply every dimension's MLP to its rows of ``x`` / ``nx``.

        Args:
            t: value multiplied into the time-conditioning vectors; unused
                when ``layer_type`` is ``'ignore'``.
            x: tensor concatenated with ``nx`` along dim 1.  The first layer has
                ``1 + naux`` inputs, so ``x`` is expected to be
                ``(batch * input_dims, 1)``.
            nx: auxiliary features, expected ``(batch * input_dims, naux)``.
            params: optional replacement for ``self.params``, indexed as
                ``[batch or 1, input_dims, nparams]``.
            am_params: optional values for an additive per-layer correction,
                indexed as ``[batch or 1, input_dims, am_nparams]``.

        Returns:
            The output of the last layer (one column per row of ``x``).
        """
        if params is None:
            params = self.params.view(1, *self.params.shape)
        rank = self.am_rank
        final_layer = len(self.shapes) - 1
        cursor = 0
        am_cursor = 0

        h = torch.cat([x, nx], 1)
        for layer_idx, (nin, nout) in enumerate(self.shapes):
            rows = h.shape[0]
            batchsize = rows // self.input_dims
            # Each row becomes a 1 x nin matrix so bmm applies a distinct weight per row.
            row_vecs = h.view(rows, 1, h.shape[1])

            nweights = nin * nout
            W = self._expand_per_dim(params[:, :, cursor:cursor + nweights], batchsize, (nin, nout))
            cursor += nweights
            b = self._expand_per_dim(params[:, :, cursor:cursor + nout], batchsize, (nout,))
            cursor += nout

            am_x = None
            if am_params is not None:
                low_rank = None
                if rank > 0:
                    am_W1 = self._expand_per_dim(
                        am_params[:, :, am_cursor:am_cursor + nin * rank], batchsize, (nin, rank)
                    )
                    am_cursor += nin * rank
                    am_W2 = self._expand_per_dim(
                        am_params[:, :, am_cursor:am_cursor + rank * nout], batchsize, (rank, nout)
                    )
                    am_cursor += rank * nout
                    low_rank = torch.bmm(torch.bmm(row_vecs, am_W1), am_W2).view(rows, -1)
                am_b = self._expand_per_dim(am_params[:, :, am_cursor:am_cursor + nout], batchsize, (nout,))
                am_cursor += nout
                am_x = am_b if low_rank is None else low_rank + am_b

            h = torch.bmm(row_vecs, W).view(rows, -1) + b
            if am_x is not None:
                h = h + am_x

            if self.layer_type == 'concat':
                b_t = self._expand_per_dim(params[:, :, cursor:cursor + nout], batchsize, (nout,))
                cursor += nout
                h = h + b_t * t
            elif self.layer_type == 'concatsquash':
                b_t = self._expand_per_dim(params[:, :, cursor:cursor + nout], batchsize, (nout,))
                cursor += nout
                g_Wt = self._expand_per_dim(params[:, :, cursor:cursor + nout], batchsize, (nout,))
                cursor += nout
                g_bt = self._expand_per_dim(params[:, :, cursor:cursor + nout], batchsize, (nout,))
                cursor += nout
                h = h * torch.sigmoid(g_Wt * t + g_bt) + b_t * t

            # No activation after the output layer.
            if layer_idx < final_layer:
                h = nonlinear_fn[self.nonlinearity](h)
        return h

    def extra_repr(self):
        """Return the constructor settings for ``nn.Module``'s printed representation."""
        template = 'idim={}, naux={}, hidden_dims={}, nonlinearity={}, layer_type={}, am_rank={}'
        return template.format(
            self.input_dims,
            self.naux,
            self.hidden_dims,
            self.nonlinearity,
            self.layer_type,
            self.am_rank,
        )
