import os

import util.utils as utils
import time
import numpy as np
import torch
from models.utils.utils import get_activation_function

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    """Collects accuracy / activation statistics and writes them to disk and TensorBoard.

    Activation statistics are gathered through ``model.module.get_activation``
    (and ``get_minmax`` for empirical saturation), so ``model`` is presumably a
    ``DataParallel``-style wrapper around the project's model — TODO confirm.
    """

    def __init__(self, args) -> None:
        self.args = args
        self.device = "cuda" if args.cuda else "cpu"
        self.cuda = self.args.cuda
        self._make_dir()
        self.writer = SummaryWriter(self.tensorboard_save_dir)
        self.saturation_logs = []
        self.skewness_logs = []
        self.grad_hook = None
        self.forward_logging = True
        self.grad_sum = None
        self.grad_sum_squares = None
        self.element_num = None  # TODO: it would be fine to compute this only once.
        self.grad_mean = []
        self.grad_std = []
        self.forward_hook = []

    def _get_state_id(self, file_type=""):
        """Return ``state_id`` joined to *file_type* with ``'_'``.

        Returns ``state_id`` alone only when *file_type* is ``None``.

        NOTE(review): with the default ``""`` this yields ``state_id + '_'``
        (trailing underscore). The log/tensorboard directory names built in
        ``_make_dir`` depend on that, so it is kept for on-disk compatibility.
        """
        if file_type is None:
            return self.state_id
        return '_'.join([self.state_id, file_type])

    def _make_dir(self):
        """Set ``self.state_id`` and the save-directory attributes, then create the directories.

        A fresh run takes its id from ``utils.make_state_id(args)``. A run
        resumed from ``args.pretrained`` uses the checkpoint file name with
        ``'.pth.tar'`` removed.
        """
        if self.args.pretrained is None:
            self.state_id = utils.make_state_id(self.args)
        else:
            self.state_id = self.args.pretrained.split('/')[-1].replace('.pth.tar', '')
            # NOTE(review): this creates 'log/<id>', while info_save_dir below is
            # 'log/<id>_' because of _get_state_id's default — confirm intent.
            os.makedirs(os.path.join('log', self.state_id), exist_ok=True)

        # NOTE(review): ':' in this timestamp is not a legal path character on Windows.
        timestamp = time.strftime('%Y-%m-%d_%I:%M:%S_%p', time.localtime(time.time()))
        self.model_save_dir = 'saved_models'
        self.tensorboard_save_dir = os.path.join('log', self._get_state_id(), 'tensorboard', timestamp)
        self.acc_save_dir = os.path.join('log', 'acc')
        self.info_save_dir = os.path.join('log', self._get_state_id())

        os.makedirs(self.model_save_dir, exist_ok=True)
        os.makedirs(self.acc_save_dir, exist_ok=True)
        os.makedirs(self.info_save_dir, exist_ok=True)
        if not self.args.evaluate:
            os.makedirs(self.tensorboard_save_dir, exist_ok=True)

    def _save_log(self, save_file_name, content, content_type='acc'):
        """Append *content* to *save_file_name*.

        Args:
            save_file_name: path of the file to append to.
            content: for ``'acc'``/``'accuracy'``, any value, written as
                ``str(content)`` on one line. For ``'skewness'``/``'saturation'``,
                an iterable of numbers, each written with 8 decimals and
                followed by a ``'#### '`` separator line.
            content_type: one of ``'acc'``, ``'accuracy'``, ``'skewness'``,
                ``'saturation'``.

        Raises:
            ValueError: for any other *content_type* (checked before the file is opened).
        """
        if content_type in ('acc', 'accuracy'):
            lines = [str(content) + '\n']
        elif content_type in ('skewness', 'saturation'):
            lines = [format(c, '.8f') + '\n' for c in content]
            lines.append('#### \n')
        else:
            raise ValueError(f"unknown content_type: {content_type!r}")

        with open(save_file_name, 'a', encoding='utf-8') as f:
            f.writelines(lines)

    def _iter_activations(self, model, loader, target):
        """Yield ``model.module.get_activation(data, target=target)`` for each batch of *loader*.

        The input is moved to the GPU when ``self.cuda`` is set. Labels are ignored.
        """
        for data, _ in loader:
            if self.cuda:
                data = data.cuda()
            yield model.module.get_activation(data, target=target)

    @staticmethod
    def _flatten_channels(block_output):
        """Reshape a ``(mini_batch, channel, ...)`` tensor to ``(channel, mini_batch * ...)``."""
        channels = block_output.shape[1]
        return block_output.transpose(0, 1).contiguous().view(channels, -1)

    @staticmethod
    def _elements_per_channel(block_output, samples):
        """Number of values one channel contributes over a dataset of *samples* inputs.

        For 4-D ``(mb, c, h, w)`` outputs this is ``samples * h * w``, otherwise ``samples``.
        """
        if block_output.ndim == 4:
            return samples * block_output.shape[2] * block_output.shape[3]
        return samples

    def calculate_skewness(self, model, loader):
        """Return, per activation block, the mean over channels of |sample skewness|.

        Three passes are made over *loader*:
          1. the per-channel mean,
          2. the per-channel (population) variance around that mean,
          3. the per-channel Fisher-Pearson skewness ``g1``, adjusted by
             ``sqrt(n(n-1)) / (n-2)``, where ``n`` is the number of values
             per channel over the whole dataset.

        Args:
            model: wrapper exposing ``eval()`` and ``module.get_activation``.
            loader: iterable of ``(data, target)`` batches with a ``dataset`` attribute.

        Returns:
            ``np.ndarray`` of shape ``(depth,)``, one float per block.
        """
        model.eval()
        eps = 1e-5
        samples = len(loader.dataset)

        with torch.no_grad():
            _, channels = self.get_depth_channels(model=model, target_layer='activation')
            mean = [torch.zeros((c, 1), device=self.device) for c in channels]
            var = [torch.zeros((c, 1), device=self.device) for c in channels]
            skewness = [torch.zeros((c, 1), device=self.device) for c in channels]

            # Pass 1: per-channel mean.
            for output in self._iter_activations(model, loader, 'activation'):
                for ind, block_output in enumerate(output):
                    n = self._elements_per_channel(block_output, samples)
                    flat = self._flatten_channels(block_output)
                    mean[ind] += torch.sum(flat, dim=1, keepdim=True) / n

            # Pass 2: per-channel variance around the final mean.
            for output in self._iter_activations(model, loader, 'activation'):
                for ind, block_output in enumerate(output):
                    n = self._elements_per_channel(block_output, samples)
                    diff = self._flatten_channels(block_output) - mean[ind]
                    var[ind] += torch.sum(torch.pow(diff, 2.0), dim=1, keepdim=True) / n

            # Pass 3: per-channel adjusted skewness.
            for output in self._iter_activations(model, loader, 'activation'):
                for ind, block_output in enumerate(output):
                    n = self._elements_per_channel(block_output, samples)
                    std_cubed = torch.pow(torch.sqrt(var[ind]), 3.0)
                    diff = self._flatten_channels(block_output) - mean[ind]
                    batch_statistic = torch.sum(torch.pow(diff, 3.0) / (std_cubed + eps), dim=1, keepdim=True) / n
                    # The bias correction is linear, so applying it per batch before summing is exact.
                    correction = float(np.sqrt(n * (n - 1)) / (n - 2))
                    skewness[ind] += batch_statistic * correction

            # Average |skewness| over channels. .item() also works for CUDA tensors.
            block_skewness = np.array(
                [torch.mean(torch.abs(unit)).item() for unit in skewness], dtype=np.float64
            )

            print("Skewness")
            for skwn in block_skewness:
                print(format(skwn, '.8f'))
            print()

        return block_skewness

    def _accumulate_channel_minmax(self, model, loader, channels):
        """Return, per block, a ``(channels, 2)`` tensor of running ``[min, max]`` over *loader*.

        Values come from ``model.module.get_minmax(data, block_output=True, channel_flag=True)``,
        which presumably returns one ``(channels, 2)`` tensor per block — TODO confirm.
        """
        accumulated = []
        for num_channels in channels:
            init = torch.zeros(num_channels, 2, device=self.device)
            init[:, 0] = float("Inf")
            init[:, 1] = float("-Inf")
            accumulated.append(init)

        for data, _ in loader:
            if self.cuda:
                data = data.cuda()
            channel_minmax = model.module.get_minmax(data, block_output=True, channel_flag=True)
            for ind, (cur, acc) in enumerate(zip(channel_minmax, accumulated)):
                new_min = torch.min(cur[:, 0], acc[:, 0])
                new_max = torch.max(cur[:, 1], acc[:, 1])
                accumulated[ind] = torch.stack((new_min, new_max), dim=1)
        return accumulated

    def calculate_saturation(self, model, loader, empirical):
        """Return per-block mean of |activation| normalised by an upper bound.

        Args:
            model: wrapper exposing ``eval()``, ``module.get_activation`` and,
                for *empirical*, ``module.get_minmax``.
            loader: iterable of ``(data, target)`` batches with a ``dataset`` attribute.
            empirical: if True, uses ``'block'`` outputs. Each channel is
                normalised by its observed max |value|, and the printed figure
                is ``1 - saturation`` ("sparsity"). Otherwise uses
                ``'activation'`` outputs and a fixed bound per activation type.

        For sigmoid, outputs are shifted by -0.5 before taking |x|, so the
        fixed bound is 0.5. The empirical bound is computed from the shifted
        min/max.

        Returns:
            list of floats, one per block.

        Raises:
            ValueError: if not *empirical* and ``args.activation_type`` is unsupported.
        """
        model.eval()
        total_samples = len(loader.dataset)
        eps = 1e-5
        is_sigmoid = self.args.activation_type == 'sigmoid'
        target_layer = 'block' if empirical else 'activation'

        with torch.no_grad():
            if empirical:
                acti_upper = None
            elif self.args.activation_type == 'lecun':
                acti_upper = 1.7159
            elif is_sigmoid:
                acti_upper = 0.5
            elif self.args.activation_type in ('tanh', 'softsign'):
                acti_upper = 1
            else:
                raise ValueError(f"invalid activation function: {self.args.activation_type!r}")

            block_len, channels = self.get_depth_channels(model, target_layer=target_layer)

            if empirical:
                upper_bounds = []
                for acc_minmax in self._accumulate_channel_minmax(model, loader, channels):
                    if is_sigmoid:
                        # Match the -0.5 shift applied to the outputs below.
                        acc_minmax = acc_minmax - 0.5
                    bound, _ = torch.max(torch.abs(acc_minmax), dim=1, keepdim=True)
                    upper_bounds.append(bound)  # (channel, 1)
            else:
                upper_bounds = [acti_upper] * block_len

            block_sum = [torch.zeros(1, device=self.device) for _ in range(block_len)]

            for output in self._iter_activations(model, loader, target_layer):
                if is_sigmoid:
                    output = [block_output - 0.5 for block_output in output]

                for ind, block_output in enumerate(output):
                    # (channel, mini_batch * ...) of |x|
                    flat_output_abs = self._flatten_channels(torch.abs(block_output))
                    flat_output_nor = flat_output_abs / (upper_bounds[ind] + eps)
                    per_sample = block_output.shape[1:].numel()
                    block_sum[ind] += torch.sum(flat_output_nor) / (total_samples * per_sample)

            block_saturation = [float(bus) for bus in block_sum]

            if empirical:
                print("sparsity")
                for strt in block_saturation:
                    print(format(1 - strt, '.8f'))
            else:
                print("Saturation")
                for strt in block_saturation:
                    print(format(strt, '.8f'))
            print()

        return block_saturation

    def skewness(self, model, loader):
        """Compute block skewness and append it to ``<info_save_dir>/skewness``."""
        skewness = self.calculate_skewness(model=model, loader=loader)
        save_file_name = os.path.join(self.info_save_dir, 'skewness')
        self._save_log(save_file_name, skewness, 'skewness')

    def saturation(self, model, loader, empirical=False):
        """Compute saturation (or, if *empirical*, sparsity = 1 - saturation) and append it to a log file."""
        saturation = self.calculate_saturation(model=model, loader=loader, empirical=empirical)

        if empirical:
            # NOTE(review): 'spasity' is a typo of 'sparsity'; kept so existing log paths stay valid.
            save_file_name = os.path.join(self.info_save_dir, 'spasity')
            saturation = [1 - sat for sat in saturation]
        else:
            save_file_name = os.path.join(self.info_save_dir, 'saturation')

        self._save_log(save_file_name, saturation, 'saturation')

    def get_depth_channels(self, model, target_layer):
        """Probe *model* with a random batch of 2 and return ``(num_blocks, channels_per_block)``.

        The probe's spatial size follows ``args.dataset``: 32 for CIFAR,
        64 for tinyImageNet and 224 for ImageNet. Channel counts are read from
        dim 1 of each returned block activation.

        Raises:
            ValueError: if ``args.dataset`` is not recognised.
        """
        model.eval()

        if 'cifar' in self.args.dataset:
            size = 32
        elif 'tinyImageNet' == self.args.dataset:
            size = 64
        elif 'ImageNet' == self.args.dataset:
            size = 224
        else:
            raise ValueError(f"unsupported dataset: {self.args.dataset!r}")
        data = torch.randn((2, 3, size, size), device=self.device)

        with torch.no_grad():
            activations = model.module.get_activation(data, target=target_layer)

        channels = [block_activation.shape[1] for block_activation in activations]
        return len(activations), channels

    def accuracy_save(self, accuracy):
        """Append *accuracy* to ``log/acc/<state_id>_log``."""
        save_file_name = os.path.join(self.acc_save_dir, self._get_state_id('log'))
        self._save_log(save_file_name, accuracy, 'accuracy')

    def tensor_board(self, label, contents, epoch):
        """Write the scalar dict *contents* under *label* at step *epoch*."""
        self.writer.add_scalars(label, contents, epoch)

    def state_save(self, model, acc):
        """Delegate checkpointing to ``utils.save_state`` with the name ``<state_id>_pth.tar``."""
        utils.save_state(model, acc, self._get_state_id('pth.tar'))

    def terminate_logging(self):
        """Close the TensorBoard writer."""
        self.writer.close()
