import unittest
import os

import torch

from bbo.problems import HPOBProblem


# Point HPOB_ROOT_DIR at a default location under the user's home directory,
# but only when the variable is not already present in the environment.
os.environ.setdefault('HPOB_ROOT_DIR', os.path.expanduser('~/dataset/hpob'))


class HPOBSurrogateTest(unittest.TestCase):
    """Smoke test: evaluate a batch of points through ``HPOBProblem``."""

    def test_run(self):
        """Check the surrogate returns a tensor with one output per input row."""
        root_dir = os.getenv('HPOB_ROOT_DIR')
        # The first two arguments are presumably the HPO-B search-space id and
        # dataset id; their meaning is defined by HPOBProblem. TODO confirm.
        hpo_surrogate = HPOBProblem('4796', '3549', root_dir)

        num_points = 10
        # Input dimensionality the original test used for this problem;
        # assumed to match the search space's dimension.
        num_dims = 3
        # A dedicated seeded generator makes the inputs reproducible across
        # runs without mutating torch's global RNG state.
        generator = torch.Generator().manual_seed(0)
        X = torch.rand(num_points, num_dims, generator=generator)

        Y = hpo_surrogate(X)
        self.assertIsInstance(Y, torch.Tensor)
        self.assertEqual(Y.shape, (num_points, 1))
