from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import os
from scipy import misc



def normalize_cifar10(X):
    """Standardize images with the per-channel CIFAR-10 mean and std.

    :param X: image tensor whose channel axis (size 3) is third from last,
        e.g. (N, 3, H, W) or (3, H, W); values presumably in [0, 1]
        (the statistics below were divided by 255).
    :return: ``(X - mean) / std`` with the statistics broadcast per channel.
    """
    cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
    cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255

    # Create the statistics on X's device so CUDA inputs do not hit a
    # CPU/GPU device-mismatch error; dtype stays the default (float32).
    mu = torch.tensor(cifar10_mean, device=X.device).view(3, 1, 1)
    std = torch.tensor(cifar10_std, device=X.device).view(3, 1, 1)
    return (X - mu) / std

def normalize_cifar100(X):
    """Standardize images with the per-channel CIFAR-100 mean and std.

    :param X: image tensor whose channel axis (size 3) is third from last,
        e.g. (N, 3, H, W) or (3, H, W); values presumably in [0, 1].
    :return: ``(X - mean) / std`` with the statistics broadcast per channel.
    """
    CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    # Create the statistics on X's device so CUDA inputs do not hit a
    # CPU/GPU device-mismatch error; dtype stays the default (float32).
    mu = torch.tensor(CIFAR100_MEAN, device=X.device).view(3, 1, 1)
    std = torch.tensor(CIFAR100_STD, device=X.device).view(3, 1, 1)
    return (X - mu) / std


def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` elementwise into ``[lower_limit, upper_limit]``.

    The bounds may be tensors broadcastable against ``X``, so every element
    can have its own limits. The upper bound is applied first and the lower
    bound last, so wherever ``lower_limit > upper_limit`` the result equals
    ``lower_limit``.
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)



class OELoss(nn.Module):
    """Cross-entropy between ``softmax(x)`` and a uniform target over dim 1.

    Presumably the Outlier Exposure ("OE") objective. Per sample the loss is
    ``logsumexp(x) - mean(x)``, which equals ``-mean_c log softmax(x)_c``.
    The result is averaged over the batch.
    """

    def __init__(self):
        super(OELoss, self).__init__()

    def forward(self, x, y):
        """Return the batch-mean uniform cross-entropy of logits ``x``.

        :param x: scores with classes along dim 1.
        :param y: ignored; accepted only so the call signature matches
            ``nn.CrossEntropyLoss``.
        """
        per_sample = torch.logsumexp(x, dim=1) - x.mean(1)
        return per_sample.mean()


class LinfPGDAttack:
    """
    Projected gradient descent (PGD) attack under an L-infinity constraint.

    Starting from a uniformly random point inside the epsilon-ball, each
    iteration steps by ``alpha`` in the sign of the loss gradient. The
    perturbation is then projected back onto the epsilon-ball around ``X``
    and onto the valid input box ``[clip_min, clip_max]``.

    :param model: network called on *normalized* inputs; its output is passed
        to the loss as class scores (logits).
    :param epsilon: maximum L-inf distortion.
    :param alpha: attack step size per iteration.
    :param attack_iters: number of PGD iterations.
    :param clip_min: minimum value per input dimension (unnormalized space).
    :param clip_max: maximum value per input dimension (unnormalized space).
    :param loss_func: 'CE' (cross-entropy against ``y``) or 'OE'
        (``OELoss``, which ignores ``y``).
    :param ds: 'CIFAR-10' selects CIFAR-10 normalization; any other value
        selects CIFAR-100 normalization.
    :raises ValueError: if ``loss_func`` is not supported.
    """

    def __init__(
            self, model, epsilon=8.0 / 255., alpha=2.0 / 255., attack_iters=40,
            clip_min=0., clip_max=1., loss_func='CE', ds='CIFAR-10'):
        self.epsilon = epsilon
        self.alpha = alpha
        self.attack_iters = attack_iters
        self.model = model
        self.ds = ds

        if loss_func == 'CE':
            self.loss_func = nn.CrossEntropyLoss()
        elif loss_func == 'OE':
            self.loss_func = OELoss()
        else:
            # Raise rather than assert so the check survives ``python -O``.
            raise ValueError('Not supported loss function {}'.format(loss_func))

        self.clip_min = clip_min
        self.clip_max = clip_max

    def _forward(self, X_adv):
        """Normalize ``X_adv`` for ``self.ds`` and run the model on it."""
        if self.ds == 'CIFAR-10':
            return self.model(normalize_cifar10(X_adv))
        return self.model(normalize_cifar100(X_adv))

    def perturb(self, X, y=None):
        """
        Return adversarial examples for ``X``.

        :param X: clean inputs in unnormalized space (within
            ``[clip_min, clip_max]``).
        :param y: target labels; required for cross-entropy, unused by OE.
        :return: detached tensor ``X + delta`` with ``|delta| <= epsilon``
            elementwise, clamped to ``[clip_min, clip_max]``.
        :raises ValueError: if ``y`` is None while using cross-entropy loss.
        """
        if y is None and isinstance(self.loss_func, nn.CrossEntropyLoss):
            raise ValueError('Labels y are required for the CE loss')

        X = X.detach().clone()
        if y is not None:
            y = y.detach().clone()

        # Random start inside the epsilon-ball, projected into the valid box.
        delta = torch.zeros_like(X).uniform_(-self.epsilon, self.epsilon)
        delta = clamp(delta, self.clip_min - X, self.clip_max - X)
        delta.requires_grad_(True)

        for _ in range(self.attack_iters):
            loss = self.loss_func(self._forward(X + delta), y)
            # Differentiate w.r.t. delta only, so the attack does not
            # accumulate gradients into the model parameters' .grad buffers.
            grad, = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                step = delta + self.alpha * torch.sign(grad)
                step = torch.clamp(step, min=-self.epsilon, max=self.epsilon)
                delta.copy_(clamp(step, self.clip_min - X, self.clip_max - X))

        # Detach so the result is a plain input tensor, not tied to delta's graph.
        return torch.clamp(X + delta, min=self.clip_min, max=self.clip_max).detach()
    
        
