
import argparse
import sys
sys.path.append("src")
from load_data import DataLoad
from args_config import parser_config, config_args
from model import *
from s_utils import same_seeds, print_namespace
import torch.optim as optim

from filiting_s import filiting_func
from din_train import din_train_func
# Build the CLI parser and hand it to parser_config, which returns the parser
# used from here on (presumably with the project's options registered on it).
parser = parser_config(
    argparse.ArgumentParser(description='Initialize Parameters!')
)
# Parse the command line, then let config_args post-process the namespace.
args = config_args(parser.parse_args())

# vars(args) is the namespace's own attribute dict, so writes to para_dict
# are visible on args as well (and vice versa).
para_dict = vars(args)
modeln = args.modelname
same_seeds(args.seed)
dls = DataLoad(args.datan)

## ex config
# Keep the raw command-line arguments alongside the parsed parameters before
# they are handed to ExpConfig.
para_dict.update(argv=sys.argv[1:])

ec = ExpConfig(para_dict)

## main
def main():
    """Run the routine selected by the module-level ``modeln``.

    Uses the module-level ``para_dict``, ``dls``, ``ec`` and ``parser``
    objects created during script start-up.

    * ``"filiting_training_list"`` calls ``filiting_func`` with the cuda,
      beam-size, top-k and batch-size entries of ``para_dict``, the
      experiment config ``ec`` and ``dls.tree_model_path_dict``.
    * ``"din_train"`` calls ``din_train_func(para_dict, dls, ec)``.

    Raises:
        AssertionError: if ``modeln`` is none of the names above. The
            parser's help text is printed before raising.
    """
    if modeln == "filiting_training_list":
        filiting_func(
            cuda=para_dict["cuda"],
            N=para_dict["beam_size"],
            topk=para_dict["topk"],
            ec=ec,
            tree_model_path_dict=dls.tree_model_path_dict,
            bs=para_dict["bs"],
        )

    elif modeln == "din_train":
        din_train_func(para_dict, dls, ec)

    # Disabled branch, kept for reference:
    # elif modeln == "LinearTS":
    #     bandit_model = LinearTS(para_dict)
    #     bandit_model.run(dls)

    else:
        parser.print_help()
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let an unsupported model
        # name fall through and the script exit as if it had succeeded.
        raise AssertionError(
            "unsupported modelname: {!r}".format(modeln)
        )


# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
        