import unittest

import jax
import jax.numpy as jnp

from tabular_mvdrl.models import EWPModel
from tabular_mvdrl.utils import support_init


class TestEWPModel(unittest.TestCase):
    """Tests for the output shape and explicit-support initialisation of ``EWPModel``."""

    def setUp(self):
        """Create the problem dimensions, PRNG keys and the explicit-support fixture."""
        self.num_states = 3
        self.reward_dim = 2
        self.num_atoms = 5
        # Split the root key so the fixture below and the model initialisation
        # draw from independent streams. Sharing one key would let a model that
        # ignores the explicit initializer (and samples normally with the same
        # key) accidentally reproduce the expected support.
        self.rng, support_rng = jax.random.split(jax.random.PRNGKey(0))
        self.explicit_support = jax.random.normal(
            support_rng, shape=(self.num_atoms, self.reward_dim)
        )
        # NOTE(review): presumably ``repeated_map`` applies the explicit-support
        # initializer once per state -- confirm against ``support_init``. The
        # test below expects every state to end up with ``explicit_support``.
        self.support_initializer = support_init.repeated_map(
            support_init.explicit_support(self.explicit_support), self.num_states
        )

    def test_support_format(self):
        """Applying the model to all states yields (states, atoms, reward_dim)."""
        ewp_model = EWPModel(self.num_states, self.reward_dim, self.num_atoms)
        params = ewp_model.init(self.rng, jnp.int32(0))
        support = ewp_model.apply(params, jnp.arange(self.num_states))
        self.assertListEqual(
            list(support.shape), [self.num_states, self.num_atoms, self.reward_dim]
        )

    def test_deterministic_param_init(self):
        """Initialising with an explicit support reproduces it for every state."""
        ewp_model = EWPModel(self.num_states, self.reward_dim, self.num_atoms)
        params = ewp_model.init_with_support(
            self.rng, jnp.int32(0), self.support_initializer
        )
        support = ewp_model.apply(params, jnp.arange(self.num_states))
        # Confirm there is one row per state before comparing rows, so the
        # loop below cannot silently check fewer states than intended.
        self.assertEqual(support.shape[0], self.num_states)
        # ``tolist`` yields plain Python floats: the comparison stays exact but
        # a failure prints a readable element-wise diff.
        expected_flat = jnp.reshape(self.explicit_support, (-1,)).tolist()
        for state, state_support in enumerate(support):
            with self.subTest(state=state):
                actual_flat = jnp.reshape(state_support, (-1,)).tolist()
                self.assertListEqual(actual_flat, expected_flat)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
