import unittest

import torch

from bbo.algorithms.np.transformer.transformer import Transformer


class TransformerTest(unittest.TestCase):
    """Smoke tests for the forward pass of ``Transformer``."""

    def setUp(self):
        """Build a small two-layer ``Transformer`` instance for each test.

        The constructor is called positionally with: input feature size,
        output size, model width, head count, feed-forward hidden size,
        dropout probability and layer count (in that order, as in the
        original test). ``x_dim`` and ``n_out`` are kept on ``self`` so the
        tests can size inputs and check output shapes.
        """
        feature_dim, out_dim = 20, 3
        width = 64
        heads = 4
        ffn_width = width * 4  # feed-forward size: 4x the model width
        drop_p = 0.1
        depth = 2
        self.transformer = Transformer(
            feature_dim,
            out_dim,
            width,
            heads,
            ffn_width,
            drop_p,
            depth,
        )
        self.x_dim = feature_dim
        self.n_out = out_dim

    def test_run(self):
        """Run one forward pass and check the shape of the result.

        Inputs are ``(seq_len, batch, x_dim)`` features and
        ``(seq_len, batch, 1)`` targets. The test expects
        ``seq_len - single_eval_pos`` rows along the first dimension of the
        output. Presumably the positions before ``single_eval_pos`` serve as
        context only; confirm this against ``Transformer.forward``.
        """
        seq_len, batch = 20, 32
        split = 5
        inputs = torch.randn(seq_len, batch, self.x_dim)
        targets = torch.randn(seq_len, batch, 1)
        prediction = self.transformer(inputs, targets, split)
        expected_shape = (seq_len - split, batch, self.n_out)
        self.assertEqual(prediction.shape, expected_shape)
