import math
import sys
import torch 
import pyro
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from scipy.stats import truncnorm
import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import argparse
from SBNN.model import Model
from pyro.infer import Predictive

# Number of recursion levels used when approximating the Cantor function.
CANTOR_ORDER = 10


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` is an argparse trap: every non-empty string is truthy, so
    ``bool("False")`` is ``True`` and ``--tnx False`` would *enable* the flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='besov')
parser.add_argument('--function',
                    default="cantor",
                    type=str,
                    choices=["cantor", "sobolev", "blocks", "heavisine"],
                    metavar='M',
                    help='cantor | sobolev | blocks | heavisine')
parser.add_argument('--L',
                    default=5,
                    type=int,
                    metavar='L',
                    help='depth of the NN model')
parser.add_argument('--W',
                    default=200,
                    type=int,
                    metavar='W',
                    help='width of the NN model')
parser.add_argument('--n',
                    default=100,
                    type=int,
                    metavar='n',
                    help='number of training data')
parser.add_argument('--epochs',
                    default=10000,
                    type=int,
                    help='number of total epochs to run')
parser.add_argument('--lr',
                    default=2e-3,
                    type=float,
                    help='learning rate for optimization step')
parser.add_argument("--tnx",
                    default=False,
                    type=_str2bool,
                    help="whether to sample x from a truncated normal"
                    )
parser.add_argument("--prior",
                    default="2GMM",
                    type=str,
                    help="prior: 2GMM (Gaussian mixture) | Gaussian"
                    )
parser.add_argument("--algorithm",
                    default="mcmc",
                    type=str,
                    help="mcmc or BBP (Bayes by BackProp)")
parser.add_argument("--verbose",
                    default=20,
                    type=int,
                    help="verbosity interval passed to BBP training")
parser.add_argument("--seed",
                    default=1,
                    type=int,
                    help="seed number")
parser.add_argument("--double-precision",
                    default=False,
                    type=_str2bool,
                    help="use double precision or not")
parser.add_argument("--batch-size",
                    default=100,
                    type=int,
                    help="batch size")
parser.add_argument("--manual-sigma-1",
                    default=1e-3,
                    type=float,
                    help="manual sigma 1 for 2GMM")
parser.add_argument("--num-samples",
                    default=1000,
                    type=int,
                    help="num MCMC samples")
parser.add_argument("--warmup-steps",
                    default=1000,
                    type=int,
                    help="num MCMC warmup steps")

class BesovFunction:
    """Wrap a scalar target function so it can be evaluated and sampled.

    Args:
        function: callable mapping a scalar ``x`` to a scalar ``y``. It is
            wrapped with ``np.vectorize`` so array inputs are evaluated
            element by element.
    """

    def __init__(self, function):
        self.function = np.vectorize(function)

    def __call__(self, x):
        """Evaluate the wrapped function element-wise on ``x``."""
        return self.function(x)

    def sample(self, m, sigma=1.0, seed=0, trunc_norm_x=False):
        """Draw ``m`` noisy observations ``y = f(x) + sigma * eps``.

        ``eps`` is standard normal noise. ``x`` is drawn uniformly on [0, 1), or,
        when ``trunc_norm_x`` is true, from a normal with mean 0.5 and std 0.2
        truncated to [0, 1].

        Args:
            m: number of samples.
            sigma: standard deviation of the additive Gaussian noise.
            seed: seed of the private random generator used for all draws.
            trunc_norm_x: sample ``x`` from the truncated normal instead of
                the uniform distribution.

        Returns:
            Tuple ``(x_sample, y_sample)`` of NumPy arrays of length ``m``.
        """
        # A private RandomState seeded with ``seed`` produces exactly the same
        # draws as seeding the global generator did. Unlike that approach, it
        # does not silently reset NumPy's global RNG for the rest of the program.
        rng = np.random.RandomState(seed)
        if trunc_norm_x:
            loc = 0.5
            scale = 0.2
            # truncnorm takes its bounds in standardized units.
            a, b = (0 - loc) / scale, (1 - loc) / scale
            x_sample = truncnorm.rvs(a, b, loc=loc, scale=scale, size=m,
                                     random_state=rng)
        else:
            x_sample = rng.rand(m)
        y_sample = self(x_sample) + rng.randn(m) * sigma
        return x_sample, y_sample

def approx_cantor_n_ftn(x, n):
    """Return the ``n``-level approximation of the Cantor function at ``x``.

    At each level:
    - a point in the left third [0, 1/3] is rescaled by 3 and its result is halved;
    - a point in the right third [2/3, 1] is mapped by ``3x - 2`` and its result
      is placed in the upper half;
    - any other point (the open middle third, or outside [0, 1]) gets the
      plateau value 1/2.

    After ``n`` levels the rescaled ``x`` itself is used. The 0-level
    approximation is the identity.

    The computation first descends ``n`` levels recording which branch was
    taken, then folds those affine maps back from the innermost level
    outwards. This performs exactly the same floating-point operations as the
    recursive definition, but uses no Python stack frames, so large ``n`` does
    not hit the recursion limit.

    Args:
        x: scalar point, expected to lie in [0, 1].
        n: non-negative number of refinement levels.

    Returns:
        The approximate Cantor-function value. For ``n == 0``, ``x`` is returned
        unchanged.

    Raises:
        ValueError: if ``n`` is negative. The recursive version never reached
            its base case in that situation and could recurse without bound.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    took_right = []  # branch taken at each level, outermost first
    hit_plateau = False
    remaining = n
    while remaining > 0:
        if 0 <= x <= 1/3:
            took_right.append(False)
            x = 3 * x
        elif 2/3 <= x <= 1:
            took_right.append(True)
            x = 3 * x - 2
        else:
            hit_plateau = True
            break
        remaining -= 1

    value = 1/2 if hit_plateau else x
    # Apply the per-level maps from the innermost level back to the outermost.
    for is_right in reversed(took_right):
        value = 1/2 + 1/2 * value if is_right else 1/2 * value
    return value

def cantor(x):
    """Approximate the Cantor function at a scalar ``x``.

    Delegates to ``approx_cantor_n_ftn``, using the module-level
    ``CANTOR_ORDER`` as the number of refinement levels.
    """
    depth = CANTOR_ORDER
    return approx_cantor_n_ftn(x, n=depth)


# Element-wise wrapper so ``cantor`` can be applied to NumPy arrays.
v_cantor = np.vectorize(cantor)

def sobolev(x):
    """Evaluate ``f(x) = 1 / log(x / 2)`` for ``x > 0``, and 0 elsewhere.

    As ``x -> 0+``, ``1/log(x/2) -> 0``, so returning 0 for ``x <= 0`` makes
    ``f`` continuous at 0.

    ``np.where`` evaluates both branches on every element. The logarithm is
    therefore also computed at ``x <= 0``, giving ``-inf`` or ``nan`` that is
    then discarded. The ``errstate`` guard silences the resulting spurious
    RuntimeWarnings (e.g. at the ``x = 0`` point of an evaluation grid)
    without changing any returned value.

    Args:
        x: scalar or array of points.

    Returns:
        NumPy array (0-d for scalar input) with the same shape as ``x``.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(x > 0, 1/np.log(x/2), 0)

def blocks(x):
    """Evaluate the piecewise-constant "Blocks" test function.

    ``f(x) = sum_j h_j * K(x - x_j)`` with ``K(t) = (1 + sign(t)) / 2``. Each
    breakpoint ``x_j`` therefore adds a jump of height ``h_j``, with exactly
    half the jump at ``x = x_j``. The breakpoints and heights are those of the
    Donoho–Johnstone "Blocks" benchmark.

    Accepts a scalar or an array of any shape. A trailing breakpoint axis is
    added and summed out, so the result has the shape of ``x``. (Broadcasting
    ``x`` directly against the 11 breakpoints only worked for scalar input.)

    Args:
        x: scalar or array of points.

    Returns:
        NumPy scalar for scalar input, otherwise an array shaped like ``x``.
    """
    x_j = np.array([0.1, 0.13, 0.15, 0.23, 0.25, 0.40, 0.44, 0.65, 0.76, 0.78, 0.81])
    h_j = np.array([4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2])
    diff = np.asarray(x)[..., None] - x_j
    return np.sum(h_j * (np.sign(diff) + 1) / 2, axis=-1)

def heavisine(x):
    """Evaluate the "HeaviSine" test function on a scalar or array ``x``.

    ``f(x) = 4 sin(4 pi x) - sign(x - 0.3) - sign(0.72 - x)``: a smooth sine
    wave with jump discontinuities at ``x = 0.3`` and ``x = 0.72``.
    """
    wave = 4 * np.sin(4 * np.pi * x)
    jump_low = np.sign(x - 0.3)
    jump_high = np.sign(0.72 - x)
    return wave - jump_low - jump_high

# Regression targets selectable with --function. Each entry holds:
#   func      - callable used as the ground-truth f(x) for data simulation
#   s, p, F   - forwarded unchanged to SBNN's Model(s=..., p=..., F=...);
#               presumably Besov smoothness / integrability / sup-norm bound
#               (TODO confirm against SBNN.model)
#   sigma     - std of the Gaussian observation noise (also Model's scale_data)
#   label_loc - matplotlib legend position used for the fit plot
func_dict = dict(
    cantor=dict(
        func=v_cantor,
        # log 2 / log 3 — presumably the Hölder exponent of the Cantor function
        s=math.log(2) / math.log(3),
        p=float("Inf"),
        sigma=0.01,
        F=1,
        label_loc="upper left",
    ),
    sobolev=dict(
        func=sobolev,
        s=1,
        p=1,
        sigma=0.01,
        F=2,
        label_loc="upper right",
    ),
    blocks=dict(
        func=blocks,
        s=1,
        p=1,
        sigma=0.1,
        F=10,
        label_loc="upper left",
    ),
    heavisine=dict(
        func=heavisine,
        s=1,
        p=1,
        sigma=0.1,
        F=10,
        label_loc="upper right",
    ),
)

def _per_draw_rmse(preds, target):
    """Root-mean-squared error of every row of ``preds`` against ``target``.

    ``target`` is flattened to a row vector and broadcast across the rows of
    ``preds``. This assumes ``preds`` is 2-D, with one row per posterior draw,
    as the original tiling code did. Returns a detached CPU tensor with one
    entry per row.
    """
    return ((preds - target.reshape(1, -1)) ** 2).mean(axis=1).sqrt().detach().cpu()


def _plot_posterior_fit(x_sample, y_sample, x_grid, y_pred, y_std, label_loc, path):
    """Plot training data, posterior mean and a ±3 std band, then save to ``path``."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(x_sample.cpu().reshape(-1), y_sample.cpu().reshape(-1), label="training data")
    ax.plot(x_grid.reshape(-1), y_pred, linewidth=5, c='blue', label="mean function")
    ax.fill_between(x_grid, y_pred - 3 * y_std, y_pred + 3 * y_std,
                    alpha=0.5, color='#ffcd3c', label=r"mean $\pm$ 3std")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend(loc=label_loc)
    fig.savefig(path)
    # Release the figure; otherwise pyplot keeps it alive for the whole process.
    plt.close(fig)


def main():
    """Parse CLI args, simulate data, fit the SBNN model and save the outputs.

    With ``--algorithm mcmc``, the network is first trained deterministically
    and its weights are used as the initial point of a NUTS run. The posterior
    predictive is then used to write a fit plot to ``figs/`` and per-draw
    train/test RMSE arrays to ``results/``. Any other algorithm value trains
    with Bayes by BackProp and saves the model's state dict under ``model/``.
    """
    import os

    global args
    args = parser.parse_args()

    func = args.function
    n = args.n
    lr = args.lr
    L = args.L
    W = args.W
    n_epochs = args.epochs
    seed = args.seed
    trunc_norm_x = args.tnx
    prior = args.prior
    # int() guards against --verbose arriving as a string from the command line.
    verbose = min(int(args.verbose), n_epochs)
    algorithm = args.algorithm
    DTYPE = torch.float64 if args.double_precision else torch.float32
    batch_size = min(args.batch_size, n)
    manual_sigma1 = args.manual_sigma_1
    mcmc = (algorithm == "mcmc")

    # Forwarded to Model.train_BBP as n_samples (BBP branch only).
    n_samples = 1

    # Create the output directories before the (possibly very long) training
    # run. Otherwise the final savefig / np.save / torch.save calls fail on a
    # fresh checkout and the whole run is lost.
    for out_dir in (("figs", "results") if mcmc else ("model",)):
        os.makedirs(out_dir, exist_ok=True)

    torch.manual_seed(seed)
    pyro.set_rng_seed(seed)
    torch.set_default_dtype(DTYPE)
    cuda = torch.cuda.is_available()

    config = func_dict[func]
    true_function = config["func"]
    besov_ = BesovFunction(true_function)
    s = config["s"]
    p = config["p"]
    sigma = config["sigma"]
    F = config["F"]
    label_loc = config["label_loc"]

    x_grid = np.linspace(0, 1, 1000)
    d = 1  # input dimension: x is scalar and reshaped to (-1, 1) below

    x_sample, y_sample = besov_.sample(n, sigma, seed=seed, trunc_norm_x=trunc_norm_x)
    x_sample = torch.from_numpy(x_sample).type(DTYPE)
    y_sample = torch.from_numpy(y_sample).type(DTYPE)
    x_test = torch.from_numpy(x_grid).type(DTYPE)
    y_test = torch.from_numpy(besov_(x_grid)).type(DTYPE)
    if cuda:
        x_sample = x_sample.cuda()
        y_sample = y_sample.cuda()
        x_test = x_test.cuda()
        y_test = y_test.cuda()

    model = Model(n=n, d=d, s=s, p=p, batch_size=batch_size, scale_data=sigma,
        W=W, L=L, F=F, cuda=cuda)
    model.setup_model(lr=lr, manual_sigma1=manual_sigma1, prior=prior, pyro=mcmc)

    if mcmc:
        x_train = x_sample.reshape(-1, 1)
        # Warm start: fit the deterministic network first, then use its
        # parameters as the initial point of the NUTS chain.
        model.model.set_deterministic(lr=lr)
        model.train_pyro_deterministic(x_train, y_sample.reshape(-1, 1),
            n_epochs=n_epochs, verbose=0)
        initial_params = {
            "net." + name: val.detach().clone()
            for name, val in model.model.net_deterministic.named_parameters()
        }
        model.run_mcmc(x_train, y_sample.reshape(-1),
            algorithm="nuts",
            mcmc_params={
                "num_samples": args.num_samples,
                "warmup_steps": args.warmup_steps,
                "initial_params": initial_params
            }
        )
        predictive = Predictive(model.model, posterior_samples=model.mcmc.get_samples())
        preds_train = predictive(x_train)['obs']
        preds = predictive(x_test.reshape(-1, 1))['obs']

        # One RMSE per posterior draw. Broadcasting the target replaces the
        # earlier tile(); that code also sized the test tiling from preds_train
        # instead of preds.
        eps_rec = _per_draw_rmse(preds_train, y_sample)
        eps_test_rec = _per_draw_rmse(preds, y_test)

        y_pred = preds.mean(axis=0).detach().cpu()
        y_std = preds.std(axis=0).detach().cpu()

        # plot (left)
        _plot_posterior_fit(x_sample, y_sample, x_grid, y_pred, y_std, label_loc,
                            f"figs/{func}_{prior}_{n}_a_{seed}.pdf")

        # plot (right)
        # fig, ax = plt.subplots(1, 1, figsize = (6, 6))
        # sns.histplot(eps_rec, bins=30, ax=ax)
        # ax.set_xlabel("empirical error")
        # plt.savefig(f"figs/{func}_{prior}_{n}_b_{seed}.pdf")

        # save errors
        np.save(f"results/{func}_{prior}_{n}_err_{seed}.npy", eps_rec.numpy())
        np.save(f"results/{func}_{prior}_{n}_err_test_{seed}.npy", eps_test_rec.numpy())

    else:
        model.train_BBP(
            x=x_sample.reshape(-1, 1), y=y_sample.reshape(-1, 1),
            verbose=verbose,
            n_epochs=n_epochs,
            n_samples=n_samples,
            best=False)
        if trunc_norm_x:
            fname = f"model/{func}_n={n}_tnx.pt"
        else:
            fname = f"model/{func}_n={n}.pt"
        torch.save(model.model.model.state_dict(), fname)

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()