import torch.nn as nn
import torch
import torch.nn.functional as F


class RQspline:
    """Elementwise nonlinearity modeled with a rational quadratic spline.

    Adapted from https://github.com/biweidai/SINF/blob/master/RQspline.py
    See appendix B of https://arxiv.org/pdf/2007.00674.pdf

    Each spline has ``nknot`` knots running from ``-B`` to ``B`` in both x and
    y, with consecutive knots more than ``eps`` apart.  Outside ``[-B, B]`` the
    map continues linearly with the derivative at the boundary knot.

    One parameter row has ``3*nknot - 2`` unconstrained entries: ``nknot-1``
    bin widths, ``nknot-1`` bin heights and ``nknot`` knot derivatives, each
    passed through softplus.

    Args:
        nknot: number of knots.
        B: half-width of the spline interval.
        FixExtraSlope: if True, the derivative at the first and last knot is
            forced to 1, so both linear tails have unit slope.
        eps: minimum spacing added between consecutive knots.
    """

    def __init__(self, nknot, B, FixExtraSlope=False, eps=1e-5):

        self.nknot = nknot
        self.B = B
        self.FixExtraSlope = FixExtraSlope
        self.eps = eps

    def _knot_positions(self, logd):
        """Map ``(n, nknot-1)`` unconstrained bin sizes to ``(n, nknot)`` knot
        positions that increase strictly from ``-B`` to ``B``."""
        nbin = self.nknot - 1
        pos = torch.cumsum(F.softplus(logd), dim=1)
        # Rescale so the last cumulative value becomes B - nbin*eps ...
        pos = (2*self.B - nbin*self.eps) * pos / pos[:, -1, None] - self.B
        # ... then shift the i-th knot by i*eps: this enforces the minimum
        # spacing and lands the last knot exactly on B.
        pos = pos + (torch.arange(nbin, device=logd.device) + 1.) * self.eps
        # Leading knot at -B, created in pos's dtype/device so cat never mixes
        # dtypes.
        first = torch.full((len(logd), 1), -self.B, dtype=pos.dtype, device=logd.device)
        return torch.cat((first, pos), dim=1)

    def _prepare(self, param):
        """Return knot x-positions, y-positions and derivatives.

        Args:
            param: ``(n, 3*nknot-2)`` tensor of unconstrained parameters.

        Returns:
            ``(xx, yy, delta)``, each of shape ``(n, nknot)``.
        """
        nbin = self.nknot - 1
        xx = self._knot_positions(param[:, :nbin])
        yy = self._knot_positions(param[:, nbin:2*nbin])

        # Work on a copy because the boundary entries may be overwritten in
        # place below.
        delta = F.softplus(param[:, 2*nbin:]).clone()

        if self.FixExtraSlope:
            delta[:, 0] = 1
            delta[:, -1] = 1
        return xx, yy, delta

    @staticmethod
    def _row_index(param, z):
        """Return a ``(len(z), len(param))`` grid whose entry ``[i, j]`` is ``j``.

        Masking this grid with a boolean selection over the transposed input
        ``z`` yields, for each selected element, the parameter row (spline)
        it belongs to.  It is created on ``param``'s device so it can be
        masked with device-resident masks (a CPU grid cannot be indexed by a
        CUDA mask).
        """
        return torch.arange(len(param), device=param.device).expand(len(z), -1)

    def forward(self, x, param):
        """Evaluate the spline elementwise.

        Args:
            x: 2-D tensor whose first dimension equals ``len(param)``; row
                ``i`` is transformed by the spline built from ``param[i]``.
            param: ``(len(x), 3*nknot-2)`` unconstrained spline parameters.

        Returns:
            ``(y, logderiv)``, both shaped like ``x``, where
            ``logderiv = log(dy/dx)`` per element.
        """
        # Transposed layout: below, the second index selects the spline row.
        x = x.T
        xx, yy, delta = self._prepare(param)  # each (len(param), nknot)

        # Bin index per element (xx[k-1] < x <= xx[k]); searchsorted needs
        # the spline rows as the leading dimension.
        index = torch.searchsorted(xx.detach(), x.T.contiguous().detach()).T
        y = torch.zeros_like(x)
        logderiv = torch.zeros_like(x)
        rows = self._row_index(param, x)

        # Linear extrapolation at or below the first knot ...
        select0 = index == 0
        row = rows[select0]
        y[select0] = yy[row, 0] + (x[select0] - xx[row, 0]) * delta[row, 0]
        logderiv[select0] = torch.log(delta[row, 0])
        # ... and above the last knot.
        selectn = index == self.nknot
        row = rows[selectn]
        y[selectn] = yy[row, -1] + (x[selectn] - xx[row, -1]) * delta[row, -1]
        logderiv[selectn] = torch.log(delta[row, -1])

        # Rational quadratic spline inside the knot range.
        select = ~(select0 | selectn)
        k = index[select]
        row = rows[select]
        x_lo, x_hi = xx[row, k-1], xx[row, k]
        y_lo, y_hi = yy[row, k-1], yy[row, k]
        d_lo, d_hi = delta[row, k-1], delta[row, k]
        width = x_hi - x_lo
        height = y_hi - y_lo
        xi = (x[select] - x_lo) / width
        s = height / width
        xi1_xi = xi*(1-xi)
        denominator = s + (d_hi + d_lo - 2*s)*xi1_xi
        xi2 = xi**2

        y[select] = y_lo + (height * (s*xi2 + d_lo*xi1_xi)) / denominator
        logderiv[select] = 2*torch.log(s) + torch.log(d_hi*xi2 + 2*s*xi1_xi + d_lo*(1-xi)**2) - 2 * torch.log(denominator)

        return y.T, logderiv.T


    __call__ = forward


    def inverse(self, y, param):
        """Invert the spline elementwise.

        Args:
            y: 2-D tensor whose first dimension equals ``len(param)``.
            param: the same parameters that were passed to ``forward``.

        Returns:
            ``(x, logderiv)``, both shaped like ``y``.  NOTE: ``logderiv`` is
            the *forward* log-derivative ``log(dy/dx)`` evaluated at the
            recovered ``x`` (the same value ``forward`` would return), not
            ``log(dx/dy)``; negate it for the inverse map's log-Jacobian.
        """
        y = y.T
        xx, yy, delta = self._prepare(param)

        index = torch.searchsorted(yy.detach(), y.T.contiguous().detach()).T
        x = torch.zeros_like(y)
        logderiv = torch.zeros_like(y)
        rows = self._row_index(param, y)

        # Inverse of the linear tails.
        select0 = index == 0
        row = rows[select0]
        x[select0] = xx[row, 0] + (y[select0] - yy[row, 0]) / delta[row, 0]
        logderiv[select0] = torch.log(delta[row, 0])
        selectn = index == self.nknot
        row = rows[selectn]
        x[selectn] = xx[row, -1] + (y[selectn] - yy[row, -1]) / delta[row, -1]
        logderiv[selectn] = torch.log(delta[row, -1])

        # Inverse of the rational quadratic: solve a*xi^2 + b*xi + c = 0.
        select = ~(select0 | selectn)
        k = index[select]
        row = rows[select]
        x_lo, x_hi = xx[row, k-1], xx[row, k]
        y_lo = yy[row, k-1]
        d_lo, d_hi = delta[row, k-1], delta[row, k]
        deltayy = yy[row, k] - y_lo
        s = deltayy / (x_hi - x_lo)
        delta_2s = d_hi + d_lo - 2*s
        dy = y[select] - y_lo
        deltay_delta_2s = dy * delta_2s

        a = deltayy * (s - d_lo) + deltay_delta_2s
        b = deltayy * d_lo - deltay_delta_2s
        c = - s * dy
        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()
        # Root written as -2c / (b + sqrt(disc)) to avoid cancellation.
        xi = - 2*c / (b + torch.sqrt(discriminant))
        xi1_xi = xi * (1-xi)

        x[select] = xi * (x_hi - x_lo) + x_lo
        logderiv[select] = 2*torch.log(s) + torch.log(d_hi*xi**2 + 2*s*xi1_xi + d_lo*(1-xi)**2) - 2 * torch.log(s + delta_2s*xi1_xi)

        return x.T, logderiv.T



class spline_transform(nn.Module):
    """Elementwise 1-D transform parametrized with an ``RQspline``.

    Each sample's entries are flattened to ``data.reshape(len(data), -1)``.
    The parameter layout depends on the mode chosen at construction:

    * ``TRE=True, conditional=False``: one spline whose knots are the learnable
      ``self.knots``, shared by every element; ``param`` is ignored.
    * ``TRE=True, conditional=True``: one spline per sample, applied to all of
      that sample's entries; ``param`` is ``(len(data), 3*nknot-2)``.
    * ``TRE=False, conditional=False``: one spline per flattened entry
      (``ndim`` of them), shared across samples; ``param`` holds
      ``ndim*(3*nknot-2)`` values.
    * ``TRE=False, conditional=True``: one spline per (sample, entry);
      ``param`` holds ``len(data)*ndim*(3*nknot-2)`` values.

    ``self.nparam`` is the number of externally supplied parameters (per
    sample in conditional mode); it is 0 when the knots are internal.

    Args:
        TRE: selects the mode, see above.
        ndim: number of flattened entries per sample; required if TRE=False.
        B, nknot, FixExtraSlope, eps: forwarded to ``RQspline``.
        conditional: selects the mode, see above.

    Raises:
        ValueError: if ``TRE=False`` and ``ndim`` is None.
    """

    def __init__(self, TRE=True, ndim=None, B=8, nknot=8, FixExtraSlope=True, eps=1e-5, conditional=False):

        super().__init__()

        self.TRE = TRE
        self.conditional = conditional
        if not TRE:
            # Explicit check instead of `assert`, which is stripped under -O.
            if ndim is None:
                raise ValueError('ndim is required when TRE=False')
            self.ndim = ndim
            self.nparam = self.ndim*(3*nknot-2)
        elif conditional:
            self.nparam = 3*nknot-2
        else:
            self.nparam = 0
            self.knots = nn.Parameter(torch.randn(3*nknot-2))
        self.activation = RQspline(nknot, B, FixExtraSlope, eps)

    def _apply(self, transform, data, param):
        # Shared driver for forward/inverse: lay out `data` and `param` for the
        # current mode, run `transform` (the spline's forward or inverse),
        # restore the input shape and sum per-element log-derivatives per sample.
        shape = data.shape
        nsample = len(data)
        nparam_1d = 3*self.activation.nknot - 2

        if self.TRE and not self.conditional:
            # Every element uses the same knots, so evaluate them all as a
            # single row against one parameter row instead of copying the
            # knots once per sample.
            out, logj0 = transform(data.reshape(1, -1), self.knots.reshape(1, -1))
        elif param is None:
            raise ValueError('param is required unless TRE=True and conditional=False')
        elif self.TRE:
            out, logj0 = transform(data.reshape(nsample, -1), param)
        elif self.conditional:
            out, logj0 = transform(data.reshape(-1, 1), param.reshape(nsample*self.ndim, nparam_1d))
        else:
            # One spline per entry: move entries to the leading axis so they
            # line up with the rows of the reshaped `param`.
            out, logj0 = transform(data.reshape(nsample, -1).transpose(0, 1), param.reshape(self.ndim, nparam_1d))
            out = out.transpose(0, 1)
            logj0 = logj0.transpose(0, 1)

        out = out.reshape(*shape)
        logj = torch.sum(logj0.reshape(nsample, -1), dim=1)
        return out, logj

    def forward(self, data, param=None):
        """Apply the spline transform.

        Returns:
            ``(y, logj)``: ``y`` has ``data``'s shape; ``logj`` is the
            ``(len(data),)`` per-sample sum of ``log(dy/dx)``.

        Raises:
            ValueError: if ``param`` is required by the mode but is None.
        """
        return self._apply(self.activation.forward, data, param)


    __call__ = forward


    def inverse(self, data, param=None):
        """Invert the spline transform.

        Returns:
            ``(x, logj)``: ``x`` has ``data``'s shape; ``logj`` is the
            per-sample sum of the *forward* log-derivative ``log(dy/dx)``
            evaluated at the recovered ``x`` (not ``log(dx/dy)``).

        Raises:
            ValueError: if ``param`` is required by the mode but is None.
        """
        return self._apply(self.activation.inverse, data, param)

