import numpy
import pandas
from sklearn.preprocessing import StandardScaler
import scipy.special
from algorithms import *

def weak_instruments_data(n, d=3, alpha=1):
    """Draw a synthetic instrumental-variable data set and fit it once.

    Generated quantities:
      * Z   -- (n, d) standard-normal instruments
      * eta -- (n, d) standard-normal noise
      * X   -- alpha * Z + eta
      * Y   -- X @ theta + (row sums of eta), so the outcome noise is built
               from the same eta that perturbs X

    theta is zero except for its first entry, which is sqrt(d / 20).

    Returns:
        Tuple (X, Y, Z, theta, err), where err is the Euclidean norm of
        (w - theta) and w is whatever IV_Moments(X, Y, Z).optimize() returns
        (presumably the IV estimate of theta -- confirm in `algorithms`).
    """
    instruments = numpy.random.randn(n, d)
    noise = numpy.random.randn(n, d)
    covariates = alpha * instruments + noise

    theta = numpy.zeros(d)
    theta[0] = (d / 20) ** .5
    outcome = covariates @ theta + noise.sum(axis=1)

    # NOTE(review): another function in this file refers to IV_Moments without
    # the `iv.` prefix -- confirm which name `algorithms` actually exports.
    estimate = iv.IV_Moments(covariates, outcome, instruments).optimize()
    estimation_error = numpy.linalg.norm(estimate - theta)
    return (covariates, outcome, instruments, theta, estimation_error)

def corrupted_weak_instruments_data(n, d, alpha, epsilon):
    """Generate weak-instrument data and overwrite its first rows.

    The first k = int(epsilon * n) samples are replaced in place: every
    corrupted Z row becomes -(1/k) * (Z_clean^T Y_clean) / sqrt(d) and every
    corrupted Y entry becomes sqrt(d), where "clean" means rows k onward.
    X is left untouched.

    Returns:
        Tuple (X, Y, Z, theta, oe). `oe` is the error value returned by
        weak_instruments_data, i.e. it was computed before the corruption.
    """
    X, Y, Z, theta, oe = weak_instruments_data(n, d, alpha)

    k = int(epsilon * n)
    clean_cross = Z[k:].T @ Y[k:]

    if k > 0:
        # The planted row does not depend on the row index, so build it once.
        planted_row = -(1.0 / k) * clean_cross / d**0.5
        targets = numpy.arange(k)
        Z[targets] = planted_row
        Y[targets] = d**0.5

    return X, Y, Z, theta, oe

def heterogeneous_effects_data(n, d, epsilon):
    """Draw synthetic data with a covariate-dependent treatment effect.

    Draws, in this order: standard-normal covariates X (n, d), a Bernoulli(0.5)
    vector Z, standard-normal U, a Bernoulli treatment T whose success
    probability is expit(Z + sqrt(d) * U * rowmean(X)), and a standard-normal
    theta of length d. The outcome is Y = (X @ theta) * T + U.

    After T and Y are formed, the first int(epsilon * n) samples are
    overwritten: their X rows become all ones and their Y entries become
    3 * sqrt(d). T and Z are not modified.

    Returns:
        Tuple (XT, Y, XZ, betatheta) where
          XT        = [X, T[:, None] * X]   -- shape (n, 2d)
          XZ        = [X, Z[:, None] * X]   -- shape (n, 2d)
          betatheta = [zeros(d), theta]     -- length 2d
    """
    X = numpy.random.normal(size=(n, d))
    Z = numpy.random.binomial(1, 0.5, size=(n,))
    U = numpy.random.normal(size=(n,))
    treat_prob = scipy.special.expit(Z + (d**.5) * U * X.mean(axis=1))
    T = numpy.random.binomial(1, treat_prob)
    theta = numpy.random.normal(size=(d,))
    Y = (X @ theta) * T + U

    # Overwrite the leading rows; an empty index array makes this a no-op.
    corrupted = numpy.arange(int(epsilon * n))
    X[corrupted] = 1.0
    Y[corrupted] = 3 * (d**0.5)

    treated_design = numpy.hstack([X, T[:, None] * X])
    instrument_design = numpy.hstack([X, Z[:, None] * X])
    betatheta = numpy.concatenate([numpy.zeros(d), theta])
    return treated_design, Y, instrument_design, betatheta

def pre_process_nlsym():
    """Load "nlsym.csv" and assemble the arrays used by the NLSYM experiments.

    Only rows with educ >= 6 are kept. Missing parental-education values are
    replaced by the column mean (computed over all rows, before filtering),
    and a 0/1 indicator column records where a value was missing. No feature
    scaling is applied. The feature column index is printed as a side effect.

    Returns:
        Tuple (X, y, Z, T):
          X -- feature matrix with columns, in order: exper, expersq, black,
               smsa, south, smsa66, momdad14, sinmom14, reg661..reg668,
               fatheduc, fatheduc_nan, motheduc, motheduc_nan
          y -- the 'lwage' column
          Z -- the 'nearc4' column
          T -- the 'educ' column
    """
    raw = pandas.read_csv("nlsym.csv")
    keep = raw['educ'].values >= 6

    # 'weights' is excluded because its meaning is unknown; 'south66' and
    # 'reg669' are excluded because they are linear combinations of the
    # columns kept here.
    passthrough = [
        'exper', 'expersq',
        'black', 'smsa', 'south', 'smsa66',
        'momdad14', 'sinmom14',
        'reg661', 'reg662', 'reg663', 'reg664',
        'reg665', 'reg666', 'reg667', 'reg668',
    ]
    features = raw[passthrough].copy()

    # Mean-impute each parental-education column and flag imputed entries.
    for parent in ('fatheduc', 'motheduc'):
        column = raw[parent]
        features[parent] = column.fillna(value=column.mean())
        features[parent + '_nan'] = column.isnull() * 1

    print(features.columns)

    X = features.values[keep]
    y = raw['lwage'].values[keep]
    Z = raw['nearc4'].values[keep]
    T = raw['educ'].values[keep]
    return X, y, Z, T

def nlsym_data(m=2):
    """Return the NLSYM arrays with a reduced feature matrix.

    Keeps the first `m` feature columns produced by pre_process_nlsym (with
    the default m=2 these are exper and expersq) and appends a constant
    column of ones.

    Returns:
        Tuple (X, y, Z, T) where X has m + 1 columns and y, Z, T are passed
        through unchanged from pre_process_nlsym.
    """
    features, y, Z, T = pre_process_nlsym()
    intercept = numpy.ones((features.shape[0], 1))
    X = numpy.concatenate([features[:, :m], intercept], axis=1)
    return X, y, Z, T


def corrupted_nlsym_data(epsilon=0.1):
    """Shuffle the NLSYM data and overwrite the first outcomes.

    Steps:
      1. Load (X, y, Z, T) from nlsym_data() and apply one random permutation
         to all four arrays.
      2. Build XT = [X * T, X] and XZ = [X * Z, X] and obtain `wnaive` from
         IV_Moments(XT, y, XZ).optimize() on the full, unmodified data.
      3. With k = int(epsilon * n), form
             disc = -2 * XZ[k:]^T @ XT[k:] @ wnaive
         and solve XZ[:k]^T q = disc in the least-squares sense.
      4. Replace y[:k] with q - XT[:k] @ wnaive.

    Returns:
        Tuple (X, y, Z, T) in permuted order; only y is altered, and X is the
        feature matrix from nlsym_data (not XT or XZ).
    """
    X, y, Z, T = nlsym_data()
    n = X.shape[0]

    shuffle = numpy.random.permutation(n)
    X, y, Z, T = X[shuffle], y[shuffle], Z[shuffle], T[shuffle]

    XT = numpy.hstack([X * T[:, None], X])
    XZ = numpy.hstack([X * Z[:, None], X])
    wnaive = IV_Moments(XT, y, XZ).optimize()

    # Rows [0, k) get overwritten; rows [k, n) are used to build the target.
    k = int(epsilon * n)
    clean_XZ, clean_XT = XZ[k:], XT[k:]

    disc = numpy.zeros(XZ.shape[1])
    disc += -2.0 * clean_XZ.T @ clean_XT @ wnaive

    q = numpy.linalg.lstsq(XZ[:k].T, disc)[0]
    y[:k] = q - XT[:k] @ wnaive
    return X, y, Z, T

    
