from models.cvx_mlp import Convex_MLP
from utils.model_utils import get_hyperplane_cuts
import jax.numpy as jnp
from jax import jit, tree_util, vmap

class CVX_ReLU_MLP(Convex_MLP):
    """Matrix-free linear operators F, G and A built from ``X`` and diagonal masks.

    With ``n, d = X.shape`` and ``P_S`` mask columns, the operators act on
    arrays ``vec`` of shape ``(2, d, P_S)``:

    * ``F``:  ``(2, d, P_S) -> (n,)``,
      ``F(vec) = sum_i d_diags[:, i] * (X @ (vec[0, :, i] - vec[1, :, i]))``
    * ``G``:  ``(2, d, P_S) -> (2, n, P_S)``,
      ``G(vec)[k, :, i] = e_diags[:, i] * (X @ vec[k, :, i])``
    * ``A``:  ``vec + (1 / rho) * F^T F vec + G^T G vec``

    ``d_diags`` and ``e_diags`` are indexed as ``(n, P_S)`` arrays.

    The class is meant to be registered as a JAX pytree (see
    ``_tree_flatten`` / ``_tree_unflatten``) so that ``self`` can be passed
    through the ``@jit``-decorated methods.

    NOTE(review): ``self.X``, ``self.y``, ``self.P_S``, ``self.beta``,
    ``self.rho`` and ``self.seed`` are presumably set by
    ``Convex_MLP.__init__``, which is defined outside this file — confirm.
    """

    def __init__(self, X, y, P_S, beta, rho, seed, d_diags=None, e_diags=None):
        """Forward the problem data to the base class and store the masks.

        Args:
            X, y, P_S, beta, rho, seed: passed unchanged to
                ``Convex_MLP.__init__``.
            d_diags: optional mask array; left as ``None`` until
                ``init_model`` is called if not supplied.
            e_diags: optional companion array; ``init_model`` sets it to
                ``2 * d_diags - 1``.
        """
        super().__init__(X, y, P_S, beta, rho, seed)
        self.d_diags = d_diags
        self.e_diags = e_diags

    def init_model(self):
        """Generate ``d_diags`` (and an updated seed) and derive ``e_diags``."""
        self.d_diags, self.seed = get_hyperplane_cuts(self.X, self.P_S, self.seed)
        # If d_diags holds 0/1 indicators (presumably — verify against
        # get_hyperplane_cuts), this yields the matching -1/+1 signs.
        self.e_diags = 2 * self.d_diags - 1

    def matvec_Fi(self, i, vec):
        """Return ``d_diags[:, i] * (X @ vec)`` for a single column ``i``."""
        return self.d_diags[:, i] * (self.X @ vec)

    def rmatvec_Fi(self, i, vec):
        """Return ``X.T @ (d_diags[:, i] * vec)``, the adjoint of ``matvec_Fi``."""
        return self.X.T @ (self.d_diags[:, i] * vec)

    @jit
    def matvec_F(self, vec):
        """Apply F: map ``vec`` of shape ``(2, d, P_S)`` to a vector of shape ``(n,)``.

        Computes ``sum_i d_diags[:, i] * (X @ (vec[0, :, i] - vec[1, :, i]))``
        with one batched matmul instead of a Python loop over ``i``; a Python
        loop inside ``jit`` is unrolled at trace time, so its cost grows with
        ``P_S``.
        """
        n = self.X.shape[0]
        p = self.P_S
        diff = vec[0, :, :p] - vec[1, :, :p]              # (d, P_S)
        masked = self.d_diags[:, :p] * (self.X @ diff)    # (n, P_S)
        # Adding to a default-float zero vector keeps the result dtype the
        # same as accumulating into jnp.zeros((n,)).
        return jnp.zeros((n,)) + jnp.sum(masked, axis=1)

    @jit
    def rmatvec_F(self, vec):
        """Apply F^T: map ``vec`` of shape ``(n,)`` to shape ``(2, d, P_S)``.

        ``out[0, :, i] = X.T @ (d_diags[:, i] * vec)`` and
        ``out[1, :, i] = -out[0, :, i]``, computed for all ``i`` at once.
        """
        d = self.X.shape[1]
        p = self.P_S
        back = self.X.T @ (self.d_diags[:, :p] * vec[:, None])   # (d, P_S)
        # Writing into a zeros buffer fixes the output shape and dtype.
        out = jnp.zeros((2, d, p))
        return out.at[0].set(back).at[1].set(-back)

    def matvec_Gi(self, i, vec):
        """Return ``e_diags[:, i] * (X @ vec)`` for a single column ``i``."""
        return self.e_diags[:, i] * (self.X @ vec)

    def rmatvec_Gi(self, i, vec):
        """Return ``X.T @ (e_diags[:, i] * vec)``, the adjoint of ``matvec_Gi``."""
        return self.X.T @ (self.e_diags[:, i] * vec)

    @jit
    def matvec_G(self, vec):
        """Apply G: map ``vec`` of shape ``(2, d, P_S)`` to shape ``(2, n, P_S)``.

        ``out[k, :, i] = e_diags[:, i] * (X @ vec[k, :, i])`` for ``k`` in
        ``{0, 1}``, computed for all ``i`` at once via a broadcasted matmul.
        """
        n = self.X.shape[0]
        p = self.P_S
        # (n, d) @ (2, d, P_S) broadcasts over the leading axis -> (2, n, P_S).
        signed = self.e_diags[:, :p] * (self.X @ vec[:, :, :p])
        out = jnp.zeros((2, n, p))
        return out.at[0].set(signed[0]).at[1].set(signed[1])

    @jit
    def rmatvec_G(self, vec):
        """Apply G^T: map ``vec`` of shape ``(2, n, P_S)`` to shape ``(2, d, P_S)``.

        ``out[k, :, i] = X.T @ (e_diags[:, i] * vec[k, :, i])`` for ``k`` in
        ``{0, 1}``, computed for all ``i`` at once via a broadcasted matmul.
        """
        d = self.X.shape[1]
        p = self.P_S
        # (d, n) @ (2, n, P_S) broadcasts over the leading axis -> (2, d, P_S).
        back = self.X.T @ (self.e_diags[:, :p] * vec[:, :, :p])
        out = jnp.zeros((2, d, p))
        return out.at[0].set(back[0]).at[1].set(back[1])

    @jit
    def matvec_A(self, vec):
        """Apply ``A = I + (1 / rho) * F^T F + G^T G`` to ``vec``.

        ``vec`` has shape ``(2, d, P_S)``; the result has the same shape.
        """
        ftf = self.rmatvec_F(self.matvec_F(vec))
        gtg = self.rmatvec_G(self.matvec_G(vec))
        # jax arrays are immutable, so vec itself is never modified.
        return vec + 1 / self.rho * ftf + gtg

    def _tree_flatten(self):
        """Split the instance into pytree children and static aux data.

        Returns:
            ``(children, aux_data)`` where ``children`` holds the dynamic
            values ``(X, y, beta, seed, d_diags, e_diags)`` and ``aux_data``
            holds the static values ``P_S`` and ``rho``.
        """
        children = (self.X, self.y, self.beta, self.seed, self.d_diags, self.e_diags)  # arrays / dynamic values
        aux_data = {'P_S': self.P_S, 'rho': self.rho}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``_tree_flatten``.

        The instance is rebuilt by calling the constructor, so ``__init__``
        (and the base-class ``__init__``) runs on every unflatten.
        """
        X, y, beta, seed, d_diags, e_diags = children
        return cls(X, y, aux_data['P_S'], beta, aux_data['rho'], seed, d_diags, e_diags)
  
# Register CVX_ReLU_MLP as a JAX pytree node, using its own flatten/unflatten
# hooks, so instances can be passed as `self` through the @jit-decorated methods.
tree_util.register_pytree_node(
    CVX_ReLU_MLP,
    CVX_ReLU_MLP._tree_flatten,
    CVX_ReLU_MLP._tree_unflatten,
)