import unittest
from unittest import TestCase

import numpy as np
import torch

from src.analysis.logging.common_logging_utils import LogMetricModelLogs


class TestLogMetricModelLogs(TestCase):
    """Tests for the ``LogMetricModelLogs`` logging utility."""

    def test_log_metric_model_logs(self):
        """Test the model log logging utility.

        Entries expected to be logged: Python ints, floats and bools, plus
        single-element numpy arrays / torch tensors of any rank. Entries
        expected to be skipped: multi-element arrays/tensors, lists and dicts.
        Only presence/absence of the prefixed key in the returned metric logs
        is checked, not the logged values.
        """
        # Dummy model_logs covering ints, floats, bools, numpy arrays and torch
        # tensors of differing shapes, plus a list and a dict.
        model_logs = {
            "a": 1,
            "b": 1.0,
            "c": True,
            "d": np.array([1]),  # shape (1,)
            "e": np.array([1])[None],  # shape (1, 1)
            "f": np.array([[1, 1]]),  # shape (1, 2) -> not a scalar
            "g": torch.tensor([1]),  # shape (1,)
            "h": torch.tensor([1])[None],  # shape (1, 1)
            "i": torch.tensor([1, 1]),  # shape (2,) -> not a scalar
            "j": [1, 1],
            "k": {"a": 1, "b": 2},
        }
        # Names of the entries that should end up in the metric logs.
        scalar_names = {"a", "b", "c", "d", "e", "g", "h"}

        # Call the logging function and see whether it accepts the right objects.
        # pl_system=None: the test assumes the logger does not need a Lightning
        # system for this code path.
        model_log_logger = LogMetricModelLogs()
        metric_logs = model_log_logger(
            metric_logs={}, pl_system=None, model_logs=model_logs
        )

        # Check that exactly the scalar-like entries were added, one subTest per
        # entry so a failure reports which key misbehaved.
        for name in model_logs:
            key = model_log_logger.append_model_logs_prefix_to_key(name)
            with self.subTest(name=name, key=key):
                if name in scalar_names:
                    self.assertIn(key, metric_logs)
                else:
                    self.assertNotIn(key, metric_logs)


if __name__ == "__main__":
    # Run from the repository root:
    #   python -m unittest -v src.analysis.logging.test_common_logging_utils
    unittest.main()
