

import numpy as np
import torch


def generate_data(num_data=10000, dim=10, dv=5, bias=0.5, scramble=0, sigma_s=3.0, sigma_v=0.3):
    """Sample a synthetic binary-classification data set.

    Each sample gets a label ``y`` drawn uniformly from ``{0, 1}`` and a
    ``dim``-dimensional Gaussian feature vector split into two groups:

    * the first ``dim - dv`` columns have standard deviation ``sigma_s`` and
      are shifted by ``y``;
    * the last ``dv`` columns have standard deviation ``sigma_v`` and are
      shifted by ``y`` only with probability ``bias`` (otherwise by 0).

    All random draws use the global ``np.random`` state, in the order
    labels -> features -> shift mask, so seeding ``np.random`` makes the
    output reproducible.

    Args:
        num_data: number of samples to draw.
        dim: total number of feature columns.
        dv: number of trailing columns in the second group; must satisfy
            ``0 <= dv <= dim``.
        bias: probability that the second group is shifted by the label.
        scramble: when equal to 1, the features are right-multiplied by an
            orthogonal matrix.  The matrix is sampled with a fixed
            ``random_state=1``, so it is the same on every call for a given
            ``dim`` and does not consume the global random state.
        sigma_s: standard deviation of the first ``dim - dv`` columns.
        sigma_v: standard deviation of the last ``dv`` columns.

    Returns:
        ``(X, y)``: float32 tensors of shape ``(num_data, dim)`` and
        ``(num_data, 1)``; ``y`` holds 0.0 / 1.0.

    Raises:
        ValueError: if ``dv`` is outside ``[0, dim]``.
    """
    if not 0 <= dv <= dim:
        # A negative or oversized split would silently mis-assign columns
        # through negative slicing, so reject it up front.
        raise ValueError('dv must satisfy 0 <= dv <= dim, got dv=%r, dim=%r' % (dv, dim))

    n_first = dim - dv

    # Keep this draw order (labels, features, mask): callers that seed the
    # global RNG rely on it for reproducible data.
    y = np.random.choice([1, 0], size=(num_data, 1))
    X = np.random.randn(num_data, dim)
    X[:, :n_first] *= sigma_s
    X[:, n_first:] *= sigma_v

    # Per-sample shift for the second group: the label with probability
    # `bias`, zero otherwise.
    flip = np.random.choice([1, 0], size=(num_data, 1), p=[bias, 1. - bias]) * y
    X[:, :n_first] += y
    X[:, n_first:] += flip

    if scramble == 1:
        # Imported and sampled lazily: the matrix is only needed here, so the
        # unscrambled path neither pays for it nor requires scipy.
        from scipy.stats import ortho_group
        S = np.float32(ortho_group.rvs(size=1, dim=dim, random_state=1))
        X = np.matmul(X, S)

    return torch.from_numpy(X).float(), torch.from_numpy(y).float()


def encode_y(y, num_classes=2):
    """One-hot encode integer-valued class labels.

    Args:
        y: tensor of class indices, shaped ``(n, 1)`` or ``(n,)``.  Values
            are cast to int64 and must lie in ``[0, num_classes)``.
        num_classes: width of the one-hot encoding (defaults to 2, the
            binary case).

    Returns:
        Tensor of shape ``(n, num_classes)`` in torch's default floating
        dtype, allocated on the same device as ``y``, with a 1 in the column
        selected by each label and 0 elsewhere.
    """
    idx = y.long()
    if idx.dim() == 1:
        # scatter_ needs the index to have as many dims as the target.
        idx = idx.unsqueeze(1)
    # Allocate next to the labels so non-CPU inputs do not hit a device
    # mismatch inside scatter_.
    one_hot = torch.zeros(len(idx), num_classes, device=idx.device)
    return one_hot.scatter_(1, idx, 1)

class DataGenerator(object):
    """Factory for synthetic environments built with ``generate_data``.

    Constructing an instance seeds the *global* numpy and torch random
    generators, so the sequence of environments produced afterwards is
    reproducible for a given ``seed``.

    Attributes set in ``__init__``:
        ps: ``dim - pv``, the number of leading feature columns.
        pv: number of trailing feature columns (forwarded to
            ``generate_data`` as ``dv``).
    """

    def __init__(self, dim=10, pv=5, seed=19260817):
        """Seed the global RNGs and record the feature split.

        Args:
            dim: total number of feature columns.
            pv: number of trailing columns passed to ``generate_data`` as ``dv``.
            seed: seed applied to both ``np.random`` and ``torch``.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.ps = dim - pv
        self.pv = pv

    def generate_env(self, r, n, y_encode=False):
        """Draw one environment.

        Args:
            r: forwarded to ``generate_data`` as ``bias``.
            n: number of samples (``num_data``).
            y_encode: when true, labels are returned one-hot encoded via
                ``encode_y``; otherwise as returned by ``generate_data``.

        Returns:
            ``(X, y)`` tuple of tensors.
        """
        X, y = generate_data(dim=self.ps + self.pv, dv=self.pv, num_data=n, bias=r, scramble=0)
        if y_encode:
            y = encode_y(y)
        return X, y

    def generate_envs(self, rs, ns, y_encode=False):
        """Draw one environment per ``(r, n)`` pair, in order.

        Args:
            rs: sequence of ``bias`` values, one per environment.
            ns: sequence of sample counts, same length as ``rs``.
            y_encode: forwarded to ``generate_env``.

        Returns:
            List of ``(X, y)`` tuples, one per environment.

        Raises:
            ValueError: if ``rs`` and ``ns`` differ in length.
        """
        if len(rs) != len(ns):
            # ValueError is still an Exception, so existing broad handlers
            # keep working.
            raise ValueError('the size of \'rs\' should be equal to the size of \'ns\'. ')

        # Delegate to generate_env so both entry points share one code path.
        return [self.generate_env(r, n, y_encode=y_encode) for r, n in zip(rs, ns)]
    

if __name__ == '__main__':
    # Demo: fit a logistic regression on one environment (r = 0.9) and report
    # its accuracy on environments generated with a range of r values, first
    # using every feature and then using only the first `ps` columns.
    r_init = 0.9

    data_gen = DataGenerator(dim=10, pv=3)
    X, y = data_gen.generate_env(r=r_init, n=1000)

    from sklearn.linear_model import LogisticRegression

    # Labels come back as an (n, 1) column; fit() is given the flat view.
    train_labels = y.reshape(-1)
    clf = LogisticRegression(fit_intercept=False).fit(X, train_labels)

    rs = [r_init, 0.7, 0.5, 0.3, 0.1]
    ns = [1000] * len(rs)

    envs = data_gen.generate_envs(rs=rs, ns=ns)
    for r_i, (X_i, y_i) in zip(rs, envs):
        print('r_i: ', r_i, '   ACC: ', clf.score(X_i, y_i))

    ### Invariant part
    print('--------------   Invariant part    --------')
    n_leading = data_gen.ps
    inv_clf = LogisticRegression(fit_intercept=False).fit(X[:, :n_leading], train_labels)

    for r_i, (X_i, y_i) in zip(rs, envs):
        print('r_i: ', r_i, '   ACC: ', inv_clf.score(X_i[:, :n_leading], y_i))