import functools
from typing import Callable

from .base import Environment
from .treasure_conveyor import ConveyorTreasure, MultiConveyorTreasure


# Registry mapping an environment name (as accepted by `create` / `create_fn`)
# to the class or factory callable that builds it. Extended at runtime via
# `register_environment`.
_envs = dict(
    treasure_conveyor=ConveyorTreasure,
    multi_treasure_conveyor=MultiConveyorTreasure,
)


def register_environment(name: str, env_class: Callable):
    """Add (or replace) an entry in the environment registry.

    Args:
      name: Key under which the environment can later be built with
        `create(name, ...)`. An existing entry with the same key is
        silently overwritten.
      env_class: Callable (typically an `Environment` subclass) invoked with
        the keyword arguments given to `create`.
    """
    _envs.update({name: env_class})


def create(name: str, **kwargs) -> Environment:
    """Instantiate the environment registered under ``name``.

    Args:
      name: Registry key, e.g. one added via `register_environment`.
      **kwargs: Forwarded unchanged to the registered class/factory.

    Returns:
      The object returned by the registered callable.

    Raises:
      KeyError: If ``name`` has not been registered. The message lists the
        registered names to make typos easy to spot.
    """
    # Keep the try narrow: only the registry lookup should be translated into
    # the "unknown environment" error, not a KeyError raised inside the
    # environment's own constructor.
    try:
        env_class = _envs[name]
    except KeyError:
        raise KeyError(
            f"Unknown environment {name!r}; "
            f"registered environments: {sorted(_envs)}"
        ) from None
    return env_class(**kwargs)


def create_fn(name: str, **kwargs) -> Callable[..., Environment]:
    """Return a deferred constructor for the environment registered as ``name``.

    The result is a `functools.partial` of `create` with ``name`` and the
    given keyword arguments pre-bound. Keyword arguments supplied when the
    partial is later called are merged in, overriding pre-bound ones with the
    same key.
    """
    factory = functools.partial(create, name, **kwargs)
    return factory
