import unittest as ut

import numpy as np

from comab.algo.comab_estimation_tools import neighborhood, grid, v_lr, v_lr_all, B_mask


class TestComabEstimationTools(ut.TestCase):
    """Table-driven checks for the helpers in ``comab.algo.comab_estimation_tools``.

    Every test walks a small table of hand-computed cases and runs each row
    inside its own ``subTest`` so that one failing row does not hide the rest.
    """

    def test_neighbors(self):
        # Rows: (N, p, n, expected lower, expected upper) for neighborhood(n, p, N).
        table = [
            (2000, 100, 6, 1, 57),
            (5, 2, 3, 1, 4),
            (5, 2, 1, 1, 2),
        ]

        for big_n, p_arg, n_arg, lo, hi in table:
            with self.subTest(
                msg="Checking assignment", N=big_n, p=p_arg, n=n_arg, l=lo, u=hi
            ):
                got_lo, got_hi = neighborhood(n_arg, p_arg, big_n)
                self.assertEqual(lo, got_lo)
                self.assertEqual(hi, got_hi)

    def test_grid(self):
        # Rows: (N, p, exact array expected from grid(N, p)).
        table = [
            (2000, 100, np.array([1, 50, 123, 233, 398, 645, 1016, 1572])),
            (4, 1, np.array([1, 2, 3])),
            (4, 2, np.array([1, 2, 3])),
            (4, 4, np.array([1, 2, 3])),
            (4, 5, np.array([1, 2])),
        ]

        for big_n, p_arg, want in table:
            with self.subTest(
                msg="Checking assignment", N=big_n, p=p_arg, expected=want
            ):
                np.testing.assert_array_equal(grid(big_n, p_arg), want)

    def test_grid_neighbors(self):
        # Rows: (S, N, n, expected v_l, expected v_r) for v_lr(S, n, N).
        # The rows cover n strictly between two entries of S, n equal to an
        # entry of S, n beyond the last entry of S (expected v_r is N), and
        # n below the first entry of S (both sides expected to be S[0]).
        # Each row builds its own array, as in the original table.
        table = [
            (np.array([1, 3, 45, 78]), 100, 12, 3, 45),
            (np.array([1, 3, 45, 78]), 100, 45, 45, 45),
            (np.array([1, 3, 45, 78]), 100, 80, 78, 100),
            (np.array([1, 2, 3]), 4, 0, 1, 1),
            (np.array([1, 2, 3]), 4, 4, 3, 4),
        ]

        for grid_pts, big_n, n_arg, left, right in table:
            with self.subTest(
                msg="Checking assignment",
                S=grid_pts,
                N=big_n,
                n=n_arg,
                v_l=left,
                v_r=right,
            ):
                got_left, got_right = v_lr(grid_pts, n_arg, big_n)
                self.assertEqual(got_left, left)
                self.assertEqual(got_right, right)

    def test_B_mask(self):
        big_n = 4
        grid_pts = np.array([1, 2, 3])
        # v_lr_all yields a pair; only its first element is fed to B_mask.
        left_all, _unused = v_lr_all(grid_pts, big_n)

        # Expected (N+1) x (N+1) boolean mask with only cells (3, 4) and (4, 4) set.
        want = np.zeros((big_n + 1, big_n + 1), dtype=bool)
        want[[3, 4], 4] = True

        np.testing.assert_array_equal(B_mask(left_all, big_n, grid_pts), want)