import random
import numpy as np
import statistics
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import svm, ensemble
from sklearn.neural_network import MLPClassifier
from folktables import ACSDataSource, ACSEmployment, ACSIncomePovertyRatio, ACSHealthInsurance
from homogenization import measure_homogenization, aggregate_measurements
from tqdm import tqdm
import pickle


def format_data(acs_data, applications_list, seed):
	"""Build per-application train/test arrays that all share one split.

	Args:
		acs_data: raw ACS data handed to each application's df_to_numpy.
		applications_list: application names; each must be one of
			'employment', 'income-poverty' or 'health-insurance'
			(anything else raises KeyError).
		seed: random_state forwarded to train_test_split.

	Returns:
		dict mapping application name to a dict with keys
		'X_tr', 'X_test', 'y_tr', 'y_test', 'z_tr', 'z_test'.
	"""
	problems = {'employment' : ACSEmployment, 'income-poverty' : ACSIncomePovertyRatio, 'health-insurance' : ACSHealthInsurance}

	# Flatten every application's arrays into one list, in applications_list order.
	# NOTE(review): the packaging below treats each application's output as
	# three arrays (features, labels, group attribute) -- confirm against folktables.
	arrays = [
		array
		for application_name in applications_list
		for array in problems[application_name].df_to_numpy(acs_data)
	]

	# A single train_test_split call over all arrays applies the same row
	# assignment to every application (all arrays must have equal length).
	pieces = train_test_split(*arrays, test_size = 0.2, random_state = seed)
	assert len(pieces) % 6 == 0

	# Each array yields a (train, test) pair, so one application owns six consecutive pieces.
	split_keys = ('X_tr', 'X_test', 'y_tr', 'y_test', 'z_tr', 'z_test')
	packaged = {}
	for offset in range(0, len(pieces), 6):
		application_name = applications_list[offset // 6]
		packaged[application_name] = dict(zip(split_keys, pieces[offset : offset + 6]))

	return packaged


def fixed_partition(applications_data, applications_list, seed, data_scale):
	"""Shrink every application's training split to the same contiguous block.

	The training rows are viewed as data_scale equal-length blocks (any
	remainder rows at the end are never selected). One block index is drawn
	from `seed` and reused for all applications.

	Args:
		applications_data: dict of per-application dicts holding 'X_tr',
			'y_tr' and 'z_tr'; modified in place.
		applications_list: unused; kept for a signature parallel to
			disjoint_partition.
		seed: seed for the block-index draw.
		data_scale: number of blocks the training split is divided into.

	Returns:
		The same applications_data object, with the three training entries
		replaced by their selected block.
	"""
	# Drawn once, outside the loop, so all applications get the same block.
	block_id = random.Random(seed).randint(0, data_scale - 1)

	for entries in applications_data.values():
		features, labels, attributes = [entries[key] for key in ('X_tr', 'y_tr', 'z_tr')]
		size = len(labels) // data_scale
		window = slice(size * block_id, size * (block_id + 1))
		entries.update(X_tr = features[window], y_tr = labels[window], z_tr = attributes[window])

	return applications_data


def disjoint_partition(applications_data, applications_list, seed, data_scale):
	"""Shrink each application's training split to its own contiguous block.

	The training rows are viewed as data_scale equal-length blocks (any
	remainder rows at the end are never selected). A seeded permutation of
	the block indices is drawn, and the application at position p (in
	applications_data iteration order) receives block permutation[p % data_scale].
	Blocks are therefore distinct across applications only while the number
	of applications does not exceed data_scale; beyond that they wrap around.

	Args:
		applications_data: dict of per-application dicts holding 'X_tr',
			'y_tr' and 'z_tr'; modified in place.
		applications_list: unused; kept for a signature parallel to
			fixed_partition.
		seed: seed for the block permutation.
		data_scale: number of blocks the training split is divided into.

	Returns:
		The same applications_data object, with the three training entries
		replaced by their selected block.
	"""
	permutation = np.random.RandomState(seed=seed).permutation(data_scale)

	for position, entries in enumerate(applications_data.values()):
		X_train, y_train, z_train = entries['X_tr'], entries['y_tr'], entries['z_tr']
		block_length = len(y_train) // data_scale
		# Modulo keeps the lookup in range when there are more applications than blocks
		index = permutation[position % data_scale]
		start = block_length * index
		end = start + block_length
		entries['X_tr'], entries['y_tr'], entries['z_tr'] = X_train[start : end], y_train[start : end], z_train[start : end]

	return applications_data


def generate_predictions(data, method, seed, predict_train=False):
	"""Fit a standardized classifier on one application's training data and predict.

	Args:
		data: mapping with 'X_tr' and 'y_tr' (used for fitting) and 'X_test'
			(used for prediction).
		method: classifier family: 'logistic', 'gbm', 'svm' or 'nn'.
		seed: passed as random_state to the classifier.
		predict_train: when True, also predict on 'X_tr'.

	Returns:
		The test-set predictions, or, when predict_train is True, a dict
		with keys 'train_predictions' and 'test_predictions'.

	Raises:
		NotImplementedError: if method is not one of the supported names.
	"""
	if method == 'logistic':
		core_model = LogisticRegression
	elif method == 'gbm':
		core_model = ensemble.GradientBoostingClassifier
	elif method == 'svm':
		core_model = svm.SVC
	elif method == 'nn':
		# Was `model == 'nn'`, which read the not-yet-assigned local `model`
		# and raised UnboundLocalError for 'nn' and for every unknown method.
		core_model = MLPClassifier
	else:
		raise NotImplementedError('Unsupported method: {!r}'.format(method))

	# Features are standardized before reaching the classifier
	model = make_pipeline(StandardScaler(), core_model(random_state=seed))
	model.fit(data['X_tr'], data['y_tr'])
	test_predictions = model.predict(data['X_test'])
	if predict_train:
		train_predictions = model.predict(data['X_tr'])
		return {'train_predictions' : train_predictions, 'test_predictions' : test_predictions}

	return test_predictions


def group_and_format(applications_data, grouping):
	"""Turn each application's test arrays into per-row records tagged with a group.

	Args:
		applications_data: dict of per-application dicts holding 'X_test',
			'y_test', 'z_test' and 'predictions'; all four must have the
			same first dimension.
		grouping: 'individual' (each row index is its own group) or 'race'
			(the group is the row's 'z_test' value).

	Returns:
		(records, groups): records maps application name to
		{row index: {'input', 'label', 'prediction', 'group'}}; groups is the
		set of every group value seen across all applications.

	Raises:
		NotImplementedError: for any other grouping, raised when the first
			row is processed.
	"""
	records_by_application = {}
	seen_groups = set()

	for application_name, application in applications_data.items():
		features, labels = application['X_test'], application['y_test']
		attributes, predicted = application['z_test'], application['predictions']
		assert features.shape[0] == labels.shape[0] == attributes.shape[0] == predicted.shape[0]

		records = {}
		for row in range(features.shape[0]):
			if grouping == 'race':
				row_group = attributes[row]
			elif grouping == 'individual':
				row_group = row
			else:
				raise NotImplementedError

			# The row index doubles as the record id
			records[row] = {
				'input' : features[row],
				'label' : labels[row],
				'prediction' : predicted[row],
				'group' : row_group,
			}
			seen_groups.add(row_group)

		records_by_application[application_name] = records

	return records_by_application, seen_groups


# Experiment to acclimate reader with homogenization quantities
def base_experiment(acs_data, applications_list, method = 'logistic', groupings = ['race', 'individual'], verbose = False):
	"""Measure homogenization for one model family, aggregated over five model seeds.

	Args:
		acs_data: raw ACS data forwarded to format_data.
		applications_list: application names forwarded to format_data.
		method: classifier family forwarded to generate_predictions.
		groupings: grouping names forwarded to group_and_format; only iterated, never modified.
		verbose: forwarded to measure_homogenization.

	Returns:
		dict mapping each grouping to the aggregate_measurements result over
		the per-seed measurements.
	"""
	print('Running base census experiment')

	data_seed = 0
	model_seeds = list(range(5))

	# The split depends only on the constant data_seed, so build it once
	# rather than re-converting and re-splitting for every (grouping, seed) pair.
	applications_data = format_data(acs_data, applications_list, data_seed)

	base_table = {}
	for grouping in tqdm(groupings):
		seed2measurements = {}
		for seed in tqdm(model_seeds):
			# Train one model per application; 'predictions' is overwritten on every pass
			for data in applications_data.values():
				data['predictions'] = generate_predictions(data, method, seed)

			# Group inputs and reformat data to prepare for homogenization measurement
			reformatted_applications_data, groups = group_and_format(applications_data, grouping)

			seed2measurements[seed] = measure_homogenization(reformatted_applications_data, groups, verbose = verbose)
		base_table[grouping] = aggregate_measurements(seed2measurements)

	return base_table


# Experiment to test role of model complexity on homogenization
def complexity_experiment(acs_data, applications_list, methods, groupings = ['individual', 'race']):
	"""Measure homogenization for several model families, aggregated over five model seeds.

	Args:
		acs_data: raw ACS data forwarded to format_data.
		applications_list: application names forwarded to format_data.
		methods: classifier families forwarded one at a time to generate_predictions.
		groupings: grouping names forwarded to group_and_format; only iterated, never modified.

	Returns:
		dict mapping (method, grouping) to the aggregate_measurements result
		over the per-seed measurements.
	"""
	print('Running complexity census experiment')

	data_seed = 0
	model_seeds = list(range(5))

	# The split depends only on the constant data_seed, so build it once
	# rather than re-converting and re-splitting for every (method, grouping, seed).
	applications_data = format_data(acs_data, applications_list, data_seed)

	complexity_table = {}
	for method in tqdm(methods):
		for grouping in tqdm(groupings):
			seed2measurements = {}
			for seed in tqdm(model_seeds):
				for data in applications_data.values():
					# With the default predict_train=False, generate_predictions returns
					# the test predictions directly; unpacking that array into two names
					# (as this code used to) fails or silently mis-assigns.
					data['predictions'] = generate_predictions(data, method, seed)
					# TODO compute train accuracy (call generate_predictions with
					# predict_train=True and read 'train_predictions' from the returned dict)

				# Group inputs and reformat data to prepare for homogenization measurement
				reformatted_applications_data, groups = group_and_format(applications_data, grouping)

				seed2measurements[seed] = measure_homogenization(reformatted_applications_data, groups)
			complexity_table[(method, grouping)] = aggregate_measurements(seed2measurements)

	return complexity_table


# Experiment to test role of data partition
def partition_experiment(acs_data, applications_list, data_scale, method = 'logistic', groupings = ['individual', 'race']):
	"""Compare homogenization when applications train on the same vs. different data blocks.

	For each of 10 model seeds x 10 partition seeds, every application's
	training split is cut down to one block of 1/data_scale of its rows,
	once with fixed_partition (same block for all applications) and once with
	disjoint_partition (a different block per application), then models are
	trained and homogenization is measured.

	Args:
		acs_data: raw ACS data forwarded to format_data.
		applications_list: application names forwarded to format_data.
		data_scale: number of blocks the training split is divided into.
		method: classifier family forwarded to generate_predictions.
		groupings: grouping names forwarded to group_and_format; only iterated, never modified.

	Returns:
		dict mapping ('fixed', grouping) and ('disjoint', grouping) to the
		aggregate_measurements result over all (model_seed, partition_seed) trials.
	"""
	print('Running partition census experiment for data scale {}'.format(data_scale))

	data_seed = 0
	model_seeds = list(range(10))
	partition_seeds = list(range(10))

	# The full split depends only on the constant data_seed, so build it once;
	# every trial works on a shallow copy (see _run_partition_trials).
	full_data = format_data(acs_data, applications_list, data_seed)

	strategies = (
		('fixed', fixed_partition, 'Training fixed models'),
		('disjoint', disjoint_partition, 'Training disjoint models'),
	)

	partition_table = {}
	for label, partition_fn, banner in strategies:
		print(banner)
		aggregated = _run_partition_trials(full_data, applications_list, partition_fn, data_scale, method, groupings, model_seeds, partition_seeds)
		for grouping in groupings:
			partition_table[(label, grouping)] = aggregated[grouping]

	return partition_table


def _run_partition_trials(full_data, applications_list, partition_fn, data_scale, method, groupings, model_seeds, partition_seeds):
	"""Run every (model_seed, partition_seed) trial for one partition strategy and aggregate per grouping."""
	seed2measurements = {grouping : {} for grouping in groupings}
	for model_seed in tqdm(model_seeds):
		for partition_seed in partition_seeds:
			# The partition functions rebind 'X_tr'/'y_tr'/'z_tr' on each application's dict,
			# and 'predictions' is added below, so each trial gets fresh inner dicts. The
			# underlying arrays are shared and only ever sliced, never written to.
			applications_data = {name : dict(entries) for name, entries in full_data.items()}
			applications_data = partition_fn(applications_data, applications_list, partition_seed, data_scale)

			for data in applications_data.values():
				data['predictions'] = generate_predictions(data, method, model_seed)

			# Models are trained once per trial and reused for every grouping
			for grouping in groupings:
				reformatted_applications_data, groups = group_and_format(applications_data, grouping)
				homogenization_measurements = measure_homogenization(reformatted_applications_data, groups)
				seed2measurements[grouping][(model_seed, partition_seed)] = homogenization_measurements

	return {grouping : aggregate_measurements(seed2measurements[grouping]) for grouping in groupings}


if __name__ == '__main__':
	# Fetch ACS data (2018 1-Year person survey) from a local cache; nothing is downloaded
	local = False
	if local:
		root_dir = 'data'
	else:
		root_dir = '/nlp/scr/rishibommasani/homogenization-cache/ACS_data'
	data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person', root_dir = root_dir)
	print('Loading data')
	acs_data = data_source.get_data(download=False)

	applications_list = ['employment', 'income-poverty', 'health-insurance']

	# Run the partition experiment at each data scale and save one pickle per scale
	for data_scale in tqdm([100000, 50000, 10000, 5000, 1000, 500, 100, 50, 10, 5]):
		results = partition_experiment(acs_data, applications_list, data_scale)
		print('\n')
		print(data_scale)
		print(results)
		print('\n')
		# Use a context manager so the file is flushed and closed before the next
		# (long-running) experiment starts, instead of leaking the handle.
		with open("results/census_partition_10x10_{}.pkl".format(data_scale), "wb") as results_file:
			pickle.dump(results, results_file)


	# methods = ['logistic', 'gbm']
	# complexity_table = complexity_experiment(acs_data, applications_list, methods)
	# print(complexity_table)
