from library.sde.base import DynamicalSystem

import torch
from torch import nn as nn
import numpy as np

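'''
Classes and helper functions used to compute adjoint (backward-in-time)
dynamics, together with activation functions, their derivatives, and small
utilities.
'''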


class AdjointRNN(DynamicalSystem):
    """Adjoint dynamics associated with an RNN whose weight matrix is ``rnn.W``.

    The vector field is ``f(x, a) = (a @ W.T) * derivative_activation(x) - a``.
    NOTE(review): this contraction corresponds to the transposed Jacobian of a
    forward field of the form ``-x + activation(x) @ W``; confirm against the
    RNN definition.

    Args:
        rnn: object exposing ``dim``, a 2-D tensor ``W`` and ``activation``.
        derivative_activation: expected to be the elementwise derivative of
            ``rnn.activation``.
        device: forwarded to ``DynamicalSystem``.
    """

    def __init__(self, rnn, derivative_activation, device='cpu'):
        super().__init__(noise_dim=None, noise=0.0, device=device)

        self.derivative_activation = derivative_activation

        self.dim = rnn.dim
        self.rnn = rnn

    def set_terminal_state(self, x, D, y):
        """Store the adjoint's boundary value from a squared readout error.

        Stores, per batch row, the gradient w.r.t. ``x`` of
        ``sum((activation(x) @ D.T - y) ** 2)``, namely
        ``2 * ((activation(x) @ D.T - y) @ D) * derivative_activation(x)``.
        Assumes ``x`` is (batch, dim), ``D`` is (n_out, dim) and ``y`` is
        (batch, n_out). TODO: confirm with callers.
        """
        residual = 2 * (self.rnn.activation(x) @ D.T - y)
        # A matmul contracts over the outputs directly, instead of building a
        # (batch, n_out, dim) intermediate and summing it along dim=1.
        self.terminal_state = (residual @ D) * self.derivative_activation(x)

    def get_initial_state(self, batch_size):
        """Return the stored terminal state; ``batch_size`` is ignored."""
        return self.terminal_state

    def f(self, x, a):
        """Return ``(a @ W.T) * derivative_activation(x) - a``.

        This is algebraically identical to contracting ``a`` with the
        per-sample matrix ``W[i, j] * derivative_activation(x)[b, i]``, but it
        never materializes a (batch, dim, dim) tensor.
        """
        return (a @ self.rnn.W.T) * self.derivative_activation(x) - a


class AdjointW(DynamicalSystem):
    """Deterministic system whose drift is the flattened outer product of its inputs.

    Each batch row contributes ``phi_x[b, i] * a[b, j]``. The state starts at
    zeros and the diffusion term ``g`` is identically zero.

    NOTE(review): ``dim`` presumably equals ``phi_x.shape[1] * a.shape[1]``, so
    that ``f`` yields one row per batch element. Confirm with callers.
    """

    def __init__(self, dim, device='cpu', dtype=torch.double):
        super().__init__(noise_dim=None, noise=0.0, device=device)
        self.dim = dim
        self.dtype = dtype

    def get_initial_state(self, batch_size):
        """Return a (batch_size, dim) zero tensor in ``self.dtype``."""
        shape = (batch_size, self.dim)
        return torch.zeros(shape, device=self.device, dtype=self.dtype)

    def f(self, phi_x, a):
        """Per-row outer product of ``phi_x`` and ``a``, reshaped to (-1, dim)."""
        outer = torch.einsum('bi,bj->bij', phi_x, a)
        return outer.reshape(-1, self.dim)

    def g(self, x, *args):
        """Zero diffusion of shape (len(x), dim) with ``x``'s dtype."""
        rows = len(x)
        return torch.zeros((rows, self.dim), device=self.device, dtype=x.dtype)


class BackwardDynamicalSystem(DynamicalSystem):
    """Time-reversed wrapper around another dynamical system ``ds``.

    The drift is the negation of ``ds.f``. Integration starts from the state
    supplied through ``set_terminal_state``.
    """

    def __init__(self, ds, device='cpu'):
        super().__init__(noise_dim=None, noise=0.0, device=device)
        self.ds = ds
        self.dim = ds.dim

    def set_terminal_state(self, s):
        """Remember ``s`` as the state later returned by ``get_initial_state``."""
        self.terminal_state = s

    def get_initial_state(self, batch_size):
        """Return the stored terminal state; ``batch_size`` is ignored."""
        return self.terminal_state

    def f(self, *args):
        """Drift of the wrapped system with its sign flipped."""
        forward_drift = self.ds.f(*args)
        return -forward_drift


def dtanh(x):
    """Derivative of tanh evaluated at ``x``: ``1 - tanh(x) ** 2``."""
    t = torch.tanh(x)
    return 1 - t ** 2

def dretanh(x):
    """Derivative of tanh(relu(x)).

    Gives ``1 - tanh(x) ** 2`` where ``x >= 0`` and 0 where ``x < 0``. The
    boolean mask is multiplied in, so ``x == 0`` takes the tanh branch.
    """
    on_positive_side = x >= 0
    tanh_slope = 1 - torch.tanh(x) ** 2
    return on_positive_side * tanh_slope

def retanh(x):
    """Rectified tanh: ``tanh(max(x, 0))``, so negative inputs map to 0."""
    return torch.relu(x).tanh()

def dsoftplus(x):
    """Derivative of softplus, i.e. the logistic sigmoid ``1 / (1 + exp(-x))``.

    Uses ``torch.sigmoid`` instead of the explicit formula. For very negative
    ``x``, ``exp(-x)`` overflows to inf, and autograd through the explicit
    formula then produces NaN gradients; ``torch.sigmoid``'s backward stays
    finite.
    """
    return torch.sigmoid(x)

def identity(x):
    """Identity map: return ``x`` unchanged (the same object)."""
    return x

def didentity(x):
    """Derivative of the identity map: a tensor of ones shaped like ``x``."""
    ones = torch.ones_like(x)
    return ones

class LDSinput(DynamicalSystem):
    """Linear system with a random per-sample coupling matrix.

    The drift is ``f(x) = x @ W[b] - x`` for each batch row ``b``.
    ``set_initial_state`` draws both the initial state (Gaussian scaled by
    ``std``) and ``W`` (Gaussian scaled by ``1 / sqrt(dim)``).
    """

    def __init__(self, dim, std=1.0, device='cpu', dtype=torch.double):
        super().__init__(device=device)
        self.std = std
        self.dim = dim
        self.dtype = dtype

    def set_initial_state(self, batch_size):
        """Sample a fresh initial state and a fresh coupling matrix per batch row."""
        opts = dict(device=self.device, dtype=self.dtype)
        # The initial state is drawn before W, so the random stream is consumed
        # in a fixed order.
        self.initial_state = self.std * torch.randn(batch_size, self.dim, **opts)
        scale = np.sqrt(self.dim)
        self.W = torch.randn(batch_size, self.dim, self.dim, **opts) / scale

    def get_initial_state(self, batch_size):
        """Return the state drawn by ``set_initial_state``; ``batch_size`` is ignored."""
        return self.initial_state

    def f(self, x, *args):
        """Per-sample linear drift ``x @ W[b] - x``; extra arguments are ignored."""
        drive = torch.einsum('bi,bij->bj', x, self.W)
        return drive - x


class FunctionToModule(nn.Module):
    """Wrap a plain callable so that it can be used as an ``nn.Module``."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        wrapped = self.f
        return wrapped(x)
