import os
import time
import random
import argparse
import numpy as np
from PIL import Image

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import sys
import time

sys.path.append('../') 
import torch.distributed as dist
import models
import torch.nn as nn
from tqdm import tqdm as tq
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_curve
import yaml
import math
from scipy.special import softmax

def reduce_tensor(tensor, n):
    """Return the element-wise sum of `tensor` across all distributed ranks.

    The input is left untouched: the all-reduce runs in place on a copy.

    Args:
        tensor: tensor to reduce over the default process group.
        n: accepted for interface compatibility but not used; the result is
            the plain sum and is NOT divided by `n`.

    Returns:
        A new tensor holding the summed values.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed

# custom attack
# from attacks.linf import DI_fgsm, GA_DI_fgsm, TDI_fgsm, GA_TDI_fgsm, TDMI_fgsm, GA_TDMI_fgsm
from attacks.linf import *
# from attacks.feature import DI_FSA, DMI_FSA, GA_DMI_FSA, GA_DI_FSA, Feature_Adam_Attack
# from perceptual_advex.attacks import ReColorAdvAttack

from utils import MyCustomDataset, get_architecture, Input_diversity, MultiEnsemble, get_dataset, get_model, get_shadow_dataset
from utils import CrossEntropyLoss, MarginLoss, get_folder_names

# Command-line interface. main() reads all three options unconditionally, so
# they are marked required: a missing flag is reported as a usage error up
# front instead of a TypeError in the middle of the run.
parser = argparse.ArgumentParser(description='PyTorch Unrestricted Attack')
parser.add_argument('--config', type=str, required=True,
                    help='path to the YAML experiment configuration file')
parser.add_argument('--rank', type=int, required=True,
                    help='zero-based index of this worker; selects its slice of the sample indices')
parser.add_argument('--world-size', type=int, required=True,
                    help='total number of workers the sample indices are split across')

# Module-level default; main() works with config.num_classes through a local instead.
NUM_CLASSES = 10

def normalize(item):
    """Min-max scale `item` into the range [0, 1].

    Args:
        item: array-like object exposing ``.min()`` / ``.max()`` and
            element-wise arithmetic (e.g. a NumPy array or a torch tensor).

    Returns:
        ``(item - min) / (max - min)``. When every element is equal the range
        is zero, so an all-zero result of the same shape is returned instead
        of dividing zero by zero.
    """
    # Local names deliberately avoid shadowing the builtins `min` / `max`.
    lowest = item.min()
    span = item.max() - lowest
    if span == 0:
        # Constant input: multiplying by 0.0 keeps the shape and yields floats.
        return (item - lowest) * 0.0
    return (item - lowest) / span

def softmax_by_row(logits, T = 1.0):
    """Temperature-scaled softmax along the last axis of `logits`.

    Args:
        logits: array of scores; the softmax is taken over ``axis=-1``.
        T: temperature that divides the shifted logits before exponentiation.

    Returns:
        Array of the same shape whose entries sum to 1 along the last axis.
    """
    # Subtract the per-row maximum first so np.exp never sees large positives.
    shifted = (logits - np.max(logits, axis=-1, keepdims=True))/T
    weights = np.exp(shifted)
    return weights/np.sum(weights, axis=-1, keepdims=True)

def _load_models(model_ids, copies, config_path):
    """Load every id in `model_ids` once per round in range(copies), on GPU and in eval mode."""
    loaded = []
    for model_num in range(copies):
        for model_id in model_ids:
            loaded.append(
                get_model(model_id=model_id, indice=1, model_num=model_num, config=config_path).cuda().eval()
            )
    return loaded


def main(config):
    """Run the configured attack for this rank and save the model outputs.

    For every threshold in ``config.thres_list`` the target model's softmax
    outputs on benign and adversarial versions of the train/test splits are
    collected with `model_adv_prediction` and written to
    ``./data/<dir2>/<dir3>/<dir4>/target/<attack_method>/thres_<thres>/``.
    When ``config.attack.attack`` is not None, shadow/target predictions on the
    shadow splits are additionally written to ``.../shadow/``. Each file is
    named ``world_size<W>_rank<R>.npz``.

    Args:
        config: despite its name, the parsed command-line namespace. It must
            expose ``config`` (path of the YAML file), ``rank`` and
            ``world_size``. The name is rebound below to the parsed YAML
            configuration.
    """
    # Use the namespace that was passed in rather than the module-global
    # `args`, which only exists when this file is executed as a script.
    cli_args = config
    config_path = cli_args.config
    config = parse_config(config_path)
    num_classes = config.num_classes
    batch_size = config.batch_size

    # Keyword arguments shared by every Input_diversity / MultiEnsemble wrapper.
    wrapper_kwargs = dict(config=config, num_classes=num_classes, prob=config.prob,
                          mode=config.mode, diversity_scale=config.scale)

    # Target model (the raw model is kept: it is re-wrapped for negative thresholds)
    tmp_model = get_model(model_id=int(config.target_id), indice=0, model_num=0, config=config_path).cuda().eval()
    Target_model = Input_diversity(tmp_model, **wrapper_kwargs)

    # Shadow model
    Shadow_model = get_model(model_id=int(config.shadow_id), indice=1, model_num=0, config=config_path).cuda().eval()
    Shadow_model = Input_diversity(Shadow_model, **wrapper_kwargs)

    # Source models: ids are encoded as an underscore-separated string
    source_id_list = [int(item) for item in config.source_list.split('_')]
    print("Source id list: {}".format(source_id_list))
    Source_model = MultiEnsemble(_load_models(source_id_list, config.s_model_num, config_path), **wrapper_kwargs)

    # Auxiliary models: same encoding as the source list
    auxiliary_id_list = [int(item) for item in config.auxiliary_list.split('_')]
    print("Auxiliary id list: {}".format(auxiliary_id_list))
    Auxiliary_model = MultiEnsemble(_load_models(auxiliary_id_list, config.a_model_num, config_path), **wrapper_kwargs)

    # Data owned by this rank; shuffle stays off so sample order is deterministic
    idx = get_idx_for_rank(config, cli_args.world_size, cli_args.rank)
    trainset, testset = get_dataset(idx)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    shadow_idx = get_idx_for_rank(config, cli_args.world_size, cli_args.rank, shadow=True)
    shadow_memset, shadow_nonmemset = get_shadow_dataset(shadow_idx)
    shadow_mem_loader = torch.utils.data.DataLoader(shadow_memset, batch_size=batch_size, shuffle=False)
    shadow_nonmem_loader = torch.utils.data.DataLoader(shadow_nonmemset, batch_size=batch_size, shuffle=False)

    # Output directories mirror three components of the config path
    all_dir_names = get_folder_names(config_path)
    config_dir_name = all_dir_names[2] #trainer
    config_second_dir_name = all_dir_names[3] #defense
    config_third_dir_name = all_dir_names[4] #attack config
    base_dir = f'./data/{config_dir_name}/{config_second_dir_name}/{config_third_dir_name}'
    shard_name = f'world_size{cli_args.world_size}_rank{cli_args.rank}.npz'

    for thres in config.thres_list:
        # The attack reads the current threshold from the config object.
        config.thres = thres
        print('*********',thres)
        if thres < 0:
            # Negative threshold: the target model itself (wrapped as a
            # one-member ensemble) takes the place of the auxiliary ensemble.
            guide_model = MultiEnsemble([tmp_model], **wrapper_kwargs)
        else:
            guide_model = Auxiliary_model

        output_train_benign, train_label, output_train_adversarial, transfer_train_pertub_transfer = \
            model_adv_prediction(config, Target_model, Source_model, guide_model, train_loader)
        output_test_benign, test_label, output_test_adversarial, transfer_test_pertub_transfer = \
            model_adv_prediction(config, Target_model, Source_model, guide_model, test_loader)

        data_path = f'{base_dir}/target/{config.attack_method}/thres_{thres}'
        os.makedirs(data_path, exist_ok=True)
        np.savez(
            os.path.join(data_path, shard_name),
            output_train_benign=output_train_benign,
            train_label=train_label,
            output_train_adversarial=output_train_adversarial,
            transfer_train_pertub_transfer=transfer_train_pertub_transfer,
            output_test_benign=output_test_benign,
            test_label=test_label,
            output_test_adversarial=output_test_adversarial,
            transfer_test_pertub_transfer=transfer_test_pertub_transfer,
        )

    if config.attack.attack is not None:
        output_shadow_mem, label_shadow_mem = model_prediction(config, Shadow_model, shadow_mem_loader)
        output_shadow_nonmem, label_shadow_nonmem = model_prediction(config, Shadow_model, shadow_nonmem_loader)

        # NOTE(review): the target model's "mem" outputs come from the shadow
        # NON-member loader and its "nonmem" outputs from the shadow member
        # loader. Presumably the two models use complementary splits - confirm.
        output_target_mem, label_target_mem = model_prediction(config, Target_model, shadow_nonmem_loader)
        output_target_nonmem, label_target_nonmem = model_prediction(config, Target_model, shadow_mem_loader)

        shadow_data_path = f'{base_dir}/shadow'
        os.makedirs(shadow_data_path, exist_ok=True)
        np.savez(
            os.path.join(shadow_data_path, shard_name),
            output_shadow_mem=output_shadow_mem,
            label_shadow_mem=label_shadow_mem,
            output_shadow_nonmem=output_shadow_nonmem,
            label_shadow_nonmem=label_shadow_nonmem,
            output_target_mem=output_target_mem,
            label_target_mem=label_target_mem,
            output_target_nonmem=output_target_nonmem,
            label_target_nonmem=label_target_nonmem,
        )

def model_adv_prediction(config, Target_model, Source_model, Auxiliary_model, dataloader):
    """Collect target-model outputs on benign and adversarial versions of a dataset.

    Each batch is moved to the GPU, classified by `Target_model`, attacked with
    the function named by ``config.attack_method`` and classified again.

    Args:
        config: configuration object; ``attack_method`` names the attack and
            the whole object is forwarded to it.
        Target_model: model whose outputs are recorded.
        Source_model: first model argument passed to the attack.
        Auxiliary_model: second model argument passed to the attack.
        dataloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        A 4-tuple of NumPy arrays concatenated over all batches: softmax
        (axis 1) of the target outputs on benign inputs, the labels, softmax
        (axis 1) of the target outputs on adversarial inputs, and the budget
        values returned by the attack.

    Raises:
        ValueError: from ``np.concatenate`` if `dataloader` yields no batches.
    """
    labels_all = []
    budgets_all = []
    adv_probs_all = []
    benign_probs_all = []
    loss_fn = nn.CrossEntropyLoss()

    # NOTE(review): eval() executes config.attack_method as arbitrary Python.
    # It is resolved once, before any batch is processed, so a bad name fails
    # fast. Only run configuration files from trusted sources.
    attack_fn = eval(config.attack_method)

    # Iterating the loader directly (not enumerate(...)) lets tqdm see its
    # length and display a total.
    for images, labels in tq(dataloader):
        benign_logits = Target_model(images.cuda())
        labels_all.append(labels.numpy())
        # Double precision is used for the softmax computed on the host.
        benign_logits = benign_logits.detach().cpu().numpy().astype(np.double)
        benign_probs_all.append(softmax(benign_logits, 1))

        img, label = images.cuda(), labels.cuda()
        # The attack receives clones so the originals are never modified.
        x_adv, budget = attack_fn(Source_model, Auxiliary_model, img.clone(), label.clone(), config, loss_fn)

        adv_logits = Target_model.forward(x_adv)
        adv_logits = adv_logits.detach().cpu().numpy().astype(np.double)
        adv_probs_all.append(softmax(adv_logits, 1))
        budgets_all.append(budget.cpu().numpy())

    output_train_benign = np.concatenate(benign_probs_all)
    output_train_adversarial_shadow = np.concatenate(adv_probs_all)
    train_label = np.concatenate(labels_all)
    transfer_train_pertub_transfer = np.concatenate(budgets_all)

    return output_train_benign, train_label, output_train_adversarial_shadow, transfer_train_pertub_transfer

def model_prediction(config, Target_model, dataloader):
    """Collect softmax outputs of a model over every batch of `dataloader`.

    Args:
        config: accepted for interface symmetry with `model_adv_prediction`;
            not used here.
        Target_model: model to evaluate; inputs are moved to the GPU first.
        dataloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(outputs, labels)`` as NumPy arrays concatenated over all batches,
        where `outputs` is the softmax (axis 1) of the model outputs computed
        in double precision.

    Raises:
        ValueError: from ``np.concatenate`` if `dataloader` yields no batches.
    """
    labels_all = []
    probs_all = []

    # Iterating the loader directly (not enumerate(...)) lets tqdm see its
    # length and display a total.
    for images, labels in tq(dataloader):
        logits = Target_model(images.cuda())
        labels_all.append(labels.numpy())
        logits = logits.detach().cpu().numpy().astype(np.double)
        probs_all.append(softmax(logits, 1))

    output_train_benign = np.concatenate(probs_all)
    train_label = np.concatenate(labels_all)

    return output_train_benign, train_label

def get_idx_for_rank(config, world_size, rank, shadow=False):
    """Return the contiguous block of sample indices assigned to `rank`.

    The indices ``0 .. total-1`` are cut into chunks of
    ``ceil(total / world_size)`` elements; chunk number `rank` is returned.
    Trailing ranks may therefore receive a shorter, or even empty, chunk.

    Args:
        config: object providing ``num_sample`` and ``shadow_sample``.
        world_size: number of chunks the index range is divided into.
        rank: index of the chunk to return.
        shadow: when True, partition ``config.shadow_sample`` indices instead
            of ``config.num_sample``.

    Returns:
        1-D NumPy integer array with the selected indices.
    """
    if shadow:
        total = config.shadow_sample
    else:
        total = config.num_sample

    chunk = math.ceil(total/world_size)
    first = rank*chunk
    return np.arange(total)[first: first + chunk]


def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an argparse.Namespace.

    Args:
        config: mapping whose keys become attribute names. Values that are
            themselves dicts are converted recursively; all other values
            (including lists) are stored unchanged.

    Returns:
        argparse.Namespace exposing every key of `config` as an attribute.
    """
    namespace = argparse.Namespace()
    for name, entry in config.items():
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(namespace, name, converted)
    return namespace

def parse_config(config_path=None):
    """Read a YAML configuration file and return it as a nested namespace.

    Args:
        config_path: path of the YAML file to read. The default of None is
            not usable: it is passed straight to open().

    Returns:
        The parsed document converted with `dict2namespace`.
    """
    # NOTE(review): yaml.Loader is PyYAML's full (unsafe) loader and can
    # construct arbitrary Python objects; only load trusted configuration files.
    with open(config_path, 'r') as stream:
        raw_config = yaml.load(stream, Loader=yaml.Loader)
    return dict2namespace(raw_config)

if __name__ == "__main__":
    # `args` stays a module-level name because other code in this file reads it.
    args = parser.parse_args()
    # perf_counter() is monotonic, so the measured duration is not affected
    # by system clock adjustments during a long run.
    start = time.perf_counter()
    main(args)
    end = time.perf_counter()
    print('running time:', end-start)
