import chex
import fiddle as fdl
import jax
import jax.numpy as jnp

from tabular_mvdrl import kernels
from tabular_mvdrl.agents.cat_td import CatTDTrainer

from . import cat_td as cat_td_configs


def dirichlet_prior(key: chex.PRNGKey, n: int, alpha: float) -> chex.Array:
    """Draw one sample from a symmetric Dirichlet distribution.

    Args:
      key: JAX PRNG key consumed by the sampler.
      n: Number of categories, i.e. the length of the returned vector.
      alpha: Concentration value shared by every category.

    Returns:
      A length-``n`` vector drawn from ``Dirichlet(alpha, ..., alpha)``.
    """
    concentration = jnp.ones(n) * alpha
    return jax.random.dirichlet(key, concentration)


def base(**kwargs) -> fdl.Config[CatTDTrainer]:
    """Return the categorical-TD ``base`` config with ``signed`` enabled.

    Args:
      **kwargs: Passed through unchanged to ``cat_td_configs.base``.

    Returns:
      The config built by ``cat_td_configs.base`` with ``signed=True``.
    """
    signed_cfg = cat_td_configs.base(**kwargs)
    signed_cfg.signed = True
    return signed_cfg


def rowland(**kwargs) -> fdl.Config[CatTDTrainer]:
    """Return the categorical-TD ``rowland`` config with ``signed`` enabled.

    Args:
      **kwargs: Passed through unchanged to ``cat_td_configs.rowland``.

    Returns:
      The config built by ``cat_td_configs.rowland`` with ``signed=True``.
    """
    signed_cfg = cat_td_configs.rowland(**kwargs)
    signed_cfg.signed = True
    return signed_cfg


def rowland_multivariate(
    **kwargs,
) -> fdl.Config[CatTDTrainer]:
    """Return the categorical-TD ``rowland_multivariate`` config, signed.

    Args:
      **kwargs: Passed through unchanged to
        ``cat_td_configs.rowland_multivariate``.

    Returns:
      The config built by ``cat_td_configs.rowland_multivariate`` with
      ``signed=True``.
    """
    cfg = cat_td_configs.rowland_multivariate(**kwargs)
    cfg.signed = True
    # Return the config so callers get it (not None), matching the
    # annotation and the sibling builders.
    return cfg


### FIDDLERS


def extrapolated_rowland(cfg: fdl.Config[CatTDTrainer]):
    """Fiddler: apply ``cat_td_configs.extrapolated_rowland`` to ``cfg``.

    The categorical-TD fiddler is reused unchanged. This wrapper returns
    ``None``.
    """
    fiddler = cat_td_configs.extrapolated_rowland
    fiddler(cfg)


def finite_horizon(cfg: fdl.Config[CatTDTrainer], horizon=4):
    """Fiddler: apply the categorical-TD finite-horizon fiddler to ``cfg``.

    Args:
      cfg: Config passed to ``cat_td_configs.finite_horizon``.
      horizon: Horizon forwarded to ``cat_td_configs.finite_horizon``.
        Defaults to 4.
    """
    # Forward ``horizon`` explicitly. Previously it was accepted here but
    # dropped, so overriding it had no effect.
    # NOTE(review): assumes cat_td_configs.finite_horizon accepts a
    # ``horizon`` keyword, as this wrapper's signature suggests; confirm.
    cat_td_configs.finite_horizon(cfg, horizon=horizon)


def terminal_reward(**kwargs) -> fdl.Config[CatTDTrainer]:
    """Return the categorical-TD ``terminal_reward`` config with ``signed`` on.

    Unlike the fiddlers around it, this builds and returns a new config
    rather than modifying an existing one.

    Args:
      **kwargs: Passed through unchanged to ``cat_td_configs.terminal_reward``.

    Returns:
      The config built by ``cat_td_configs.terminal_reward`` with
      ``signed=True``.
    """
    cfg = cat_td_configs.terminal_reward(**kwargs)
    cfg.signed = True
    return cfg


def l1_kernel(cfg: fdl.Config[CatTDTrainer]):
    """Fiddler: set ``cfg.kernel`` to ``kernels.l1``.

    Modifies ``cfg`` in place and returns ``None``.
    """
    chosen_kernel = kernels.l1
    cfg.kernel = chosen_kernel
