from clus.config.master_config import *
import optax
from clus.models.utils.train_state import *
from clus.models.model.basic import *
from clus.models.model.cdm import *

# Pre-training experiment configuration.
#
# Entries set to ``None`` and marked "lazy loading" are expected to be filled
# in later by code outside this file (NOTE(review): confirm where that happens).
PRETRAIN_CONFIG = CLUSConfig(
    # Adam optimizer; b1=0.9 is set explicitly (it matches optax's default).
    optimizer_config=OptimizerConfig(
        optimizer_cls=optax.adam,
        optimizer_kwargs={
            'learning_rate': 3e-5,
            'b1': 0.9,
        },
    ),
    # InputConfig is constructed with no arguments.
    input_config=InputConfig(),
    # Outer model is ConditionalDiffusion; its kwargs are passed through as a dict.
    model_config=ModelConfig(
        model_cls=ConditionalDiffusion,
        model_kwargs={
            'input_config': None,      # lazy loading
            'optimizer_config': None,  # lazy loading
            # Presumably the inner denoising network and its constructor
            # kwargs — TODO confirm against ConditionalDiffusion.
            'model_config': {
                'model_cls': FlaxDenoisingBlockMLP,
                'model_kwargs': {
                    'dim': 512,
                    'n_blocks': 4,
                    'context_emb_dim': 512,
                    'dropout': 0.1,
                },
            },
            'clip_denoised': False,
            # Presumably the number of diffusion steps — confirm.
            'diffusion_step': 32,
        },
    ),
    exp_config=ExpConfig(
        phase_epoch=20000,
        eval_epoch=5000,
        batch_size=1024,
        eval_env=True,
        base_path='./data/l2ms/test',  # base path for saving items
        phase_optim='re_initialize',
        replay_method='random',  # 'kmeans' or 'random' or 'sequential'
        # Alternative: add phase_batch_sz=0 for no replay.
        # Alternative: warm-start by setting init_model_path to e.g.
        #   './data/l2ms/kitchen_base/models/model_0.pkl'
        init_model_path=None,
    ),
    scenario_config=ScenarioConfig(
        dataloader_config={
            'dataloader_cls': BaseDataloader,
            'dataloader_kwargs': {
                'skill_embedding_path': 'data/continual_dataset/evolving_world/mm_lang_embedding.pkl',
                'skill_exclude': None,
                'semantic_flag': False,
            },
        },
        phase_config=None,  # lazy loading
        evaluator_config={
            'eval_mode': 'obs',
            'eval_episodes': 10,
            # NOTE(review): True here but False in dataloader_kwargs above —
            # confirm the mismatch is intentional.
            'semantic_flag': True,
        },
    ),
)
