#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import AdultDataset
from models import MLPNet, FairNet
from utils import conditional_errors
from utils import get_logger

# Command-line interface of the experiment script.  Every option has a default,
# so the script can be launched without any arguments.
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", help="Name used to save the log file.", type=str, default="adult")
parser.add_argument("-s", "--seed", help="Random seed.", type=int, default=42)
parser.add_argument("-u", "--mu", help="Hyperparameter of the coefficient of the adversarial classification loss",
                    type=float, default=1.0)
parser.add_argument("-e", "--epoch", help="Number of training epochs", type=int, default=5)
parser.add_argument("-r", "--lr", type=float, help="Learning rate of optimization", default=1.0)
parser.add_argument("-b", "--batch_size", help="Batch size during training", type=int, default=512)
# Restrict --model to the two supported values so that a typo is rejected by
# argparse immediately, instead of only after the datasets have been loaded.
parser.add_argument("-m", "--model", help="Which model to run: [mlp|fair]", type=str, default="mlp",
                    choices=["mlp", "fair"])
# Read the command-line options, then set up the runtime environment:
# compute device, CPU thread count, logger and RNG seeds.
args = parser.parse_args()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Fixed number of CPU threads handed to PyTorch.
torch.set_num_threads(8)

# The logger is named after the --name option.
logger = get_logger(args.name)

# Seed NumPy and PyTorch with the same user-supplied value.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
dtype = np.float32

# Build the train/test splits of the UCI Adult data ('income' as the target,
# 'sex' as the protected attribute), wrap them in data loaders, and log how
# long the whole step took.
time_start = time.time()
adult_train, adult_test = (
    AdultDataset(root_dir='data', phase=phase, tar_attr='income', priv_attr='sex')
    for phase in ('train', 'test')
)
# Only the training loader shuffles its samples.
train_loader = DataLoader(adult_train, shuffle=True, batch_size=args.batch_size)
test_loader = DataLoader(adult_test, shuffle=False, batch_size=args.batch_size)
time_end = time.time()
logger.info("Time used to load all the data sets: {} seconds.".format(time_end - time_start))
# Input dimension as reported by the training split; fed to the model config.
input_dim = adult_train.xdim

# Hyperparameters handed to the model constructor (MLPNet or FairNet alike).
# Key order is kept as-is because the whole dict is written to the log.
configs = {
    "num_classes": 2,
    "num_epochs": args.epoch,
    "batch_size": args.batch_size,
    "lr": args.lr,
    "mu": args.mu,
    "input_dim": input_dim,
    "hidden_layers": [500, 200, 100],
    "adversary_layers": [50],
}
# Convenience aliases for the values the training loops read directly.
num_epochs, batch_size, lr = (configs[key] for key in ("num_epochs", "batch_size", "lr"))

def _predict_labels(forward, dataset):
    """Return hard class predictions of ``forward`` on all of ``dataset``.

    Args:
        forward: Callable mapping an input tensor to per-class scores, e.g.
            ``net`` for MLPNet or ``net.inference`` for FairNet.
        dataset: Object exposing the feature matrix as the NumPy array ``X``.

    Returns:
        NumPy array holding, for each row of ``dataset.X``, the index of the
        highest score along dimension 1.

    The forward pass runs under ``torch.no_grad()``: nothing is
    back-propagated at test time, so no autograd graph has to be built and
    kept in memory for the whole test set.
    """
    insts = torch.from_numpy(dataset.X).float().to(device)
    with torch.no_grad():
        return torch.max(forward(insts), 1)[1].cpu().numpy()


def _log_group_metrics(preds_labels, dataset):
    """Log overall / per-group errors and disparity statistics.

    Args:
        preds_labels: NumPy array of predicted class indices, aligned with the
            rows of ``dataset``.
        dataset: Object exposing the arrays ``Y`` (targets) and ``A``
            (protected attribute).  Both are decoded with an arg-max over
            axis 1 -- presumably they are one-hot encoded; confirm against
            AdultDataset.
    """
    target_labels = np.argmax(dataset.Y, axis=1)
    target_attrs = np.argmax(dataset.A, axis=1)
    # NOTE(review): conditional_errors is defined in utils; from its use here
    # it returns the overall error and the errors within A=0 and A=1.
    cls_error, error_0, error_1 = conditional_errors(preds_labels, target_labels, target_attrs)
    idx = target_attrs == 0
    # Mean of the 0/1 labels within each group: the rate of label 1 in the
    # ground truth (base_*) and in the predictions (pred_*).
    base_0, base_1 = np.mean(target_labels[idx]), np.mean(target_labels[~idx])
    pred_0, pred_1 = np.mean(preds_labels[idx]), np.mean(preds_labels[~idx])
    logger.info("Overall predicted error = {}, Err|A=0 = {}, Err|A=1 = {}".format(cls_error, error_0, error_1))
    logger.info("|Err|A=0 + Err|A=1| = {}".format(error_0 + error_1))
    logger.info("|Err|A=0 - Err|A=1| = {}".format(np.abs(error_0 - error_1)))
    logger.info("Bias: |Pred=1|A=0 - Pred=1|A=1| = {}".format(np.abs(pred_0 - pred_1)))
    logger.info("Pr(Y = 1|A = 0) = {}".format(base_0))
    logger.info("Pr(Y = 1|A = 1) = {}".format(base_1))
    logger.info("Total Variation Lower bound = {}".format(np.abs(base_0 - base_1)))
    logger.info("*" * 100)


# Run the experiment selected with --model: train the network on the training
# split, then report error and group-disparity statistics on the test split.
if args.model == "mlp":
    logger.info("Experiment without debiasing:")
    logger.info("Hyperparameter setting = {}.".format(configs))
    # Train MLPNet without debiasing.
    time_start = time.time()
    net = MLPNet(configs).to(device)
    optimizer = optim.Adadelta(net.parameters(), lr=lr)
    net.train()
    for t in range(num_epochs):
        # Sum of the per-batch losses over this epoch (for logging only).
        running_loss = 0.0
        # The baseline ignores the protected attribute, so it is not moved to
        # the device.
        for xs, ys, _ in train_loader:
            xs, ys = xs.to(device), ys.to(device)
            optimizer.zero_grad()
            ypreds = net(xs)
            # Classification loss on the current mini-batch.
            # NOTE(review): nll_loss expects log-probabilities; presumably the
            # model ends in a log-softmax -- confirm in models.py.
            loss = F.nll_loss(ypreds, ys)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        logger.info("Iteration {}, loss value = {}".format(t, running_loss))
    time_end = time.time()
    logger.info("Time used for training = {} seconds.".format(time_end - time_start))
    # Test.
    net.eval()
    _log_group_metrics(_predict_labels(net, adult_test), adult_test)
elif args.model == "fair":
    # Training with FairNet to show the debiased results.
    logger.info("Experiment with FairNet adversarial debiasing:")
    logger.info("Hyperparameter setting = {}.".format(configs))

    time_start = time.time()
    net = FairNet(configs).to(device)
    optimizer = optim.Adadelta(net.parameters(), lr=lr)
    # Weight of the adversarial (attribute-prediction) loss in the objective.
    mu = args.mu
    net.train()
    for t in range(num_epochs):
        # Per-epoch sums of the two per-batch losses (for logging only).
        running_loss, running_adv_loss = 0.0, 0.0
        for xs, ys, attrs in train_loader:
            xs, ys, attrs = xs.to(device), ys.to(device), attrs.to(device)
            optimizer.zero_grad()
            # In training mode FairNet returns target and attribute predictions.
            ypreds, apreds = net(xs)
            # Compute both the prediction loss and the adversarial loss.
            loss = F.nll_loss(ypreds, ys)
            adv_loss = F.nll_loss(apreds, attrs)
            running_loss += loss.item()
            running_adv_loss += adv_loss.item()
            # Combine the two terms out of place so the logged `loss` tensor
            # is not mutated while it is part of the autograd graph.
            total_loss = loss + mu * adv_loss
            total_loss.backward()
            optimizer.step()
        logger.info("Iteration {}, loss value = {}, adv_loss value = {}".format(t, running_loss, running_adv_loss))
    time_end = time.time()
    logger.info("Time used for training = {} seconds.".format(time_end - time_start))
    # Test: FairNet exposes a separate inference entry point for target scores.
    net.eval()
    _log_group_metrics(_predict_labels(net.inference, adult_test), adult_test)
else:
    # Unknown model name.
    raise NotImplementedError("{} not supported.".format(args.model))
