# -*- coding: utf-8 -*-

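'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of a dataset.
    - init_params: network parameter initialization.
    - progress_bar: a progress bar mimicking xlua.progress.
    - LiRA membership-inference utilities (logit scaling, online/offline
      attacks, ROC statistics) and label-flipping poison datasets.
'''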
import os
import sys
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
import copy
import scipy

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset

import torchvision


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std of an image dataset.

    Each sample is loaded on its own (batch_size=1). The result is the average,
    over samples, of each sample's per-channel pixel mean and (unbiased) pixel
    std. Averaging per-sample stds is not the same as the std over the whole
    dataset; this keeps the helper's original definition.

    Args:
        dataset: map-style dataset yielding (image, target) pairs whose images
            are laid out (C, H, W). C is taken from the data, so grayscale
            (e.g. MNIST) and RGB inputs both work.
        num_workers: number of DataLoader worker processes (default 2).

    Returns:
        (mean, std): 1-D tensors of length C.

    Raises:
        ValueError: if the dataset yields no samples.
    '''
    # Shuffling only changes the summation order, so it is not needed here.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    print("==> Computing mean and std..")
    for inputs, _ in dataloader:
        # (1, C, H, W) -> (C, H*W): one row of pixels per channel.
        per_channel = inputs.transpose(0, 1).reshape(inputs.shape[1], -1)
        if mean is None:
            mean = torch.zeros(per_channel.shape[0])
            std = torch.zeros(per_channel.shape[0])
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
    if mean is None:
        raise ValueError("get_mean_and_std: dataset is empty")
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Initialise, in place, every Conv2d, BatchNorm2d and Linear module in net.

    - Conv2d: Kaiming-normal weights (mode="fan_out"), zero bias.
    - BatchNorm2d: weight 1, bias 0.
    - Linear: normal weights with std 1e-3, zero bias.

    Parameters that do not exist are skipped: layers built with bias=False,
    and BatchNorm2d with affine=False, which has no weight or bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode="fan_out")
            # `is not None`: truth-testing a multi-element tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width used by progress_bar for padding and cursor positioning.
# `stty size` prints "<rows> <cols>". Without a usable terminal (e.g. output
# piped or run as a batch job) its output is empty or unparsable, so fall back
# to 80 columns. Only those expected failures are caught, and the pipe is
# closed via `with`.
try:
    with os.popen("stty size", "r") as _stty:
        _, term_width = _stty.read().split()
    term_width = int(term_width)
except (OSError, ValueError):
    term_width = 80

TOTAL_BAR_LENGTH = 65.
# Timestamps shared with progress_bar: last call and start of the current bar.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar on stdout (mimics xlua.progress).

    Args:
        current: zero-based index of the current step (displayed as current+1).
        total: number of steps. Once current == total - 1 the line ends with a
            newline; otherwise it ends with a carriage return so the next call
            redraws it in place.
        msg: optional extra text appended after the step/total timings.

    Reads the module-level term_width and TOTAL_BAR_LENGTH. Stores timestamps
    in the globals last_time (previous call) and begin_time (start of the bar).
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # first step of a new bar: restart the total timer

    done = int(TOTAL_BAR_LENGTH * current / total)
    todo = int(TOTAL_BAR_LENGTH - done) - 1
    sys.stdout.write(" [" + "=" * done + ">" + "." * todo + "]")

    now = time.time()
    step_elapsed = now - last_time
    last_time = now
    total_elapsed = now - begin_time

    info = "  Step: %s" % format_time(step_elapsed)
    info += " | Tot: %s" % format_time(total_elapsed)
    if msg:
        info += " | " + msg
    # Pad out the line, then backspace to the centre of the bar to print the counter.
    sys.stdout.write(info + " " * (term_width - int(TOTAL_BAR_LENGTH) - len(info) - 3))
    sys.stdout.write("\b" * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(" %d/%d " % (current + 1, total))

    sys.stdout.write("\r" if current < total - 1 else "\n")
    sys.stdout.flush()

def format_time(seconds):
    '''Render a duration using at most its two most significant non-zero units.

    Units are days "D", hours "h", minutes "m", seconds "s" and milliseconds
    "ms". For example, 3725.5 -> "1h2m" and 0.25 -> "250ms". A duration with
    no positive unit gives "0ms".
    '''
    rest = seconds
    days = int(rest / 3600 / 24)
    rest = rest - days * 3600 * 24
    hours = int(rest / 3600)
    rest = rest - hours * 3600
    minutes = int(rest / 60)
    rest = rest - minutes * 60
    whole_seconds = int(rest)
    millis = int((rest - whole_seconds) * 1000)

    units = ((days, "D"), (hours, "h"), (minutes, "m"), (whole_seconds, "s"), (millis, "ms"))
    pieces = []
    for amount, suffix in units:
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + suffix)
    return "".join(pieces) or "0ms"

def set_random_seed(seed=0):
    '''Seed the PyTorch (CPU and all CUDA devices), NumPy and Python RNGs from one base seed.

    Effective seeds: torch CPU -> seed, every CUDA device -> seed + 4,
    NumPy -> seed + 3, Python `random` -> seed + 5. The offsets are kept as
    they were, so earlier runs remain reproducible. Intermediate CUDA seedings
    (seed + 1, seed + 2) were always overwritten by seed + 4 and have been
    dropped.
    '''
    torch.manual_seed(seed)  # also seeds CUDA; overridden on the next line
    torch.cuda.manual_seed_all(seed + 4)
    np.random.seed(seed + 3)
    random.seed(seed + 5)


def imshow(img):
    '''Display an image tensor with matplotlib.

    Assumes img is laid out (channels, height, width), since the transpose
    moves axis 0 to the end. Also assumes it was normalised to about [-1, 1];
    x / 2 + 0.5 maps that back to [0, 1] before display.
    '''
    pixels = (img.detach().cpu() / 2 + 0.5).numpy()
    plt.imshow(np.transpose(pixels, (1, 2, 0)))
    plt.show()


def generate_aug_imgs(args):
    '''Return args.num_aug copies of the target image, each with a leading batch axis.

    Every copy is fetched afresh via args.aug_trainset[args.target_img_id].
    Presumably that dataset applies random augmentation on access, so the
    copies may differ (TODO confirm).
    '''
    return [
        args.aug_trainset[args.target_img_id][0].unsqueeze(0)
        for _ in range(args.num_aug)
    ]


def get_logits(curr_canary, model, keep_tensor=False):
    '''Run model on curr_canary with autograd disabled.

    Returns the raw output tensor when keep_tensor is True. Otherwise returns
    the output moved to the CPU and converted to (nested) Python lists.
    '''
    with torch.no_grad():
        out = model(curr_canary)
    return out if keep_tensor else out.detach().cpu().tolist()


def normalize_logits(logits):
    '''Softmax over the last axis, returned as a new float64 array.

    The per-row maximum is subtracted first for numerical stability. exp() runs
    in the input's dtype and the result is then cast to float64.
    '''
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted).astype(np.float64)
    return exps / exps.sum(axis=-1, keepdims=True)


def get_pure_logits(pred_logits, class_labels):
    '''Pick the raw logit of each sample's true class.

    Args:
        pred_logits: iterable of arrays indexed [sample, j, class]. The middle
            axis is presumably augmentations/queries (TODO confirm).
        class_labels: integer class index per sample.

    Returns:
        np.ndarray stacking one (num_samples, j) array per element of
        pred_logits. The input is not modified.
    '''
    # Advanced indexing already returns a new array, so the former deep copies
    # of the whole input and of every result were pure overhead.
    return np.array([
        logits[np.arange(len(logits)), :, class_labels] for logits in pred_logits
    ])


def get_normal_logits(pred_logits, class_labels):
    '''Softmax each logit array and pick the probability of each sample's true class.

    Args:
        pred_logits: iterable of arrays indexed [sample, j, class]. The middle
            axis is presumably augmentations/queries (TODO confirm).
        class_labels: integer class index per sample.

    Returns:
        np.ndarray stacking one (num_samples, j) array of true-class
        probabilities per element of pred_logits.
    '''
    # normalize_logits builds a new array and advanced indexing copies again,
    # so the input is never modified and no defensive deep copies are needed.
    return np.array([
        probs[np.arange(len(probs)), :, class_labels]
        for probs in map(normalize_logits, pred_logits)
    ])


def get_log_logits(pred_logits, class_labels):
    '''Score each sample by the log-odds of its true class after a softmax.

    For every array in pred_logits, the score is

        log(p_true + 1e-45) - log(sum of the other classes' probabilities + 1e-45)

    The arrays are indexed [sample, j, class]; j is presumably the
    augmentation/query axis (TODO confirm). The tiny constant keeps log()
    finite when a probability underflows to 0.

    Returns:
        np.ndarray stacking one (num_samples, j) score array per element of
        pred_logits. The input is not modified.
    '''
    scores = []
    for logits in pred_logits:
        # normalize_logits returns a new array, so zeroing entries below never
        # touches the caller's data; no up-front deep copy is needed.
        probs = normalize_logits(logits)
        rows = np.arange(len(probs))
        p_true = probs[rows, :, class_labels]  # advanced indexing -> independent copy
        probs[rows, :, class_labels] = 0
        p_wrong = np.sum(probs, axis=2)
        scores.append(np.log(p_true + 1e-45) - np.log(p_wrong + 1e-45))
    return np.array(scores)


def calibrate_logits(pred_logits, class_labels, logits_strategy):
    '''Convert raw logits into per-sample scores using the named strategy.

    logits_strategy selects one of:
        "pure_logits":   raw logit of the true class,
        "log_logits":    log-odds of the true class after a softmax,
        "normal_logits": softmax probability of the true class.

    Raises:
        NotImplementedError: for any other strategy name.
    '''
    strategies = (
        ("pure_logits", get_pure_logits),
        ("log_logits", get_log_logits),
        ("normal_logits", get_normal_logits),
    )
    for name, scorer in strategies:
        if logits_strategy == name:
            return scorer(pred_logits, class_labels)
    raise NotImplementedError()

'''
    LiRA (likelihood-ratio attack) implementation adapted from:
    https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021
'''

def lira_online(shadow_scores, shadow_in_out_labels, target_scores, target_in_out_labels, fix_variance=False):
    '''Online LiRA: compare each target score against "in" and "out" shadow Gaussians.

    For every sample j, one Gaussian is fitted to the scores of shadow models
    whose mask marks j as a member ("in") and one to the rest ("out"). The
    median serves as the centre. Each target model's score is then rated by
    the likelihood ratio of the two fits, averaged over the last axis.

    Args:
        shadow_scores: array indexed [shadow_model, sample, k]. k is presumably
            the augmentation/query axis (TODO confirm).
        shadow_in_out_labels: [shadow_model, sample] membership mask. It is
            coerced to bool, so 0/1 integer masks work too; with an integer
            mask `~` would give -1/-2 and select the wrong rows.
        target_scores: iterable of [sample, k] arrays, one per target model.
        target_in_out_labels: iterable of per-sample membership labels, one per
            target model (same order as target_scores).
        fix_variance: if True, use one global std per distribution instead of
            per-sample stds.

    Returns:
        (preds, labels): flat arrays over all (target model, sample) pairs.
        preds = log p_in - log p_out, so larger means "more likely a member".
    '''
    shadow_scores = np.asarray(shadow_scores)
    in_mask = np.asarray(shadow_in_out_labels, dtype=bool)

    dat_in = []
    dat_out = []
    for j in range(shadow_scores.shape[1]):
        dat_in.append(shadow_scores[in_mask[:, j], j, :])
        dat_out.append(shadow_scores[~in_mask[:, j], j, :])

    # Truncate so every sample uses the same number of shadow scores,
    # which lets them be stacked into rectangular arrays.
    in_size = min(map(len, dat_in))
    out_size = min(map(len, dat_out))
    dat_in = np.array([x[:in_size] for x in dat_in])
    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_in = np.median(dat_in, 1)
    mean_out = np.median(dat_out, 1)

    if fix_variance:
        std_in = np.std(dat_in)
        std_out = np.std(dat_out)
    else:
        std_in = np.std(dat_in, 1)
        std_out = np.std(dat_out, 1)

    final_preds = []
    true_labels = []
    for ans, sc in zip(target_in_out_labels, target_scores):
        # Negative log-likelihoods under each hypothesis; +1e-30 guards a zero std.
        pr_in = -scipy.stats.norm.logpdf(sc, mean_in, std_in + 1e-30)
        pr_out = -scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)
        score = pr_in - pr_out

        final_preds.extend(score.mean(1))
        true_labels.extend(ans)

    return -np.array(final_preds), np.array(true_labels)


def lira_offline(shadow_scores, shadow_in_out_labels, target_scores, target_in_out_labels, fix_variance=False):
    '''Offline LiRA: rate target scores using only the "out" shadow Gaussian.

    For every sample j, a Gaussian is fitted (median as centre) to the scores of
    shadow models whose mask marks j as a non-member. Each target model's score
    is rated by its log-likelihood under that fit, averaged over the last axis.

    Args:
        shadow_scores: array indexed [shadow_model, sample, k]. k is presumably
            the augmentation/query axis (TODO confirm).
        shadow_in_out_labels: [shadow_model, sample] membership mask. It is
            coerced to bool, so 0/1 integer masks work too; with an integer
            mask `~` would give -1/-2 and select the wrong rows.
        target_scores: iterable of [sample, k] arrays, one per target model.
        target_in_out_labels: iterable of per-sample membership labels, one per
            target model (same order as target_scores).
        fix_variance: if True, use one global std instead of per-sample stds.

    Returns:
        (preds, labels): flat arrays over all (target model, sample) pairs.
        preds is the negative log-likelihood under the "out" fit, so larger
        means less consistent with non-membership.
    '''
    shadow_scores = np.asarray(shadow_scores)
    out_mask = ~np.asarray(shadow_in_out_labels, dtype=bool)

    dat_out = [shadow_scores[out_mask[:, j], j, :] for j in range(shadow_scores.shape[1])]

    # Truncate so every sample uses the same number of "out" scores.
    out_size = min(map(len, dat_out))
    dat_out = np.array([x[:out_size] for x in dat_out])

    mean_out = np.median(dat_out, 1)
    std_out = np.std(dat_out) if fix_variance else np.std(dat_out, 1)

    final_preds = []
    true_labels = []
    for ans, sc in zip(target_in_out_labels, target_scores):
        # +1e-30 guards against a zero std.
        score = scipy.stats.norm.logpdf(sc, mean_out, std_out + 1e-30)

        final_preds.extend(score.mean(1))
        true_labels.extend(ans)

    return -np.array(final_preds), np.array(true_labels)


def cal_stats(final_preds, true_labels):
    '''Summarise attack predictions with ROC statistics (positive label = 1).

    Returns:
        fpr, tpr: ROC curve arrays from sklearn.metrics.roc_curve.
        auc: area under that curve.
        acc: best balanced accuracy over thresholds, max(1 - (FPR + FNR) / 2).
        low: TPR at the last ROC point whose FPR is below 1%.
    '''
    fpr, tpr, _thresholds = metrics.roc_curve(true_labels, final_preds, pos_label=1)
    fnr = 1 - tpr
    balanced_acc = np.max(1 - (fpr + fnr) / 2)
    below_one_pct = np.nonzero(fpr < .01)[0]
    tpr_at_low_fpr = tpr[below_one_pct[-1]]
    return fpr, tpr, metrics.auc(fpr, tpr), balanced_acc, tpr_at_low_fpr


def cal_results(shadow_scores, shadow_in_out_labels, target_scores, target_in_out_labels, logits_mul=1):
    '''Run four LiRA variants and collect AUC, accuracy and TPR@1%FPR for each.

    Key prefixes:
        "fix_"      online,  one global std
        "fix_off_"  offline, one global std
        ""          online,  per-sample std
        "off_"      offline, per-sample std

    Each prefix gets "auc", "acc" and "TPR@0.01FPR" entries. Predictions are
    multiplied by logits_mul before the ROC analysis.
    '''
    variants = (
        ("fix_", lira_online, True),
        ("fix_off_", lira_offline, True),
        ("", lira_online, False),
        ("off_", lira_offline, False),
    )
    some_stats = {}
    for prefix, attack, fixed in variants:
        preds, labels = attack(shadow_scores, shadow_in_out_labels, target_scores, target_in_out_labels, fix_variance=fixed)
        _, _, auc, acc, low = cal_stats(logits_mul * preds, labels)
        some_stats[prefix + "auc"] = auc
        some_stats[prefix + "acc"] = acc
        some_stats[prefix + "TPR@0.01FPR"] = low
    return some_stats


def get_dataset(args):
    '''Return the torchvision dataset class named by args.dataset.

    Side effect: sets args.data_mean and args.data_std (one entry per channel)
    and args.num_classes for the chosen dataset.
    Supported names: "cifar10", "cifar100", "mnist".

    Raises:
        NotImplementedError: for any other dataset name (args is left unchanged).
    '''
    # name -> (mean, std, number of classes, torchvision class name)
    known = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 10, "CIFAR10"),
        "cifar100": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), 100, "CIFAR100"),
        "mnist": ((0.1307,), (0.3081,), 10, "MNIST"),
    }
    spec = known.get(args.dataset)
    if spec is None:
        raise NotImplementedError()
    args.data_mean, args.data_std, args.num_classes, class_name = spec
    return getattr(torchvision.datasets, class_name)


def get_log_logits_torch(logits, y):
    '''Torch version of the "log_logits" score: log p_y - log(sum of p_j for j != y).

    Args:
        logits: 2-D tensor indexed [batch, class].
        y: integer index of the class to score, shared by every row.

    Returns:
        1-D tensor with one score per row. There is no detach or no_grad,
        so gradients flow back to logits.
    '''
    # Numerically stable softmax over the class dimension.
    probs = torch.exp(logits - torch.max(logits, dim=-1, keepdim=True)[0])
    probs = probs / torch.sum(probs, dim=-1, keepdim=True)

    p_true = probs[:, y]
    wrong = [c for c in range(probs.shape[-1]) if c != y]
    p_wrong = torch.sum(probs[:, wrong], dim=-1)
    # 1e-45 keeps log() finite when a probability underflows to 0.
    return torch.log(p_true + 1e-45) - torch.log(p_wrong + 1e-45)


def split_shadow_models(shadow_models, target_img_id):
    '''Partition shadow models by whether target_img_id is in their in_data.

    Side effect: sets `is_in_model` on every model to that membership result.

    Returns:
        (in_models, out_models), each in the original order.
    '''
    groups = {True: [], False: []}
    for model in shadow_models:
        model.is_in_model = target_img_id in model.in_data
        groups[model.is_in_model].append(model)
    return groups[True], groups[False]


def get_curr_shadow_models(shadow_models, x, args):
    '''Choose which shadow models to use for the current target (x is not used).

    - args.offline: the "out" models for args.target_img_id, or a random
      sample of args.stochastic_k of them.
    - stochastic_k is None: all shadow models.
    - stochastic_k > 1: int(k/2) random "in" models plus the remaining k - int(k/2)
      random "out" models (balanced selection).
    - otherwise, if args.balance_shadow: equal-size random samples of "in" and
      "out" models, sized by the smaller group.
    - otherwise: stochastic_k models sampled uniformly from all of them.

    Uses Python's `random` module. Branches that split by membership also set
    `is_in_model` on every shadow model.
    '''
    k = args.stochastic_k

    if args.offline:
        _, out_models = split_shadow_models(shadow_models, args.target_img_id)
        return out_models if k is None else random.sample(out_models, k)

    if k is None:
        return shadow_models

    if k > 1:
        # more balanced for kl loss
        in_models, out_models = split_shadow_models(shadow_models, args.target_img_id)
        num_in = int(k / 2)
        return random.sample(in_models, num_in) + random.sample(out_models, k - num_in)

    if args.balance_shadow:
        in_models, out_models = split_shadow_models(shadow_models, args.target_img_id)
        per_side = min(len(in_models), len(out_models))
        return random.sample(in_models, per_side) + random.sample(out_models, per_side)

    return random.sample(shadow_models, k)


class FixedFlipLabelDataset(Dataset):
    '''Label-flipping dataset whose wrong labels are drawn once and then reused.

    Every sample of target_dataset appears poison_k times in a row: indices
    i*poison_k .. i*poison_k + poison_k - 1 map to sample i. All copies carry
    the same randomly chosen label, which always differs from the true one.

    target_dataset must expose `.dataset.classes`, for example a Subset over
    a torchvision dataset.
    '''

    def __init__(self, target_dataset, poison_k=1):
        self.target_dataset = target_dataset
        self.num_classes = len(target_dataset.dataset.classes)
        self.poison_k = poison_k

        # Store length for efficient data access
        self.total_length = len(target_dataset) * poison_k

        # Draw one wrong label per underlying sample up front, so repeats agree.
        # This iterates the whole target_dataset once, loading every item.
        self.flipped_labels = [
            random.choice([c for c in range(self.num_classes) if c != true_label])
            for _, true_label in target_dataset
        ]

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        base_idx = idx // self.poison_k
        image, _ = self.target_dataset[base_idx]
        return image, self.flipped_labels[base_idx]


class RandomFlipLabelDataset(Dataset):
    '''Label-flipping dataset that draws a fresh wrong label on every access.

    Index i maps to sample i % len(target_dataset), so the samples cycle
    poison_k times. Each lookup picks, via Python's `random`, a label
    different from the sample's true label; repeated lookups of the same
    index may therefore return different labels.

    target_dataset must expose `.dataset.classes`, for example a Subset over
    a torchvision dataset.
    '''

    def __init__(self, target_dataset, poison_k=1):
        self.target_dataset = target_dataset
        self.num_classes = len(target_dataset.dataset.classes)
        self.poison_k = poison_k
        self.total_length = len(target_dataset) * poison_k

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        image, label = self.target_dataset[idx % len(self.target_dataset)]
        wrong_labels = [c for c in range(self.num_classes) if c != label]
        return image, random.choice(wrong_labels)


def plot_one(fpr, tpr, axis, label=None):
    '''Plot the same (fpr, tpr) curve on axis[0] and then on axis[1], with the given label.'''
    for i in (0, 1):
        axis[i].plot(fpr, tpr, label=label)
    