import numpy as np
from sklearn.metrics import mean_squared_error
import torch

def infer_u(model, r, d, a):
    """Draw a latent ``u`` from ``model`` for the inputs ``r``, ``d`` and ``a``.

    All three tensors are placed on the device of ``r`` before being handed
    to ``model.q_u``. The two values returned by ``model.q_u`` (presumably a
    mean and a log-variance -- confirm against the model definition) are then
    passed to ``model.reparameterize``.

    Args:
        model: object exposing ``q_u(r, d, a)`` and ``reparameterize(mu, logvar)``.
        r: tensor whose device decides where the computation runs.
        d: tensor, moved to the device of ``r``.
        a: tensor, moved to the device of ``r``.

    Returns:
        Whatever ``model.reparameterize`` returns for the encoder outputs.
    """
    target = r.device
    on_device = [tensor.to(target) for tensor in (r, d, a)]
    posterior_mu, posterior_logvar = model.q_u(*on_device)
    return model.reparameterize(posterior_mu, posterior_logvar)

def gen_x(model, u, a):
    """Generate features from latent ``u`` and attribute ``a``.

    ``a`` is moved to the device of ``u``, ``model.reconstruct_hard(u, a)`` is
    called, and the first two of its three outputs are concatenated along
    dimension 1. The third output is discarded.

    Args:
        model: object exposing ``reconstruct_hard(u, a)`` returning a 3-tuple.
        u: tensor whose device decides where the computation runs.
        a: tensor, moved to the device of ``u``.

    Returns:
        ``torch.cat`` of the first two reconstruction outputs along ``dim=1``.
    """
    # Unpacking enforces that exactly three values come back.
    r_hat, d_hat, _ = model.reconstruct_hard(u, a.to(u.device))
    return torch.cat((r_hat, d_hat), dim=1)

def cf_eval(y, y_cf, a):
    """Measure how far counterfactual predictions move from factual ones.

    The effect is the element-wise absolute difference ``|y_cf - y|``. It is
    averaged over all samples, over the samples with ``a == 0`` and over the
    samples with ``a == 1`` (``a`` is squeezed first).

    Args:
        y: array of factual predictions.
        y_cf: array of counterfactual predictions, same shape as ``y``.
        a: array of group indicators; must support ``.squeeze()``.

    Returns:
        Tuple ``(overall, group0, group1)`` of per-sample mean effects. A
        group with no members produces ``nan``.
    """
    groups = a.squeeze()
    in_group0 = groups == 0
    in_group1 = groups == 1

    gap = np.abs(y_cf - y)

    def _per_sample_mean(values):
        # Sum divided by the leading dimension, as opposed to np.mean over all entries.
        return np.sum(values) / values.shape[0]

    overall = _per_sample_mean(gap)
    group0 = _per_sample_mean(gap[in_group0])
    group1 = _per_sample_mean(gap[in_group1])
    return overall, group0, group1

def cfe_classifier(data_dict, clf):
    """Fit ``clf`` on the latent ``u_hat`` alone and evaluate it.

    Args:
        data_dict: mapping with ``"train"`` and ``"test"`` entries. Reads
            ``train["u_hat"]`` and ``train["y"]`` for fitting, and
            ``test["u_hat"]``, ``test["u_cf_hat"]``, ``test["y"]`` and
            ``test["a"]`` for evaluation.
        clf: estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place.

    Returns:
        Tuple ``(train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1,
        clf)``: RMSE on the train and test splits, the three values returned
        by ``cf_eval`` for factual vs. counterfactual test predictions, and
        the fitted estimator.
    """
    train_dat = data_dict["train"]
    test_dat = data_dict["test"]

    inputs = train_dat["u_hat"]
    y = train_dat["y"].ravel()

    clf.fit(inputs, y)
    # RMSE is computed as sqrt(MSE): the ``squared=False`` keyword was
    # deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # TypeError. For the 1-D targets used here both forms are identical.
    train_rmse = np.sqrt(mean_squared_error(y, clf.predict(inputs)))

    y_factual = clf.predict(test_dat["u_hat"])
    test_rmse = np.sqrt(
        mean_squared_error(test_dat["y"].ravel(), y_factual.ravel())
    )

    y_counter = clf.predict(test_dat["u_cf_hat"])
    cf_effect, cf_effect0, cf_effect1 = cf_eval(
        y_factual, y_counter, test_dat["a"]
    )

    return train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1, clf

def _cfr_features(u, x_a, x_b):
    """Stack ``u`` with the element-wise average of ``x_a`` and ``x_b`` column-wise."""
    return np.concatenate([u, (x_a + x_b) / 2], axis=1)


def cfr_classifier(data_dict, clf):
    """Fit ``clf`` on ``u_hat`` plus averaged factual/counterfactual features.

    Each input row is ``u`` concatenated (``axis=1``) with the mean of a
    feature array and its generated counterpart.

    Args:
        data_dict: mapping with ``"train"`` and ``"test"`` entries. Reads
            ``u_hat``, ``x``, ``x_cf_uhat`` and ``y`` from the train split and
            ``u_hat``, ``u_cf_hat``, ``x``, ``x_cf``, ``x_cf_uhat``,
            ``x_cf_cf_uhat``, ``y`` and ``a`` from the test split.
        clf: estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place.

    Returns:
        Tuple ``(train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1,
        clf)``: RMSE on the train and test splits, the three values returned
        by ``cf_eval`` for factual vs. counterfactual test predictions, and
        the fitted estimator.
    """
    train_dat = data_dict["train"]
    test_dat = data_dict["test"]

    inputs = _cfr_features(
        train_dat["u_hat"], train_dat["x"], train_dat["x_cf_uhat"]
    )
    y = train_dat["y"].ravel()

    clf.fit(inputs, y)
    # RMSE is computed as sqrt(MSE): the ``squared=False`` keyword was
    # deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # TypeError. For the 1-D targets used here both forms are identical.
    train_rmse = np.sqrt(mean_squared_error(y, clf.predict(inputs)))

    y_factual = clf.predict(
        _cfr_features(test_dat["u_hat"], test_dat["x"], test_dat["x_cf_uhat"])
    )
    test_rmse = np.sqrt(
        mean_squared_error(test_dat["y"].ravel(), y_factual.ravel())
    )

    y_counter = clf.predict(
        _cfr_features(
            test_dat["u_cf_hat"], test_dat["x_cf"], test_dat["x_cf_cf_uhat"]
        )
    )
    cf_effect, cf_effect0, cf_effect1 = cf_eval(
        y_factual, y_counter, test_dat["a"]
    )

    return train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1, clf

def erm_classifier(data_dict, clf):
    """Fit ``clf`` on the observed features together with the attribute ``a``.

    Inputs are ``x`` and ``a`` concatenated along ``axis=1``; counterfactual
    predictions use ``x_cf`` and ``a_cf`` in the same layout.

    Args:
        data_dict: mapping with ``"train"`` and ``"test"`` entries. Reads
            ``x``, ``a`` and ``y`` from the train split and ``x``, ``a``,
            ``x_cf``, ``a_cf`` and ``y`` from the test split.
        clf: estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place.

    Returns:
        Tuple ``(train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1,
        clf)``: RMSE on the train and test splits, the three values returned
        by ``cf_eval`` for factual vs. counterfactual test predictions, and
        the fitted estimator.
    """
    train_dat = data_dict["train"]
    test_dat = data_dict["test"]

    inputs = np.concatenate([train_dat["x"], train_dat["a"]], axis=1)
    y = train_dat["y"].ravel()

    clf.fit(inputs, y)
    # RMSE is computed as sqrt(MSE): the ``squared=False`` keyword was
    # deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # TypeError. For the 1-D targets used here both forms are identical.
    train_rmse = np.sqrt(mean_squared_error(y, clf.predict(inputs)))

    y_factual = clf.predict(
        np.concatenate([test_dat["x"], test_dat["a"]], axis=1)
    )
    test_rmse = np.sqrt(
        mean_squared_error(test_dat["y"].ravel(), y_factual.ravel())
    )

    y_counter = clf.predict(
        np.concatenate([test_dat["x_cf"], test_dat["a_cf"]], axis=1)
    )
    cf_effect, cf_effect0, cf_effect1 = cf_eval(
        y_factual, y_counter, test_dat["a"]
    )

    return train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1, clf

def pcf_mix(y_score, ycf_score, a, is_cf=False):
    """Blend factual and counterfactual scores, weighted by group frequency.

    Let ``p0`` be the fraction of entries of ``a`` equal to 0 and
    ``p1 = 1 - p0``. For samples with ``a == 0`` the result is
    ``y_score * p0 + ycf_score * p1``; for samples with ``a == 1`` it is
    ``y_score * p1 + ycf_score * p0``. Entries of ``a`` that are neither 0
    nor 1 are left at 0 in the output.

    Args:
        y_score: scores for each sample under its own attribute value.
        ycf_score: scores for each sample under the flipped attribute value.
        a: attribute value per sample, aligned with ``y_score``.
        is_cf: when truthy, ``a`` holds counterfactual attributes, so ``p0``
            and ``p1`` are swapped to recover the ratios of the real data.

    Returns:
        1-D array of blended scores with one entry per sample. The dtype is
        that of ``y_score`` when it is floating point, otherwise ``float64``.
    """
    # Flatten everything so (n,) and (n, 1) inputs index consistently.
    y_score = np.asarray(y_score).ravel()
    ycf_score = np.asarray(ycf_score).ravel()
    a = np.asarray(a).ravel()

    # attribute corresponding to y
    a_0_indices = a == 0
    a_1_indices = a == 1
    n_samples = a.size
    a_0_ratio = np.sum(a_0_indices) / n_samples if n_samples else 0.0
    a_1_ratio = 1 - a_0_ratio
    if is_cf:
        # we need to use the ratio in the real data
        a_0_ratio, a_1_ratio = a_1_ratio, a_0_ratio

    # Never allocate an integer buffer: the blended values are fractional and
    # would be silently truncated on assignment.
    if np.issubdtype(y_score.dtype, np.floating):
        out_dtype = y_score.dtype
    else:
        out_dtype = np.float64
    y_output = np.zeros(y_score.shape, dtype=out_dtype)

    y_output[a_0_indices] = (
        y_score[a_0_indices] * a_0_ratio + ycf_score[a_0_indices] * a_1_ratio
    )
    y_output[a_1_indices] = (
        y_score[a_1_indices] * a_1_ratio + ycf_score[a_1_indices] * a_0_ratio
    )

    return y_output

def pcf_classifier(data_dict, clf):
    """Fit ``clf`` on ``(x, a)`` and evaluate predictions blended by ``pcf_mix``.

    The factual prediction blends the score on ``(x, a)`` with the score on
    ``(x_cf_uhat, a_cf)``. The counterfactual prediction blends the score on
    ``(x_cf, a_cf)`` with the score on ``(x_cf_cf_uhat, a)``, calling
    ``pcf_mix`` with ``is_cf=True``.

    Args:
        data_dict: mapping with ``"train"`` and ``"test"`` entries. Reads
            ``x``, ``a`` and ``y`` from the train split and ``x``, ``a``,
            ``x_cf``, ``a_cf``, ``x_cf_uhat``, ``x_cf_cf_uhat`` and ``y``
            from the test split.
        clf: estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place.

    Returns:
        Tuple ``(train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1,
        clf)``: RMSE of the raw predictions on the train split, RMSE of the
        blended predictions on the test split, the three values returned by
        ``cf_eval`` for blended factual vs. counterfactual predictions, and
        the fitted estimator.
    """
    train_dat = data_dict["train"]
    test_dat = data_dict["test"]

    inputs = np.concatenate([train_dat["x"], train_dat["a"]], axis=1)
    y = train_dat["y"].ravel()

    clf.fit(inputs, y)
    # RMSE is computed as sqrt(MSE): the ``squared=False`` keyword was
    # deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # TypeError. For the 1-D targets used here both forms are identical.
    train_rmse = np.sqrt(mean_squared_error(y, clf.predict(inputs)))

    # ======= factual pred ======= #
    y_factual_score = clf.predict(
        np.concatenate([test_dat["x"], test_dat["a"]], axis=1)
    )
    y_factual_cf_score = clf.predict(
        np.concatenate([test_dat["x_cf_uhat"], test_dat["a_cf"]], axis=1)
    )
    y_factual = pcf_mix(
        y_factual_score, y_factual_cf_score, test_dat["a"].ravel()
    )
    test_rmse = np.sqrt(
        mean_squared_error(test_dat["y"].ravel(), y_factual.ravel())
    )

    # ======= counter pred ======= #
    y_counter_score = clf.predict(
        np.concatenate([test_dat["x_cf"], test_dat["a_cf"]], axis=1)
    )
    y_counter_cf_score = clf.predict(
        np.concatenate([test_dat["x_cf_cf_uhat"], test_dat["a"]], axis=1)
    )
    y_counter = pcf_mix(
        y_counter_score,
        y_counter_cf_score,
        test_dat["a_cf"].ravel(),
        is_cf=True,
    )

    cf_effect, cf_effect0, cf_effect1 = cf_eval(
        y_factual, y_counter, test_dat["a"]
    )

    return train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1, clf

def pcfaug_classifier(data_dict, clf):
    """Fit ``clf`` on factual plus generated counterfactual rows, then evaluate.

    The training matrix stacks ``(x, a)`` rows on top of ``(x_cf_uhat, a_cf)``
    rows (``axis=0``), and the training targets are ``y`` repeated twice so
    each counterfactual row shares the label of its factual row. Evaluation
    blends predictions with ``pcf_mix``: the factual prediction blends the
    score on ``(x, a)`` with the score on ``(x_cf_uhat, a_cf)``, and the
    counterfactual prediction blends the score on ``(x_cf, a_cf)`` with the
    score on ``(x_cf_cf_uhat, a)`` using ``is_cf=True``.

    Args:
        data_dict: mapping with ``"train"`` and ``"test"`` entries. Reads
            ``x``, ``a``, ``x_cf_uhat``, ``a_cf`` and ``y`` from the train
            split and ``x``, ``a``, ``x_cf``, ``a_cf``, ``x_cf_uhat``,
            ``x_cf_cf_uhat`` and ``y`` from the test split.
        clf: estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place.

    Returns:
        Tuple ``(train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1,
        clf)``: RMSE of the raw predictions on the augmented train set, RMSE
        of the blended predictions on the test split, the three values
        returned by ``cf_eval`` for blended factual vs. counterfactual
        predictions, and the fitted estimator.
    """
    train_dat = data_dict["train"]
    test_dat = data_dict["test"]

    factual_rows = np.concatenate([train_dat["x"], train_dat["a"]], axis=1)
    counterfactual_rows = np.concatenate(
        [train_dat["x_cf_uhat"], train_dat["a_cf"]], axis=1
    )
    inputs = np.concatenate([factual_rows, counterfactual_rows], axis=0)

    # Targets are duplicated to line up with the two stacked row groups.
    y = np.concatenate([train_dat["y"], train_dat["y"]], axis=0).ravel()

    clf.fit(inputs, y)
    # RMSE is computed as sqrt(MSE): the ``squared=False`` keyword was
    # deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # TypeError. For the 1-D targets used here both forms are identical.
    train_rmse = np.sqrt(mean_squared_error(y, clf.predict(inputs)))

    # ======= factual pred ======= #
    y_factual_score = clf.predict(
        np.concatenate([test_dat["x"], test_dat["a"]], axis=1)
    )
    y_factual_cf_score = clf.predict(
        np.concatenate([test_dat["x_cf_uhat"], test_dat["a_cf"]], axis=1)
    )
    y_factual = pcf_mix(
        y_factual_score, y_factual_cf_score, test_dat["a"].ravel()
    )
    test_rmse = np.sqrt(
        mean_squared_error(test_dat["y"].ravel(), y_factual.ravel())
    )

    # ======= counter pred ======= #
    y_counter_score = clf.predict(
        np.concatenate([test_dat["x_cf"], test_dat["a_cf"]], axis=1)
    )
    y_counter_cf_score = clf.predict(
        np.concatenate([test_dat["x_cf_cf_uhat"], test_dat["a"]], axis=1)
    )
    y_counter = pcf_mix(
        y_counter_score,
        y_counter_cf_score,
        test_dat["a_cf"].ravel(),
        is_cf=True,
    )

    cf_effect, cf_effect0, cf_effect1 = cf_eval(
        y_factual, y_counter, test_dat["a"]
    )

    return train_rmse, test_rmse, cf_effect, cf_effect0, cf_effect1, clf

