import torch
import sys
import torch.nn.functional as F
sys.path.append('..')
from utils import Averager, clip_perturbed_image
from loss import *


def clean_sample_selection(model, test_loader, C, rho, tau, device):
    """Select high-confidence samples and pseudo-label them with the model's prediction.

    A sample is a candidate when its maximum softmax probability is strictly
    greater than ``tau``. Candidates are grouped by predicted class. Each class
    keeps at most ``int(len(test_loader.dataset) * rho / C)`` of its most
    confident members, or all of them when the class does not exceed that budget.

    Args:
        model: classifier returning logits (softmax is taken over the last dim).
        test_loader: iterable of ``(x, _, idx)`` batches whose ``dataset``
            supports ``len``.
        C: number of classes; predictions ``0 .. C-1`` are collected.
        rho: fraction of the dataset used to size the per-class budget.
        tau: confidence threshold (strict ``>``).
        device: device each batch is moved to.

    Returns:
        A 2-D tensor with one row ``[confidence, idx, predicted class]`` per
        selected sample, grouped by ascending class. Within a class, rows are
        in loader order when the class is under budget, and in descending
        confidence otherwise. All columns share one promoted dtype (typically
        float32), so indices above 2**24 are not represented exactly in float32.

    Raises:
        RuntimeError: if no sample passes the confidence threshold.
    """
    model.eval()

    # Every above-threshold row, in loader order. Run under no_grad so the
    # stored confidences never keep autograd graphs alive, whatever the caller does.
    candidates = []
    with torch.no_grad():
        for x, _, idx in test_loader:
            x, idx = x.to(device), idx.to(device)
            prob = F.softmax(model(x), dim=-1)
            conf, pred = torch.max(prob, dim=-1)
            keep = conf > tau
            if keep.any():
                candidates.append(torch.hstack([
                    conf[keep].unsqueeze(-1),
                    idx[keep].unsqueeze(-1),
                    pred[keep].type(torch.int).unsqueeze(-1),
                ]))
    if not candidates:
        raise RuntimeError(f"clean_sample_selection: no sample has confidence > {tau}")
    candidates = torch.vstack(candidates)

    # Per-class budget; compared as a float, truncated only for topk.
    k = len(test_loader.dataset) * rho / C
    per_class = []
    for c in range(C):
        # Boolean filtering preserves loader order within the class.
        rows = candidates[candidates[:, 2] == c]
        if len(rows) == 0:
            continue
        if len(rows) > k:
            _, top = torch.topk(rows[:, 0], k=int(k))
            rows = rows[top]
        per_class.append(rows)
    if not per_class:
        # Only reachable if every prediction falls outside 0 .. C-1.
        raise RuntimeError("clean_sample_selection: no prediction falls in classes 0..C-1")
    return torch.vstack(per_class)


def zoo(x, y, model, delta, args, device, clean=True):
    """Take one signed-gradient step on the input perturbation ``delta``.

    The loss is ``args.wclean`` times per-sample cross-entropy against ``y``
    when ``clean`` is True, and ``im_loss`` otherwise. Its gradient w.r.t. the
    perturbed input is either computed exactly with autograd (``args.zo`` false)
    or estimated with a forward-difference zeroth-order estimator using
    ``args.q`` random unit directions (``args.zo`` true). ``delta`` then moves
    by ``args.lr * sign(grad)``, is clamped to ``[-args.ad_scale, args.ad_scale]``,
    and is re-clipped with ``clip_perturbed_image``.

    Args:
        x: unperturbed input batch; the ZO branch reads it as
            (batch, channel, h, w).
        y: target class indices for the CE loss when ``clean`` is True; may be
            None when ``clean`` is False.
        model: classifier returning logits.
        delta: perturbation with the same shape as ``x``. Its ``.data`` is
            overwritten in place. The first-order branch needs
            ``delta.requires_grad``.
        args: namespace providing zo, wclean, lr, ad_scale, and, for the ZO
            branch, mu, q, sigma.
        device: device for the ZO helper tensors.
        clean: choose the weighted CE loss (True) or ``im_loss`` (False).

    Returns:
        ``(delta, loss_tmps)``: the updated perturbation and, for the ZO branch,
        a list of ``args.q`` CPU scalars (batch-mean loss at each probe). The
        list is empty in the first-order branch.
    """
    loss_tmps = []  # bound in both branches so the return never fails
    x_tilda = clip_perturbed_image(x, x + delta)

    if not args.zo:
        # First-order branch: exact gradient through the model via autograd.
        pred = model(x_tilda)
        if clean:
            loss = args.wclean * F.cross_entropy(pred, y, reduction='none')
        else:
            loss = im_loss(pred)
        # autograd.grad needs a scalar output. The CE loss is per-sample, so sum
        # it; .sum() is a no-op if im_loss already returns a scalar.
        grad, = torch.autograd.grad(loss.sum(), delta)
    else:
        batch_size, channel, h, w = x_tilda.size()
        x_temp = x_tilda.detach()

        with torch.no_grad():
            mu = torch.tensor(args.mu).to(device)
            q = torch.tensor(args.q).to(device)

            # Per-sample loss at the unperturbed point (reference for the differences).
            recon_pre = model(x_temp)
            if clean:
                loss_0 = args.wclean * F.cross_entropy(recon_pre, y, reduction='none')
            else:
                loss_0 = im_loss(recon_pre, reduce=False)

            # Forward-difference ZO estimate: sum_k u_k * (L(x + mu u_k) - L(x)) / (mu q).
            # Its overall scale is irrelevant because only sign(grad) is used below.
            grad_est = torch.zeros_like(x_temp)
            for _ in range(args.q):
                # Random direction, normalized to unit L2 norm per sample. The
                # normalization cancels the scale set by args.sigma.
                u = torch.normal(0, args.sigma, size=(batch_size, channel, h, w)).to(device)
                u /= torch.sqrt(torch.sum(u ** 2, dim=(1, 2, 3))).reshape(batch_size, 1, 1, 1)

                recon_q_pre = model(x_temp + mu * u)
                if clean:
                    loss_tmp = args.wclean * F.cross_entropy(recon_q_pre, y, reduction='none')
                else:
                    loss_tmp = im_loss(recon_q_pre, reduce=False)
                loss_diff = loss_tmp - loss_0
                grad_est = grad_est + u * loss_diff.reshape(batch_size, 1, 1, 1) / (mu * q)
                loss_tmps.append(loss_tmp.detach().cpu().mean())
        grad = grad_est

    # NOTE(review): adding sign(grad) *increases* the loss (ascent). Confirm this
    # matches the intended sign conventions of args.lr / args.wclean / im_loss.
    delta.data = delta + args.lr * torch.sign(grad)  # update perturbation
    delta.data = torch.clamp(delta, -args.ad_scale, args.ad_scale)  # project onto L-infinity ball
    delta.data = clip_perturbed_image(x, x + delta) - x  # clip perturbed image
    return delta, loss_tmps


def _perturbed_accuracy(model, x, delta, y):
    """Return the fraction of clipped ``x + delta`` that ``model`` classifies as ``y``."""
    with torch.no_grad():
        x_tilda = clip_perturbed_image(x, x + delta.detach())
        ypred = model(x_tilda).argmax(dim=-1)
        return (ypred == y).float().mean().item()


def offline_tta(args, model, train_loader, test_loader, device=None):
    """Adapt a frozen (black-box) classifier offline by optimizing input perturbations.

    Confident samples of ``train_loader`` are first selected with
    ``clean_sample_selection``, and each gets its predicted class as a
    pseudo-label. Each batch of ``train_loader`` is then split into those
    "clean" samples and the remaining "noisy" ones. A separate perturbation is
    optimized for each subset with ``args.steps`` calls to ``zoo``:
    pseudo-label CE for the clean subset and ``im_loss`` for the noisy one.
    Accuracy on the perturbed inputs is then measured against the ground-truth
    labels ``y``.

    Args:
        args: namespace read here for rho, tau, ad_scale and steps (``zoo``
            reads its own hyper-parameters from it too).
        model: classifier exposing ``out_dim``; all its parameters are frozen.
        train_loader: yields ``(x, y, idx)`` batches. It is used both for
            selection and adaptation.
        test_loader: unused; kept for interface compatibility.
        device: device batches are moved to.

    Returns:
        A one-element list holding ``Averager.avg`` after all batches. Updates
        are passed ``nrep`` = subset size, presumably giving a sample-weighted
        mean; this depends on ``utils.Averager``.
    """
    # black-box model: freeze everything, only input perturbations are optimized
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    with torch.no_grad():
        clean_sets = clean_sample_selection(model, train_loader, model.out_dim, args.rho, args.tau, device)
    print(f"number of clean data selected: {clean_sets.shape[0]}")

    # Lookup table sample index -> pseudo-label, sorted by index. searchsorted
    # can then return each clean sample's label in *batch* order, aligned with
    # x[clean_mask]. Plain masking of clean_sets would yield clean_sets' own
    # (class-grouped) order instead.
    clean_ids = clean_sets[:, 1].long()
    order = torch.argsort(clean_ids)
    sorted_ids = clean_ids[order]
    sorted_labels = clean_sets[order, 2].long()

    avgr = Averager()
    for i_bat, data_bat in enumerate(train_loader):
        x, y, idx = (data_bat[0].to(device), data_bat[1].to(device), data_bat[2].to(device))
        batch_ids = idx.long()
        clean_mask = torch.isin(batch_ids, sorted_ids)
        clean_x = x[clean_mask]
        noisy_x = x[~clean_mask]
        clean_y = sorted_labels[torch.searchsorted(sorted_ids, batch_ids[clean_mask])]

        # An empty subset is skipped. Its mean accuracy would be NaN, and zoo
        # would otherwise run on an empty batch.
        has_clean = clean_x.shape[0] > 0
        has_noisy = noisy_x.shape[0] > 0

        delta_clean = torch.zeros_like(clean_x).uniform_(-args.ad_scale, args.ad_scale)
        delta_clean.requires_grad = True
        delta_noisy = torch.zeros_like(noisy_x).uniform_(-args.ad_scale, args.ad_scale)
        delta_noisy.requires_grad = True
        for _ in range(args.steps):
            if has_clean:
                delta_clean, _ = zoo(clean_x, clean_y, model, delta_clean, args, device, clean=True)
            if has_noisy:
                delta_noisy, _ = zoo(noisy_x, None, model, delta_noisy, args, device, clean=False)

        # Reset every batch so the log never shows a stale or undefined value.
        acc_clean = None
        acc_noisy = None
        if has_clean:
            acc_clean = _perturbed_accuracy(model, clean_x, delta_clean, y[clean_mask])
            avgr.update(acc_clean, nrep=clean_x.shape[0])
        if has_noisy:
            acc_noisy = _perturbed_accuracy(model, noisy_x, delta_noisy, y[~clean_mask])
            avgr.update(acc_noisy, nrep=noisy_x.shape[0])
        print(f"batch {i_bat+1}, acc_clean = {acc_clean}, acc_noisy = {acc_noisy}, total acc = {avgr.avg}", flush=True)
    final_acc = avgr.avg

    return [final_acc]