
import numpy as np
import torch
import math



def sign(x):
    """Return the sign of ``x`` as a Python int: 1, -1, or 0 (also 0 for NaN)."""
    if x > 0:
        return 1
    return -1 if x < 0 else 0


def modified_selection_bias_old(ps, pv, sol, n, r):
    """Legacy generator: spurious features V are drawn around +/-Y.

    Stable features are ``S[:, i] = 0.8 * Z[:, i] + 0.2 * Z[:, i + 1]`` with
    ``Z ~ N(0, 1)``. The target is ``Y = S @ sol + noise + S0 * S1 * S2`` with
    ``noise ~ N(0, 0.3**2)``. Each row of V is drawn from a multivariate normal
    centred on ``+Y_noiseless`` (r > 0) or ``-Y_noiseless`` (r < 0), broadcast
    to all ``pv`` columns, with covariance ``sigma * I``, where
    ``sigma = sqrt(1 / log2|r|)``. Larger |r| gives a smaller sigma and so
    ties V more tightly to Y.

    Args:
        ps: number of stable features (must be >= 3; Y uses S[:, 0:3]).
        pv: number of spurious features.
        sol: coefficient array of shape (ps, 1) for the linear part of Y.
        n: number of samples.
        r: bias strength; its sign sets the sign of the V/Y relation.
            Must satisfy |r| > 1, otherwise sigma is undefined.

    Returns:
        (X, Y): float32 tensors, X = [S, V] of shape (n, ps + pv) and Y of
        shape (n, 1).

    Raises:
        ValueError: if |r| <= 1 (log2|r| <= 0 makes sigma undefined).
    """
    # Previously |r| == 1 raised ZeroDivisionError and |r| < 1 a bare
    # "math domain error"; fail early with an explicit message instead.
    if abs(r) <= 1:
        raise ValueError(f'|r| must be greater than 1, got r={r}')

    # NOTE: this draw is overwritten by the loop below; it is kept so the global
    # NumPy RNG stream (and therefore seeded datasets) stays unchanged.
    S = np.random.normal(0, 1, [n, ps])
    Z = np.random.normal(0, 1, [n, ps + 1])
    for i in range(ps):
        S[:, i:i + 1] = 0.8 * Z[:, i:i + 1] + 0.2 * Z[:, i + 1:i + 2]

    noise = np.random.normal(0, 0.3, [n, 1])

    Y = np.dot(S, sol) + noise + 1 * S[:, 0:1] * S[:, 1:2] * S[:, 2:3]
    # Noise-free target used as the centre of the spurious features.
    Y_compare = np.dot(S, sol) + 1 * S[:, 0:1] * S[:, 1:2] * S[:, 2:3]

    if r > 0:
        center = Y_compare
    else:
        center = -Y_compare

    r = abs(r)
    # NOTE(review): sigma is passed as the covariance scale below, so the
    # per-coordinate std is sqrt(sigma), not sigma -- confirm this is intended.
    sigma = math.sqrt(1/math.log2(r))

    V = np.zeros((center.shape[0], pv), dtype=np.float32)
    for i in range(center.shape[0]):
        # One draw per row; the mean is the row's centre repeated pv times.
        V[i,:] = np.random.multivariate_normal(center[i]*(np.zeros(pv)+1.0), sigma*np.eye(pv), 1)

    X = np.concatenate((S,V), axis=1)
    X = torch.Tensor(X)
    Y = torch.Tensor(Y)
    return X, Y


def modified_selection_bias(ps, pv, sol, n, r, n_pool=1000000):
    """Generate samples whose last non-stable feature is spuriously tied to Y.

    A pool of ``n_pool`` candidates is drawn:

    * stable features ``S[:, i] = 0.8 * Z[:, i] + 0.2 * Z[:, i + 1]``, ``Z ~ N(0, 1)``;
    * non-stable features ``V ~ N(0, 2**2)`` (independent of Y before selection);
    * target ``Y = S @ sol + noise + S0 * S1 * S2``, ``noise ~ N(0, 0.5**2)``.

    Each candidate is kept with probability ``min(1, |r| ** (-5 * D))``, where
    ``D = |V_b * sign(r) - Y|`` and ``V_b`` is the last column of V. For
    |r| > 1 this favours rows where ``V_b`` is close to ``sign(r) * Y``.
    For |r| <= 1 every candidate is kept, i.e. no bias is introduced.
    Finally ``n`` survivors are chosen at random without replacement.

    Args:
        ps: number of stable features (must be >= 3; Y uses S[:, 0:3]).
        pv: total number of non-stable features (must be >= 1); only the
            last one is subject to selection.
        sol: coefficient array of shape (ps, 1) for the linear part of Y.
        n: number of samples requested.
        r: selection-bias strength; its sign sets the sign of the V_b/Y link.
        n_pool: number of candidates drawn before selection.

    Returns:
        (X, y): float32 tensors. X = [S, V] has shape (m, ps + pv) and y has
        shape (m, 1), where m = min(n, number of surviving candidates).

    Raises:
        ValueError: if ``ps < 3`` or ``pv < 1``.

    Warns:
        RuntimeWarning: if fewer than ``n`` candidates survive selection.
    """
    import warnings

    n_biased = 1  # number of trailing V columns subjected to selection
    if ps < 3:
        raise ValueError(f'ps must be >= 3 (Y uses S[:, 0:3]), got ps={ps}')
    if pv < n_biased:
        raise ValueError(f'pv must be >= {n_biased}, got pv={pv}')
    n_unbiased = pv - n_biased

    # NOTE: this S draw is overwritten by the loop below; it is kept only so the
    # global NumPy RNG stream (and therefore seeded datasets) stays unchanged.
    S = np.random.normal(0, 2, [n_pool, ps])
    V = np.random.normal(0, 2, [n_pool, pv])

    Z = np.random.normal(0, 1, [n_pool, ps + 1])
    for i in range(ps):
        S[:, i:i + 1] = 0.8 * Z[:, i:i + 1] + 0.2 * Z[:, i + 1:i + 2]

    noise = np.random.normal(0, 0.5, [n_pool, 1])
    Y = np.dot(S, sol) + noise + 1. * S[:, 0:1] * S[:, 1:2] * S[:, 2:3]

    keep = np.ones([n_pool, 1], dtype=bool)
    for i in range(n_biased):
        v_col = V[:, n_unbiased + i:n_unbiased + i + 1]
        dist = np.abs(v_col * np.sign(r) - Y)
        # Acceptance probability decays geometrically in the distance.
        prob = np.power(np.abs(r), -dist * 5)
        keep &= np.random.random([n_pool, 1]) < prob
    rows = np.flatnonzero(keep[:, 0])

    n_kept = rows.shape[0]
    if n_kept < n:
        warnings.warn(
            f'only {n_kept} of {n_pool} candidates survived selection '
            f'(r={r}); returning {n_kept} samples instead of n={n}',
            RuntimeWarning,
            stacklevel=2,
        )

    order = np.random.permutation(n_kept)[:n]
    X_kept = np.hstack((S[rows, :], V[rows, :]))
    Y_kept = Y[rows]

    X = torch.from_numpy(X_kept[order, :]).float()
    y = torch.from_numpy(Y_kept[order, :]).float()
    return X, y


        
def generate_env(r, n=1000, dim=10, pv=3):
    """Generate one selection-biased environment.

    Args:
        r: selection-bias strength passed to ``modified_selection_bias``.
        n: number of samples to generate. Previously this argument was
            ignored and 1000 samples were always requested.
        dim: total number of features (stable + non-stable).
        pv: number of non-stable features; ``dim - pv`` are stable.

    Returns:
        (X, y, sol): the generated data and the (dim - pv, 1) coefficient
        array used for the linear part of y.
    """
    ps = dim - pv
    # Alternating-sign coefficients with magnitudes cycling 1/3, 2/3, 1.
    sol = np.zeros((ps, 1))
    for i in range(ps):
        sol[i] = (-1) ** i * (i % 3 + 1) * 1.0 / 3

    X, y = modified_selection_bias(ps, pv, sol, n, r)

    return X, y, sol

class DataGenerator:
    """Factory for selection-biased environments sharing one coefficient vector.

    Construction seeds the global NumPy and torch RNGs, so environments
    generated afterwards are reproducible for a given ``seed``.
    """

    def __init__(self, dim=10, pv=3, seed=19260817):
        """Seed the global RNGs and build the coefficient vector.

        Args:
            dim: total number of features.
            pv: number of non-stable features; ``dim - pv`` are stable.
            seed: seed for ``np.random`` and ``torch``.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.ps = dim - pv  # number of stable features
        self.pv = pv

        # Alternating-sign coefficients with magnitudes cycling 1/3, 2/3, 1.
        self.sol = np.zeros((self.ps, 1))
        for i in range(self.ps):
            self.sol[i] = (-1) ** i * (i % 3 + 1) * 1.0 / 3

    def generate_env(self, r, n):
        """Return one environment ``(X, y)`` with bias ``r`` and ``n`` samples."""
        X, y = modified_selection_bias(self.ps, self.pv, self.sol, n, r)
        return X, y

    def generate_envs(self, rs, ns):
        """Return a list of ``(X, y)`` environments, one per ``(rs[i], ns[i])``.

        Raises:
            ValueError: if ``rs`` and ``ns`` differ in length (a subclass of
                ``Exception``, so existing ``except Exception`` handlers still work).
        """
        if len(rs) != len(ns):
            raise ValueError('the size of \'rs\' should be equal to the size of \'ns\'. ')

        return [self.generate_env(r_i, n_i) for r_i, n_i in zip(rs, ns)]
        

def combine_envs(envs):
    """Stack the features and targets of several environments row-wise.

    Args:
        envs: sequence whose items expose features at index 0 and targets at
            index 1 (e.g. ``(X_i, y_i)`` tuples of tensors).

    Returns:
        (X, y): concatenated features reshaped to ``(-1, X.shape[1])`` and
        concatenated targets reshaped to a column ``(-1, 1)``.
    """
    features = torch.cat([env[0] for env in envs], dim=0)
    targets = torch.cat([env[1] for env in envs], dim=0)
    n_cols = features.shape[1]
    return features.reshape(-1, n_cols), targets.reshape(-1, 1)


source_r = 2.0


def _print_mse_report(envs, rs, X_o, y_o, coef, n_cols=None):
    """Print per-environment and pooled MSE of the linear predictor ``X @ coef.T``.

    Args:
        envs: list of ``(X_i, y_i)`` tensors, one per entry of ``rs``.
        rs: bias values used to generate ``envs`` (printed as labels).
        X_o, y_o: all environments concatenated.
        coef: coefficient tensor; assumed (1, d) as produced from a Ridge
            ``coef_`` fitted on a 2-D target -- a 1-D tensor would broadcast
            incorrectly against ``y``.
        n_cols: if given, only the first ``n_cols`` feature columns are used.
    """
    cols = slice(None, n_cols)
    for r_i, (X_i, y_i) in zip(rs, envs):
        err = torch.mean((X_i[:, cols].matmul(coef.T) - y_i) ** 2.).item()
        print('r_i: ', r_i, '   MSE: ', err)

    err = torch.mean((X_o[:, cols].matmul(coef.T) - y_o) ** 2.).item()
    print('overall MSE: ', err)


if __name__ == '__main__':
    from sklearn.linear_model import Ridge

    # DataGenerator seeds the global NumPy/torch RNGs itself.
    data_gen = DataGenerator(dim=10, pv=3)
    X, y = data_gen.generate_env(r=source_r, n=1000)

    print('source bias: ', source_r)

    # Fit on all features (stable + spurious) of the source environment.
    estimated_sol = torch.Tensor(Ridge(fit_intercept=False).fit(X, y).coef_)

    rs = [3.0, 2.0, 1.5, -1.5, -2.0, -3.0]
    ns = [500] * len(rs)

    envs = data_gen.generate_envs(rs=rs, ns=ns)
    X_o, y_o = combine_envs(envs)

    print('range of Y: ', '(', y_o.min(), y_o.max(), ')')

    _print_mse_report(envs, rs, X_o, y_o, estimated_sol)

    ### Invariant part: fit and evaluate on the stable features only.
    print('--------------   Invariant feature    --------')
    estimated_inv_sol = torch.Tensor(
        Ridge(fit_intercept=False).fit(X[:, : data_gen.ps], y).coef_
    )
    _print_mse_report(envs, rs, X_o, y_o, estimated_inv_sol, n_cols=data_gen.ps)
    
    
    