import unittest
from unittest import TestCase
from functools import partial

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.dl.models.skeleton import Skeleton
from src.utils.testing_utils.bilevel_quadratic_test_case import BilevelQuadraticTestCase

# Shared sizes and hyper-parameters for the toy problems in this module.
# NOTE(review): BATCH_SIZE is not referenced anywhere in the visible part of this file - confirm it is still needed.
BATCH_SIZE = 32
# Used as `out_features` of the identity modules and as the side length of the square (DIM, DIM) test inputs.
DIM = 50
# Default number of fixed-point iterations performed by the solvers defined in the tests.
NUM_FORWARD_ITERS = 1000
# Default step size of the gradient-descent cell in `test_check_implicit_gradients`.
LEARNING_RATE = 0.1


class TestSkeleton(TestCase):

    def _get_skeleton(self, z0_init_method="zeros", mask_input_injection=False):
        """Build a minimal ``Skeleton`` made of identity modules and an additive cell.

        The forward solver applies the cell map exactly twice and no backward solver is supplied.

        Args:
            z0_init_method: Passed through unchanged to ``Skeleton``.
            mask_input_injection: Passed through unchanged to ``Skeleton``.

        Returns:
            The constructed ``Skeleton`` instance.
        """
        class IdentityModule(torch.nn.Module):
            """Returns its input untouched and advertises ``out_features = DIM``."""

            def __init__(self):
                super().__init__()
                self.out_features = DIM

            def forward(self, xs):
                return xs

        class CellModule(torch.nn.Module):
            """Cell that adds the injected input to the current state."""

            def forward(self, z, x):
                return x + z

        def solver(fn_map, zs, curr_num_iters=NUM_FORWARD_ITERS):
            # Plain fixed-point iteration: apply the map a fixed number of times.
            remaining = curr_num_iters
            while remaining > 0:
                zs = fn_map(zs)
                remaining -= 1
            return {"result": zs}

        return Skeleton(
            input_preprocessor=IdentityModule(),
            cell=CellModule(),
            classifier_layer=IdentityModule(),
            forward_solver=partial(solver, curr_num_iters=2),
            backward_solver=None,
            weight_sharing=True,
            z0_init_method=z0_init_method,
            mask_input_injection=mask_input_injection,
        )

    def test_check_forward_coverage(self):
        """Check that the forward pass enters the right conditionals.

        The skeleton is run with and without a backward solver, with ``training`` set to True and False,
        and in pretraining mode. Which branch was taken is read back from the logs dictionary returned by
        ``Skeleton.forward``.
        """
        # Building blocks shared by every skeleton instantiated below.
        input_preprocessor = torch.nn.Linear(in_features=DIM, out_features=DIM)
        input_preprocessor.output_dim = DIM

        class DummyCell(torch.nn.Module):
            def forward(self, z, x):
                return z + x

        cell = DummyCell()

        classifier_layer = torch.nn.Linear(in_features=DIM, out_features=DIM)

        # Forward solver: a fixed-point iterator that applies the map exactly once.
        def forward_solver(fn_map, zs):
            return dict(result=fn_map(zs))

        def backward_solver(fn_map, zs):
            return fn_map(zs)

        def run_forward_pass_with_instantiated_skeleton(forward_solv, backward_solv, in_training_mode,
                                                        in_pretraining_mode=False, num_pretraining_layers=2):
            """Instantiate a skeleton, run one forward pass on zeros and return the model logs."""
            skeleton = Skeleton(
                input_preprocessor=input_preprocessor,
                cell=cell,
                classifier_layer=classifier_layer,
                forward_solver=forward_solv,
                backward_solver=backward_solv,
                weight_sharing=True,
                num_pretraining_layers=num_pretraining_layers
            )
            skeleton.training = in_training_mode

            xs = torch.zeros((DIM, DIM))
            _, model_logs = skeleton.forward(xs, in_pretraining_mode=in_pretraining_mode)

            return model_logs

        # ____ Case 1: Use unrolling-based gradients. ____
        model_logs = run_forward_pass_with_instantiated_skeleton(forward_solv=forward_solver,
                                                                 backward_solv=None,
                                                                 in_training_mode=True)
        self.assertEqual(model_logs["inferred_in_pretraining"], 0)
        self.assertEqual(model_logs["inferred_backward_mode"], "unrolled")
        self.assertNotIn("no_grad_entered", model_logs)
        self.assertNotIn("registering_backwards_hook", model_logs)

        # ____ Case 2: Use implicit gradients, in training mode. ____
        model_logs = run_forward_pass_with_instantiated_skeleton(forward_solv=forward_solver,
                                                                 backward_solv=backward_solver,
                                                                 in_training_mode=True)
        self.assertEqual(model_logs["inferred_in_pretraining"], 0)
        self.assertEqual(model_logs["inferred_backward_mode"], "implicit")
        self.assertIn("no_grad_entered", model_logs)
        self.assertIn("registering_backwards_hook", model_logs)

        # ____ Case 3: Use implicit gradients, but in test mode. ____
        model_logs = run_forward_pass_with_instantiated_skeleton(forward_solv=forward_solver,
                                                                 backward_solv=backward_solver,
                                                                 in_training_mode=False)
        self.assertEqual(model_logs["inferred_in_pretraining"], 0)
        self.assertEqual(model_logs["inferred_backward_mode"], "implicit")
        self.assertIn("no_grad_entered", model_logs)
        self.assertNotIn("registering_backwards_hook", model_logs)

        # ____ Case 4: Test pretraining mode. ____
        # A backward solver is supplied here, yet the expected mode in pretraining is "unrolled".
        num_pretraining_layers = 11
        model_logs = run_forward_pass_with_instantiated_skeleton(forward_solv=forward_solver,
                                                                 backward_solv=backward_solver,
                                                                 in_training_mode=True,
                                                                 in_pretraining_mode=True,
                                                                 num_pretraining_layers=num_pretraining_layers)
        self.assertEqual(model_logs["inferred_in_pretraining"], 1)
        self.assertEqual(model_logs["inferred_num_pretraining_layers"], num_pretraining_layers)
        self.assertEqual(model_logs["inferred_backward_mode"], "unrolled")
        self.assertNotIn("no_grad_entered", model_logs)
        self.assertNotIn("registering_backwards_hook", model_logs)
        self.assertEqual(model_logs["num_iters"], num_pretraining_layers)

    @unittest.skip("FIX THE Z0 INITIALIZATION TESTS. SKIPPING FOR NOW. ")
    def test_zs0_initialization(self):
        """Check whether the deq states z0s are initialized correctly.

        Covers the "zeros", "copy_of_input" and "external" initialization methods of ``Skeleton.init_zs0``.
        The test is currently skipped (see the decorator).
        """
        # Zero initialization.
        xs_proj = torch.randn(size=(128, 15, 17))
        skeleton = self._get_skeleton(z0_init_method="zeros")
        zs0 = skeleton.init_zs0(xs_proj=xs_proj)
        self.assertTrue(torch.allclose(zs0, torch.zeros_like(xs_proj)))

        # Copy of xs_proj: values must match and gradients must flow back to xs_proj.
        xs_proj = torch.randn(size=(128, 15, 17), requires_grad=True)
        skeleton = self._get_skeleton(z0_init_method="copy_of_input")
        zs0 = skeleton.init_zs0(xs_proj=xs_proj)
        self.assertTrue(torch.allclose(zs0, xs_proj))
        # If zs0 is a non-leaf tensor its `.grad` stays None after backward unless it opted in, which
        # would make the comparison below raise. `retain_grad` is a no-op when zs0 is a leaf.
        # NOTE(review): whether init_zs0 returns xs_proj itself or a derived tensor is not visible here.
        zs0.retain_grad()
        dummy_loss = zs0.sum()
        dummy_loss.backward()
        self.assertTrue(torch.allclose(xs_proj.grad, zs0.grad))

        # External zs0 without a batch dimension: the first batch entry must equal the external state.
        xs_proj = torch.randn(size=(128, 15, 17), requires_grad=True)
        skeleton = self._get_skeleton(z0_init_method="external")
        external_zs0 = torch.zeros(size=(15, 17), requires_grad=True)
        zs0 = skeleton.init_zs0(xs_proj=xs_proj, external_zs0=external_zs0)
        self.assertTrue(torch.allclose(zs0[0], external_zs0))

        # External zs0 that already has the same shape as xs_proj: it must be used as is.
        xs_proj = torch.randn(size=(128, 15, 17), requires_grad=True)
        skeleton = self._get_skeleton(z0_init_method="external")
        external_zs0 = torch.randn(size=(128, 15, 17))
        zs0 = skeleton.init_zs0(xs_proj=xs_proj, external_zs0=external_zs0)
        self.assertTrue(torch.allclose(zs0, external_zs0))

    def test_input_injection(self):
        """Verify the skeleton output with input injection enabled and with it masked out."""

        def forward_random_batch(mask_input_injection):
            # Build a fresh skeleton, then push one random batch through it.
            skeleton = self._get_skeleton(z0_init_method="copy_of_input",
                                          mask_input_injection=mask_input_injection)
            xs = torch.randn(size=(5, 6, 7))
            xs.requires_grad_()
            outs, _ = skeleton(xs)
            return xs, outs

        # Input injection is on.
        # Expected coefficient 4 = 1 (z0 is a copy of the input) + 2 (solver steps) + 1 (one extra step
        # after the solver).
        # NOTE(review): this breakdown is inherited from the previous comment - confirm against Skeleton.forward.
        xs, outs = forward_random_batch(mask_input_injection=False)
        self.assertTrue(torch.allclose(4 * xs, outs))

        # Input injection is off: the output is expected to equal the input.
        xs, outs = forward_random_batch(mask_input_injection=True)
        self.assertTrue(torch.allclose(xs, outs))

    def test_check_implicit_gradients(self, plot_implicit_vs_explicit_gradient_fidelity=False):
        """Check that the implicit gradients are computed properly on a toy example.

        Toy example: given x and z, minimize (x - z)**2 over z with gradient descent. One descent step,
        ``z + lr * 2 * (x - z)``, is used as the cell. The input x is all ones and the outer loss is
        ``sum(0.5 * y ** 2)``; the test asserts that the gradient of the outer loss w.r.t. x is all ones
        both without a backward solver (unrolled) and with one (implicit).

        Args:
            plot_implicit_vs_explicit_gradient_fidelity: When True, additionally sweep the number of
                forward iterations for two learning rates, print the gradient errors and show a scatter
                plot of gradient error against fixed-point error. Defaults to False.
        """
        class IdentityModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.out_features = DIM

            def forward(self, xs):
                return xs

        input_preprocessor = IdentityModule()

        # Define the solver as a fixed point iterator.
        def solver(fn_map, zs, curr_num_iters=NUM_FORWARD_ITERS):
            for _ in range(curr_num_iters):
                zs = fn_map(zs)
            return dict(result=zs)

        def get_grad_wrt_x(backwards_solv, num_forward_iters=NUM_FORWARD_ITERS, lr=LEARNING_RATE):
            """Instantiate the skeleton, run forward and backward, and return ``(x, y)``.

            After the call ``x.grad`` holds the gradient of the outer loss w.r.t. the input.
            """
            class CellModule(torch.nn.Module):
                def forward(self, z, x):
                    # One gradient-descent step on (x - z)**2 with step size lr.
                    return z + lr * 2 * (x - z)

            curr_forwards_solver = partial(solver, curr_num_iters=num_forward_iters)
            curr_backwards_solver = (
                partial(backwards_solv, curr_num_iters=1000) if backwards_solv is not None else None
            )
            skeleton = Skeleton(
                input_preprocessor=input_preprocessor,
                cell=CellModule(),
                classifier_layer=IdentityModule(),
                forward_solver=curr_forwards_solver,
                backward_solver=curr_backwards_solver,
                weight_sharing=True
            )
            skeleton.training = True

            x = torch.ones((DIM, DIM))
            x.requires_grad = True
            y, _ = skeleton.forward(x)

            outer_loss = torch.sum(0.5 * y ** 2)
            outer_loss.backward()

            return x, y

        # Confirm that the gradient is correct using explicit/unrolled gradients, then implicit ones.
        xs, _ = get_grad_wrt_x(backwards_solv=None)
        self.assertTrue(torch.allclose(xs.grad, torch.ones_like(xs.grad)))
        xs, _ = get_grad_wrt_x(backwards_solv=solver)
        self.assertTrue(torch.allclose(xs.grad, torch.ones_like(xs.grad)))

        if not plot_implicit_vs_explicit_gradient_fidelity:
            return

        def sweep_num_forward_iters(backwards_solv, lr):
            """Return (fixed-point errors, gradient errors) for 0..199 forward iterations."""
            diffs = []
            grad_dists = []
            for num_iters in range(200):
                curr_xs, curr_ys = get_grad_wrt_x(backwards_solv=backwards_solv,
                                                  num_forward_iters=num_iters,
                                                  lr=lr)
                diffs.append(abs(float(curr_xs[0, 0]) - float(curr_ys[0, 0])))
                grad_dists.append(abs(1. - float(curr_xs.grad[0, 0])))
            return diffs, grad_dists

        _, axs = plt.subplots(nrows=1, ncols=1, squeeze=False, figsize=(15, 5))
        ax = axs[0, 0]
        for lr in [0.04, 0.998]:
            # Check how quickly the explicit and implicit gradients converge to the correct value.
            unrolled_diffs, unrolled_grad_dist = sweep_num_forward_iters(backwards_solv=None, lr=lr)
            implicit_diffs, implicit_grad_dist = sweep_num_forward_iters(backwards_solv=solver, lr=lr)

            print(unrolled_grad_dist)
            print(implicit_grad_dist)
            ax.scatter(unrolled_diffs, unrolled_grad_dist, label=f"unrolled - lr={lr}")
            ax.scatter(implicit_diffs, implicit_grad_dist, label=f"implicit - lr={lr}")
            ax.set_xlabel("|current estimate - correct|")
            ax.set_ylabel("|current grad estimate - correct|")
            ax.legend()
        plt.tight_layout()
        plt.show()

    def test_multidimentional_toy_example(self):
        """Check that gradients computed with the bilevel quadratic model are correct.

        Every combination of matrix ``A``, gradient mode (no backward solver vs. backward solver) and
        number of additional unroll steps runs as its own sub-test, so a failure report names the
        offending configuration.
        """
        curr_lr = 0.1
        curr_num_solver_iter = 1000

        # The solver does not depend on the configuration, so it is defined once for all sub-tests.
        def solver(fn_map, zs, curr_num_iters=curr_num_solver_iter):
            for _ in range(curr_num_iters):
                zs = fn_map(zs)
            return dict(result=zs)

        As = [
            torch.tensor([[2., 0.], [0., 0.5]]),
            torch.tensor([[2., 0.], [1., 0.5]]),
            torch.tensor([[1., 0.5], [0.5, 1.]]),
        ]
        for matrix_idx, curr_A in enumerate(As):
            for use_backward_solver in (False, True):
                for additional_unroll_steps in (0, 5):
                    with self.subTest(matrix_idx=matrix_idx,
                                      use_backward_solver=use_backward_solver,
                                      additional_unroll_steps=additional_unroll_steps):
                        self._check_bilevel_quadratic_case(
                            A=curr_A,
                            solver=solver,
                            use_backward_solver=use_backward_solver,
                            additional_unroll_steps=additional_unroll_steps,
                            lr=curr_lr,
                        )

    def _check_bilevel_quadratic_case(self, A, solver, use_backward_solver, additional_unroll_steps, lr):
        """Run one bilevel quadratic configuration and assert its mode, fixed point and gradient."""
        # Instantiate the bilevel quadratic test case.
        dim = A.shape[0]
        test_case = BilevelQuadraticTestCase(
            A=A,
            solver_fn=solver,
            use_backwards_solver=use_backward_solver,
            dim=dim,
            x_init_value=torch.ones((1, dim)),
            lr=lr,
            num_additional_unroll_steps_after_implicit_forward=additional_unroll_steps
        )

        # Get the fixed point and gradients.
        z, x_grad, model_logs = test_case.get_fixed_point_and_grad()

        # Check that the correct gradient computation means is selected.
        expected_mode = "implicit" if use_backward_solver else "unrolled"
        self.assertEqual(model_logs["inferred_backward_mode"], expected_mode)

        # Check that the fixed point is correct: the inner loss must be (numerically) zero there.
        inner_loss = test_case.get_inner_loss(z=z)
        self.assertTrue(torch.allclose(inner_loss, torch.zeros_like(inner_loss), atol=1e-6))

        # Check that the gradients are computed correctly.
        correct_x_grad = test_case.get_correct_grad(z=z)
        self.assertTrue(torch.allclose(x_grad, correct_x_grad))


if __name__ == "__main__":
    # Use the following command to run the tests:
    #     python -m unittest -v src.dl.models.test_skeleton
    unittest.main()
