import os
import numpy as np
import torch


from .basefunc import BaseFunc


class Goldstein(BaseFunc):
    """Negated, log-scaled Goldstein-Price objective on a discrete 2-D domain.

    Inputs are rescaled with ``4 * x - 2`` before evaluation (presumably
    mapping the unit square onto [-2, 2]^2 -- confirm against
    ``BaseFunc.generate_discrete_points``).  The sign of the usual
    logarithmic form is flipped, so the classical minimum becomes a maximum.
    """

    def __init__(
        self,
        xsize=10,
        noise_std=0.01,
    ):
        """Set up the 2-D objective and its discrete input domain.

        Parameters
        ----------
        xsize : int
            size argument handed to ``BaseFunc.__init__`` and to
            ``BaseFunc.generate_discrete_points`` (coerced with ``int``)
        noise_std : float
            noise standard deviation forwarded to ``BaseFunc.__init__``
            (coerced with ``float``); presumably used there to produce
            noisy observations -- confirm in ``BaseFunc``

        """
        input_dim = 2  # the function is defined on two coordinates only
        n_points = int(xsize)
        noise_level = float(noise_std)

        super().__init__(input_dim, n_points, noise_std=noise_level)

        self.module_name = "goldstein"
        self.xsize = n_points
        self.x_domain = BaseFunc.generate_discrete_points(n_points, input_dim)

    def get_noiseless_observation_from_inputs(self, x):
        """Evaluate the noise-free objective at the given inputs.

        Parameters
        ----------
        x : tensor
            inputs to be evaluated; reshaped to (n, self.xdim) internally

        Returns
        -------
        val : tensor of shape (n,)
            negated, log-scaled function values, computed without
            gradient tracking

        """
        with torch.no_grad():
            # Affine rescaling of the raw inputs before applying the formula.
            scaled = 4.0 * x.reshape(-1, self.xdim) - 2.0
            u = scaled[:, 0]
            v = scaled[:, 1]

            # First bracket: 1 + (u + v + 1)^2 * quadratic(u, v)
            first_poly = (
                19.0
                - 14.0 * u
                + 3.0 * u ** 2
                - 14.0 * v
                + 6.0 * u * v
                + 3.0 * v ** 2
            )
            first_factor = 1 + (u + v + 1.0) ** 2 * first_poly

            # Second bracket: 30 + (2u - 3v)^2 * quadratic(u, v)
            second_poly = (
                18.0
                - 32.0 * u
                + 12 * u ** 2
                + 48.0 * v
                - 36 * u * v
                + 27.0 * v ** 2
            )
            second_factor = 30.0 + (2.0 * u - 3.0 * v) ** 2 * second_poly

            # 8.693 and 2.427 look like the shift/scale constants of the
            # standardised logarithmic Goldstein-Price form -- TODO confirm.
            log_product = torch.log(first_factor * second_factor)
            val = -(log_product - 8.693) / 2.427

            val = val.reshape(-1)
        return val

    def get_noiseless_observation_from_input_idxs(self, x_idxs):
        """Evaluate the noise-free objective at rows of ``self.x_domain``.

        Parameters
        ----------
        x_idxs : tensor or list of int
            row indices into ``self.x_domain`` selecting the inputs

        Returns
        -------
        val : tensor of shape (n,)
            function values at the selected inputs, as returned by
            ``get_noiseless_observation_from_inputs``

        """
        selected = self.x_domain[x_idxs, :]
        return self.get_noiseless_observation_from_inputs(
            selected.reshape(-1, self.xdim)
        )
