import torch
import numpy as np
from pathlib import Path
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
from config import get_config
from networks import get_network
from datasets import load_data

def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) is parsed as ``True``. This
    converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
parser.add_argument('--device', default=0, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=100, type=int, help='Amount of samples used during inference')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
parser.add_argument('--stoch_varianz', default=0.05, type=float, help='Only applicable for SNN models - variance of noise added to the input')
parser.add_argument('--smooth', type=_str2bool, default=False, help='Boolean option; accepts true/false, yes/no or 1/0')
args = parser.parse_args()
# get_config receives the parsed namespace and returns the namespace used by
# the rest of the script (batch_size, root_dir, epochs, randseed, ... are read
# from it later on).
args = get_config(args)
# Select the CUDA device before any tensor or model is moved to the GPU.
torch.cuda.set_device(args.device)


def main(args):
    """Load a trained model and report its test accuracy over 10 seeds.

    Args:
        args: configuration namespace. Reads ``dataset``, ``batch_size``,
            ``root_dir``, ``net``, ``epochs``, ``randseed``, ``droprate`` and,
            for SNN models, ``layer`` and ``stoch_varianz``. ``droprate`` is
            overwritten with 0 for non-ResNet models and ``randseed`` is
            advanced by 1000 for every evaluation run.

    Prints the mean and standard deviation of the accuracy (in percent) as
    ``"<mean> \\pm <std>"``.
    """
    # get data
    _, test_loader, _ = load_data(args.dataset, args.batch_size, args.root_dir)

    # get model
    model = get_network(args)
    if args.net != 'ResNet':
        # The drop rate is only meaningful for ResNet; it is reset after the
        # network is built and ends up in the checkpoint file name below.
        args.droprate = 0

    # load model: SNN checkpoints are keyed by layer and input-noise variance,
    # all other checkpoints by the drop rate.
    if args.net == 'SNN':
        suffix = f'{args.layer}_{args.stoch_varianz}'
    else:
        suffix = f'{args.droprate}'
    checkpoint = Path(
        args.root_dir,
        f'models/{args.dataset}/model_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{suffix}.bin',
    )
    model.load_state_dict(torch.load(checkpoint))

    if args.net == 'ResNet':
        # Only the batch-norm layers are switched to eval mode; the remaining
        # modules keep their current mode (presumably so dropout stays active
        # for sampling -- confirm against the network definition).
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    # using 10 different seeds for testing
    accuracy = []
    for _ in range(10):
        args.randseed += 1000
        np.random.seed(args.randseed)
        torch.manual_seed(args.randseed)
        random.seed(args.randseed)
        accuracy.append(test_model(model, test_loader, args))

    mean_acc = np.mean(accuracy) * 100
    std_acc = np.sqrt(np.var(accuracy)) * 100
    # The backslash is escaped explicitly: "\p" is an invalid escape sequence.
    print(f'{mean_acc:.2f} \\pm {std_acc:.2f}')


def test_model(model, test_loader, args):
    """Compute the accuracy of ``model`` on every batch of ``test_loader``.

    Each batch is pushed through ``predict_model`` ``args.n_samples`` times;
    the sampled outputs are averaged over the sample axis and the arg-max of
    the average is compared with the labels.

    Args:
        model: network to evaluate; it is moved to the GPU here.
        test_loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        args: configuration namespace; only ``n_samples`` is read here.

    Returns:
        Fraction of correctly classified examples.
    """
    model.cuda()

    sampled_outputs = []
    label_batches = []
    for inputs, labels in test_loader:
        inputs = inputs.cuda()
        labels = labels.long()
        sampled_outputs.append(predict_model(model, inputs, args.n_samples))
        label_batches.append(labels.numpy())

    target = np.concatenate(label_batches, 0)
    # predict_model stacks the samples on axis 0, so batches are joined on axis 1.
    stacked = np.concatenate(sampled_outputs, 1)
    averaged = np.mean(stacked, axis=0)
    predicted = np.argmax(averaged, axis=1)

    # A non-zero difference marks a misclassified example.
    n_wrong = np.count_nonzero(predicted - target)
    return 1 - (n_wrong / len(predicted))


def predict_model(model, test_data, pred_number, net=None):
    """Run ``pred_number`` forward passes of ``model`` on one batch.

    Args:
        model: callable network (e.g. an ``nn.Module``).
        test_data: input batch passed unchanged to the model.
        pred_number: number of forward passes to perform.
        net: model-type name. A softmax over dimension 1 is applied to the
            model output unless this is ``'SNN'``. When omitted, the
            module-level ``args.net`` is used, as before.

    Returns:
        ``numpy`` array holding the outputs of all passes stacked along a new
        leading axis.
    """
    if net is None:
        net = args.net
    apply_softmax = net != 'SNN'

    preds = []
    # Pure inference: no autograd graph is needed for the sampled outputs.
    with torch.no_grad():
        for _ in range(pred_number):
            # Call the module itself rather than ``forward`` so hooks run.
            pred = model(test_data)
            if apply_softmax:
                pred = F.softmax(pred, dim=1)
            preds.append(pred.detach().cpu().numpy())
    return np.array(preds)


# Script entry point: run the evaluation with the namespace that was parsed
# from the command line at module level.
if __name__ == "__main__":
    main(args)
