import jax.numpy as jnp
import jax
import numpy as np
from sklearn.datasets import make_regression
from sgd import run_alphas_experiment
from utils import save_experiment

jax.config.update("jax_enable_x64", True)


# Fix both random sources so the experiment is reproducible. The JAX key is
# passed to run_alphas_experiment below. The NumPy global RNG drives
# make_regression (random_state=None falls back to NumPy's global RNG), the
# initialization x0, and the np.random targets drawn in the later sections.
SEED = 42
key = jax.random.PRNGKey(SEED)
np.random.seed(SEED)

# number of samples, number of features
n, p = 10, 4
# batch size
bs = 1
# generate data: only the design matrix is kept (the regression targets are
# discarded), rescaled to unit spectral norm.
A, _ = make_regression(n_samples=n, n_features=p, noise=0.1)
A = A / np.linalg.norm(A, 2)
A = jnp.array(A)

# initialization: random unit-norm vector (entries drawn from [0, 1))
x0 = np.random.rand(p)
x0 = x0 / np.linalg.norm(x0, 2)

# total number of iterations
n_iter_cst = 100000
n_iter_dec = 100000

# step sizes
# A^T A is symmetric, so use eigvalsh: it always returns real eigenvalues.
# np.linalg.eig uses a general solver whose output becomes complex if round-off
# leaves any non-zero imaginary part, which would make alpha_ref complex.
eigs = np.linalg.eigvalsh(np.asarray(A.T @ A))
# Extreme eigenvalues of A^T A (L == 1 up to round-off after the rescaling above).
mu, L = np.min(eigs), np.max(eigs)
facs = [1.0, 0.5, 0.1]
alpha_ref = mu / (4 * L * L)

# Constant schedules fac * alpha_ref; the iteration index argument is ignored.
# ``fac`` is bound as a default argument to avoid Python's late-binding closures.
cst_alphas = [(lambda _, fac=fac: fac * alpha_ref) for fac in facs]


########################################
# Constant step size || A x - theta ||^2
########################################
def lstsq(x, theta, idx):
    """Mini-batch least-squares loss on the rows of ``A`` selected by ``idx``.

    Computes ``(0.5 / len(idx)) * sum_{i in idx} (A[i] @ x - theta[i]) ** 2``,
    i.e. half the mean squared residual over the batch.

    Args:
        x: parameter vector multiplied by the module-level matrix ``A``.
        theta: target vector, indexed by row like ``A``.
        idx: integer index array selecting the batch rows; its leading
            dimension is the batch size.

    Returns:
        The scalar batch loss.
    """
    batch_rows = A[idx, :]
    residual = batch_rows @ x - theta[idx]
    scale = 0.5 / idx.shape[0]
    return scale * jnp.sum(residual ** 2)


# "true" solution. theta is a generic random target in R^n, so (almost surely)
# it does not lie in the range of A and the optimal residual is non-zero.
theta = np.random.randn(n) / np.sqrt(n)
sol = np.linalg.lstsq(A, theta, rcond=1e-14)[0]
# Objective value at the least-squares solution, over the full data set.
val = lstsq(sol, theta, idx=jnp.arange(n))
# Jacobian of theta -> argmin_x ||A x - theta||^2, i.e. (A^T A)^{-1} A^T,
# assuming A has full column rank. Solve the normal equations rather than
# forming the explicit inverse (numerically more stable).
true_jac = jnp.linalg.solve(A.T @ A, A.T)

### Constant step size
print("Running constant step size in regular regime")
fs_cst, dys_cst = run_alphas_experiment(
    val, true_jac, lstsq, theta, x0, cst_alphas, n_iter_cst, n, key, batch_size=bs
)
save_experiment(
    lstsq,
    fs_cst,
    dys_cst,
    val,
    true_jac,
    theta,
    x0,
    # Raw f-string: "\e" is not a valid escape sequence (a warning on modern
    # Pythons); the produced label text is unchanged.
    [rf"{fac} $\eta_0$" for fac in facs],
    n_iter_cst,
    n,
    bs,
    custom_name="cst",
)

##########################################
# Decreasing step size || A x - theta||^2
##########################################

# Decreasing schedules as functions of the iteration index i, each paired with
# its plot label so schedules and labels cannot drift out of sync. The dict
# keeps insertion order, so the label order matches the schedule order.
dec_schedules = {
    "1/k": lambda i: alpha_ref / (i + 1),
    "1/sqrt(k)": lambda i: alpha_ref / np.sqrt(i + 1),
    "1/k^0.25": lambda i: alpha_ref / (i + 1) ** 0.25,
    # NOTE(review): with k = i + 1 (as the labels above imply) this schedule is
    # 1/log(k+1), not 1/log(k+2) -- confirm which indexing the label intends.
    "1/log(k+2)": lambda i: alpha_ref / np.log(i + 2),
}
dec_alphas = list(dec_schedules.values())
dec_labels = list(dec_schedules)

# Decreasing step size
print("Running decreasing step size in regular regime")
fs_dec, dys_dec = run_alphas_experiment(
    val, true_jac, lstsq, theta, x0, dec_alphas, n_iter_dec, n, key, batch_size=bs
)
save_experiment(
    lstsq,
    fs_dec,
    dys_dec,
    val,
    true_jac,
    theta,
    x0,
    dec_labels,
    n_iter_dec,
    n,
    bs,
    custom_name="dec",
)


##########################################
# Constant step size || A x - A theta ||^2
##########################################
def lstsq_interpol(x, theta, idx):
    """Mini-batch interpolation-regime loss on the rows of ``A`` in ``idx``.

    Computes ``(0.5 / len(idx)) * sum_{i in idx} (A[i] @ x - (A @ theta)[i]) ** 2``.
    The target of row i is ``(A @ theta)[i]``, so ``x = theta`` attains zero loss.

    Args:
        x: parameter vector with one entry per column of ``A``.
        theta: "true" parameter vector, also with one entry per column of ``A``
            (unlike ``lstsq``, where theta is indexed by row).
        idx: integer index array selecting the batch rows; its leading
            dimension is the batch size.

    Returns:
        The scalar batch loss.
    """
    # A[idx] @ x - (A @ theta)[idx] == A[idx] @ (x - theta): only the batch rows
    # are multiplied, instead of forming the full product A @ theta every call.
    residual = A[idx, :] @ (x - theta)
    return (0.5 / idx.shape[0]) * jnp.sum(residual ** 2)


# "true" solution. The target A @ theta_interpol lies in the range of A, so
# (assuming A has full column rank) the minimizer is theta_interpol itself up
# to round-off and the optimal loss is ~0.
theta_interpol = np.random.randn(p)

sol_interpol = np.linalg.lstsq(A, A @ theta_interpol, rcond=1e-14)[0]
val_interpol = lstsq_interpol(sol_interpol, theta_interpol, idx=jnp.arange(n))
# d/dtheta argmin_x ||A x - A theta||^2 = (A^T A)^{-1} A^T A = identity.
true_jac_interpol = np.eye(p)

### Constant step size in interpolation regime
print("Running constant step size in interpolation regime")
fs_interpol_cst, dys_interpol_cst = run_alphas_experiment(
    val_interpol,
    true_jac_interpol,
    lstsq_interpol,
    theta_interpol,
    x0,
    cst_alphas,
    n_iter_cst,
    n,
    key,
    batch_size=bs,
)
save_experiment(
    lstsq_interpol,
    fs_interpol_cst,
    dys_interpol_cst,
    val_interpol,
    true_jac_interpol,
    theta_interpol,
    x0,
    # Raw f-string: "\e" is not a valid escape sequence (a warning on modern
    # Pythons); the produced label text is unchanged.
    [rf"{fac} $\eta_0$" for fac in facs],
    n_iter_cst,
    n,
    bs,
    # NOTE(review): same custom_name as the regular-regime run; confirm that
    # save_experiment disambiguates by objective so results are not overwritten.
    custom_name="cst",
)


##################################################################
# Constant step size || A x - theta ||^2 with theta in range(A)
##################################################################
# This experiment reuses the ``lstsq`` objective defined for the regular
# regime. It is deliberately not redefined here, so both experiments are
# guaranteed to optimize the same loss.


# "true" solution. theta_img lies in the range of A, so the least-squares
# problem has (near-)zero residual while still using the row-indexed ``lstsq``
# objective.
theta_img = A @ np.random.randn(p)
sol_img = np.linalg.lstsq(A, theta_img, rcond=1e-12)[0]
val_img = lstsq(sol_img, theta_img, idx=jnp.arange(n))
# Jacobian of theta -> argmin_x ||A x - theta||^2, i.e. (A^T A)^{-1} A^T
# (assuming full column rank), via the normal equations instead of an inverse.
true_jac_img = jnp.linalg.solve(A.T @ A, A.T)

### Constant step size, theta in the range of A
# Distinct message so this run can be told apart from the regular-regime run
# in the log.
print("Running constant step size in image regime")
fs_cst_img, dys_cst_img = run_alphas_experiment(
    val_img,
    true_jac_img,
    lstsq,
    theta_img,
    x0,
    cst_alphas,
    n_iter_cst,
    n,
    key,
    batch_size=bs,
)
save_experiment(
    lstsq,
    fs_cst_img,
    dys_cst_img,
    val_img,
    true_jac_img,
    theta_img,
    x0,
    # Raw f-string: "\e" is not a valid escape sequence (a warning on modern
    # Pythons); the produced label text is unchanged.
    [rf"{fac} $\eta_0$" for fac in facs],
    n_iter_cst,
    n,
    bs,
    custom_name="cst-img",
)
