import jax.numpy as jnp
import jax
import numpy as np
from sklearn.datasets import make_classification
from sgd import sgd, run_alphas_experiment
from utils import save_experiment

# Enable float64 before any JAX array below is created.
jax.config.update("jax_enable_x64", True)


# One seed shared by every random source (data, initial point, JAX key) so
# that repeated runs are reproducible.
seed = 42
key = jax.random.PRNGKey(seed)
rng = np.random.default_rng(seed)
# number of samples, number of features
n, p = 100, 10
# batch size
bs = 1
# generate data
A, theta = make_classification(n_samples=n, n_features=p, random_state=seed)
# scale A by its spectral norm (largest singular value)
A = A / np.linalg.norm(A, 2)
A = jnp.array(A)
# make_classification returns labels in {0, 1}, but the hinge loss in `svm`
# needs {-1, +1}: with a 0 label the term max(1 - 0 * margin, 0) is the
# constant 1, so those samples would contribute no gradient at all.
theta = jnp.array(2 * theta - 1, dtype=jnp.float64)

# initialization: random nonnegative vector scaled to unit Euclidean norm
x0 = rng.random(p)
x0 = jnp.array(x0 / np.linalg.norm(x0, 2))

# total number of iterations
n_iter = 200000

# weight of the squared-L2 regularizer used in `svm`
mu = 0.05


def svm(x, theta, idx):
    """Regularized hinge-loss (linear SVM) objective on a subset of samples.

    Reads the module-level design matrix ``A`` and regularization weight
    ``mu``.

    Args:
        x: weight vector with one entry per column of ``A``.
        theta: labels for all samples; the hinge loss assumes they are in
            {-1, +1}.
        idx: index or array of indices selecting the rows of ``A`` and the
            entries of ``theta`` that form the (mini-)batch.

    Returns:
        Scalar: mean hinge loss over the selected samples plus
        ``mu * ||x||_2^2``.
    """
    margins = theta[idx] * (A[idx, :] @ x)
    hinge = jnp.mean(jnp.maximum(1 - margins, 0))
    # x @ x equals ||x||_2^2, but unlike jnp.linalg.norm(x) ** 2 (which goes
    # through sqrt) its gradient 2x is finite at x = 0 instead of NaN.
    return hinge + mu * jnp.dot(x, x)


# Reference run with batch_size=n, i.e. every step sees all samples. The
# full-data objective at its final iterate becomes `val`, which the
# experiments below receive (presumably as the reference optimum).
all_samples = jnp.arange(n)


def full_objective(x, dx):
    """Callback: objective over all ``n`` samples at ``x``; ``dx`` is unused."""
    return svm(x, theta, all_samples)


truex, trued, trueclb = sgd(
    svm,
    theta,
    x0,
    lambda i: 0.1,  # constant schedule
    n_iter,
    n,
    key,
    callback=full_objective,
    batch_size=n,
)
val = full_objective(truex, None)


# One constant schedule per factor in `facs`. Binding `fac` as a default
# argument avoids Python's late-binding closure pitfall; otherwise every
# schedule would return the last factor.
facs = [0.1, 0.01, 0.001]
alphas = [(lambda _, fac=fac: fac) for fac in facs]

# Use the configured batch size `bs` (the same value passed to
# save_experiment) rather than a hard-coded literal, so the saved metadata
# always matches the run.
fs, dfs = run_alphas_experiment(
    val, trued, svm, theta, x0, alphas, n_iter, n, key, batch_size=bs
)

# Save the results. The arguments are positional, so this order must match
# save_experiment's signature.
saved_args = (svm, fs, dfs, val, trued, theta, x0, facs, n_iter, n, bs)
save_experiment(*saved_args)
