import numpy as np

class linear_instance(object):
    """Linear-Gaussian regression target with an identity-covariance prior.

    The potential (negative log-density up to an additive constant) is

        U(theta) = ||X @ theta - Y||^2 / (2 * sigma^2) + ||theta||^2 / 2,

    and ``U_grad`` returns its gradient, optionally estimated on a mini-batch.
    """

    def __init__(self, X, Y, sigma) -> None:
        self.X = X  # design matrix; rows are observations (N rows, d columns)
        self.Y = Y  # responses, one per row of X
        self.N = X.shape[0]
        self.d = X.shape[1]
        self.sigma = sigma  # enters only as sigma**2 (noise scale)

    def U(self, theta):
        """Return the full-data potential U(theta), consistent with ``U_grad``.

        Assumes ``theta`` is a 1-D array of length ``d`` — TODO confirm.
        """
        resid = self.X @ theta - self.Y
        return np.sum(resid**2) / (2 * self.sigma**2) + np.sum(theta**2) / 2

    def U_grad(self, theta, idx=None):
        """Return the gradient of U at ``theta``.

        Args:
            theta: parameter vector (assumed 1-D of length ``d``).
            idx: optional row selector (int, slice, index array or mask). If
                given, the likelihood gradient is computed on those rows only
                and rescaled by ``N / batch_size``; the prior gradient is exact.

        Returns:
            Gradient array of the same length as ``theta``.
        """
        if idx is None:
            return self.X.T @ (self.X @ theta - self.Y) / self.sigma**2 + theta
        # atleast_1d / atleast_2d keep shapes consistent when idx is a scalar.
        Y_sub = np.atleast_1d(self.Y[idx])
        X_sub = np.atleast_2d(self.X[idx])
        # For a uniformly drawn batch, N / batch_size makes the batch sum an
        # unbiased estimate of the full-data likelihood gradient.
        factor = self.N / len(Y_sub)
        return X_sub.T @ (X_sub @ theta - Y_sub) / self.sigma**2 * factor + theta

class logistic_instance(object):
    """Logistic-regression target with a Gaussian prior.

    The potential (negative log-density up to an additive constant) is

        U(theta) = -sum_i [Y_i log s(z_i) + (1 - Y_i) log(1 - s(z_i))]
                   + theta^T SigmaX theta / 2,     z = X @ theta,

    where s is the logistic sigmoid. Labels are presumably in {0, 1} — confirm
    with callers.
    """

    def __init__(self, X, Y) -> None:
        self.X = X
        self.Y = Y
        self.N = X.shape[0]
        self.d = X.shape[1]
        # Enters U as theta^T SigmaX theta / 2, i.e. as a prior *precision*;
        # it is the identity here, so precision vs covariance does not matter.
        self.SigmaX = np.eye(self.d)

    @staticmethod
    def _sigmoid(z):
        """Logistic sigmoid 1 / (1 + exp(-z)), evaluated without overflow.

        ``logaddexp(0, -z) >= 0``, so the outer ``exp`` only sees non-positive
        arguments and never overflows, even for very negative ``z``.
        """
        return np.exp(-np.logaddexp(0.0, -z))

    def U_grad(self, theta, idx=None):
        """Return the gradient of U at ``theta``.

        Args:
            theta: parameter vector (assumed 1-D of length ``d``).
            idx: optional row selector (int, slice, index array or mask). If
                given, the likelihood gradient is computed on those rows only
                and rescaled by ``N / batch_size``; the prior gradient is exact.
        """
        if idx is None:
            W = self._sigmoid(self.X @ theta)
            return -self.X.T @ (self.Y - W) + self.SigmaX @ theta
        # atleast_1d / atleast_2d keep shapes consistent when idx is a scalar.
        Y_sub = np.atleast_1d(self.Y[idx])
        X_sub = np.atleast_2d(self.X[idx])
        factor = self.N / len(Y_sub)
        W = self._sigmoid(X_sub @ theta)
        return -X_sub.T @ (Y_sub - W) * factor + self.SigmaX @ theta

    def U(self, theta):
        """Return the full-data potential U(theta).

        Uses -log s(z) = logaddexp(0, -z) and -log(1 - s(z)) = logaddexp(0, z),
        which stay finite for large |z|. The naive log(W) / log(1 - W) form
        yields 0 * log(0) = nan once s(z) rounds to exactly 0 or 1.
        """
        z = self.X @ theta
        nll = np.sum(self.Y * np.logaddexp(0.0, -z) + (1 - self.Y) * np.logaddexp(0.0, z))
        return nll + theta.T @ self.SigmaX @ theta / 2

class double_well_instance(object):
    """One-dimensional double-well target.

    The potential (up to an additive constant) is

        U(theta) = (theta^2 / 4 - log(1 + theta^2) / 2) / sigma^2,

    whose gradient theta * (theta^2 - 1) / (2 (1 + theta^2)) / sigma^2 vanishes
    at theta = 0 (local maximum) and theta = +-1 (minima).
    """

    def __init__(self, sigma) -> None:
        self.sigma = sigma  # enters only as sigma**2, acting as a temperature
        self.d = 1

    def U(self, theta):
        """Return U(theta), evaluated elementwise like ``U_grad``."""
        # log1p keeps precision for |theta| << 1.
        return (theta**2 / 4 - np.log1p(theta**2) / 2) / self.sigma**2

    def U_grad(self, theta):
        """Return dU/dtheta, evaluated elementwise."""
        return (theta / 2 - theta / (1 + theta**2)) / self.sigma**2

class CRE_instance(object):
    """Two-way additive-effects target (presumably "crossed random effects").

    ``theta`` is laid out as ``[a_1..a_I, b_1..b_J, mu, s_a, s_b]`` (length
    ``I + J + 3``), and the potential is

        U = sum_ij (Y_ij - mu - a_i - b_j)^2 / 2
            + sum_i a_i^2 / (2 exp(s_a)) + sum_j b_j^2 / (2 exp(s_b))
            + (mu^2 + s_a^2 + s_b^2) / 2.

    NOTE(review): a N(0, exp(s)) prior on the effects would also contribute the
    log-normaliser ``I * s_a / 2 + J * s_b / 2``. It is absent from both ``U``
    and ``U_grad`` (they agree with each other); confirm this is intentional.
    """

    def __init__(self, Y) -> None:
        self.Y = Y  # observations, indexed [i, j] with shape (I, J)
        self.I = Y.shape[0]
        self.J = Y.shape[1]
        self.d = self.I + self.J + 3

    def _residual(self, theta):
        """Return the (I, J) array mu + a_i + b_j - Y_ij via broadcasting."""
        I = self.I
        a = theta[:I]
        b = theta[I:-3]
        return theta[-3] + a[:, None] + b[None, :] - self.Y

    def U(self, theta):
        """Return the potential U(theta) described in the class docstring."""
        I = self.I
        a, b = theta[:I], theta[I:-3]
        mu, s_a, s_b = theta[-3], theta[-2], theta[-1]
        res = np.sum(self._residual(theta) ** 2) / 2
        res += np.sum(a**2) / (2 * np.exp(s_a))
        res += np.sum(b**2) / (2 * np.exp(s_b))
        res += (mu**2 + s_a**2 + s_b**2) / 2
        return res

    def U_grad(self, theta):
        """Return the gradient of U, laid out like ``theta``."""
        I = self.I
        a, b = theta[:I], theta[I:-3]
        mu, s_a, s_b = theta[-3], theta[-2], theta[-1]
        r = self._residual(theta)
        inv_var_a = np.exp(-s_a)
        inv_var_b = np.exp(-s_b)
        g = np.zeros(self.d)
        g[:I] = r.sum(axis=1) + a * inv_var_a
        g[I:-3] = r.sum(axis=0) + b * inv_var_b
        g[-3] = r.sum() + mu
        # d/ds of sum(x^2) / (2 e^s) is -sum(x^2) / (2 e^s).
        g[-2] = -0.5 * np.sum(a**2) * inv_var_a + s_a
        g[-1] = -0.5 * np.sum(b**2) * inv_var_b + s_b
        return g
