import logging
import os
import datetime

import torch
import torch.nn as nn
import numpy as np
import torch.backends.cudnn as cudnn
import pickle

from torch.utils.data import DataLoader
from torch import optim
from torch.nn import functional as F

from utils.parameters import get_parameter
from utils.dataset import construct_datasets, targetData
from utils.utils import get_grad_diff, read_results, set_random_seed
from utils.model import Linear, ResNet18, VGG16, MobileNetV2
from pretrain import test


def victim(args, logging, poison_weight, train_data=None, test_data=None):
    """Apply a first-order unlearning update to a pretrained model and report
    how the target samples are classified before and after the update.

    The function loads the checkpoint
    ``<args.moddir>/<args.dataset>_<args.net>_<args.modname>.pth``, evaluates
    it with ``test``, prints the predictions on the target samples, selects
    the training samples whose ``poison_weight`` exceeds 0.5, obtains a
    per-parameter update from ``get_grad_diff`` and applies
    ``p <- p - args.tau * update`` to every trainable parameter. Finally it
    re-evaluates the target samples and prints/logs the success rates.

    Args:
        args: Parsed run configuration. Fields read here: ``dataset``,
            ``datadir``, ``batchsize``, ``num_classes``, ``net``, ``device``,
            ``moddir``, ``modname``, ``targetids``, ``tau``, ``atsetting``,
            ``poisonclass`` and ``targetclass``.
        logging: Logger-like object; only its ``info()`` method is used.
        poison_weight: Tensor of weights; entries greater than 0.5 select the
            training indices used as the data to unlearn (presumably one
            entry per training sample -- confirm against ``read_results``).
        train_data: Training dataset. Reloaded via ``construct_datasets``
            when either dataset is missing.
        test_data: Test dataset; ``args.targetids`` index into it.

    Raises:
        ValueError: If ``args.net`` is not a supported architecture, or if
            ``get_grad_diff`` returns fewer tensors than the model has
            trainable parameters.
    """
    print('==> begin victim')
    logging.info('==> begin victim')

    # Reload when *either* dataset is missing; a half-specified pair would
    # otherwise reach DataLoader(None). The call mirrors the one in the
    # script entry point, which passes ``args`` first.
    if train_data is None or test_data is None:
        train_data, test_data = construct_datasets(args, args.dataset, args.datadir, load=True)

    test_loader = DataLoader(test_data, batch_size=args.batchsize, shuffle=True, num_workers=1)

    num_classes = args.num_classes
    if args.net == 'ResNet18':
        model = ResNet18(num_classes)
    elif args.net == 'VGG16':
        model = VGG16()
    elif args.net == 'MobileNetV2':
        model = MobileNetV2()
    elif args.net == 'Linear':
        # Only the linear model needs the input size, taken from the first
        # training sample.
        input_shape = len(train_data[0][0])
        model = Linear(input_shape, num_classes)
    else:
        raise ValueError(f'Unsupported network architecture: {args.net!r}')
    model = model.to(args.device)

    checkpoint_path = os.path.join(
        args.moddir, f'{args.dataset}_{args.net}_{args.modname}.pth')
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    model.load_state_dict(checkpoint['net'])

    test(args, logging, model, test_loader)

    target_data = torch.utils.data.Subset(test_data, args.targetids)
    target_loader = DataLoader(target_data, batch_size=args.batchsize, shuffle=False, num_workers=1)

    # find unlearning data
    base_index = (poison_weight > 0.5).nonzero().squeeze()
    if base_index.dim() == 0:
        # A single selected sample squeezes down to a 0-d tensor, whose
        # tolist() is a bare int; Subset needs a sequence of indices.
        base_index = base_index.unsqueeze(0)
    baseids = base_index.tolist()
    base_data = torch.utils.data.Subset(train_data, baseids)
    base_loader = DataLoader(base_data, batch_size=args.batchsize, shuffle=False, num_workers=1)

    # Predictions on the target samples before the unlearning update.
    model.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        for images, labels, _ in target_loader:
            images, labels = images.to(args.device), labels.to(args.device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            print(predicted.detach().cpu().numpy())
            logging.info(predicted.detach().cpu().numpy())

    # first-order unlearning method
    d_theta = get_grad_diff(args, model, poison_weight, base_loader)

    trainable = [p for p in model.parameters() if p.requires_grad]
    if len(d_theta) < len(trainable):
        raise ValueError(
            f'get_grad_diff returned {len(d_theta)} tensors for '
            f'{len(trainable)} trainable parameters')

    # Re-assert eval mode in case get_grad_diff changed the model's mode.
    model.eval()
    with torch.no_grad():
        for param, delta in zip(trainable, d_theta):
            param.copy_(param - args.tau * delta)

    total, att_success, pred_success = 0, 0, 0

    # Predictions on the target samples after the unlearning update.
    with torch.no_grad():
        for images, labels, _ in target_loader:
            images, labels = images.to(args.device), labels.to(args.device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            if args.atsetting == 'targeted':
                att_success += (predicted == args.poisonclass).sum().item()
            elif args.atsetting == 'untargeted':
                att_success += (predicted != args.targetclass).sum().item()
            pred_success += (predicted == args.targetclass).sum().item()

            print(predicted.detach().cpu().numpy())
            logging.info(predicted.detach().cpu().numpy())

    if total == 0:
        # No target samples were evaluated; report NaN instead of dividing
        # by zero.
        logging.info('==> no target samples evaluated; success rates are undefined')
        att_success_rate = pred_success_rate = float('nan')
    else:
        att_success_rate = att_success / total
        pred_success_rate = pred_success / total
    print(f'Attack success = {att_success_rate} Predicted success = {pred_success_rate}')
    logging.info(f'Attack success = {att_success_rate} Predicted success = {pred_success_rate}')


if __name__ == '__main__':
    # Script entry point: parse the configuration, set up file logging,
    # choose the device, load the data and the crafting results, then run
    # the unlearning experiment.
    args = get_parameter()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.vicdir, exist_ok=True)

    # One log file per run, named after the crafting project and start time.
    # NOTE: the ':' in the timestamp is not a valid file-name character on
    # Windows; the format is kept unchanged so existing log names still match.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M-%S")
    log_path = f'{args.craftproj}-log-{timestamp}.txt'
    logging.basicConfig(
        filename=os.path.join(args.vicdir, log_path),
        format="%(asctime)s - %(name)s - %(message)s",
        datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')

    # The root logger is raised to DEBUG after basicConfig set it to INFO.
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.info(str(args))

    if torch.cuda.is_available():
        # args.device keeps whatever get_parameter() supplied; only the
        # cuDNN benchmark flag is switched on here.
        cudnn.benchmark = True
    else:
        args.device = "cpu"
    print(f'device: {args.device}')
    logger.info(f'device: {args.device}')

    train_data, test_data = construct_datasets(args, args.dataset, args.datadir, load=True)

    args.targetclass, args.poisonclass, args.targetids, poison_weight = read_results(args)

    # The logging module itself is passed: victim() only calls .info() on it.
    victim(args, logging, poison_weight, train_data, test_data)

"""
Example usage:

python first_order_victim.py -craftproj=craft -dataset=CIFAR10 -net=ResNet18 -modname=model1 -outdir=outputs-res -tau=0.0001
"""