import time
from argparse import Namespace
from typing import List, Optional

import opacus
import optuna
import ray
import torch
from tqdm.auto import tqdm

from .kernels import Kernel, RBFKernel
from .likelihoods import GaussianLikelihood
from .sgp import SparseGP

# Module import side effect: make float64 the default dtype for floating-point
# tensors created afterwards (e.g. the torch.linspace inducing-point grids below).
torch.set_default_dtype(torch.float64)


def _predictive_normal(
    model: torch.nn.Module, x: torch.Tensor, out_dim: int
) -> torch.distributions.Normal:
    """Run ``model`` on ``x`` and wrap its output as an independent Normal.

    The first ``out_dim`` output columns are used as means and the remaining
    columns as variances, which are square-rooted into standard deviations.
    """
    params = model(x)
    loc = params[:, :out_dim]
    var = params[:, out_dim:]
    return torch.distributions.Normal(loc, var.pow(0.5))


def _kl_divergence(model: torch.nn.Module) -> torch.Tensor:
    """Return the model's KL term, unwrapping an ``_module`` wrapper if present."""
    # Wrapped models (e.g. Opacus-wrapped ones) hold the original in ``_module``.
    if hasattr(model, "_module"):
        return model._module.kl_divergence()
    return model.kl_divergence()


def train_model_instance(
    model: torch.nn.Module,
    optimiser: torch.optim.Optimizer,
    epochs: int,
    train_loader: torch.utils.data.DataLoader,
    xc: torch.Tensor,
    yc: torch.Tensor,
    trial: Optional[optuna.Trial] = None,
):
    """Maximise the ELBO of ``model`` over minibatches, then evaluate it on all data.

    Each step builds a factorised Normal from the model output. It rescales the
    minibatch expected log-likelihood to the full dataset size (``len(xc)``),
    subtracts the KL term and takes an optimiser step on ``-elbo``. After
    training, the ELBO is recomputed on ``(xc, yc)`` under ``torch.no_grad()``.

    Args:
        model: Module whose output is split along dim 1 into ``yc.shape[-1]``
            mean columns followed by variance columns. It must expose
            ``likelihood.expected_log_prob`` and ``kl_divergence`` (directly or
            on a wrapped ``_module``).
        optimiser: Optimiser over ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(xb, yb)`` minibatches.
        xc: Full training inputs; ``len(xc)`` sets the likelihood rescaling.
        yc: Full training targets.
        trial: Optional Optuna trial. If given, the final ELBO is reported at
            step ``epochs - 1``.

    Returns:
        ``(elbo, model, timings)``. ``elbo`` is the full-data ELBO tensor and
        ``model`` is the trained model. ``timings`` maps
        ``"train_and_inference"``, ``"train"`` (summed over all epochs) and
        ``"inference"`` to wall-clock seconds.
    """
    out_dim = yc.shape[-1]
    dataset_size = len(xc)

    epochs_iter = tqdm(range(epochs), desc="Epoch")
    start = time.perf_counter()
    # Accumulated across epochs. Initialised up front so epochs == 0 still
    # returns a valid timing instead of hitting an unbound name.
    train_time = 0.0
    for _ in epochs_iter:
        model.train()
        epoch_start = time.perf_counter()
        for xb, yb in train_loader:
            optimiser.zero_grad()

            qf = _predictive_normal(model, xb, out_dim)
            exp_ll = model.likelihood.expected_log_prob(yb, qf).sum()
            # Guard against empty minibatches (NOTE(review): presumably from
            # Poisson-sampled DP loaders) before rescaling to the full dataset.
            if len(xb) > 0:
                exp_ll = exp_ll * (dataset_size / len(xb))

            kl = _kl_divergence(model)
            elbo = exp_ll - kl
            (-elbo).backward()
            optimiser.step()

            epochs_iter.set_postfix(
                {
                    "elbo": elbo.item(),
                    "kl": kl.item(),
                    "exp_ll": exp_ll.item(),
                }
            )
        train_time += time.perf_counter() - epoch_start

    model.eval()
    inference_start = time.perf_counter()
    with torch.no_grad():
        qf = _predictive_normal(model, xc, out_dim)
        exp_ll = model.likelihood.expected_log_prob(yc, qf).sum()
        elbo = exp_ll - _kl_divergence(model)

        if trial is not None:
            trial.report(elbo.item(), epochs - 1)
    inference_time = time.perf_counter() - inference_start

    total_time = time.perf_counter() - start

    return (
        elbo,
        model,
        {
            "train_and_inference": total_time,
            "train": train_time,
            "inference": inference_time,
        },
    )


def train_model(
    xc: torch.Tensor,
    yc: torch.Tensor,
    train_args: Namespace,
    trial: Optional[optuna.Trial] = None,
):
    """Build a non-private SparseGP, fit it with Adam and return the result.

    Inducing points start on an evenly spaced grid over
    ``[train_args.xmin, train_args.xmax]``. For 1-D inputs that is
    ``num_inducing`` points. For 2-D inputs ``num_inducing`` points are used
    per dimension, giving a flattened ``num_inducing ** 2``-point grid.

    Args:
        xc: Training inputs; ``xc.shape[-1]`` must be 1 or 2.
        yc: Training targets.
        train_args: Must provide ``kernel``, ``kernel_kwargs``, ``xmin``,
            ``xmax`` (scalars for 1-D, indexable pairs for 2-D),
            ``num_inducing``, ``init_noise``, ``batch_size``, ``lr`` and
            ``epochs``.
        trial: Optional Optuna trial forwarded to ``train_model_instance``.

    Returns:
        Whatever ``train_model_instance`` returns: ``(elbo, model, timings)``.

    Raises:
        NotImplementedError: if the input dimensionality is not 1 or 2.
    """
    kernel = train_args.kernel(**train_args.kernel_kwargs)

    input_dim = xc.shape[-1]
    if input_dim == 1:
        init_z = torch.linspace(
            train_args.xmin, train_args.xmax, train_args.num_inducing
        ).unsqueeze(-1)
    elif input_dim == 2:
        # num_inducing is per dimension.
        init_z1 = torch.linspace(
            train_args.xmin[0], train_args.xmax[0], train_args.num_inducing
        )
        init_z2 = torch.linspace(
            train_args.xmin[1], train_args.xmax[1], train_args.num_inducing
        )
        # "ij" is the historical default; passing it explicitly avoids the
        # deprecation warning without changing the ordering.
        grid_z1, grid_z2 = torch.meshgrid(init_z1, init_z2, indexing="ij")
        # Flatten the (n, n, 2) grid to (n * n, 2) so the inducing points have
        # the same (num_points, input_dim) layout as the 1-D branch.
        init_z = torch.stack((grid_z1, grid_z2), dim=-1).reshape(-1, 2)
    else:
        raise NotImplementedError(
            f"Inducing-point initialisation supports 1-D or 2-D inputs, "
            f"got {input_dim}-D."
        )

    likelihood = GaussianLikelihood(noise=train_args.init_noise)
    model = SparseGP(kernel, likelihood, init_z)

    train_dataset = torch.utils.data.TensorDataset(xc, yc)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_args.batch_size, shuffle=True
    )

    optimiser = torch.optim.Adam(model.parameters(), lr=train_args.lr)

    return train_model_instance(
        model, optimiser, train_args.epochs, train_loader, xc, yc, trial
    )


def dp_train_model(
    xc: torch.Tensor,
    yc: torch.Tensor,
    train_args: Namespace,
    trial: Optional[optuna.Trial] = None,
):
    """Build a SparseGP, make it differentially private with Opacus, and train it.

    The model, optimiser and data loader are wrapped by
    ``opacus.PrivacyEngine.make_private_with_epsilon`` (PRV accountant,
    functorch per-sample gradients). The noise level is calibrated to reach
    ``(train_args.epsilon, train_args.delta)`` after ``train_args.epochs``
    epochs.

    Args:
        xc: Training inputs; only 1-D inputs (``xc.shape[-1] == 1``) are
            supported.
        yc: Training targets.
        train_args: Must provide ``kernel``, ``kernel_kwargs``, ``xmin``,
            ``xmax``, ``num_inducing``, ``init_noise``, ``batch_size``, ``lr``,
            ``epochs``, ``epsilon``, ``delta`` and ``max_grad_norm``.
        trial: Optional Optuna trial forwarded to ``train_model_instance``.

    Returns:
        Whatever ``train_model_instance`` returns: ``(elbo, model, timings)``.

    Raises:
        NotImplementedError: if the input dimensionality is not 1.
    """
    kernel = train_args.kernel(**train_args.kernel_kwargs)

    input_dim = xc.shape[-1]
    if input_dim != 1:
        raise NotImplementedError(
            f"DP training currently supports only 1-D inputs, got {input_dim}-D."
        )
    init_z = torch.linspace(
        train_args.xmin, train_args.xmax, train_args.num_inducing
    ).unsqueeze(-1)

    likelihood = GaussianLikelihood(noise=train_args.init_noise)
    model = SparseGP(kernel, likelihood, init_z)

    train_dataset = torch.utils.data.TensorDataset(xc, yc)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_args.batch_size, shuffle=True
    )

    privacy_engine = opacus.PrivacyEngine(accountant="prv")
    optimiser = torch.optim.Adam(model.parameters(), lr=train_args.lr)

    dp_model, dp_optimiser, dp_train_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimiser,
        data_loader=train_loader,
        epochs=train_args.epochs,
        target_epsilon=train_args.epsilon,
        target_delta=train_args.delta,
        max_grad_norm=train_args.max_grad_norm,
        grad_sample_mode="functorch",
    )

    return train_model_instance(
        dp_model, dp_optimiser, train_args.epochs, dp_train_loader, xc, yc, trial
    )


@ray.remote
def ray_dp_train_model(*args, **kwargs):
    """Ray remote-task wrapper around ``dp_train_model``.

    All positional and keyword arguments are forwarded unchanged, and the
    wrapped function's return value is passed straight back.
    """
    outcome = dp_train_model(*args, **kwargs)
    return outcome


@ray.remote
def ray_train_model(*args, **kwargs):
    """Ray remote-task wrapper around the non-private ``train_model``.

    All positional and keyword arguments are forwarded unchanged, and the
    wrapped function's return value is passed straight back.
    """
    outcome = train_model(*args, **kwargs)
    return outcome


def batch_dp_train_model(
    epsilon: List[float],
    delta: List[float],
    xc: List[torch.Tensor],
    yc: List[torch.Tensor],
    train_args: Namespace,
    trial: optuna.Trial,
):
    """Train one DP model per dataset as parallel Ray tasks and average the ELBOs.

    Dataset ``i`` is trained with privacy budget ``(epsilon[i], delta[i])``.
    Each task gets its own shallow copy of ``train_args`` with those two
    fields set, so the caller's Namespace is left untouched.

    NOTE(review): ``trial`` is serialised into every Ray task. Whether
    ``trial.report`` calls made there reach the study depends on the Optuna
    storage backend.

    Returns:
        Mean of the non-``None`` ELBOs (first element of each task's result).

    Raises:
        ValueError: if the four input lists differ in length, or if no ELBO
            values are produced (e.g. empty input).
    """
    if not len(xc) == len(yc) == len(epsilon) == len(delta):
        raise ValueError(
            "epsilon, delta, xc and yc must have the same length; got "
            f"{len(epsilon)}, {len(delta)}, {len(xc)}, {len(yc)}."
        )

    futures = []
    for xcb, ycb, epsilon_, delta_ in zip(xc, yc, epsilon, delta):
        task_args = Namespace(
            **{**vars(train_args), "epsilon": epsilon_, "delta": delta_}
        )
        futures.append(ray_dp_train_model.remote(xcb, ycb, task_args, trial))

    # A single ray.get waits for all tasks and preserves submission order.
    results = ray.get(futures)
    elbos = [result[0] for result in results if result[0] is not None]
    if not elbos:
        raise ValueError("No ELBO values were produced; nothing to average.")

    return sum(elbos) / len(elbos)


def batch_train_model(
    epsilon: List[float],
    delta: List[float],
    xc: List[torch.Tensor],
    yc: List[torch.Tensor],
    train_args: Namespace,
    trial: optuna.Trial,
):
    """Train one non-private model per dataset as parallel Ray tasks and average the ELBOs.

    Signature mirrors ``batch_dp_train_model``. ``epsilon``/``delta`` are still
    attached to each task's copy of ``train_args``, though ``train_model``
    does not read them. Each task gets its own shallow copy, so the caller's
    Namespace is left untouched.

    NOTE(review): ``trial`` is serialised into every Ray task. Whether
    ``trial.report`` calls made there reach the study depends on the Optuna
    storage backend.

    Returns:
        Mean of the non-``None`` ELBOs (first element of each task's result).

    Raises:
        ValueError: if the four input lists differ in length, or if no ELBO
            values are produced (e.g. empty input).
    """
    if not len(xc) == len(yc) == len(epsilon) == len(delta):
        raise ValueError(
            "epsilon, delta, xc and yc must have the same length; got "
            f"{len(epsilon)}, {len(delta)}, {len(xc)}, {len(yc)}."
        )

    futures = []
    for xcb, ycb, epsilon_, delta_ in zip(xc, yc, epsilon, delta):
        task_args = Namespace(
            **{**vars(train_args), "epsilon": epsilon_, "delta": delta_}
        )
        futures.append(ray_train_model.remote(xcb, ycb, task_args, trial))

    # A single ray.get waits for all tasks and preserves submission order.
    results = ray.get(futures)
    elbos = [result[0] for result in results if result[0] is not None]
    if not elbos:
        raise ValueError("No ELBO values were produced; nothing to average.")

    return sum(elbos) / len(elbos)


def objective(
    epsilon: List[float],
    delta: List[float],
    xc: List[torch.Tensor],
    yc: List[torch.Tensor],
    limits,
    xmin: float,
    xmax: float,
    kernel: type[Kernel] = RBFKernel,
    trial: Optional[optuna.Trial] = None,
):
    """Optuna objective: sample training hyperparameters and return the mean DP ELBO.

    Hyperparameters are drawn from ``trial`` within the bounds given by
    ``limits``. ``limits`` needs ``min_*``/``max_*`` attributes for
    ``inducing``, ``epochs``, ``batch_size``, ``max_grad_norm``, ``lr``,
    ``init_lengthscale``, ``init_scale`` and ``init_noise``. ``init_period`` is
    sampled only when both ``min_init_period`` and ``max_init_period`` exist.
    The sampled values are then passed to ``batch_dp_train_model``.

    ``trial`` has a default only so it can come last. NOTE(review): presumably
    the other arguments are bound beforehand (e.g. ``functools.partial``). It
    is required at call time.

    Returns:
        The mean ELBO produced by ``batch_dp_train_model``.

    Raises:
        ValueError: if ``trial`` is ``None``.
    """
    if trial is None:
        raise ValueError("objective requires an optuna.Trial, but trial is None.")

    # Suggest values of hypers using a trial object.
    num_inducing = trial.suggest_int(
        "num_inducing", limits.min_inducing, limits.max_inducing
    )
    epochs = trial.suggest_int("epochs", limits.min_epochs, limits.max_epochs)
    batch_size = trial.suggest_int(
        "batch_size", limits.min_batch_size, limits.max_batch_size
    )
    max_grad_norm = trial.suggest_float(
        "max_grad_norm", limits.min_max_grad_norm, limits.max_max_grad_norm, log=True
    )
    lr = trial.suggest_float("lr", limits.min_lr, limits.max_lr, log=True)
    init_lengthscale = trial.suggest_float(
        "init_lengthscale",
        limits.min_init_lengthscale,
        limits.max_init_lengthscale,
        log=True,
    )
    init_scale = trial.suggest_float(
        "init_scale", limits.min_init_scale, limits.max_init_scale
    )
    init_noise = trial.suggest_float(
        "init_noise", limits.min_init_noise, limits.max_init_noise
    )

    kernel_kwargs = {
        "init_lengthscale": init_lengthscale,
        "init_scale": init_scale,
    }
    # Only periodic-style kernels need a period; sample it when bounds exist.
    if hasattr(limits, "min_init_period") and hasattr(limits, "max_init_period"):
        kernel_kwargs["init_period"] = trial.suggest_float(
            "init_period",
            limits.min_init_period,
            limits.max_init_period,
        )

    train_args = Namespace(
        **{
            "num_inducing": num_inducing,
            "epochs": epochs,
            "batch_size": batch_size,
            "lr": lr,
            "max_grad_norm": max_grad_norm,
            "init_noise": init_noise,
            "kernel": kernel,
            "kernel_kwargs": kernel_kwargs,
            "xmin": xmin,
            "xmax": xmax,
        }
    )

    return batch_dp_train_model(epsilon, delta, xc, yc, train_args, trial)
