import sys
import time

import torch
from tqdm import tqdm

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import tridiag_to_eigv, hvp
from src.tools.sharpness_tools.utils import get_device
from torch.cuda.amp import GradScaler


def lanczos(model, data_loader, max_itr):
    """
    Run ``max_itr`` steps of Lanczos iteration on the operator applied by ``hvp``.

    Follows the wikipedia article here
            https://en.wikipedia.org/wiki/Lanczos_algorithm
    and additionally re-orthogonalizes every new vector against all
    previously stored Lanczos vectors.

    :param model: object exposing ``parameters()``; the total parameter count
        fixes the vector length and the first parameter fixes the device
    :param data_loader: passed through unchanged to ``hvp``
    :param max_itr: max iteration (also the side length of the returned
        tridiagonal matrix)
    :return: ``(vecs, tridiag)`` -- the ``(max_itr, model_dim)`` Lanczos
        vectors (one per row) and the ``(max_itr, max_itr)`` symmetric
        tridiagonal matrix, both float64 on the model's device
    :raises ZeroDivisionError: if the residual norm drops below ``1e-6``
        before the last step, i.e. the next Lanczos vector cannot be
        normalized
    """
    scalar = GradScaler()
    float_dtype = torch.float64

    model_dim = sum(p.numel() for p in model.parameters())
    device = next(model.parameters()).device

    # Storage for the tridiagonal coefficients and the Lanczos vectors,
    # allocated directly on the target device.
    tridiag = torch.zeros((max_itr, max_itr), dtype=float_dtype, device=device)
    vecs = torch.zeros((max_itr, model_dim), dtype=float_dtype, device=device)

    # Initialize a random unit norm vector (drawn on the CPU, then copied
    # into the device buffer by the indexed assignment).
    init_vec = torch.zeros(model_dim, dtype=float_dtype).uniform_(-1, 1)
    init_vec /= torch.norm(init_vec)
    vecs[0] = init_vec

    # State carried between iterations: previous off-diagonal and vector.
    beta = 0.0
    v_old = torch.zeros(model_dim, dtype=float_dtype, device=device)

    for k in range(max_itr):
        v = vecs[k]
        w = hvp(model, data_loader, v, scalar)
        w = w.type(float_dtype)

        w -= beta * v_old
        alpha = torch.dot(w, v)
        tridiag[k, k] = alpha

        # The last step only contributes its diagonal entry: no further
        # vector is stored, so the residual is neither orthogonalized nor
        # normalized, and a tiny residual here is not an error.
        if k + 1 == max_itr:
            break

        w -= alpha * v

        # Reorthogonalization against every earlier Lanczos vector
        for j in range(k):
            tau = vecs[j]
            coeff = torch.dot(w, tau)
            w -= coeff * tau

        beta = torch.norm(w)

        if beta < 1e-6:
            raise ZeroDivisionError(
                f"Lanczos breakdown at iteration {k}: "
                f"residual norm {float(beta):.3e} is below 1e-6"
            )

        tridiag[k, k + 1] = beta
        tridiag[k + 1, k] = beta
        vecs[k + 1] = w / beta

        v_old = v

    return vecs, tridiag


def calc_eigenvalues(model, data_loader, max_itr, draws):
    """
    Estimate eigen values of the Hessian matrix over the entire dataset by
    running the lanczos algorithm ``draws`` times and aggregating the results.

    :param model: forwarded to ``lanczos``
    :param data_loader: forwarded to ``lanczos``
    :param max_itr: the number of eigen values to be calculated, it is also the iteration number of the lanczos algorithm
    :param draws: how many times to be repeated to calculate mean
    :return: ``(mean, std)`` of the values returned by ``tridiag_to_eigv``,
        reduced over dim 0 (presumably the draw axis -- confirm against
        ``tridiag_to_eigv``). ``torch.std`` is the unbiased estimator by
        default, so a single draw yields NaN for the std.
    """
    # Stack the per-draw tridiagonal matrices on the host: each one is moved
    # to the CPU anyway and the eigen solver receives a CPU tensor, so a
    # device-side accumulator would only add transfers.
    # NOTE(review): this buffer uses torch's default dtype while lanczos
    # builds its matrix in float64, so the assignment below downcasts --
    # confirm whether tridiag_to_eigv should receive float64 instead.
    tri = torch.zeros((draws, max_itr, max_itr))
    for num_draws in tqdm(range(draws), ncols=120):
        _, tridiag = lanczos(model, data_loader, max_itr)
        tri[num_draws, :, :] = tridiag.detach().cpu()

    eigen_values, _ = tridiag_to_eigv(tri)
    # Convert the numpy result once and reuse it for both statistics.
    eigen_values = torch.from_numpy(eigen_values)
    eigen_values_std = torch.std(eigen_values, dim=0)
    eigen_values_mean = torch.mean(eigen_values, dim=0)
    return eigen_values_mean, eigen_values_std


if __name__ == "__main__":
    import numpy as np


    class func:
        """Quadratic form ``x^T H x / 2`` defined by a fixed matrix ``H``."""

        def __init__(self, H):
            super().__init__()
            self.H = H
            self.dim = H.shape[0]

        def __call__(self, x):
            # Value of the quadratic form at x.
            hx = self.H @ x
            return x.T @ hx / 2

        def gradient(self, x):
            # Written as H x; matches the true gradient when H is symmetric.
            return self.H @ x

        def hessian(self):
            return self.H

        def hvp(self, x):
            # Hessian-vector product: the Hessian is H itself.
            return self.H @ x


    # Build a random symmetric positive semi-definite matrix A^T A / 2 and
    # print its eigen values in descending order as a reference.
    d = 10
    A = np.random.randn(d, d)
    H = torch.from_numpy(A.T @ A / 2).type(torch.float64)
    f = func(H)
    # print(calc_eigenvalues(f, 9, draws=100, max_itr=100))
    eig_vals, _ = np.linalg.eig(H)
    theoretical_res = np.sort(eig_vals)[::-1]
    for e in theoretical_res:
        print(f"{e:.4f}")
