"""
The Anderson solver implementation is taken from https://github.com/locuslab/deq/blob/master/lib/solvers.py.
"""
from typing import Sequence, Union
from spaghettini import quick_register
import math

from functools import reduce, partial

import torch

# Default hyper-parameters for the `anderson` solver below.
M: int = 6  # history size: number of past iterates / evaluations stored
LAM: float = 1.0e-4  # regularization added to the diagonal of the Gram matrix
THRESHOLD: int = 50  # iteration budget
EPS: float = 0.001  # convergence tolerance; `anderson` currently only reports it
DEFAULT_STOP_MODE: str = "rel"  # residual ('abs' or 'rel') that selects the returned iterate
BETA: float = 1.0  # mixing weight between function evaluations and past iterates
RELATIVE_RESIDUAL_EPS: float = 0.00001  # keeps relative residuals finite when ||f(x)|| is 0


def prod(lst: Sequence[Union[int, float]]):
    """Return the product of the numbers in ``lst``.

    Lists and tuples are accepted (``torch.Size`` is a ``tuple`` subclass, so tensor
    shapes work directly). The empty product is 1, so ``prod(())`` -- e.g. the shape
    of a 0-d tensor -- returns 1 instead of raising.

    Raises:
        TypeError: if ``lst`` is not a list or tuple.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(lst, (list, tuple)):
        raise TypeError(f"prod expects a list or tuple, got {type(lst).__name__}")
    # Initializer 1 makes the empty sequence well-defined.
    return reduce(lambda x, y: x * y, lst, 1)


def normalized_l2_norm(x, normalize_by_sqrt=True):
    """Return the global L2 norm of tensor ``x`` divided by a size-dependent factor.

    Args:
        x: Tensor whose L2 norm (over all elements) is taken.
        normalize_by_sqrt: If True, divide by ``sqrt(x.numel())`` (i.e. the RMS value);
            otherwise divide by ``x.numel()``.

    Returns:
        A Python float.

    ``numel()`` is used instead of multiplying the shape so 0-d tensors (empty shape,
    one element) work; an empty tensor (0 elements) still raises ``ZeroDivisionError``.
    """
    numel = x.numel()
    denom = math.sqrt(numel) if normalize_by_sqrt else numel
    return x.norm().item() / denom


def differentiable_slice_assignment(target, source, curr_slice, use_differentiable_version=True,
                                    confirm_equivalence=False):
    """
    Perform ``target[curr_slice] = source``, optionally in an autograd-friendly (out-of-place) way.

    https://stackoverflow.com/questions/60927234/torch-assign-not-in-place-by-tensor-slicing-in-pytorch

    Args:
        target: Tensor to write into.
        source: Values for the selected region. The differentiable path uses
            ``masked_scatter``, which consumes ``source`` elements in order and does not
            broadcast, so ``source`` should have the same shape as ``target[curr_slice]``.
        curr_slice: Index (e.g. a tuple of ``slice`` objects) selecting the region to overwrite.
        use_differentiable_version: If True, return a new tensor built with ``masked_scatter``
            and leave ``target`` untouched; if False, assign in place and return ``target``.
        confirm_equivalence: Debug flag. If True, check that both versions produce the same
            values, print the outcome and assert on mismatch. The check runs on a copy, so it
            never modifies ``target``.

    Returns:
        The tensor holding the assigned values (a new tensor, or ``target`` itself when
        ``use_differentiable_version`` is False).

    Todo: See if this function can be made more efficient.
    """
    if use_differentiable_version or confirm_equivalence:
        mask = torch.zeros_like(target, dtype=torch.bool)
        mask[curr_slice] = True
        diff_result = target.masked_scatter(mask, source)

    if confirm_equivalence:
        # Work on a copy: assigning into ``target`` here would silently mutate the caller's
        # tensor and would raise for leaf tensors that require grad.
        nondiff_result = target.clone()
        nondiff_result[curr_slice] = source

        is_equivalent = torch.allclose(diff_result, nondiff_result)
        print(f"\nSlice assignment comparison: {is_equivalent}\n")
        assert is_equivalent

    if use_differentiable_version:
        return diff_result

    target[curr_slice] = source
    return target


def _solve_linear_system(A, B):
    """Solve the (batched) linear system ``A @ X = B`` and return ``X``.

    ``torch.solve(B, A)`` is deprecated in favour of ``torch.linalg.solve(A, B)`` and is no
    longer available in recent PyTorch releases, so the modern API is used whenever it exists.
    The legacy call (note its reversed argument order) is only a fallback for old versions.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        return linalg.solve(A, B)
    return torch.solve(B, A)[0]


@quick_register
def anderson(f, x0, m=M, lam=LAM, threshold=THRESHOLD, eps=EPS, stop_mode=DEFAULT_STOP_MODE, beta=BETA,
             make_differentiable=False, **kwargs):
    """Anderson acceleration for the fixed-point iteration ``x = f(x)``.

    Args:
        f: Callable applied to tensors shaped like ``x0``. Its output is reshaped to
            ``(bsz, 1, -1)``, so it must have the same number of elements as ``x0``.
        x0: Initial guess. Dimension 0 is the batch dimension; the remaining dimensions are
            flattened into one vector per example.
        m: Number of past iterates / evaluations kept in the ring-buffer history.
        lam: Regularization added to the diagonal of the Gram matrix ``G @ G^T``.
        threshold: Iteration budget. The first two history slots are filled by plain
            fixed-point steps, then Anderson steps run for ``k = 2 .. threshold - 1``.
            Must be at least 3.
        eps: Only reported in the output; early stopping on it is disabled.
        stop_mode: ``'abs'`` or ``'rel'``; the residual that decides which iterate is returned.
        beta: Mixing weight. The new iterate is
            ``beta * (alpha @ F) + (1 - beta) * (alpha @ X)``.
        make_differentiable: If True, history buffers are updated out of place (via
            ``masked_scatter``) so autograd can track them; if False they are updated in place.
        **kwargs: Ignored.

    Returns:
        A dict. Notable entries:
            "result": the iterate with the lowest ``stop_mode`` residual, shaped like ``x0``
                (``None`` if no residual ever fell below the initial bound of 1e10, e.g. all NaN).
            "diff_l2" / "rel_diff": per-example absolute / relative residuals of the *final*
                iteration (not necessarily of "result").
            "anderson_abs_trace" / "anderson_rel_trace": residuals of every Anderson step.

    Raises:
        ValueError: if ``stop_mode`` is not ``'abs'``/``'rel'`` or ``threshold < 3``.
    """
    if stop_mode not in ("abs", "rel"):
        raise ValueError(f"stop_mode must be 'abs' or 'rel', got {stop_mode!r}")
    if threshold < 3:
        raise ValueError(f"threshold must be at least 3 to run any Anderson step, got {threshold}")

    bsz = x0.shape[0]
    # Elements per example; the initializer keeps 1-D inputs (no feature dims) working.
    dim = reduce(lambda a, b: a * b, x0.shape[1:], 1)

    X = torch.zeros(bsz, m, dim, dtype=x0.dtype, device=x0.device)
    F = torch.zeros(bsz, m, dim, dtype=x0.dtype, device=x0.device)

    # Decide whether to use the differentiable version of sliced assignment or not.
    slice_assign = partial(differentiable_slice_assignment, use_differentiable_version=make_differentiable)

    # Seed the history with two plain fixed-point steps: x0 -> f(x0) -> f(f(x0)).
    curr_slice = (slice(None), slice(0, 1))
    X = slice_assign(X, x0.reshape(bsz, 1, -1), curr_slice)
    F = slice_assign(F, f(x0).reshape(bsz, 1, -1), curr_slice)

    curr_slice = (slice(None), slice(1, 2))
    X = slice_assign(X, F[:, 0].unsqueeze(1), curr_slice)
    F = slice_assign(F, f(F[:, 0].view_as(x0)).reshape(bsz, 1, -1), curr_slice)

    # Bordered system [[0, 1^T], [1, G G^T + lam I]] @ [mu; alpha] = [1; 0]; its first row
    # enforces sum(alpha) == 1. Only the border of H and the first entry of y are fixed here.
    H = torch.zeros(bsz, m + 1, m + 1, dtype=x0.dtype, device=x0.device)
    H = slice_assign(H, torch.ones_like(H[:, 0:1, 1:]), (slice(None), slice(0, 1), slice(1, None)))
    H = slice_assign(H, torch.ones_like(H[:, 1:, 0:1]), (slice(None), slice(1, None), slice(0, 1)))

    y = torch.zeros(bsz, m + 1, 1, dtype=x0.dtype, device=x0.device)
    y = slice_assign(y, torch.ones_like(y[:, 0:1]), (slice(None), slice(0, 1)))

    # Initialize trackers.
    diff_modes = ['abs', 'rel', 'abs_normalized_by_sqrt_of_dims', 'abs_normalized_by_num_of_dims']
    trace_dict = {mode: [] for mode in diff_modes}
    lowest_dict = {mode: 1e10 for mode in diff_modes}
    lowest_step_dict = {mode: 0 for mode in diff_modes}
    lowest_xest = None

    for k in range(2, threshold):
        n = min(k, m)
        slot = k % m  # ring-buffer slot overwritten by this step
        G = F[:, :n] - X[:, :n]

        # Update the Gram block of H.
        curr_slice = (slice(None), slice(1, n + 1), slice(1, n + 1))
        H_update = torch.bmm(G, G.transpose(1, 2)) + lam * torch.eye(n, dtype=x0.dtype, device=x0.device)[None]
        H = slice_assign(H, H_update, curr_slice)

        # Solve for the mixing coefficients.
        alpha = _solve_linear_system(H[:, :n + 1, :n + 1], y[:, :n + 1])[:, 1:n + 1, 0]  # (bsz, n)

        # New iterate and its function evaluation go into the same slot.
        curr_slice = (slice(None), slice(slot, slot + 1))
        X_update = beta * (alpha[:, None] @ F[:, :n])[:, 0] + (1 - beta) * (alpha[:, None] @ X[:, :n])[:, 0]
        X = slice_assign(X, X_update.unsqueeze(1), curr_slice)

        F_update = f(X[:, slot].reshape_as(x0)).reshape(bsz, 1, -1)
        F = slice_assign(F, F_update, curr_slice)

        # Residuals of the new iterate.
        gx_flat = F[:, slot] - X[:, slot]  # (bsz, dim)
        gx = gx_flat.view_as(x0)
        abs_diff = gx.norm().item()
        abs_diff_normalized_by_sqrt_of_dims = normalized_l2_norm(gx, normalize_by_sqrt=True)
        abs_diff_normalized_by_num_of_dims = normalized_l2_norm(gx, normalize_by_sqrt=False)
        rel_diff = abs_diff / (RELATIVE_RESIDUAL_EPS + F[:, slot].norm().item())
        # Reduce over the flattened per-example vector so any input rank works.
        abs_diff_per_example = torch.sqrt(torch.sum(gx_flat ** 2, dim=1))
        rel_diff_per_example = abs_diff_per_example / (RELATIVE_RESIDUAL_EPS + F[:, slot].norm(dim=1))

        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff,
                     'abs_normalized_by_sqrt_of_dims': abs_diff_normalized_by_sqrt_of_dims,
                     'abs_normalized_by_num_of_dims': abs_diff_normalized_by_num_of_dims}

        for mode in diff_modes:
            trace_dict[mode].append(diff_dict[mode])
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode:
                    # Clone: in the in-place (non-differentiable) mode a bare view into X would be
                    # overwritten when this slot is reused m steps later. Clone keeps autograd intact.
                    lowest_xest = X[:, slot].view_as(x0).clone()
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = k

        # Early stopping on ``eps`` is intentionally disabled; the full budget is always used.

    recorded_nsteps = k

    return {"result": lowest_xest,
            "anderson_lowest_diff": lowest_dict[stop_mode],
            "anderson_lowest_abs_diff": lowest_dict["abs"],
            "diff_l2": abs_diff_per_example,
            "diff_l2_mean": lowest_dict["abs"],
            "anderson_lowest_rel_diff": lowest_dict["rel"],
            "rel_diff_mean": lowest_dict["rel"],
            "rel_diff": rel_diff_per_example,
            "anderson_lowest_abs_diff_normalized_by_sqrt_of_dims": lowest_dict["abs_normalized_by_sqrt_of_dims"],
            "anderson_lowest_abs_diff_normalized_by_num_of_dims": lowest_dict["abs_normalized_by_num_of_dims"],
            "anderson_nstep": lowest_step_dict[stop_mode],
            "anderson_recorded_nsteps": recorded_nsteps,
            "anderson_abs_trace": trace_dict['abs'],
            "anderson_rel_trace": trace_dict['rel'],
            "anderson_eps": eps,
            "anderson_threshold": threshold}
