import yaml
import gym

from omegaconf import OmegaConf


def load_config(args, config_path='./config/base_config.yaml'):
    """Load the base config and fill in fields derived from the Gym environment.

    A temporary Gym environment is created for ``args.env_name`` only to read
    its observation/action dimensions, action bound and episode step limit.
    The environment is always closed before this function returns or raises.

    Args:
        args: Object exposing an ``env_name`` attribute (e.g. an argparse
            namespace) naming a registered Gym environment.
        config_path: Path to the base config file loaded with
            ``OmegaConf.load``.

    Returns:
        The loaded OmegaConf config with these fields set:
        ``env.name``, ``env.state_dim``, ``env.action_dim``, ``env.max_action``,
        ``env.max_episode_timestep``, ``data.path``, ``policy.path``,
        ``model.input_dim`` (``action_dim * data.seg_len``) and
        ``model.context_dim`` (``state_dim``).

    Raises:
        ValueError: If the environment exposes no episode step limit, either
            through the ``_max_episode_steps`` attribute or through
            ``env.spec.max_episode_steps``.
    """
    config = OmegaConf.load(config_path)
    env_name = args.env_name
    seg_len = config.data.seg_len

    env = gym.make(env_name)
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        # NOTE(review): only the first component of the upper bound is used,
        # which assumes every action dimension shares the same bound -- confirm.
        max_action = float(env.action_space.high[0])

        # `_max_episode_steps` is a private attribute. Fall back to the public
        # registered spec when the outermost wrapper does not expose it.
        max_episode_timestep = getattr(env, '_max_episode_steps', None)
        if max_episode_timestep is None:
            spec = getattr(env, 'spec', None)
            if spec is not None:
                max_episode_timestep = spec.max_episode_steps
        if max_episode_timestep is None:
            raise ValueError(
                f'Environment {env_name!r} does not define a maximum episode length'
            )
    finally:
        # `del env` alone does not release simulator/render resources.
        env.close()

    config.env.name = env_name
    config.env.state_dim = state_dim
    config.env.action_dim = action_dim
    config.env.max_action = max_action
    config.env.max_episode_timestep = max_episode_timestep

    config.data.path = f'./preference_dataset/small/{env_name}.pkl'
    config.policy.path = f'./weights/pretrained/{env_name}.pth'

    # Model input is a flattened segment of `seg_len` actions; the state is
    # supplied as conditioning context.
    config.model.input_dim = action_dim * seg_len
    config.model.context_dim = state_dim

    return config


