"""
The DEQ implementation is largely based on https://github.com/locuslab/deq.
"""
from typing import Callable, Union, Tuple, Iterable, Optional
from spaghettini import quick_register
from functools import partial

import torch
from torch import nn
from torch.nn import Module
import torch.autograd as autograd
import wandb

from src.dl.fixed_point_solvers.fixed_point_iterator import fixed_point_iterator
from src.utils.misc import is_scalar

# Default for Skeleton's ``num_pretraining_layers``: the ``num_iters`` handed to ``fixed_point_iterator``
# when the model runs in pretraining mode.
NUM_PRETRAINING_LAYERS: int = 2


def _log_backward_solver_metrics(backward_solver_outputs):
    """Log the scalar metrics outputted by the backward solver to wandb.

    Only entries for which ``is_scalar`` is true are logged, each under the key
    ``"training/backwards_solver_metrics_<name>"``.

    Logging is best-effort. Any ``Exception`` raised while building or logging the metrics is
    reported as a ``RuntimeWarning`` instead of interrupting the backward pass that calls this
    function. ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed.
    """
    prepend_key = "training/backwards_solver_metrics"
    try:
        backwards_metrics = {f"{prepend_key}_{k}": v for k, v in backward_solver_outputs.items() if is_scalar(v)}
        # Make sure to set commit=False when logging. Otherwise, the wandb's internal step gets messed up.
        wandb.log(data=backwards_metrics, commit=False)
    except Exception as err:  # Deliberately best-effort: logging must never break backprop.
        import warnings

        # With the default warning filters, an identical message is shown only once per call site,
        # so this does not flood the output on every backward pass.
        warnings.warn(f"Could not log backward solver metrics: {err!r}", RuntimeWarning, stacklevel=2)


@quick_register
class Skeleton(torch.nn.Module):
    """Fixed-point ("DEQ"-style) model: preprocess the input, iterate ``cell`` with a solver, classify.

    The forward pass (see ``forward``) is:
      1. ``xs_proj = input_preprocessor(xs)``;
      2. an initial iterate ``zs0`` is built according to ``z0_init_method`` (see ``init_zs0``);
      3. ``forward_map = lambda z: cell(z, input_injection)`` is handed to a solver, where
         ``input_injection`` is ``xs_proj`` (or zeros when ``mask_input_injection`` is set);
      4. the resulting iterate is passed to ``classifier_layer`` (identity when ``None``).

    How gradients are obtained depends on the mode:
      * pretraining mode: ``fixed_point_iterator`` runs ``num_pretraining_layers`` iterations and
        backprop goes through them;
      * ``backward_solver is None``: ``forward_solver`` runs with autograd enabled (unrolled backprop);
      * otherwise: ``forward_solver`` runs under ``torch.no_grad()``, and the incoming gradient is
        replaced inside a tensor backward hook by the solution that ``backward_solver`` finds
        (implicit differentiation).

    Solvers are called as ``solver(map, initial_iterate)`` and must return a dict containing at least
    a ``"result"`` entry.
    """

    def __init__(
            self,
            input_preprocessor: Module,
            cell: Module,
            classifier_layer: Union[Module, None],
            forward_solver: Callable,
            backward_solver: Union[Callable, None],
            weight_sharing: bool = True,
            z0_init_method: str = "zeros",
            weight_init_std: Optional[float] = None,
            mask_input_injection: bool = False,
            num_pretraining_layers: int = NUM_PRETRAINING_LAYERS,
            num_additional_unroll_steps_after_implicit_forward: int = 0,
            deq_jacobian_scaling: float = 1.) -> None:
        """Store the components and validate the configuration.

        Args:
            input_preprocessor: Maps raw inputs ``xs`` to ``xs_proj``.
            cell: Called as ``cell(z, input_injection)``.
            classifier_layer: Applied to the final iterate. ``None`` means identity.
            forward_solver: Called as ``forward_solver(forward_map, zs0)``.
            backward_solver: Same calling convention, used for the implicit backward pass. ``None``
                selects unrolled backprop through ``forward_solver`` instead.
            weight_sharing: Must be ``True``; non-weight-sharing is not supported yet.
            z0_init_method: One of ``"zeros"``, ``"normal"``, ``"copy_of_input"``, ``"external"``,
                ``"mixed_normal_and_zeros"``.
            weight_init_std: If given, re-initialize the ``weight`` of every module whose exact type
                is ``nn.Linear``, ``nn.Conv1d`` or ``nn.Conv2d`` from a normal with this std.
            mask_input_injection: If ``True``, the cell receives zeros as input injection. This
                requires ``z0_init_method == "copy_of_input"``, otherwise the input is never seen.
            num_pretraining_layers: ``num_iters`` used in pretraining mode.
            num_additional_unroll_steps_after_implicit_forward: Extra differentiable iterations run
                after the implicit forward pass.
            deq_jacobian_scaling: Multiplier on the vector-Jacobian product in the backward map.

        Raises:
            AssertionError: If ``weight_sharing`` is ``False``.
            ValueError: If ``z0_init_method`` is unknown, or if ``mask_input_injection`` is set
                without ``z0_init_method == "copy_of_input"``.
        """
        super().__init__()
        self.input_preprocessor = input_preprocessor
        self.f = cell
        # nn.Identity instead of a lambda keeps the model picklable (e.g. torch.save(model)). It has
        # no parameters or buffers, so the state_dict is unchanged.
        self.classifier_layer = classifier_layer if classifier_layer is not None else nn.Identity()
        self.forward_solver = forward_solver
        self.backward_solver = backward_solver
        self.weight_sharing = weight_sharing
        self.z0_init_method = z0_init_method
        self.weight_init_std = weight_init_std
        self.mask_input_injection = mask_input_injection
        self.num_pretraining_layers = num_pretraining_layers
        self.num_additional_unroll_steps_after_implicit_forward = num_additional_unroll_steps_after_implicit_forward
        self.deq_jacobian_scaling = deq_jacobian_scaling
        # Handle of the most recently registered implicit-backward hook (None until one is registered).
        self.hook = None

        # For now, avoid supporting non-weight-sharing.
        assert weight_sharing, "We're not supporting non-weight sharing just yet. "

        # Make sure only the supported z0 initialization methods are accepted.
        accepted_z0_init_methods = ["zeros", "normal", "copy_of_input", "external", "mixed_normal_and_zeros"]
        if z0_init_method not in accepted_z0_init_methods:
            message = f"z0_init_method is set to {z0_init_method}, it should be one of {accepted_z0_init_methods}."
            raise ValueError(message)

        # If mask_input_injection is True, then z0s must be initialized with the input. Otherwise the network never
        # sees the input.
        if mask_input_injection and z0_init_method != "copy_of_input":
            raise ValueError("If input injection is masked, then z0 must be initialized with the inputs. ")

        if self.weight_init_std is not None:
            init_types = (nn.Linear, nn.Conv1d, nn.Conv2d)

            def init_weights(m):
                # Exact type match (not isinstance): subclasses of these layers are left untouched.
                if type(m) in init_types:
                    torch.nn.init.normal_(m.weight, std=self.weight_init_std)

            self.apply(init_weights)

    def forward_for_unrolled_backward(self, zs0, input_injection, forward_map, model_logs):
        """Run ``forward_solver`` with autograd enabled so backprop goes through its iterations.

        One extra ``self.f`` application is taken on the solver's result.

        Returns:
            ``(final_z_star, solver_outputs, model_logs)``.
        """
        model_logs["inferred_backward_mode"] = "unrolled"

        solver_outputs = self.forward_solver(forward_map, zs0)
        z_star = solver_outputs["result"]
        final_z_star = self.f(z_star, input_injection)

        return final_z_star, solver_outputs, model_logs

    def forward_for_implicit_backward(self, zs0, input_injection, forward_map, model_logs):
        """Solve without autograd, then take one differentiable step and hook its gradient.

        ``forward_solver`` runs under ``torch.no_grad()``. A single ``self.f`` step from its result
        (``final_z_star``) is recorded by autograd. In training mode, a hook on ``final_z_star``
        replaces its incoming gradient ``g`` with the ``"result"`` that ``backward_solver`` returns.
        The backward solver starts from zeros and uses the map
        ``y -> deq_jacobian_scaling * (y^T J) + g``, where ``y^T J`` is the vector-Jacobian product
        of ``final_z_star`` w.r.t. the solver result.

        Returns:
            ``(output, solver_outputs, model_logs)``. ``output`` is ``final_z_star``, optionally
            followed by ``num_additional_unroll_steps_after_implicit_forward`` differentiable
            iterations of ``fixed_point_iterator``.
        """
        model_logs["inferred_backward_mode"] = "implicit"
        # Forward pass.
        with torch.no_grad():
            model_logs["no_grad_entered"] = 1
            solver_outputs = self.forward_solver(forward_map, zs0)
            z_star = solver_outputs["result"]
        final_z_star = self.f(z_star.requires_grad_(), input_injection)

        # Prepare for backward pass. register_hook raises on a tensor that does not require grad
        # (e.g. a training-mode forward run under torch.no_grad()), so only hook when it does.
        if self.training and final_z_star.requires_grad:
            model_logs["registering_backwards_hook"] = 1
            hook_handle = None  # Rebound below, before the hook can ever run.

            def backward_hook(grad):
                # Remove *this* call's hook. It is read from the closure, not from self.hook, which a
                # later forward call may have overwritten before backward runs.
                if hook_handle is not None:
                    hook_handle.remove()
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()  # To avoid infinite recursion.

                # Fixed point of y -> scale * (y J) + grad, where J = J_f is the Jacobian of f at z_star.
                def backward_map(y):
                    vjp = autograd.grad(final_z_star, z_star, y, retain_graph=True)[0]
                    return self.deq_jacobian_scaling * vjp + grad

                backward_solver_outputs = self.backward_solver(backward_map, torch.zeros_like(grad))
                new_grad = backward_solver_outputs['result']

                # Log the backwards solver statistics. This, unfortunately, has to be done here as it's not
                # possible to pipe these statistics through PyTorch internals back to the training loop.
                _log_backward_solver_metrics(backward_solver_outputs)

                return new_grad

            hook_handle = final_z_star.register_hook(backward_hook)
            # Kept on the module for backward compatibility with code that inspects ``self.hook``.
            self.hook = hook_handle

        # If asked, take a few unroll steps (as part of the computational graph) to stabilize gradient computation.
        if self.num_additional_unroll_steps_after_implicit_forward != 0:
            additional_unroll_solver = partial(fixed_point_iterator,
                                               num_iters=self.num_additional_unroll_steps_after_implicit_forward)
            additional_solver_outputs = additional_unroll_solver(forward_map, final_z_star.requires_grad_())
            output = additional_solver_outputs["result"]
        else:
            output = final_z_star

        return output, solver_outputs, model_logs

    def pretraining_forward(self, zs0, forward_map, model_logs):
        """Run ``fixed_point_iterator`` for ``num_pretraining_layers`` iterations with autograd enabled.

        Returns:
            ``(final_z_star, solver_outputs, model_logs)``. The solver-reported ``"num_iters"`` is
            copied to ``model_logs["inferred_num_pretraining_layers"]``.
        """
        model_logs["inferred_backward_mode"] = "unrolled"
        pretraining_solver = partial(fixed_point_iterator, num_iters=self.num_pretraining_layers)

        solver_outputs = pretraining_solver(forward_map, zs0)
        model_logs["inferred_num_pretraining_layers"] = solver_outputs["num_iters"]

        final_z_star = solver_outputs["result"]

        return final_z_star, solver_outputs, model_logs

    def init_zs0(self, xs_proj, external_zs0=None):
        """Build the solver's initial iterate according to ``self.z0_init_method``.

        Args:
            xs_proj: Preprocessed input. Its first axis is treated as the batch axis.
            external_zs0: Used only for ``"external"``. It must have either ``xs_proj``'s full shape
                (used as-is) or ``xs_proj.shape[1:]`` (tiled along a new leading batch axis).

        Returns:
            A tensor shaped like ``xs_proj``. For ``"copy_of_input"`` it stays attached to the graph,
            so gradients flow into ``xs_proj``.
        """
        if self.z0_init_method == "zeros":
            zs0 = torch.zeros_like(xs_proj, requires_grad=True)
        elif self.z0_init_method == "normal":
            zs0 = torch.randn_like(xs_proj)
        elif self.z0_init_method == "mixed_normal_and_zeros":
            # Per-sample coin flip: each sample starts either from standard normal noise or from zeros.
            bs = xs_proj.shape[0]
            normal_z0s = torch.randn_like(xs_proj)
            mask = torch.bernoulli(0.5 * torch.ones(size=(bs,))).type_as(xs_proj)
            # Shape (bs, 1, ..., 1) so the mask broadcasts over every non-batch axis.
            mask = mask.reshape((bs,) + (1,) * (xs_proj.dim() - 1))
            zs0 = mask * normal_z0s
        elif self.z0_init_method == "external":
            assert external_zs0 is not None, "zs0 asked to be initialized with external data, yet None given. "
            # Make sure that the dimensionality matches.
            if xs_proj.shape == external_zs0.shape:
                zs0 = external_zs0
            else:
                message = "Externally provided zs0 has the incorrect shape. "
                message = message + f" It should have shape={xs_proj.shape[1:]} but instead has {external_zs0.shape}. "
                assert xs_proj.shape[1:] == external_zs0.shape, message

                # Duplicate along the batch axis.
                bs = xs_proj.shape[0]
                repeats = [bs] + [1] * external_zs0.dim()
                zs0 = torch.tile(external_zs0[None, ...], repeats)
        else:
            assert self.z0_init_method == "copy_of_input"
            zs0 = torch.clone(xs_proj)  # We do want the gradients to flow through zs0, so no .detach() is added.
        return zs0

    def forward(self, xs, in_pretraining_mode=False, external_zs0=None) -> Tuple[torch.Tensor, dict]:
        """Run the model on ``xs``.

        Args:
            xs: Raw input, passed to ``self.input_preprocessor``.
            in_pretraining_mode: If ``True``, use ``pretraining_forward`` instead of the solvers.
            external_zs0: Initial iterate, used only when ``z0_init_method == "external"``.

        Returns:
            ``(output, model_logs)``: the classifier output and a dict of diagnostics. The dict is
            updated with every entry of the solver's output dict, including its ``"result"``.
        """
        model_logs = dict()

        # ____ Preprocess the input. ____
        xs_proj = self.input_preprocessor(xs)

        # ____ Initialize the fixed point. ____
        zs0 = self.init_zs0(xs_proj, external_zs0=external_zs0)

        # ____ Run the recurrent body of the network. ____
        input_injection = xs_proj if not self.mask_input_injection else torch.zeros_like(xs_proj)
        forward_map = lambda z: self.f(z, input_injection)

        # If set, do a pretraining forward pass. (i.e. unroll cell a few times, use unrolled backprop for training).
        if in_pretraining_mode:
            model_logs["inferred_in_pretraining"] = 1
            final_z_star, solver_outs, model_logs = self.pretraining_forward(zs0=zs0, forward_map=forward_map,
                                                                             model_logs=model_logs)
        else:
            model_logs["inferred_in_pretraining"] = 0

            # Without a backward solver, backprop through the forward solver (unrolled); otherwise use implicit
            # (i.e. trajectory independent) gradients.
            if self.backward_solver is None:
                forward_fn = self.forward_for_unrolled_backward
            else:
                forward_fn = self.forward_for_implicit_backward
            final_z_star, solver_outs, model_logs = forward_fn(zs0=zs0, input_injection=input_injection,
                                                               forward_map=forward_map, model_logs=model_logs)

        # ____ Run the final classifier. ____
        # TODO: Decide whether to apply any nonlinearity before the final classifier.
        output = self.classifier_layer(final_z_star)

        # Append the solver outputs to the model logs.
        model_logs.update(solver_outs)

        return output, model_logs
