from typing import Optional, Dict
import logging
import copy

import numpy as np
import torch 
from torch import nn
from torch import Tensor, optim
import botorch
from botorch import fit_gpytorch_mll
from botorch.acquisition import ExpectedImprovement
from botorch.models import SingleTaskGP
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, MaternKernel
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.constraints import GreaterThan

from bbo.algorithms.basic_bo.wrapper import (
    WrapperMean,
    WrapperKernel,
    create_wrapper,
)
from bbo.algorithms.utils import latin_hypercube, from_unit_cube, timer_wrapper
from bbo.algorithms.basic_bo.definitions import (
    MeanConfig,
    KernelConfig,
    TrainConfig,
    AcqfConfig,
    PretrainConfig,
)
from bbo.datasets.base import SimpleDataset
from bbo.datasets import X_transform, Y_transform

# Module-level logger; BO reports the selected device, checkpoint loading and
# training/evaluation progress through it.
log = logging.getLogger(__name__)


class BO:
    def __init__(
        self,
        dim: int,
        lb: Tensor,
        ub: Tensor,
        name: str = 'BO',
        n_init: int = 10,
        q: int = 1,
        is_share_parameters: bool = False, # whether mean and kernel use shared parameters
        train_config: dict = None,
        mean_config: dict = None,
        kernel_config: dict = None,
        acqf_config: dict = None,
        pretrain_config: dict = None,
        device: str = 'cpu',
    ):
        assert lb.ndim == 1 and ub.ndim == 1
        assert lb.shape == ub.shape
        assert (lb < ub).all()
        self.dim = dim
        self.lb = lb
        self.ub = ub
        self.name = name
        self.n_init = n_init
        self.q = q
        self.is_share_parameters = is_share_parameters

        self.train_config = TrainConfig(**train_config) if train_config is not None else TrainConfig()
        self.mean_config = MeanConfig(**mean_config) if mean_config is not None else MeanConfig()
        self.kernel_config = KernelConfig(**kernel_config) if kernel_config is not None else KernelConfig()
        self.acqf_config = AcqfConfig(**acqf_config) if acqf_config is not None else AcqfConfig()
        self.pretrain_config = PretrainConfig(**pretrain_config)

        assert self.train_config.mll_opt in ['l-bfgs', 'adam']
        assert self.mean_config.name in ['constant', 'mlp']
        assert self.kernel_config.name in ['rbf', 'matern']
        assert self.kernel_config.wrapper in ['identity', 'kumar', 'mlp']
        assert self.acqf_config.name in ['EI']
        assert self.acqf_config.acqf_opt in ['l-bfgs']

        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        log.info('Device: {}'.format(self.device))

        self.X = torch.zeros((0, dim))
        self.Y = torch.zeros((0, 1))

    def init(self):
        init_X = latin_hypercube(self.n_init, self.dim)
        init_X = from_unit_cube(init_X, self.lb.detach().cpu().numpy(), self.ub.detach().cpu().numpy())
        init_X = torch.from_numpy(init_X)
        return init_X

    def create_shared_mean_covar_module(self, mean_config, kernel_config):
        assert mean_config.name == 'mlp' and kernel_config.wrapper == 'mlp'
        assert mean_config.hidden_features == kernel_config.hidden_features

        dim = self.dim if kernel_config.wrapper != 'mlp' \
            else kernel_config.out_features
        base_kernel = ScaleKernel(MaternKernel(ard_num_dims=dim))
        # base_kernel = ScaleKernel(
        #     MaternKernel(
        #         nu=2.5,
        #         ard_num_dims=train_X.shape[-1],
        #         lengthscale_prior=GammaPrior(3.0, 6.0),
        #     ),
        #     outputscale_prior=GammaPrior(2.0, 0.15),
        # )

        wrapper = create_wrapper(
            'mlp', 
            config={
                'in_features': self.dim,
                'hidden_features': kernel_config.hidden_features,
            }
        )
        mean_final_layer = nn.Linear(
            in_features=mean_config.hidden_features[-1],
            out_features=1,
        )
        kernel_final_layer = nn.Sequential(
            nn.Linear(
                in_features=kernel_config.hidden_features[-1],
                out_features=kernel_config.out_features,
            ),
            nn.Tanh(),
        )
        mean_module = WrapperMean(wrapper, mean_final_layer)
        covar_module = WrapperKernel(base_kernel, wrapper, kernel_final_layer)

        return mean_module, covar_module

    def create_mean_module(self, mean_config):
        if mean_config.name == 'constant':
            mean_module = ConstantMean()
        elif mean_config.name == 'mlp':
            wrapper = create_wrapper(
                'mlp',
                config={
                    'in_features': self.dim,
                    'hidden_features': mean_config.hidden_features,
                }
            )
            mean_final_layer = nn.Linear(
                in_features=mean_config.hidden_features[-1],
                out_features=1,
            )
            mean_module = WrapperMean(wrapper, mean_final_layer)
        else:
            raise NotImplementedError
        return mean_module

    def create_covar_module(self, kernel_config):
        dim = self.dim if kernel_config.wrapper != 'mlp' \
            else kernel_config.out_features
        base_kernel = ScaleKernel(MaternKernel(ard_num_dims=dim))

        if kernel_config.wrapper == 'mlp':
            config = {
                'in_features': self.dim,
                'hidden_features': kernel_config.hidden_features,
            }
            kernel_final_layer = nn.Sequential(
                nn.Linear(
                    in_features=kernel_config.hidden_features[-1],
                    out_features=kernel_config.out_features,
                ),
                nn.Tanh(),
            )
        else:
            config = dict()
            kernel_final_layer = nn.Identity()

        wrapper = create_wrapper(
            kernel_config.wrapper,
            config=config,
        )

        covar_module = WrapperKernel(
            base_kernel,
            wrapper,
            kernel_final_layer,
        )
        return covar_module

    def create_model(self, train_X, train_Y):
        if self.is_share_parameters:
            mean_module, covar_module = self.create_shared_mean_covar_module(self.mean_config, self.kernel_config)
        else:
            mean_module = self.create_mean_module(self.mean_config)
            covar_module = self.create_covar_module(self.kernel_config)

        model = SingleTaskGP(train_X, train_Y, covar_module=covar_module, mean_module=mean_module)
        model.likelihood.noise_covar.register_constraint('raw_noise', GreaterThan(1e-4))
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
        model, mll = model.to(self.device), mll.to(self.device)

        return mll, model

    def load_pretrain(self, model, path):
        state_dict = torch.load(path)
        model.load_state_dict(state_dict, strict=False)
        log.info('Load from {}'.format(path))

        finetune_param_names = ['likelihood', 'covar_module.base_kernel']
        for name, param in model.named_parameters():
            is_finetune = [name for i in finetune_param_names if name.startswith(i)]
            if not is_finetune:
                param.requires_grad = False
                    
        return model

    def optimize_model(self, mll, model, train_X, train_Y):
        if self.pretrain_config.save_path is not None:
            path = self.pretrain_config.save_path
            path = path if path.endswith('.pth') else path + '.pth'
            model = self.load_pretrain(model, path)

        mll_opt = self.train_config.mll_opt
        mll_opt_lr = self.train_config.mll_opt_lr
        mll_opt_epochs = self.train_config.mll_opt_epochs
        if mll_opt == 'l-bfgs':
            fit_gpytorch_mll(mll)
        elif mll_opt == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=mll_opt_lr)
            model.train()
            model.likelihood.train()
            for _ in range(mll_opt_epochs):
                optimizer.zero_grad()
                output = model(train_X)
                loss = - mll(output, train_Y.reshape(-1))
                loss.backward()
                optimizer.step()
            model.eval()
            model.likelihood.eval()
        else:
            raise NotImplementedError

    def create_acqf(self, model, train_X, train_Y):
        if self.acqf_config.name == 'EI':
            AF = ExpectedImprovement(model, train_Y.max())
        else:
            raise NotImplementedError
        return AF

    def optimize_acqf(self, AF):
        bounds = torch.vstack((torch.zeros(self.dim), torch.ones(self.dim))).double().to(self.device)
        if self.acqf_config.acqf_opt == 'l-bfgs':
            next_X, _ = botorch.optim.optimize.optimize_acqf(AF, bounds=bounds, q=self.q, num_restarts=10, raw_samples=1024)
        else:
            raise NotImplementedError

        assert next_X.shape == (self.q, self.dim)
        return next_X

    def preprocess(self):
        train_X = (self.X - self.lb) / (self.ub - self.lb)
        train_Y = (self.Y - self.Y.mean()) / (self.Y.std() + 1e-6)
        train_X, train_Y = train_X.to(self.device), train_Y.to(self.device)
        
        return train_X.double(), train_Y.double()

    def postprocess(self, next_X):
        next_X = next_X.to('cpu')
        next_X = self.lb + next_X * (self.ub - self.lb)
        return next_X

    def ask(self) -> Tensor:
        """
        Outputs:
            Tensor with shape (q, dim)
        """
        if len(self.X) == 0:
            next_X = self.init()
        else:
            train_X, train_Y = self.preprocess()
            mll, model = self.create_model(train_X, train_Y)
            self.optimize_model(mll, model, train_X, train_Y)
            AF = self.create_acqf(model, train_X, train_Y)
            next_X = self.optimize_acqf(AF)
            next_X = self.postprocess(next_X)

        return next_X

    def tell(self, X: Tensor, Y: Tensor) -> Tensor:
        """
        Inputs:
            X: Tensor with shape (bs, dim)
            Y: Tensor with shape (bs, 1)
        """
        X, Y = X.to(self.X), Y.to(self.Y)
        self.X = torch.vstack((self.X, X))
        self.Y = torch.vstack((self.Y, Y))

    def train(self, train_id2dataset: Dict[str, SimpleDataset], val_id2dataset: Dict[str, SimpleDataset]):
        config = self.pretrain_config
        optim_config = config.optim_config
        device = self.device

        # merge all dataset
        train_X, train_Y = [], []
        for dataset in train_id2dataset.values():
            X, Y = dataset[np.arange(len(dataset))]
            train_X.append(X)
            train_Y.append(Y)
        train_X, train_Y = torch.vstack(train_X), torch.vstack(train_Y)
        assert (train_X >= 0).all() and (train_X <= 1).all()
        # train_X = X_transform(train_X, self.lb, self.ub, device)
        train_Y, _ , _ = Y_transform(train_Y, device=device)

        # initialize model
        mll, model = self.create_model(train_X, train_Y)
        optimizer = optim.Adam(model.parameters(), lr=optim_config.lr)
        best_model, best_mll_val = None, None

        for epoch in range(config.epochs):
            model.train()
            model.likelihood.train()

            # sample the dataset
            dataset_idx = np.random.randint(low=0, high=len(train_id2dataset))
            key = list(train_id2dataset.keys())[dataset_idx]
            dataset = train_id2dataset[key]

            # sample the data
            idx = np.random.choice(len(dataset), config.bs, replace=True)
            X, Y = dataset[idx]
            X, Y = X.to(device), Y.to(device)

            # scaling and shifting
            shifting = 2 * config.shifting * torch.rand(1).item() - config.shifting
            Y = Y + shifting

            model.set_train_data(X, Y, strict=False)
            output = model(X)
            loss = - mll(output, Y.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # eval
            if (epoch + 1) % config.eval_intervals == 0:
                model.eval()
                model.likelihood.eval()
                if val_id2dataset is not None:
                    id2dataset = val_id2dataset
                else:
                    id2dataset = train_id2dataset
                    log.info('Eval on training dataset')

                mean_mll_val_list = []
                for dataset_id in id2dataset:
                    dataset = id2dataset[dataset_id]
                    mll_val_list = []
                    for idx in np.split(np.arange(len(dataset)), np.arange(config.bs, len(dataset), config.bs)):
                        val_X, val_Y = dataset[idx] 
                        val_X, val_Y = val_X.to(device), val_Y.to(device)
                        with gpytorch.settings.prior_mode():
                            output = model.likelihood(model(val_X))
                            mll_val_list.append(output.log_prob(val_Y).mean().item())
                    mll_val = np.mean(mll_val_list)
                    
                    mean_mll_val_list.append(mll_val)
                    log.info('Epoch: {}, dataset id: {}, loss: {}'.format(epoch, dataset_id, mll_val))

                mean_mll_val = np.mean(mean_mll_val_list)
                log.info('Epoch: {}, mean loss: {}'.format(epoch, mean_mll_val))

                # record best model
                if best_mll_val is None or best_mll_val < mean_mll_val:
                    best_model = copy.deepcopy(model)
                    best_mll_val = mean_mll_val

        return model, mean_mll_val, best_model, best_mll_val