import os
import pickle
import numpy as np
import torch

from gp import GP
from util import get_iset_idxs

class BaseFunc:
    """Base class for a function evaluated on a finite set of input points.

    This class does not define the following members itself but uses them,
    so subclasses are expected to provide:

    * ``self.x_domain`` -- tensor of input points, indexed along dim 0,
    * ``self.module_name`` -- string used in the hyperparameter cache filename,
    * ``get_noiseless_observation_from_input_idxs(x_idxs)`` -- function values
      at the given indices into ``x_domain``,
    * ``get_params(transformation)``.
    """

    def __init__(self, xdim, xsize, noise_std=None):
        """Store the problem description and create empty caches.

        Parameters
        ----------
        xdim : int
            Input dimension (used for the per-dimension lengthscale init).
        xsize : int
            Number of points in the discrete input domain.
        noise_std : float, optional
            Standard deviation of the observation noise. ``None`` means the
            noise is unknown and noisy observations cannot be drawn.
        """
        self.xdim = xdim
        self.xsize = xsize
        self.noise_std = noise_std

        # Lazily filled caches, all keyed by iset_type (see ``iset``).
        self._iset_idxs = {}  # iset_type -> indices returned by get_iset_idxs
        self._iset = {}  # iset_type -> rows of x_domain at those indices
        self._iset_evals = {}  # iset_type -> noiseless values at those indices

        self._gp_hyperparameters = None

    @staticmethod
    def generate_discrete_points(n, dim=1, low=0.0, high=1.0):
        """Return ``n`` input points of dimension ``dim`` as a float tensor.

        For ``dim == 1`` the points are evenly spaced in ``[low, high]`` and
        returned with shape ``(n, 1)``. For ``dim > 1`` the first ``n * dim``
        entries of ``random_inputs.txt`` (resolved against the current working
        directory) are reshaped to ``(n, dim)``; ``low`` and ``high`` are not
        applied in that case.

        Raises
        ------
        ValueError
            If ``dim`` is not positive.
        """
        if dim < 1:
            raise ValueError("Dimension must be positive!")

        if dim == 1:
            return torch.linspace(low, high, n).reshape(-1, 1)

        # NOTE(review): presumably the file holds a flat list of values in
        # [0, 1] with at least n * dim entries -- confirm against the data.
        rand01 = np.loadtxt("random_inputs.txt")
        return torch.from_numpy(rand01[: n * dim]).float().reshape(n, dim)

    @property
    def gp_hyperparameters(self):
        """gp_hyperparameters getter

        Note
        ----
        This is used for experiments that assume the GP hyperparameters are
        known. The result is cached on the instance and on disk under
        ``func/gp_hyperparameters/``; when no cache file exists, the
        hyperparameters are fitted to noisy observations of the whole domain
        and then written to that file.

        Returns
        -------
        dictionary
            optimized GP hyperparameters of the function
            e.g.,
            {
                "likelihood.noise_covar.noise": 0.01,
                "covar_module.base_kernel.lengthscale": [0.1] * self.xdim,
                "covar_module.outputscale": 1.0,
                "mean_module.constant": 0.0,
            }
        """
        if self._gp_hyperparameters is not None:
            return self._gp_hyperparameters

        filename = f"func/gp_hyperparameters/gp_hyperparameters_{self.module_name}_xsize_{self.xsize}.pkl"
        if os.path.isfile(filename):
            # Load previously optimized hyperparameters.
            # NOTE: pickle can execute arbitrary code on load; this file is
            # expected to be a cache written by this same property below.
            with open(filename, "rb") as file:
                self._gp_hyperparameters = pickle.load(file)
            return self._gp_hyperparameters

        # No cache yet: fit the GP on one noisy observation per domain point.
        x_idxs = torch.arange(self.xsize, dtype=torch.int64)
        y = self.get_noisy_observation_from_input_idxs(x_idxs)

        init_hyperparameters = {
            "likelihood.noise_covar.noise": self.noise_std**2,
            "covar_module.base_kernel.lengthscale": [0.1] * self.xdim,
            "covar_module.outputscale": 1.0,
            "mean_module.constant": 0.0,
        }

        gp_model = GP(
            self.x_domain,
            y,
            initialization=init_hyperparameters,
            prior=None,
            ard=True,
        )

        GP.optimize_hyperparameters(
            gp_model,
            self.x_domain,
            y,
            learning_rate=0.1,
            training_iter=50,
            verbose=False,
        )

        self._gp_hyperparameters = {
            "likelihood.noise_covar.noise": gp_model.likelihood.noise_covar.noise.item(),
            # .cpu() so the conversion also works when the model is not on CPU
            "covar_module.base_kernel.lengthscale": gp_model.covar_module.base_kernel.lengthscale.detach().cpu().numpy(),
            "covar_module.outputscale": gp_model.covar_module.outputscale.item(),
            "mean_module.constant": gp_model.mean_module.constant.item(),
        }

        # Write the cache; create the directory first so a fresh checkout
        # does not fail with FileNotFoundError after the optimization ran.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as file:
            pickle.dump(
                self._gp_hyperparameters,
                file,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

        return self._gp_hyperparameters

    def get_params(self, transformation):
        """Abstract hook; must be overridden by the child class."""
        raise NotImplementedError("To be implemented in child class!")

    def get_init_observations(self, n, seed=0):
        """Sample ``n`` domain points (with replacement) and observe them.

        Parameters
        ----------
        n : int
            Number of initial observations.
        seed : int
            Seed for the index sampling. The same seed always selects the
            same points. The observation noise is still drawn from torch's
            global random state.

        Returns
        -------
        (init_x, init_y)
            The selected rows of ``x_domain`` and their noisy observations.
        """
        # A local generator keeps the selection reproducible without
        # reseeding (and thereby disturbing) the global torch RNG.
        generator = torch.Generator()
        generator.manual_seed(seed)

        with torch.no_grad():
            init_idxs = torch.randint(
                low=0,
                high=self.x_domain.shape[0],
                size=(n,),
                generator=generator,
            )
            init_x = self.x_domain[init_idxs]
            init_y = self.get_noisy_observation_from_input_idxs(init_idxs)

        return init_x, init_y

    def iset(self, iset_type):
        """Return the domain points selected by ``get_iset_idxs`` for ``iset_type``.

        The first call for a given ``iset_type`` evaluates the noiseless
        function on the whole domain and fills the index, input and
        evaluation caches; later calls return the cached inputs.
        """
        if iset_type in self._iset:
            return self._iset[iset_type]

        with torch.no_grad():
            domain_idxs = list(range(self.xsize))
            func_range = self.get_noiseless_observation_from_input_idxs(domain_idxs)

            idxs = get_iset_idxs(func_range, iset_type)
            self._iset_idxs[iset_type] = idxs
            self._iset[iset_type] = self.x_domain[idxs]
            self._iset_evals[iset_type] = func_range[idxs]

        return self._iset[iset_type]

    def iset_idxs(self, iset_type):
        """Return the cached domain indices for ``iset_type``, computing them if needed."""
        if iset_type not in self._iset_idxs:
            self.iset(iset_type)
        return self._iset_idxs[iset_type]

    def iset_evals(self, iset_type):
        """Return the cached noiseless values for ``iset_type``, computing them if needed."""
        if iset_type not in self._iset_evals:
            self.iset(iset_type)
        return self._iset_evals[iset_type]

    def get_noisy_observation_from_input_idxs(self, x_idxs):
        """Return noiseless values at ``x_idxs`` plus Gaussian noise.

        The noise is ``torch.randn(len(x_idxs)) * self.noise_std``.

        Raises
        ------
        Exception
            If ``noise_std`` is ``None``.
        """
        if self.noise_std is None:
            raise Exception("Unknown noise")

        n_obs = len(x_idxs)

        # NOTE(review): the noise is 1-D of length n_obs; this assumes the
        # noiseless observations are 1-D as well -- confirm in the subclasses,
        # otherwise the addition broadcasts.
        return (
            self.get_noiseless_observation_from_input_idxs(x_idxs)
            + torch.randn(n_obs) * self.noise_std
        )
