import unittest

import jax
import jax.numpy as jnp

from tabular_mvdrl.utils import support_init


class TestSupportInitializers(unittest.TestCase):
    """Tests for the lattice-based support initializers in ``support_init``."""

    def setUp(self):
        self.rng = jax.random.PRNGKey(0)
        self.num_states = 3
        self.reward_dim = 2
        self.bins_per_dim = 4
        # Per-dimension bounds for the rectangular (non-unit) lattice test;
        # the second dimension deliberately spans only negative values.
        self.minvals = jnp.array([0.0, -5.0])
        self.maxvals = jnp.array([4.0, -1.0])

    def _assert_allclose(self, actual, expected):
        """Assert two arrays have equal shape and are elementwise close.

        Lattice coordinates are floats, so exact equality against an
        independently computed ``jnp.linspace`` is brittle.
        """
        actual = jnp.asarray(actual)
        expected = jnp.asarray(expected)
        self.assertEqual(tuple(actual.shape), tuple(expected.shape))
        self.assertTrue(
            bool(jnp.allclose(actual, expected)),
            msg=f"{actual} is not close to {expected}",
        )

    def _assert_lattice(self, lattice_map, minvals, maxvals):
        """Check the shape and per-axis coordinates of a uniform lattice.

        Args:
            lattice_map: array returned by a ``uniform_lattice`` initializer.
            minvals: scalar or per-dimension lower bounds.
            maxvals: scalar or per-dimension upper bounds.

        The slicing below assumes ``reward_dim == 2``. The test expects the
        first coordinate to vary slowest (it repeats with stride
        ``bins_per_dim``) and the second coordinate to vary fastest.
        """
        self.assertEqual(
            tuple(lattice_map.shape),
            (self.bins_per_dim**self.reward_dim, self.reward_dim),
        )
        # Accept either scalar or per-dimension bounds.
        minvals = jnp.broadcast_to(jnp.asarray(minvals), (self.reward_dim,))
        maxvals = jnp.broadcast_to(jnp.asarray(maxvals), (self.reward_dim,))
        self._assert_allclose(
            lattice_map[:: self.bins_per_dim, 0],
            jnp.linspace(minvals[0], maxvals[0], self.bins_per_dim),
        )
        self._assert_allclose(
            lattice_map[: self.bins_per_dim, 1],
            jnp.linspace(minvals[1], maxvals[1], self.bins_per_dim),
        )

    def test_uniform_lattice(self):
        lattice_map = support_init.uniform_lattice(
            self.reward_dim, self.bins_per_dim, maxval=1.0, minval=0.0
        )(self.rng)
        self._assert_lattice(lattice_map, 0.0, 1.0)

    def test_uniform_lattice_rectangle(self):
        lattice_map = support_init.uniform_lattice(
            self.reward_dim,
            self.bins_per_dim,
            maxval=self.maxvals,
            minval=self.minvals,
        )(self.rng)
        self._assert_lattice(lattice_map, self.minvals, self.maxvals)

    def test_repeat_map(self):
        # Use minval < maxval so the lattice is non-degenerate. With
        # minval == maxval every point coincides, and the per-state comparison
        # below could not detect a shuffled or mis-ordered copy.
        lattice_map_fn = support_init.uniform_lattice(
            self.reward_dim, self.bins_per_dim, maxval=1.0, minval=0.0
        )
        one_lattice = lattice_map_fn(self.rng)
        repeated_lattice = support_init.repeated_map(lattice_map_fn, self.num_states)(
            self.rng
        )
        self.assertEqual(
            tuple(repeated_lattice.shape),
            (self.num_states, self.bins_per_dim**self.reward_dim, self.reward_dim),
        )
        for s in range(self.num_states):
            with self.subTest(state=s):
                # Each state's slice should be an exact copy of a single
                # lattice, so exact equality (not a tolerance) is intended.
                self.assertTrue(
                    bool(jnp.array_equal(repeated_lattice[s], one_lattice))
                )


if __name__ == "__main__":
    # Allow running this test module directly as a script. ``module`` is
    # spelled out explicitly; "__main__" is also unittest.main's default.
    unittest.main(module="__main__")
