import torch
import torch.nn as nn

from ..diffopnet.diagjac_fn import f_and_jac_fn, f_and_jac_fn_amortized, f_and_jac_fn_low_rank


class GlacierODEfunc(nn.Module):
    """ODE right-hand side that also returns the negated divergence of its dynamics.

    Wraps an ``exclusive_net`` / ``dimwise_net`` pair. ``forward`` evaluates the
    dynamics ``dx`` of ``states[0]`` together with the diagonal of their Jacobian
    (computed by the ``diffopnet`` helpers). It returns ``-sum(diag)`` per sample as
    the derivative of ``states[1]``; that state is presumably a log-density
    accumulator, as in continuous normalizing flows (confirm with callers). When a
    third state is supplied, it is passed to the amortized helper as dimwise
    parameters. Every state after the first two receives a zero derivative.
    """

    def __init__(self, exclusive_net, dimwise_net):
        super().__init__()
        self.exclusive_net = exclusive_net
        self.dimwise_net = dimwise_net
        # Number of forward evaluations since the last ``before_odeint`` call.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self):
        """Reset the function-evaluation counter (call before each integration)."""
        self._num_evals.fill_(0)

    @staticmethod
    def _prepare_time(t, x):
        """Return ``t`` as a tensor with ``x``'s dtype that requires grad.

        Python scalars are built directly in ``x``'s dtype and on ``x``'s device.
        Going through ``torch.tensor(t)`` first would round them to the default
        dtype (usually float32) before the upcast, which loses precision for
        float64 states.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        return t.type_as(x).requires_grad_(True)

    def forward(self, t, states):
        """Evaluate the augmented dynamics.

        Args:
            t: Current time, as a tensor or Python scalar.
            states: Sequence ``(x, <second state>, [dimwise_params], ...)`` with at
                least two entries. ``x.shape[0]`` is treated as the batch size.

        Returns:
            Tuple ``(dx, -divergence, zeros...)``. ``divergence`` has shape
            ``(batch, 1)``. Each state from index 2 onward gets a zero tensor of
            its own shape with ``requires_grad=True``.

        Raises:
            ValueError: If fewer than two states are given.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(states) < 2:
            raise ValueError(
                f"GlacierODEfunc expects at least 2 states, got {len(states)}")
        x = states[0]

        self._num_evals += 1

        t = self._prepare_time(t, x)
        batchsize = x.shape[0]

        if len(states) > 2:
            dimwise_params = states[2]
            dx, diag_jac = f_and_jac_fn_amortized(
                self.exclusive_net, self.dimwise_net, dimwise_params, t, x)
        else:
            dx, diag_jac = f_and_jac_fn(self.exclusive_net, self.dimwise_net, t, x)
        # Contiguity of the helper's output is not visible here. reshape only
        # copies when a view is impossible, whereas view would raise.
        divergence = diag_jac.reshape(batchsize, -1).sum(dim=1, keepdim=True)

        zeros = [torch.zeros_like(s).requires_grad_(True) for s in states[2:]]
        return tuple([dx, -divergence] + zeros)


class GlacierLowRankODEfunc(nn.Module):
    """Low-rank amortized variant of the augmented ODE right-hand side.

    ``forward`` requires exactly three states ``(x, <second state>,
    dimwise_am_params)``. It evaluates ``dx`` and the Jacobian diagonal with
    ``f_and_jac_fn_low_rank`` and returns ``-sum(diag)`` per sample as the
    derivative of the second state. That state is presumably a log-density
    accumulator (confirm with callers). The parameter state gets a zero
    derivative. Optionally it also returns a tuple holding the Jacobian diagonal
    for each state component.
    """

    def __init__(self, exclusive_net, dimwise_net):
        super().__init__()
        self.exclusive_net = exclusive_net
        self.dimwise_net = dimwise_net
        # Number of forward evaluations since the last ``before_odeint`` call.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self):
        """Reset the function-evaluation counter (call before each integration)."""
        self._num_evals.fill_(0)

    @staticmethod
    def _prepare_time(t, x):
        """Return ``t`` as a tensor with ``x``'s dtype that requires grad.

        Python scalars are built directly in ``x``'s dtype and on ``x``'s device.
        Going through ``torch.tensor(t)`` first would round them to the default
        dtype (usually float32) before the upcast, which loses precision for
        float64 states.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        return t.type_as(x).requires_grad_(True)

    def forward(self, t, states, with_djac=False):
        """Evaluate the augmented low-rank dynamics.

        Args:
            t: Current time, as a tensor or Python scalar.
            states: Exactly ``(x, <second state>, dimwise_am_params)``.
                ``x.shape[0]`` is treated as the batch size.
            with_djac: If true, also return the per-component Jacobian diagonal.

        Returns:
            ``f = (dx, -divergence, zeros_like(states[2]))``, where ``divergence``
            has shape ``(batch, 1)``. If ``with_djac`` is true, returns
            ``(f, djac)`` with ``djac = (diag_jac, zeros_like(divergence),
            zeros_like(states[2]))``. The zero tensor for ``states[2]`` is the
            same object in both tuples.

        Raises:
            ValueError: Unless exactly three states are given.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(states) != 3:
            raise ValueError(
                f"GlacierLowRankODEfunc expects exactly 3 states, got {len(states)}")
        x = states[0]

        self._num_evals += 1

        t = self._prepare_time(t, x)
        batchsize = x.shape[0]

        # amortized low rank dimwise net
        dimwise_am_params = states[2]
        dx, diag_jac = f_and_jac_fn_low_rank(
            self.exclusive_net, self.dimwise_net, dimwise_am_params, t, x)
        # Contiguity of the helper's output is not visible here. reshape only
        # copies when a view is impossible, whereas view would raise.
        divergence = diag_jac.reshape(batchsize, -1).sum(dim=1, keepdim=True)

        zeros = [torch.zeros_like(s).requires_grad_(True) for s in states[2:]]
        f = tuple([dx, -divergence] + zeros)

        if not with_djac:
            return f
        djac = tuple([diag_jac, torch.zeros_like(divergence)] + zeros)
        return f, djac
