from .cohort import Cohort
from .rand import Random, RandomHomogeneous
from .tb import TB
from .mmitra import mMitra

from inspect import signature
import argparse
import ast

# NOTE: The ARMMAN data is not shared.

# Registry mapping a domain name (as given by the user) to the class that implements it.
NameToDomain = dict(
    random=Random,
    random_homogeneous=RandomHomogeneous,
    tb=TB,
    mmitra=mMitra,
)

def get_from_command_line(args, argv=None):
    """
    Read the named options from the command line and convert their values.

    Each name in ``args`` is registered as a required ``--<name>`` option;
    options not in ``args`` are ignored (``parse_known_args``). Every parsed
    value is passed through ``ast.literal_eval`` so that numbers, lists, dicts,
    booleans, ``None`` etc. come back as Python objects. Values that are not
    valid literals are kept as strings and a warning is printed.

    Args:
        args (Iterable[str]): Option names, without the leading ``--``.
            Duplicates are ignored; any iterable (including a generator) works.
        argv (list[str] | None): Tokens to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: One attribute per name in ``args``, holding the
        converted value.

    Note:
        argparse raises ``SystemExit`` if any requested option is missing.
    """
    # Materialise once: the names are iterated twice below, and a one-shot
    # iterable would otherwise leave the conversion loop empty. dict.fromkeys
    # also drops duplicates (argparse rejects conflicting option strings)
    # while keeping the caller's order.
    arg_names = list(dict.fromkeys(args))

    # Parse the arguments from the command line
    parser = argparse.ArgumentParser(allow_abbrev=False)
    for name in arg_names:
        parser.add_argument(f"--{name}", type=str, required=True)
    # parse_known_args tolerates options intended for other parsers.
    cmd_args, _unknown = parser.parse_known_args(argv)

    # Convert each raw string into the Python literal it spells, if any.
    for name in arg_names:
        raw = getattr(cmd_args, name)
        try:
            setattr(cmd_args, name, ast.literal_eval(raw))
        except (ValueError, SyntaxError, TypeError):
            # TypeError: e.g. "{[1]: 2}" is syntactically a literal but has an
            # unhashable dict key.
            print(f"WARNING: Could not convert value ``{raw}`` to non-string type. Using string instead.")
    return cmd_args

def domain_from_cmdline_args(
    domain: "type[Cohort]",
    kwargs: "dict | None" = None,
) -> "Cohort":
    """
    Instantiate ``domain``, reading constructor arguments absent from ``kwargs`` from the command line.

    Every named parameter of ``domain.__init__`` is considered, except ``self``
    and any ``*args`` / ``**kwargs`` catch-alls. Each one not present in
    ``kwargs`` is requested via ``get_from_command_line`` as a required
    ``--<name>`` option. Parameters that have default values are requested too.

    Args:
        domain (type[Cohort]): Domain class to instantiate.
        kwargs (dict | None): Keyword arguments to pass to the domain class.
            The dict is not modified. ``None`` is treated as an empty dict.

    Returns:
        Cohort: A domain object initialized with provided + command line arguments.
    """
    # Work on a copy so the caller's dict is not mutated.
    provided = dict(kwargs) if kwargs is not None else {}

    # Named constructor parameters. Catch-alls (*args / **kwargs) cannot be
    # supplied as individual --<name> options, so they are skipped.
    init_params = signature(domain.__init__).parameters.values()
    init_names = [
        p.name
        for p in init_params
        if p.name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]

    # Sorted so argparse's usage/error messages list options in a stable order.
    missing_args = sorted(name for name in init_names if name not in provided)

    # If any argument is missing, get it from the command line
    if missing_args:
        cmd_args = get_from_command_line(missing_args)
        provided.update(vars(cmd_args))

    return domain(**provided)
