import copy
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Callable, Sequence, Any, Dict, List

import optuna
import torch
import torch.nn as nn

from adl4cv.classification.loss.loss_calculator import LossEvaluatorHyperParameterSet
from adl4cv.classification.model.classification_module import ClassificationModule, \
    ClassificationModuleHyperParameterSet, ClassificationModuleHyperParameterSpace
from adl4cv.classification.model.graph._graph.edge_attrib import EdgeAttributeType
from adl4cv.classification.model.graph._graph.graph_builder import DenseGraphBuilderDefinitionSet
from adl4cv.classification.model.graph._transformer.softmax import AttentionScalingType
from adl4cv.classification.model.graph._transformer.transformer_conv import TransformerConv
from adl4cv.classification.model.models import ModelType, ClassificationModuleDefinition
from adl4cv.classification.optimizer.optimizers import OptimizerDefinition, AdamOptimizerDefinition
from adl4cv.classification.optimizer.schedulers import SchedulerDefinition
from adl4cv.parameters.params import DefinitionSet, DefinitionSpace, HyperParameterSpace
from adl4cv.utils.utils import SerializableEnum


class TransformersNormalizationType(SerializableEnum):
    """Normalization variants selectable for the norm layers of ``TransformerLayer``."""
    # TransformerLayer builds nn.LayerNorm modules for this member.
    LAYER_NORM = "layer_norm"
    # TransformerLayer builds nn.BatchNorm1d modules for this member.
    BATCH_NORM = "batch_norm"
    # TransformerLayer builds nn.Identity modules, i.e. the input is passed through unchanged.
    NO_NORM = "no_norm"


def _get_clones(module, N):
    """
    Builds a sequential container out of independent deep copies of a module
    :param module: The module to be copied
    :param N: The number of copies to create
    :return: ``nn.Sequential`` whose children are named ``layer_0`` ... ``layer_{N-1}``
    """
    clones = OrderedDict()
    for index in range(N):
        # deepcopy gives every clone its own parameters instead of sharing them with ``module``
        clones[f"layer_{index}"] = copy.deepcopy(module)
    return nn.Sequential(clones)


class MessagePassingNetHyperParameterSet(ClassificationModuleHyperParameterSet):
    """HyperParameterSet of the MessagePassingNet"""

    # NOTE: the ``graph_builder_def`` and ``optimizer_definition`` defaults are evaluated once, when this
    # class is defined, so all instances created without these arguments share the same default objects.
    def __init__(self,
                 backbone_size: int = 512,
                 feature_size: int = 150,
                 num_message_pass: int = 1,
                 num_heads: int = 4,
                 dropout_prob: float = 0.0,
                 output_size: int = 10,
                 concat: bool = False,
                 skip_downsampling: bool = False,
                 beta: bool = False,
                 root_weight: bool = True,
                 edge_dim: int = None,
                 bias: bool = True,
                 intermediate_layers_to_concat: Dict[str, int] = None,
                 attention_scaling: AttentionScalingType = AttentionScalingType.NO_SCALING,
                 attention_scaling_threshold: float = None,
                 transformer_normalization_type: TransformersNormalizationType = TransformersNormalizationType.LAYER_NORM,
                 graph_builder_def: DefinitionSet = DenseGraphBuilderDefinitionSet(),
                 head_bias: bool = False,
                 ultimate_super_root_weight_off: bool = False,
                 disable_mp: bool = False,
                 disable_norm1: bool = False,
                 disable_norm2: bool = False,
                 disable_layer1: bool = False,
                 disable_layer2: bool = False,
                 optimizer_definition: OptimizerDefinition = AdamOptimizerDefinition(),
                 scheduler_definition: SchedulerDefinition = None,
                 loss_calc_params: Dict[str, LossEvaluatorHyperParameterSet] = None,
                 **kwargs):
        """
        Creates new HyperParameterSet
        :param backbone_size: The size of the incoming feature vector, reduced by the first linear layer
        :param feature_size: The dimension of the latent space the input is reduced to
        :param num_message_pass: The number of message passings (stacked transformer layers)
        :param num_heads: The number of heads in the multi-headed attention
        :param dropout_prob: The probability of the Dropout layers
        :param output_size: The size of the output
        :param concat: Forwarded to the ``concat`` argument of TransformerConv
        :param skip_downsampling: If True, the transformer layer keeps its widened output size instead of
            projecting back to its input size; only allowed with a single message pass
        :param beta: Forwarded to the ``beta`` argument of TransformerConv
        :param root_weight: Forwarded to the ``root_weight`` argument of TransformerConv
        :param edge_dim: Stored on the set; note that ``MessagePassingNet.define_model`` derives the edge
            dimension from the graph builder definition and does not read this value
        :param bias: Forwarded to the ``bias`` argument of TransformerConv
        :param intermediate_layers_to_concat: Maps the name of an intermediate layer to its size; every
            entry gets a reduction block whose output is concatenated to the reduced input.
            ``None`` (the default) means no intermediate layers
        :param attention_scaling: The type of the attention scaling, forwarded to TransformerConv
        :param attention_scaling_threshold: The threshold of the attention scaling, forwarded to TransformerConv
        :param transformer_normalization_type: The normalization used inside the transformer layers
        :param graph_builder_def: Definition of the graph builder providing the edges for message passing
        :param head_bias: Whether the final linear layer has a bias
        :param ultimate_super_root_weight_off: If True, the transformer layer does not add its input to the
            message passing output
        :param disable_mp: If True, the transformer layer skips the message passing
        :param disable_norm1: If True, the transformer layer skips its first normalization
        :param disable_norm2: If True, the transformer layer skips its second normalization
        :param disable_layer1: If True, the transformer layer skips its first linear layer and activation
        :param disable_layer2: If True, the transformer layer skips its second linear layer
        :param optimizer_definition: Forwarded to the base class
        :param scheduler_definition: Forwarded to the base class
        :param loss_calc_params: Forwarded to the base class; ``None`` (the default) means an empty dict
        :func:`~ClassificationModuleHyperParameterSet.__init__`
        """
        # A ``{}`` default would be created once and shared (and mutated) across all instances,
        # so every instance gets its own fresh dictionary instead.
        if intermediate_layers_to_concat is None:
            intermediate_layers_to_concat = {}
        if loss_calc_params is None:
            loss_calc_params = {}

        super().__init__(optimizer_definition, scheduler_definition, loss_calc_params, **kwargs)
        self.backbone_size = backbone_size
        self.feature_size = feature_size
        self.num_message_pass = num_message_pass
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob
        self.output_size = output_size
        self.concat = concat
        self.beta = beta
        self.root_weight = root_weight
        self.skip_downsampling = skip_downsampling
        self.graph_builder_def = graph_builder_def
        self.edge_dim = edge_dim
        self.attention_scaling = attention_scaling
        self.attention_scaling_threshold = attention_scaling_threshold
        self.bias = bias
        self.intermediate_layers_to_concat = intermediate_layers_to_concat
        self.head_bias = head_bias
        self.transformer_normalization_type = transformer_normalization_type
        self.ultimate_super_root_weight_off = ultimate_super_root_weight_off
        self.disable_mp = disable_mp
        self.disable_norm1 = disable_norm1
        self.disable_norm2 = disable_norm2
        self.disable_layer1 = disable_layer1
        self.disable_layer2 = disable_layer2

    def definition_space(self):
        """
        Creates the search space around this HyperParameterSet
        :return: MessagePassingNetHyperParameterSpace using this set as its default
        """
        return MessagePassingNetHyperParameterSpace(self)


class MessagePassingNetDefinition(ClassificationModuleDefinition):
    """Definition pairing ``ModelType.MessagePassingNet`` with the hyperparameters of the MessagePassingNet"""

    def __init__(self, hyperparams: MessagePassingNetHyperParameterSet = MessagePassingNetHyperParameterSet()):
        """
        Creates new definition
        :param hyperparams: The hyperparameters of the MessagePassingNet
        """
        model_type = ModelType.MessagePassingNet
        super().__init__(model_type, hyperparams)

    @property
    def _instantiate_func(self) -> Callable:
        """The callable building the model: the ``MessagePassingNet`` class itself"""
        model_cls = MessagePassingNet
        return model_cls

    def definition_space(self):
        """
        Creates the definition space around the hyperparameters of this definition
        :return: MessagePassingNetDefinitionSpace wrapping the hyperparameter space
        """
        hyperparam_space = self.hyperparams.definition_space()
        return MessagePassingNetDefinitionSpace(hyperparam_space)


class MessagePassingNet(ClassificationModule):
    """
    Message Passing Network for feature refinement, which includes the latent dimension reduction as first layer
    Inspiration: https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8
    The refinement is a stack of ``TransformerLayer`` modules, each wrapping a ``TransformerConv``
    """

    def __init__(self, params: MessagePassingNetHyperParameterSet = MessagePassingNetHyperParameterSet()):
        """
        Creates new MessagePassingNet
        :param params: The hyperparameters of the network
        """
        # Assigned before the base constructor is invoked; the builder itself is created lazily
        # by the ``graph_builder`` property on first access.
        self._graph_builder = None
        super().__init__(params)

    def define_model(self) -> torch.nn.Module:
        """
        Builds the layers of the network
        :return: ModuleDict holding the intermediate reduction blocks (``intermediate_<name>``), the input
            reduction (``hybrid_fc``), the message passing layers (``mpn_layers``) and the head (``fc``)
        """
        params = self.params
        input_dim = params.feature_size
        num_message_pass = params.num_message_pass
        num_heads = params.num_heads
        # Fall back to False when the parameter set does not define these attributes
        concat = getattr(params, "concat", False)
        skip_downsampling = getattr(params, "skip_downsampling", False)
        assert (not skip_downsampling) or num_message_pass == 1, \
            'Cannot do more than one message passing with feature size expansion'

        # The edge dimension is derived from the graph builder: 1 when edge attributes are produced,
        # None otherwise. ``params.edge_dim`` is not consulted here.
        edge_attrib_type = params.graph_builder_def.hyperparams.edge_attrib_def.type
        edge_dim = 1 if edge_attrib_type is not EdgeAttributeType.NO_EDGE_ATTRIB else None

        # The layers are created in a fixed order (hybrid_fc, transformer, head, intermediate blocks)
        # so that the random weight initialization stays reproducible for a given seed.
        hybrid_fc = nn.Linear(params.backbone_size, input_dim)
        # The reduced input plus one reduced vector per intermediate layer are concatenated
        full_feature_size = input_dim * (1 + len(params.intermediate_layers_to_concat))

        abstract_transformer = TransformerLayer(in_channels=full_feature_size,
                                                hidden_dim=full_feature_size,
                                                num_heads=num_heads,
                                                dropout_prob=params.dropout_prob,
                                                concat=concat,
                                                skip_downsampling=skip_downsampling,
                                                beta=params.beta,
                                                root_weight=params.root_weight,
                                                edge_dim=edge_dim,
                                                attention_scaling=params.attention_scaling,
                                                attention_scaling_threshold=params.attention_scaling_threshold,
                                                bias=params.bias,
                                                transformer_normalization_type=params.transformer_normalization_type,
                                                ultimate_super_root_weight_off=params.ultimate_super_root_weight_off,
                                                disable_mp=params.disable_mp,
                                                disable_norm1=params.disable_norm1,
                                                disable_norm2=params.disable_norm2,
                                                disable_layer1=params.disable_layer1,
                                                disable_layer2=params.disable_layer2)
        mpn_layers = _get_clones(abstract_transformer, num_message_pass)

        # Without downsampling the head is sized for the features of all heads
        head_in_features = full_feature_size * num_heads if skip_downsampling else full_feature_size
        fc = nn.Linear(in_features=head_in_features, out_features=params.output_size, bias=params.head_bias)

        intermediate_hybrid_layers = []
        for layer_name, intermediate_layer_size in params.intermediate_layers_to_concat.items():
            reduction_block = nn.Sequential(OrderedDict([
                ("pooling", nn.AdaptiveAvgPool2d((1, 1))),
                ("flatten", nn.Flatten()),
                ("reduction", nn.Linear(intermediate_layer_size, input_dim))
            ]))
            intermediate_hybrid_layers.append((self._get_intermediate_layer_key(layer_name), reduction_block))

        return nn.ModuleDict(OrderedDict([
            *intermediate_hybrid_layers,
            ('hybrid_fc', hybrid_fc),
            ('mpn_layers', mpn_layers),
            ('fc', fc)
        ]))

    @property
    def graph_builder(self):
        """The graph builder, instantiated from ``params.graph_builder_def`` on first access"""
        if self._graph_builder is None:
            # Goes through the setter, so the instantiation is logged as well
            self.graph_builder = self.params.graph_builder_def.instantiate()
        return self._graph_builder

    @graph_builder.setter
    def graph_builder(self, graph_builder):
        self._graph_builder = graph_builder
        self.module_logger.info(f"Graph builder set to {self._graph_builder}!")

    def dimension_reduction(self, x):
        """
        Reduces the input features with the ``hybrid_fc`` layer
        :param x: The features to be reduced
        :return: The reduced features
        """
        return self.model.hybrid_fc(x)

    def reduce_intermediate(self, intermediate):
        """
        Reduces the configured intermediate values with their reduction blocks and concatenates the results
        :param intermediate: Mapping indexed by the keys of ``params.intermediate_layers_to_concat``
        :return: The reduced intermediate values concatenated along dimension 1
        """
        reduced_intermediate = [
            self.model[self._get_intermediate_layer_key(layer_name)](intermediate[layer_name])
            for layer_name in self.params.intermediate_layers_to_concat
        ]
        # An empty tensor is always prepended - presumably so that the concatenation also has an
        # operand when no intermediate layers are configured.
        return torch.cat([torch.tensor([], device=self.device)] + reduced_intermediate, dim=1)

    def feature_refinement(self, x):
        """
        Refines the features by running them through every message passing layer
        :param x: The features to be refined, also used to build the graph
        :return: The refined features
        """
        edge_attribs, edge_index = self.graph_builder.get_graph(x)
        for lay in self.model.mpn_layers:
            x = lay(x, edge_index, edge_attribs)
        return x

    def head(self, x):
        """
        Applies the final linear layer
        :param x: The refined features
        :return: The output of the model
        """
        return self.model.fc(x)

    def forward(self, x, intermediate):
        """
        Runs the forward pass on the data
        :param x: The features to be forwarded
        :param intermediate: The intermediate values to be reduced and concatenated to the reduced features
        :return: The output of the model
        """
        x = self.dimension_reduction(x)
        intermediate = self.reduce_intermediate(intermediate)
        x = torch.cat((x, intermediate), dim=1)

        x = self.feature_refinement(x)
        x = self.head(x)
        return x

    def initialize_model(self):
        """Intentionally does nothing"""
        pass

    def _get_intermediate_layer_key(self, layer_name):
        """
        Creates the key under which the reduction block of an intermediate layer is stored in the model
        :param layer_name: The name of the intermediate layer
        :return: The key of the reduction block
        """
        return f"intermediate_{layer_name}"

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """
        Extends the parser of the ClassificationModule with the arguments of the MessagePassingNet
        :param parent_parser: The parser to be extended
        :return: The new parser including the arguments of the network
        """

        def str_to_bool(value: str) -> bool:
            # ``type=bool`` would turn every non-empty string - including "False" - into True.
            # argparse reports the ValueError as an invalid argument value.
            normalized = value.strip().lower()
            if normalized in ("true", "t", "yes", "y", "1"):
                return True
            if normalized in ("false", "f", "no", "n", "0", ""):
                return False
            raise ValueError(f"Cannot interpret {value!r} as a boolean")

        super_parser = ClassificationModule.add_argparse_args(parent_parser)
        parser = ArgumentParser(parents=[super_parser], add_help=False)
        parser.add_argument('--input_dim', type=int)
        parser.add_argument('--num_message_pass', type=int)
        parser.add_argument('--num_heads', type=int)
        parser.add_argument('--dropout_prob', type=float)
        # The output size is a layer dimension, hence it has to be an integer
        parser.add_argument('--output_size', type=int, default=10)
        parser.add_argument('--concat', type=str_to_bool, default=False,
                            help='concatenation within the multiheaded attention')
        parser.add_argument('--absolute_concat', type=str_to_bool, default=False,
                            help='will not downsample to original size')
        parser.add_argument('--beta', type=str_to_bool, default=False, help='Possibility to scale message passing')
        parser.add_argument('--root_weight', type=str_to_bool, default=True,
                            help='If false, just attention will happen')
        parser.add_argument('--edge_dim', type=int, default=None,
                            help='Specifies the dimension of edge feature vectors')
        parser.add_argument('--scaling', type=str, default=None,
                            help='Specifies the type of regularization before applying softmax')
        parser.add_argument('--threshold', type=float, default=None, help='Specifies the scale of scaling')
        parser.add_argument('--bias', type=str_to_bool, default=True,
                            help='Whether or not to use bias in query and key layers')
        parser.add_argument('--transformer_normalization_type', type=TransformersNormalizationType,
                            default=TransformersNormalizationType.LAYER_NORM,
                            help='The type of normalization used in the transformer layers')
        parser.add_argument('--ultimate_super_root_weight_off', type=str_to_bool, default=False,
                            help='Whether to skip the skip connection in forward of trans layer')
        return parser


class TransformerLayer(nn.Module):
    """
    One message passing layer: a TransformerConv (optionally with a residual connection) followed by
    normalization, a linear layer with ReLU, a second linear layer and a second normalization.
    Each of these stages can be switched off individually with the ``disable_*`` flags.
    """

    def __init__(self,
                 in_channels: int,
                 hidden_dim: int,
                 num_heads: int,
                 dropout_prob: float,
                 concat: bool = False,
                 skip_downsampling: bool = False,
                 beta: bool = False,
                 root_weight: bool = True,
                 edge_dim: int = None,
                 bias: bool = True,
                 attention_scaling: AttentionScalingType = AttentionScalingType.NO_SCALING,
                 attention_scaling_threshold: float = None,
                 ultimate_super_root_weight_off: bool = False,
                 transformer_normalization_type: TransformersNormalizationType = TransformersNormalizationType.LAYER_NORM,
                 disable_mp: bool = False,
                 disable_norm1: bool = False,
                 disable_norm2: bool = False,
                 disable_layer1: bool = False,
                 disable_layer2: bool = False,
                 ):
        """
        Creates new TransformerLayer
        :param in_channels: The size of the input features
        :param hidden_dim: The ``out_channels`` of the TransformerConv
        :param num_heads: The number of attention heads
        :param dropout_prob: The probability of the dropout layers and of the TransformerConv dropout
        :param concat: Forwarded to TransformerConv; if True the layer works on ``hidden_dim * num_heads`` features
        :param skip_downsampling: If True, the second linear layer keeps the widened size instead of
            projecting back to ``in_channels``
        :param beta: Forwarded to TransformerConv
        :param root_weight: Forwarded to TransformerConv
        :param edge_dim: Forwarded to TransformerConv
        :param bias: Forwarded to TransformerConv
        :param attention_scaling: Forwarded to TransformerConv
        :param attention_scaling_threshold: Forwarded to TransformerConv
        :param ultimate_super_root_weight_off: If True, the input is not added to the message passing output
        :param transformer_normalization_type: The type of both normalization layers
        :param disable_mp: If True, the message passing (and its residual connection) is skipped
        :param disable_norm1: If True, the first normalization is skipped
        :param disable_norm2: If True, the second normalization is skipped
        :param disable_layer1: If True, the first linear layer and the activation are skipped
        :param disable_layer2: If True, the second linear layer is skipped
        :raises ValueError: If ``transformer_normalization_type`` is not a known normalization type
        """
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob
        self.concat = concat
        self.skip_downsampling = skip_downsampling
        self.beta = beta
        self.root_weight = root_weight
        self.scaling = attention_scaling
        self.threshold = attention_scaling_threshold
        self.bias = bias
        self.transformer_normalization_type = transformer_normalization_type
        self.ultimate_super_root_weight_off = ultimate_super_root_weight_off
        self.disable_mp = disable_mp
        self.disable_norm1 = disable_norm1
        self.disable_norm2 = disable_norm2
        self.disable_layer1 = disable_layer1
        self.disable_layer2 = disable_layer2

        # hidden_dim * num_heads when the heads are concatenated, hidden_dim otherwise
        self.out_size = hidden_dim * (1 + (num_heads - 1) * int(concat))

        # The submodules are registered in the same order as before (trans, linear1, dropout, linear2,
        # norm2, norm1, dropout1, dropout2, activation) to keep the parameter ordering stable.
        self.trans = TransformerConv(in_channels=in_channels,
                                     out_channels=hidden_dim,
                                     heads=num_heads,
                                     dropout=dropout_prob,
                                     concat=concat,
                                     beta=beta,
                                     root_weight=root_weight,
                                     edge_dim=edge_dim,
                                     attention_scaling=attention_scaling,
                                     attention_scaling_threshold=attention_scaling_threshold,
                                     bias=bias)
        self.linear1 = nn.Linear(self.out_size, self.out_size)
        self.dropout = nn.Dropout(dropout_prob)

        linear2_out_size = self.out_size if self.skip_downsampling else in_channels
        self.linear2 = nn.Linear(self.out_size, linear2_out_size)
        # norm2 follows linear2, so it is sized to the output of linear2 for every normalization type
        # (BatchNorm1d used to be sized to out_size even when linear2 projected to in_channels).
        self.norm2 = self._build_norm(transformer_normalization_type, linear2_out_size)
        self.norm1 = self._build_norm(transformer_normalization_type, self.out_size)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.dropout2 = nn.Dropout(dropout_prob)
        self.activation = nn.ReLU()

    @staticmethod
    def _build_norm(normalization_type: TransformersNormalizationType, num_features: int) -> nn.Module:
        """
        Creates the normalization layer of the given type
        :param normalization_type: The type of the normalization
        :param num_features: The size of the features to be normalized
        :return: The normalization layer
        :raises ValueError: If the normalization type is unknown
        """
        if normalization_type == TransformersNormalizationType.LAYER_NORM:
            return nn.LayerNorm(num_features)
        if normalization_type == TransformersNormalizationType.BATCH_NORM:
            return nn.BatchNorm1d(num_features)
        if normalization_type == TransformersNormalizationType.NO_NORM:
            return nn.Identity()
        raise ValueError(f"Unknown transformer normalization type: {normalization_type!r}")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor = None):
        """
        Runs the layer on the node features
        :param x: The node features
        :param edge_index: The edges of the graph, forwarded to TransformerConv
        :param edge_attr: The attributes of the edges, forwarded to TransformerConv
        :return: The refined node features
        """
        if not self.disable_mp:
            x2 = self.trans(x, edge_index, edge_attr)

            if self.ultimate_super_root_weight_off:
                x = x2
            else:
                # The input is tiled num_heads times when the heads are concatenated,
                # so that it matches the width of the message passing output
                x = x.repeat(1, 1 + (self.num_heads - 1) * int(self.concat)) + self.dropout1(x2)

        if not self.disable_norm1:
            x = self.norm1(x)
        if not self.disable_layer1:
            x = self.activation(self.linear1(x))
        # NOTE: self.dropout and self.dropout2 are created in __init__ but are not applied here
        if not self.disable_layer2:
            x = self.linear2(x)
        if not self.disable_norm2:
            x = self.norm2(x)
        return x


class MessagePassingNetHyperParameterSpace(ClassificationModuleHyperParameterSpace):
    """HyperParameterSpace of the MessagePassingNet"""

    def __init__(self,
                 default_hyperparam_set: MessagePassingNetHyperParameterSet = MessagePassingNetHyperParameterSet()):
        """
        Creates new HyperParameterSpace
        :param default_hyperparam_set: The HyperParameterSet the suggestions are based on
        """
        super().__init__(default_hyperparam_set)

    @property
    def _searches_attention_scaling_threshold(self) -> bool:
        """
        Whether the attention scaling threshold is part of the search.
        Shared by ``search_grid``, ``search_space`` and ``suggest`` so that the advertised
        parameters always match the ones that are actually suggested.
        """
        return self.default_hyperparam_set.attention_scaling is not AttentionScalingType.NO_SCALING

    @property
    def search_grid(self) -> Dict[str, Sequence[Any]]:
        """
        The grid of the searched parameters
        :return: Maps the name of every searched parameter to the list of its candidate values
        """
        search_grid = {
            "feature_size_exp": [5, 6, 7, 8, 9],
            "num_message_pass": [1, 2, 3, 4],
            "num_heads_exp": [0, 1, 2, 3]
        }
        if self._searches_attention_scaling_threshold:
            search_grid.update({"attention_scaling_threshold_exp": [0, 1, 2, 3, 4]})
        # The entries of the parent are applied last, so they override the entries above on a name clash
        search_grid.update(super().search_grid)
        return search_grid

    @property
    def search_space(self) -> Dict[str, Sequence[Any]]:
        """
        The ranges of the searched parameters
        :return: Maps the name of every searched parameter to its ``[low, high]`` bounds
        """
        search_space = {
            "feature_size_exp": [5, 9],
            "num_message_pass": [1, 4],
            "num_heads_exp": [0, 3]
        }
        if self._searches_attention_scaling_threshold:
            search_space.update({"attention_scaling_threshold_exp": [0, 4]})
        # The entries of the parent are applied last, so they override the entries above on a name clash
        search_space.update(super().search_space)
        return search_space

    def suggest(self, trial: optuna.Trial) -> MessagePassingNetHyperParameterSet:
        """
        Suggests new HyperParameterSet for a trial
        :param trial: The trial to suggest the parameters for
        :return: Suggested HyperParameterSet
        """
        hyperparams = super().suggest(trial=trial)

        # Only the parameters left as None are searched; the sizes are sampled as exponents of 2
        if hyperparams.feature_size is None:
            hyperparams.feature_size = 2 ** trial.suggest_int("feature_size_exp", 5, 9)

        if hyperparams.num_message_pass is None:
            hyperparams.num_message_pass = trial.suggest_int("num_message_pass", 1, 4)

        if hyperparams.num_heads is None:
            hyperparams.num_heads = 2 ** trial.suggest_int("num_heads_exp", 0, 3)

        # NOTE(review): unlike the parameters above, the threshold is overwritten even if it was set
        # explicitly - confirm whether that is intended.
        if self._searches_attention_scaling_threshold:
            hyperparams.attention_scaling_threshold = 2 ** trial.suggest_int("attention_scaling_threshold_exp", 0, 4)

        return hyperparams


class MessagePassingNetDefinitionSpace(DefinitionSpace):
    """DefinitionSpace pairing ``ModelType.MessagePassingNet`` with the hyperparameter space of the network"""

    def __init__(self, hyperparam_space: MessagePassingNetHyperParameterSpace = MessagePassingNetHyperParameterSpace()):
        """
        Creates new DefinitionSpace
        :param hyperparam_space: The hyperparameter space of the MessagePassingNet
        """
        model_type = ModelType.MessagePassingNet
        super().__init__(model_type, hyperparam_space)

    def suggest(self, trial: optuna.Trial) -> MessagePassingNetDefinition:
        """
        Suggests new definition for a trial
        :param trial: The trial to suggest the hyperparameters for
        :return: Definition wrapping the suggested hyperparameters
        """
        suggested_hyperparams = self.hyperparam_space.suggest(trial)
        return MessagePassingNetDefinition(suggested_hyperparams)
