import numpy as np
from matplotlib import pyplot as plt

import pickle


def plotGenErrs(f_abs, res, T, name=""):
    """Plot log10 of the squared error between ``f_abs`` and each column of ``res``.

    Args:
        f_abs: reference array; ``f_abs[:, np.newaxis]`` is broadcast against
            ``res`` (presumably a 1-D vector with one entry per row of ``res``).
        res: 2-D array whose columns are successive snapshots; the columns are
            placed evenly over the time interval ``[0, T]``.
        T: right end of the time axis.
        name: figure title.
    """
    squared_diff = (f_abs[:, np.newaxis] - res) ** 2
    errs_per_step = np.sum(squared_diff, axis=0)
    times = np.linspace(0, T, res.shape[1])

    plt.xlabel("t")
    plt.ylabel("log10 err")
    plt.title(f"{name}")
    plt.plot(times, np.log10(errs_per_step))
    plt.show()


def compareGenErrs(f_abs, ress, T, names=None):
    """Overlay the log10 squared-error curves of several runs on one figure.

    Args:
        f_abs: reference array, broadcast as ``f_abs[:, np.newaxis]`` against
            every run.
        ress: sequence of 2-D arrays, one per run; each curve's columns are
            spread evenly over ``[0, T]``.
        T: right end of the time axis.
        names: legend labels, one per run; defaults to ``res_0``, ``res_1``, ...
            Runs beyond the length of ``names`` are not drawn (``zip`` stops early).
    """
    labels = names if names is not None else [f"res_{i}" for i in range(len(ress))]
    plt.xlabel("t")
    plt.ylabel("log10 err")
    plt.title("Generalization Errors")
    for run, label in zip(ress, labels):
        errs = np.sum((f_abs[:, np.newaxis] - run) ** 2, axis=0)
        times = np.linspace(0, T, run.shape[1])
        plt.plot(times, np.log10(errs), label=f"{label}")
    plt.legend()
    plt.show()


def solveAB(x, y):
    """Least-squares fit of the straight line ``y ≈ a * x + b``.

    Args:
        x: 1-D sequence of abscissae.
        y: sequence of ordinates with the same length as ``x``.

    Returns:
        NumPy array ``[a, b]`` holding the slope and the intercept.
    """
    # Each row of the design matrix is [x_i, 1], so the solution is (slope, intercept).
    design = np.stack([x, np.ones(len(x))], axis=1)
    coeffs, _residuals, _rank, _singular_values = np.linalg.lstsq(design, y, rcond=None)
    return coeffs


def loglogPlot(x, y, title=""):
    """Plot ``log10(y)`` against ``log10(x)`` together with a least-squares line.

    The fitted line (via ``solveAB``) is drawn dashed in gray; the data curve's
    legend entry reports the fitted slope.
    """
    lx, ly = np.log10(x), np.log10(y)
    slope, intercept = solveAB(lx, ly)
    fitted = slope * lx + intercept
    plt.title(title)
    plt.plot(lx, fitted, color='gray', linestyle="--")
    plt.plot(lx, ly, label=f"slope={slope}")
    plt.legend()
    plt.show()


def compareGenErrsN(expr_res, T_max=None):
    """Overlay, for every sample size ``n``, the log10 of the error averaged over axis 0.

    Args:
        expr_res: experiment result exposing ``gen_err_list`` (one 2-D array per
            ``n``; columns are spread evenly over ``[0, T]``), ``ns`` and
            ``T_ada_list`` (the ``T`` used for each ``n``).
        T_max: if given, the x-axis is clipped to ``[0, T_max]``.
    """
    plt.xlabel("t")
    plt.ylabel("log10 err")
    plt.title("Generalization Errors")

    # Gradient colors: the i-th curve is drawn with viridis evaluated at i / len(ns).
    palette = plt.colormaps.get_cmap("viridis")
    count = len(expr_res.ns)
    curves = zip(expr_res.gen_err_list, expr_res.ns, expr_res.T_ada_list)
    for idx, (res, n, T) in enumerate(curves):
        avg_errs = np.mean(res, axis=0)
        times = np.linspace(0, T, res.shape[1])
        plt.plot(times, np.log10(avg_errs), label=f"n={n}", color=palette(idx / count))

    if T_max is not None:
        plt.xlim((0, T_max))
    plt.legend()
    plt.show()


def plotMinGenErrsN(expr_res, export_filename=None):
    """Plot the best-over-time generalization error against ``n`` on log-log axes.

    For every sample size the minimum of each row of its error array is taken
    (``np.min(res, axis=1)``); ``log10`` of the mean of those minima is plotted,
    together with a dashed least-squares line whose equation appears in the legend.
    When ``meta["repeats"] > 1``, error bars show the std of ``log10`` of the minima.

    Args:
        expr_res: experiment result exposing ``gen_err_list`` (one 2-D array per
            entry of ``ns``), ``T_ada_list``, ``ns`` and ``meta`` with keys
            ``"f_decay"`` and ``"repeats"``.
        export_filename: if given, the figure is also saved to this path.
    """
    plt.figure()
    plt.xlabel("n (logarithmic)")
    plt.ylabel("$\\log_{10}$ generalization error ")
    p = expr_res.meta["f_decay"]
    plt.title(f"Generalization Errors; optimal rate={(2 * p - 1) / (2 * p):.3f}")
    rep = expr_res.meta["repeats"]
    means = []
    log_stds = []
    for res, _T, _n in zip(expr_res.gen_err_list, expr_res.T_ada_list, expr_res.ns):
        min_gen_errs = np.min(res, axis=1)
        means.append(np.mean(min_gen_errs))
        if rep > 1:
            log_stds.append(np.nanstd(np.log10(min_gen_errs)))
    ns = np.array(expr_res.ns)
    log_ns = np.log10(ns)
    log_means = np.log10(np.array(means))
    if rep > 1:
        plt.errorbar(log_ns, log_means, yerr=log_stds, fmt="-o", label="err")
    else:
        plt.plot(log_ns, log_means, "-o", label="err")
    # Linear fit in log-log space; the slope is the empirical decay rate in n.
    k, m = solveAB(log_ns, log_means)
    plt.plot(log_ns, k * log_ns + m, color='gray', linestyle="--",
             label=f"$\\log Err\\approx {k:.3f} \\log n + {m:.2f}$")
    # Label only the first and last ticks. Building the labels per position keeps
    # exactly one label per tick; the concatenation [first] + [""] * (len - 2) + [last]
    # produced two labels for a single n, which plt.xticks rejects.
    last = len(ns) - 1
    xlabels = [n if i in (0, last) else "" for i, n in enumerate(ns)]
    plt.xticks(log_ns, xlabels)
    plt.legend()
    if export_filename is not None:
        plt.savefig(export_filename)
    plt.show()

def computeMinGenErrorRate(expr_res):
    """Return the log-log slope of the mean minimum generalization error versus ``n``.

    Performs the same fit as ``plotMinGenErrsN`` without plotting. For each sample
    size the minimum of every row (``np.min(res, axis=1)``) is averaged. ``log10``
    of those means is then regressed on ``log10(n)`` with ``solveAB``.

    Args:
        expr_res: experiment result exposing ``gen_err_list`` (one 2-D array per
            entry of ``ns``) and ``ns``.

    Returns:
        The slope ``k`` of the fit ``log10(mean min err) ≈ k * log10(n) + m``.
    """
    # Zipping with ns keeps each mean aligned with the sample size it belongs to.
    means = np.array([np.mean(np.min(res, axis=1))
                      for res, _n in zip(expr_res.gen_err_list, expr_res.ns)])
    log_ns = np.log10(np.array(expr_res.ns))
    k, _m = solveAB(log_ns, np.log10(means))
    return k


def plotMinGenStepsN(expr_res, export_filename=None):
    """Plot the oracle stopping time (time of minimal error) against ``n``, log-log.

    For each sample size, the row-wise argmin of its error array is converted to a
    time via ``index * T / meta["ticks"]``. The mean of ``log10`` of those times is
    plotted with a dashed least-squares line. With ``meta["repeats"] > 1``, error
    bars show the std of the ``log10`` times.

    Runs whose minimum sits at index 0 (time 0) have no finite logarithm. They are
    treated as missing (NaN) and ignored by the mean/std, instead of turning the
    whole point into -inf and breaking the fit.

    Args:
        expr_res: experiment result exposing ``gen_err_list`` (one 2-D array per
            entry of ``ns``), ``T_ada_list`` (the ``T`` used for each ``n``), ``ns``
            and ``meta`` with keys ``"repeats"`` and ``"ticks"``.
        export_filename: if given, the figure is also saved to this path.
    """
    plt.figure()
    plt.xlabel("n (logarithmic)")
    plt.ylabel("log10 min gen step")
    plt.title("Oracle stopping time")
    rep = expr_res.meta["repeats"]
    ticks = expr_res.meta["ticks"]
    means = []
    stds = []
    for res, T, _n in zip(expr_res.gen_err_list, expr_res.T_ada_list, expr_res.ns):
        stop_times = np.argmin(res, axis=1) * T / ticks
        # Map t == 0 to NaN before the log so it is skipped rather than yielding -inf.
        log_stop_times = np.log10(np.where(stop_times > 0, stop_times, np.nan))
        means.append(np.nanmean(log_stop_times))
        if rep > 1:
            stds.append(np.nanstd(log_stop_times))
    ns = np.array(expr_res.ns)
    log_ns = np.log10(ns)
    means = np.array(means)
    if rep > 1:
        plt.errorbar(log_ns, means, yerr=stds, fmt="-o", label="stopping time")
    else:
        plt.plot(log_ns, means, "-o", label="stopping time")
    # Fit only finite points so one fully degenerate n does not break the regression.
    finite = np.isfinite(means)
    k, m = solveAB(log_ns[finite], means[finite])
    plt.plot(log_ns, k * log_ns + m, color='gray', linestyle="--",
             label=f"$\\log t\\approx {k:.2f} \\log n + {m:.2f}$")
    # Label only the first and last ticks, with exactly one label per tick (the
    # concatenation form produced two labels when there is a single n).
    last = len(ns) - 1
    xlabels = [n if i in (0, last) else "" for i, n in enumerate(ns)]
    plt.xticks(log_ns, xlabels)
    plt.legend()
    if export_filename is not None:
        plt.savefig(export_filename)
    plt.show()


def plotErrorSteps(expr_res, pow, factor, export_filename=None):
    """Plot the error at the prescribed stopping time ``t = factor * n ** pow`` vs ``n``.

    For each sample size, ``t`` is mapped to column ``int(t / T * meta["ticks"])``
    of the error array, clamped to the valid column range. ``log10`` of the mean
    error in that column is plotted on log-log axes with a dashed least-squares
    line. With ``meta["repeats"] > 1``, error bars show the std of the ``log10``
    errors.

    Args:
        expr_res: experiment result exposing ``gen_err_list`` (one 2-D array per
            entry of ``ns``), ``T_ada_list`` (the ``T`` used for each ``n``), ``ns``
            and ``meta`` with keys ``"f_decay"``, ``"repeats"`` and ``"ticks"``.
        pow: exponent of ``n`` in the stopping rule (shadows the builtin; the name
            is kept for keyword-argument compatibility).
        factor: multiplicative constant of the stopping rule.
        export_filename: if given, the figure is also saved to this path.
    """
    # Start a fresh figure like the other n-sweep plots instead of drawing on
    # whatever figure happens to be current.
    plt.figure()
    plt.xlabel("n")
    plt.ylabel("$\\log_{10}$ generalization error ")
    p = expr_res.meta["f_decay"]
    plt.title(
        f"Generalization Errors;expected rate={(2 * p - 1) / (2 * p):.3f} \n $t={factor:.2f} \\times n^{{{pow:.2f}}} $")
    rep = expr_res.meta["repeats"]
    total_ticks = expr_res.meta["ticks"]
    means = []
    log_stds = []
    for res, T, n in zip(expr_res.gen_err_list, expr_res.T_ada_list, expr_res.ns):
        t = factor * n ** pow
        tick = int(t / T * total_ticks)
        # Clamp into [0, last column]: a t beyond T must not index past the stored
        # columns, and a negative index would silently wrap to the final column.
        tick = max(0, min(tick, total_ticks, res.shape[1] - 1))
        errs = res[:, tick]
        means.append(np.mean(errs))
        if rep > 1:
            log_stds.append(np.nanstd(np.log10(errs)))
    ns = np.array(expr_res.ns)
    log_ns = np.log10(ns)
    log_means = np.log10(np.array(means))
    if rep > 1:
        plt.errorbar(log_ns, log_means, yerr=log_stds, fmt="-o", label="err")
    else:
        plt.plot(log_ns, log_means, "-o", label="err")
    # Linear fit in log-log space; the slope is the empirical decay rate in n.
    k, m = solveAB(log_ns, log_means)
    plt.plot(log_ns, k * log_ns + m, color='gray', linestyle="--",
             label=f"$\\log Err\\approx {k:.3f} \\log n + {m:.2f}$")
    # Label only the first and last ticks, with exactly one label per tick (the
    # concatenation form produced two labels when there is a single n).
    last = len(ns) - 1
    xlabels = [n if i in (0, last) else "" for i, n in enumerate(ns)]
    plt.xticks(log_ns, xlabels)
    plt.legend()
    if export_filename is not None:
        plt.savefig(export_filename)
    plt.show()


def loadResult(filename):
    """Unpickle and return the object stored in ``filename``.

    Security note: ``pickle.load`` can execute arbitrary code while loading, so
    only pass files you created yourself or otherwise trust.
    """
    with open(filename, mode="rb") as handle:
        result = pickle.load(handle)
    return result


def loadAndPlot(filename, name=None, lower=0, upper=None):
    """Load a pickled experiment result, slice it, and export its two summary plots.

    Args:
        filename: path of the pickled result (conventionally ending in ``.pkl``).
        name: output path prefix; defaults to ``filename`` with a trailing ``.pkl``
            removed. Figures are written to ``<name>.pdf`` (min errors) and
            ``<name>_steps.pdf`` (stopping times).
        lower: start of the slice ``result[lower:upper]``.
        upper: end of that slice; defaults to ``len(result.ns)``. Assumes the
            result object supports slicing — TODO confirm against its class.
    """
    result = loadResult(filename)
    if upper is None:
        upper = len(result.ns)
    result = result[lower:upper]
    if name is None:
        # Strip ".pkl" only when it is actually the suffix; blindly dropping four
        # characters mangles names with another extension or none at all.
        suffix = ".pkl"
        name = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    plotMinGenErrsN(result, name + ".pdf")
    plotMinGenStepsN(result, name + "_steps.pdf")
