import argparse
import numpy as np
from scipy.special import gamma as gamma
from tqdm import tqdm
import time
import torch


def get_surface(dim):
    """Return the surface area of the unit sphere S^{dim-1} in R^dim.

    Evaluates 2 * pi**(dim/2) / Gamma(dim/2); for example dim=2 gives the
    circumference 2*pi and dim=3 gives 4*pi.
    """
    half_dim = dim / 2
    numerator = 2 * np.pi**half_dim
    return numerator / gamma(half_dim)


def number_of_harmonics(deg, dim):
    """Return the number of linearly independent spherical harmonics of degree
    ``deg`` on the unit sphere S^{dim-1} in R^dim.

    For dim >= 3 this is
        N(deg, dim) = (2*deg + dim - 2) / (dim - 2) * C(deg + dim - 3, dim - 3),
    e.g. 2*deg + 1 for dim == 3 and (deg + 1)**2 for dim == 4. On the circle
    (dim == 2) there is one constant harmonic (deg == 0) and two (cos, sin)
    for every positive degree.

    Args:
        deg: non-negative polynomial degree.
        dim: ambient dimension, at least 2.

    Returns:
        The count as a number (float for dim >= 3, int for dim == 2).

    Raises:
        ValueError: if ``dim < 2`` or ``deg < 0``.
    """
    if dim < 2 or deg < 0:
        raise ValueError(f"need dim >= 2 and deg >= 0, got dim={dim}, deg={deg}")
    if dim == 2:
        return 1 if deg == 0 else 2
    # Exact integer C(n, k) with k = min(deg, dim - 3); each partial product
    # is itself a binomial coefficient, so the floor division is exact.
    n = deg + dim - 3
    k = min(deg, dim - 3)
    binomial = 1
    for i in range(1, k + 1):
        binomial = binomial * (n - k + i) // i
    return (2 * deg + dim - 2) * binomial / (dim - 2)


def legendre_polval(x, coeffs, d):
    """Evaluate sum_k coeffs[k] * P_k(x) for the Legendre (Gegenbauer)
    polynomials associated with the unit sphere in R^d.

    The polynomials start from P_0 = 1, P_1 = t and follow the three-term
    recurrence
        (k + d - 3) P_k(t) = (2k + d - 4) t P_{k-1}(t) - (k - 1) P_{k-2}(t),
    which keeps P_k(1) = 1 for every k. For d == 3 these are the classical
    Legendre polynomials; for d == 2 the recurrence reduces to the Chebyshev one.

    Args:
        x: scalar or array of evaluation points (evaluated elementwise).
        coeffs: non-empty sequence of coefficients; ``coeffs[k]`` multiplies P_k.
        d: ambient dimension used in the recurrence.

    Returns:
        The weighted sum, broadcast to the shape of ``x``.
    """
    deg = len(coeffs) - 1
    if deg == 0:
        return np.ones_like(x) * coeffs[0]
    if deg == 1:
        return coeffs[0] + x * coeffs[1]
    p_prev = 1   # P_{k-2}
    p_curr = x   # P_{k-1}
    out = coeffs[0] + coeffs[1] * p_curr
    for k in range(2, deg + 1):
        a1 = (2 * k + d - 4) / (k + d - 3)
        a2 = (k - 1) / (k + d - 3)
        p_next = a1 * x * p_curr - a2 * p_prev
        # Out-of-place add: with integer x and integer coefficients ``out`` has
        # an integer dtype, and an in-place ``+=`` of the float term would raise
        # a NumPy casting error.
        out = out + coeffs[k] * p_next
        p_prev, p_curr = p_curr, p_next
    return out
    

def get_args(argv=None):
    """Parse the command-line options for the sweep.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``; passing a list lets the parser be
            used from tests or notebooks.

    Returns:
        argparse.Namespace with ``dim``, ``num_iters``, ``noise``, ``seed`` and
        ``ntest``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", type=int, default=4,
                        help="ambient dimension d")
    parser.add_argument("--num_iters", type=int, default=100,
                        help="number of random trials per (q, s) cell")
    parser.add_argument("--noise", type=float, default=1e-6,
                        help="scale of the noise polynomial added to training labels")
    parser.add_argument("--seed", type=int, default=1,
                        help="base random seed")
    parser.add_argument("--ntest", type=int, default=100,
                        help="number of test points per trial")
    return parser.parse_args(argv)


def main():
    """Run the kernel-interpolation sweep over (q, s) and save the results.

    For every degree q in ``qall`` and every training-set size s in ``sall``
    this repeats ``num_iters`` trials of:
      * sample ``s`` training and ``ntest`` test points on the unit sphere in
        R^d (normalized Gaussian vectors);
      * draw a target sum_{k<=q} c_k P_k(<x, v>), with v the normalized
        all-ones vector, and add ``noise`` times a polynomial supported on
        degrees q+1..2q to the training labels only;
      * fit with the kernel sum_{k<=q} N(k, d) P_k(<x, y>) via least squares;
      * count test points whose absolute error is below the magnitude of the
        noise polynomial at that point (floored at 1e-12, see below).
    The per-(q, s) success frequency goes into ``heatmap``, and the heatmap,
    per-point errors and noise magnitudes are saved with ``torch.save`` to
    ``./heatmap_d{d}{noise_name}.pth``.

    Raises:
        ValueError: if ``--num_iters`` or ``--ntest`` is not positive, or if
            ``--dim`` has no training-size grid (only 3 and 4 are defined).
    """
    args = get_args()
    noise = args.noise
    d = args.dim
    num_iters = args.num_iters

    noise_name = '' if noise == 0.0 else f"_noise{noise}"
    print(f"d={d}, num_iters={num_iters}, noise={noise}")

    ntest = args.ntest
    random_state = args.seed

    # Both are used as divisors in the success frequency below.
    if num_iters < 1 or ntest < 1:
        raise ValueError(
            f"--num_iters and --ntest must be positive, got {num_iters} and {ntest}"
        )

    # Training-set sizes swept for each supported dimension.
    if d == 3:
        sall = np.arange(1, 16) * 40
    elif d == 4:
        sall = np.arange(1, 16) * 160
    else:
        raise ValueError(f"no training-size grid defined for --dim {d}; supported: 3, 4")

    # Target direction: the normalized all-ones vector.
    v = np.ones(d) / np.sqrt(d)

    qall = np.arange(5, 23)
    print(qall)
    heatmap = np.zeros((len(qall), len(sall)))
    errors = np.zeros((len(qall), num_iters, len(sall), ntest))
    l2norm_noise = np.zeros((len(qall), num_iters, len(sall), ntest))
    freqs = np.zeros(len(sall))
    # Floor for the success threshold: with noise == 0 the noise polynomial is
    # identically zero and a strict ``err < 0`` test could never succeed.
    success_floor = 1e-12

    tic0 = time.time()

    for j, q in enumerate(qall):
        tic1 = time.time()
        # Kernel coefficients: number of degree-k spherical harmonics, k = 0..q.
        coeffs = [number_of_harmonics(deg_, d) for deg_ in range(q + 1)]
        print(f" q : {q}")
        for i, s in enumerate(sall):
            # The seed depends only on the s-index (not on q).
            rng = np.random.RandomState(23 * i + random_state)
            print(f"s : {s}")
            num_success = 0
            for it in tqdm(range(num_iters)):
                # NOTE: the order of the rng draws below fixes reproducibility.
                xtrain = rng.randn(s, d)
                xtrain = xtrain / np.linalg.norm(xtrain, axis=1)[:, None]

                # Ground truth: sum_{k<=q} c_k P_k(<x, v>).
                c = rng.randn(q + 1)
                # Noise polynomial: zero on degrees 0..q, random on q+1..2q.
                c_noise = [0] * (int(q) + 1) + rng.randn(q).tolist()

                train_proj = np.dot(xtrain, v)
                ftrain = (legendre_polval(train_proj, c, d)
                          + noise * legendre_polval(train_proj, c_noise, d))

                xtest = rng.randn(ntest, d)
                xtest = xtest / np.linalg.norm(xtest, axis=1)[:, None]
                test_proj = np.dot(xtest, v)
                ftest = legendre_polval(test_proj, c, d)

                noise_value = noise * legendre_polval(test_proj, c_noise, d)
                l2norm_noise[j, it, i, :] = abs(noise_value)

                # Solve (G / s) z = ftrain / sqrt(s) with G the training Gram
                # kernel; the 1/s and 1/sqrt(s) scalings cancel in fpred.
                K = legendre_polval(np.dot(xtrain, xtrain.T), coeffs, d) / s
                f = ftrain * np.sqrt(1 / s)
                z = np.linalg.lstsq(K, f, rcond=None)[0]

                fpred = (z.T @ legendre_polval(xtrain @ xtest.T, coeffs, d) / np.sqrt(s))
                err = abs(fpred - ftest)

                errors[j, it, i, :] = err

                num_success += (err < np.maximum(abs(noise_value), success_floor)).sum()

            freqs[i] = num_success / (num_iters * ntest)
            print(f"it:{it}, i:{i}, freq: {freqs[i]} ({num_success}/{num_iters * ntest})")

            heatmap[j, i] = freqs[i]

        print(f"q = {q} is done.!! |  q-time : {time.time() - tic1:.3f} s | total_time : {time.time() - tic0:.3f} s")

    torch.save({"heatmap": heatmap, "l2norm_noise":l2norm_noise, "errors": errors, "qall": qall, "sall": sall, "d": d, "ntest": ntest, "num_iters": num_iters}, f"./heatmap_d{d}{noise_name}.pth")


# Script entry point: run the sweep only when executed directly, not on import.
if __name__ == "__main__":
    main()
