import os
import numpy as np
import copy

import torch
import torch.nn as nn
import torch.distributions as D

from multiquery_randomized_smoothing.src.models import architectures
from multiquery_randomized_smoothing.src.train_utils import log, get_image_size, get_mask_shape

class SINGLE_QUERY_ARCH(nn.Module):
    """Single-query noisy classifier: optional learnable mask, Gaussian noise, base classifier.

    The forward pass optionally multiplies the input by a learnable mask, adds
    zero-mean Gaussian noise with standard deviation ``sigma_1`` and feeds the
    noisy result to the base classifier.

    Attributes read from ``args``:
        linf_pert, mu: used to compute ``vanilla_sigma``.
        first_query_with_mask: whether a learnable mask is created and applied.
        mask_init: ``"random"`` (uniform in [0, 1)) or ``"identity"`` (all ones).
        wm_across_channels: ``"same"`` or ``"different"``; selects how the
            mask's L2 norm is computed in ``forward``.
        first_query_budget_frac: divisor applied to ``vanilla_sigma`` when the
            mask is used.
        base_classifier, dataset, num_classes: forwarded to
            ``architectures.get_architecture``.
    """

    # Channel count assumed by the sigma computation, the mask norm and the
    # base classifier's ``input_channels``.
    _NUM_CHANNELS = 3

    def __init__(self, args):
        """Build the (optional) mask parameter and the base classifier.

        Raises:
            ValueError: if ``args.first_query_with_mask`` is set and
                ``args.mask_init`` is neither ``"random"`` nor ``"identity"``.
        """
        super().__init__()

        self.args = args

        # compute vanilla sigma
        image_size = get_image_size(args)
        self.image_dims = image_size * image_size * self._NUM_CHANNELS
        self.vanilla_sigma = (args.linf_pert * np.sqrt(self.image_dims)) / args.mu

        # if learning an average mask in first query, initialize a random or identity mask
        if self.args.first_query_with_mask:
            mask_shape = get_mask_shape(args, image_size)
            if args.mask_init == "random":
                initial_mask = torch.rand(mask_shape)
            elif args.mask_init == "identity":
                initial_mask = torch.ones(mask_shape)
            else:
                # Fail here rather than with an AttributeError on the first forward pass.
                raise ValueError(
                    "Unsupported mask_init {!r}; expected 'random' or 'identity'".format(args.mask_init)
                )
            self.first_query_mask = nn.Parameter(initial_mask, requires_grad=True)

        # finally, initialize base classifier (common to vanilla and adaptive mode)
        self.base_classifier = architectures.get_architecture(
            arch=args.base_classifier,
            prepend_preprocess_layer=True,
            prepend_normalize_layer=True,
            dataset=args.dataset,
            input_size=image_size,
            input_channels=self._NUM_CHANNELS,
            num_classes=args.num_classes,
        )

    def forward(self, x, logging_trackers):
        """Mask (optionally), add Gaussian noise, and classify.

        The input tensor ``x`` is not modified.

        Args:
            x: input batch; must be broadcast-compatible with
                ``first_query_mask`` when the mask is enabled.
            logging_trackers: dict with key ``"mode"``. When the mode is
                ``'train'`` or ``'test'`` the keys ``"sigma_log_file"`` and
                ``"epoch"`` are also read, and ``"batch_idx"`` and
                ``"saved_dir"`` are read on epochs where tensors are dumped.

        Returns:
            Output of the base classifier on the noisy input.

        Raises:
            ValueError: if the mask is enabled and ``args.wm_across_channels``
                is neither ``"same"`` nor ``"different"``.
        """
        # Plain reference, used only by the debug dump below. No copy is needed
        # because every later operation on ``x`` is out-of-place.
        x_original = x

        # FIRST QUERY
        if self.args.first_query_with_mask:
            x = torch.mul(x, self.first_query_mask)
            sigma_1 = self._masked_sigma()
        else:
            # no mask: noise is added directly to the untransformed input, so
            # the vanilla sigma is used without any rescaling
            sigma_1 = self.vanilla_sigma

        norm_dist = D.Normal(loc=0., scale=sigma_1)
        # rsample keeps the noise differentiable w.r.t. sigma_1 (and hence the mask).
        noise = norm_dist.rsample(x.shape).to(x.device)
        # Out-of-place add: an in-place ``+=`` would overwrite the caller's
        # tensor when no mask is applied.
        x = x + noise

        # NOTE: a disabled experiment re-multiplied x by the mask after adding
        # noise; it is not part of this forward pass.

        # Perform the usual forward pass on the noisy image
        output_pred = self.base_classifier(x)

        if logging_trackers["mode"] in ('train', 'test'):
            log(logging_trackers["sigma_log_file"], "{}".format(sigma_1))
            self._save_debug_tensors(logging_trackers, x_original, x)

        return output_pred

    def _masked_sigma(self):
        """Return the noise std for the masked query, scaled by the mask's L2 norm."""
        mode = self.args.wm_across_channels
        mask_norm = self.first_query_mask.norm(2)
        if mode == "same":
            # presumably the mask is shared by all channels, so its norm is
            # scaled by sqrt(#channels) -- TODO confirm against get_mask_shape
            norm_1 = float(np.sqrt(self._NUM_CHANNELS)) * mask_norm
        elif mode == "different":
            norm_1 = mask_norm
        else:
            raise ValueError(
                "Unsupported wm_across_channels {!r}; expected 'same' or 'different'".format(mode)
            )
        # computing sigma for first query based on mask's l2 norm
        return (self.vanilla_sigma / self.args.first_query_budget_frac) * (norm_1 / np.sqrt(self.image_dims))

    def _save_debug_tensors(self, logging_trackers, x_original, x_noisy):
        """Dump input, mask and noisy input for batch 0 of epoch 1 and every 10th epoch."""
        epoch = logging_trackers["epoch"]
        if not (epoch % 10 == 0 or epoch == 1):
            return
        if logging_trackers["batch_idx"] != 0:
            return

        saved_dir = logging_trackers["saved_dir"]
        torch.save(x_original, os.path.join(saved_dir, "original_images.pt"))
        if self.args.first_query_with_mask:
            torch.save(self.first_query_mask, os.path.join(saved_dir, "first_query_mask.pt"))
        torch.save(x_noisy, os.path.join(saved_dir, "q_1_images.pt"))