import math
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.distributions as D
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss

from multiquery_randomized_smoothing.src.models import architectures

def get_mask_shape(args):
    """Return the shape of the multiplicative mask for the configured dataset.

    Args:
        args: namespace providing ``dataset`` ("cifar10" or "tiny_imagenet"),
            ``pad_size`` (pixels added on each side of the image),
            ``multiply`` (whether a mask is used at all) and, when
            ``multiply`` is truthy, ``wm_across_channels`` ("different" for a
            per-channel mask, "same" for one mask shared by all channels).

    Returns:
        ``(3, size, size)`` for per-channel masks, ``(size, size)`` for a
        shared mask, or ``None`` when ``args.multiply`` is falsy.

    Raises:
        ValueError: if ``dataset`` or ``wm_across_channels`` is not one of the
            supported values (previously this surfaced as an
            ``UnboundLocalError``).
    """
    base_image_sizes = {"cifar10": 32, "tiny_imagenet": 64}
    if args.dataset not in base_image_sizes:
        raise ValueError("unsupported dataset: {!r}".format(args.dataset))

    # padding is applied on both sides of each spatial dimension
    image_size = base_image_sizes[args.dataset] + (2 * args.pad_size)

    if not args.multiply:
        # no masking configured, so there is no mask shape to report
        return None

    if args.wm_across_channels == "different":
        return (3, image_size, image_size)
    if args.wm_across_channels == "same":
        return (image_size, image_size)
    raise ValueError(
        "unsupported wm_across_channels: {!r}".format(args.wm_across_channels))

class GENERAL_ARCH(nn.Module):
    """Base classifier queried once or twice on (optionally masked) noisy inputs.

    First query: the input is optionally multiplied by ``self.mask``, Gaussian
    noise is optionally added, and the result is passed to
    ``self.base_classifier``.

    Second query (``args.num_queries == 2``): the clean input is optionally
    multiplied by a mask predicted by ``self.second_query_mask_model`` from the
    first query's input, noise is added, and the two predictions are blended
    using the squared budget fractions as weights.

    NOTE(review): ``forward`` reads ``self.mask`` whenever ``args.multiply`` is
    set, but ``__init__`` does not create it (the mask initialisation that used
    to live here was disabled). Confirm where the mask is supposed to come from.

    NOTE(review): ``log`` and the ``wm_norm_1_file`` / ``wm_norm_2_file`` /
    ``budget_query_split_file`` / ``sigma_file`` handles used by ``forward``
    are not defined in the visible part of this module; they are assumed to be
    module-level names provided elsewhere.
    """

    def __init__(self, args):
        """Build the sub-networks and the query budget split.

        Args:
            args: configuration namespace. Read here: ``num_queries``,
                ``second_query_mask_model``, ``device``, ``budget_split``,
                ``budget_split_ratio``, ``base_classifier``, ``dataset``,
                ``num_classes``. The namespace is kept on the instance because
                ``forward`` reads further options from it.
        """
        super().__init__()

        # forward() takes its options from here instead of relying on a
        # module-level `args` being in scope
        self.args = args

        if args.num_queries == 2:
            # network that predicts the mask applied in the second query
            self.second_query_mask_model = architectures.get_architecture(
                arch=args.second_query_mask_model).to(args.device)

        # budget splitting mechanism: the stored value is the square root of
        # the first query's share, so squared fractions of both queries sum to 1
        if args.budget_split == "fixed":
            if args.budget_split_ratio == "20/80":
                self.first_query_budget_frac = torch.tensor(np.sqrt(0.2), device=args.device)
            elif args.budget_split_ratio == "50/50":
                self.first_query_budget_frac = torch.tensor(1 / np.sqrt(2), device=args.device)
            elif args.budget_split_ratio == "80/20":
                self.first_query_budget_frac = torch.tensor(np.sqrt(0.8), device=args.device)
        elif args.budget_split == "learnt":
            # learnable split, initialised at 50/50
            self.first_query_budget_frac = nn.Parameter(
                torch.tensor(1 / np.sqrt(2), device=args.device), requires_grad=True)

        # base classifier shared by both queries
        self.base_classifier = architectures.get_architecture(
            arch=args.base_classifier,
            prepend_normalize_layer=True,
            dataset=args.dataset,
            input_channels=3,
            num_classes=args.num_classes)

    @staticmethod
    def _channel_norm_scale(wm_across_channels):
        """Return the factor applied to a mask's L2 norm for the channel mode."""
        if wm_across_channels == "same":
            # a single mask is shared by the 3 image channels
            return math.sqrt(3)
        if wm_across_channels == "different":
            return 1.0
        raise ValueError(
            "unsupported wm_across_channels: {!r}".format(wm_across_channels))

    def forward(self, x, logging_trackers):
        """Run the first (and optionally second) query and return the prediction.

        The input tensor is never modified in place; all masking and noising
        produce new tensors.

        Args:
            x: input batch; the first dimension is treated as the batch
                dimension and the remaining ones as a single sample.
            logging_trackers: dict with at least ``"mode"`` and
                ``"batch_idx"``; ``"saved_dir"`` is also read for the first
                training batch.

        Returns:
            Tuple ``(output_pred, x)`` where ``x`` is the (masked, noised)
            tensor that was fed to the classifier in the first query.
        """
        args = self.args
        mode = logging_trackers["mode"]

        # nothing below mutates its operands, so the clean input can be shared
        x_original = x
        # number of elements in one sample
        d = x.shape[1:].numel()

        # ---- first query ----
        if args.multiply:
            # multiply input with the mask
            x = torch.mul(x, self.mask)

        if args.noise:
            # computing sigma for first query
            if args.multiply:
                norm_1 = self._channel_norm_scale(args.wm_across_channels) * self.mask.norm(2)
                sigma_1 = (args.linf_pert * norm_1) / args.mu

                if mode == "train":
                    log(wm_norm_1_file, "{}".format(norm_1))
            else:
                # no transformation, so the norm term is sqrt(d)
                sigma_1 = (args.linf_pert * math.sqrt(d)) / args.mu

            # scaled standard-normal sample; keeps the gradient path through
            # sigma_1 and leaves the caller's tensor untouched
            x = x + sigma_1 * torch.randn_like(x)

        # usual forward pass for first query
        output_pred = self.base_classifier(x)

        # ---- second query ----
        if args.num_queries == 2:

            if args.multiply:
                # predict the mask and apply it to the clean input
                second_query_mask = self.second_query_mask_model(x)
                x_transformed = torch.mul(x_original, second_query_mask)
            else:
                x_transformed = x_original

            # NOTE(review): the second classifier pass only happens when noise
            # is enabled, as in the original control flow — confirm intended.
            if args.noise:

                # budget for second query (according to GDP formulation)
                second_query_budget_frac = torch.sqrt(
                    1 - torch.square(self.first_query_budget_frac)).to(args.device)

                # computing sigma for second query
                if args.multiply:
                    # one norm per sample, using the actual batch size rather
                    # than the configured training batch size
                    per_sample_norm = torch.norm(
                        second_query_mask.flatten(start_dim=1), p=2, dim=1)
                    norm_2 = self._channel_norm_scale(args.wm_across_channels) * per_sample_norm

                    sigma_2 = (args.sigma / second_query_budget_frac) * (norm_2 / math.sqrt(d))

                    if mode == "train":
                        log(wm_norm_2_file, "{}".format(norm_2.mean()))

                    # broadcast each sample's sigma over that sample's elements
                    broadcast_shape = (-1,) + (1,) * (x_transformed.dim() - 1)
                    noise_scale = sigma_2.reshape(broadcast_shape)
                    x_transformed = x_transformed + noise_scale * torch.randn_like(x_transformed)
                else:
                    # no transformation, so the norm ratio is 1
                    sigma_2 = args.sigma / second_query_budget_frac
                    x_transformed = x_transformed + sigma_2 * torch.randn_like(x_transformed)

                log(budget_query_split_file,
                    "{} \t {}".format(self.first_query_budget_frac, second_query_budget_frac))

                # pass x_transformed (with noise) through the classifier
                output_pred_2 = self.base_classifier(x_transformed)

                # weighted average of the two query outputs; the weights are
                # the squared budget fractions of the respective queries
                output_pred = (torch.square(self.first_query_budget_frac) * output_pred
                               + torch.square(second_query_budget_frac) * output_pred_2)

        if mode == 'test' and logging_trackers["batch_idx"] == 0:
            if args.noise:
                if args.num_queries == 1:
                    log(sigma_file, "{}".format(sigma_1))
                elif args.num_queries == 2:
                    log(sigma_file, "{} \t {}".format(sigma_1, sigma_2.mean()))

        if mode == 'train' and logging_trackers["batch_idx"] == 0:
            saved_dir = logging_trackers["saved_dir"]
            torch.save(x_original, os.path.join(saved_dir, "x_original.pt"))
            torch.save(x, os.path.join(saved_dir, "x.pt"))
            if args.multiply:
                torch.save(self.mask, os.path.join(saved_dir, "mask.pt"))
            if args.num_queries == 2:
                torch.save(x_transformed, os.path.join(saved_dir, "x_transformed.pt"))

        return output_pred, x