from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from active_ranking import utils
from active_ranking.experiments import experiment
from active_ranking.metrics.roc import ROC, norm_one
from active_ranking.scenarios import inputs

# Compare every model in `experiment.models` on one scenario:
#   * fig1: each model's ROC curve (from its last stored prediction) against
#     the ROC curve of the true eta,
#   * fig2 / fig3: the d_1 / d_infinity regret curves stored on each model,
# then save the three figures under FIGURE_DIR and print utils' timing caches.
# NOTE(review): assumes constructing a model runs the experiment and fills
# `predictions`, `n_sample`, `norm_one` and `norm_infinity` — the model
# classes live outside this file; confirm.

SCENARIO_NAME = "scenario_3"
FIGURE_DIR = Path("results/figures/one_example")

scenario = getattr(inputs, SCENARIO_NAME)
j_max = scenario["j_max"]
d = scenario["d"]
e = scenario["eta"]
n_max = scenario["n_max"]
n_0 = scenario["n_0"]

models = [m(n_0, n_max, j_max, d, e) for m in experiment.models]

fig1, ax1 = plt.subplots(figsize=(6, 6), dpi=200)  # ROC curves
fig2, ax2 = plt.subplots(figsize=(6, 6), dpi=200)  # d_1 regret
fig3, ax3 = plt.subplots(figsize=(6, 6), dpi=200)  # d_infinity regret

# Reference ROC curve of the true eta, built on the first model's test set.
# NOTE(review): this presumes all models share the same test set (they are
# built from identical scenario parameters) — otherwise norm_one below
# compares curves on different data; TODO confirm.
reference_model = models[0]
roc_true = ROC(
    y_true=reference_model.y_test,
    y_prediction=reference_model.true_eta(reference_model.x_test),
)

for model in models:
    # The last value of `predictions` is used as the model's final scores.
    # NOTE(review): relies on dict insertion order matching training order.
    final_prediction = list(model.predictions.values())[-1]
    roc_model = ROC(y_true=model.y_test, y_prediction=final_prediction)

    # No axes is passed to plot_roc_curve, so it presumably draws on pyplot's
    # current axes; select ax1 explicitly before calling it.
    plt.sca(ax1)
    roc_model.plot_roc_curve(
        label=f"{model.name} ROC curve (AUC = {np.round(roc_model.auc, 2)})",
        marker=".",
    )

    print(f"{model.name}:", norm_one(roc_true, roc_model))

    ax2.plot(model.n_sample, model.norm_one, label=model.name)
    ax3.plot(model.n_sample, model.norm_infinity, label=model.name)

# savefig fails with FileNotFoundError if the target directory is missing.
FIGURE_DIR.mkdir(parents=True, exist_ok=True)

plt.sca(ax1)  # the ROC helpers below also draw on the current axes
roc_true.plot_roc_curve(c="k", label="True ROC curve", marker=".")
roc_true.plot_back_ground()
ax1.legend(facecolor="w")
fig1.savefig(FIGURE_DIR / "roc.png")

ax2.set_xlabel("t")
ax2.set_ylabel("$d_1$ regret")
ax2.legend()
fig2.savefig(FIGURE_DIR / "regret_d_1.png")

ax3.grid(True)
ax3.legend()
ax3.set_xlabel("Number of sample")
# Raw string: "\i" is not a valid escape and warns on Python 3.12+.
ax3.set_ylabel(r"$d_\infty$ regret")
fig3.savefig(FIGURE_DIR / "regret_d_infty.png")

print(utils.__cached_time_n_call__)
print(utils.__cached_time__)
