from rl_algorithm.ddqn.config import device
import utils
import tensorboardX
import torch
import wandb


def get_max_episode_length(env):
    """Return the maximum episode length configured for ``env``.

    Args:
        env: Environment identifier; each known marker is tested with
            ``marker in env``.

    Returns:
        The length paired with the first marker found in ``env``, or 100
        when none of the markers is present.
    """
    # Checked top to bottom; the first marker contained in ``env`` wins.
    known_lengths = (
        ("FourRooms", 100),
        ("MultiRoom-N2-S4", 40),
        ("MultiRoom-N4-S5", 80),
        ("PutNear-6x6", 30),
        ("Crossing", 324),
    )
    for marker, length in known_lengths:
        if marker in env:
            return length
    # Fallback for environments without a dedicated entry.
    return 100


def init(self, env, preprocess_obs, args, train_interval=1):
    """Set common agent attributes on ``self`` and build its optimizers.

    ``self`` must already expose ``policy_network`` and ``reset_multi``
    (plus ``policy_network2`` .. ``policy_network4`` when ``reset_multi``
    selects them); this function only reads those attributes.

    Args:
        self: Object that receives the attributes.
        env: Environment object; ``env.action_space.n`` is stored as
            ``n_actions``.
        preprocess_obs: Stored unchanged as ``self.preprocess_obs``.
        args: Namespace providing ``env``, ``algorithm`` and ``lr``.
        train_interval: Stored as both ``train_interval`` and
            ``rnd_train_interval``.
    """

    def _rmsprop(network):
        # Every optimizer is an RMSprop over one network, using args.lr.
        return torch.optim.RMSprop(network.parameters(), args.lr)

    self.device = device
    self.preprocess_obs = preprocess_obs
    self.n_actions = env.action_space.n

    # Both step counters start at 1.
    self.learn_step_counter = self.rnd_learn_step_counter = 1

    # Training, testing and logging schedule.
    self.max_episode_length = get_max_episode_length(args.env)
    self.update_target = 100
    self.train_interval = self.rnd_train_interval = train_interval
    self.test_interval = 5000
    self.log_interval = 500
    self.algorithm = args.algorithm

    # The primary optimizer is always created.
    self.optimizer = _rmsprop(self.policy_network)

    # reset_multi == 2 adds optimizer2; reset_multi == 4 adds
    # optimizer2..optimizer4; any other value adds none.
    if self.reset_multi == 2:
        extra_ids = (2,)
    elif self.reset_multi == 4:
        extra_ids = (2, 3, 4)
    else:
        extra_ids = ()
    for idx in extra_ids:
        network = getattr(self, f"policy_network{idx}")
        setattr(self, f"optimizer{idx}", _rmsprop(network))


def init_exploration(self, args):
    """Set the exploration-related attributes on ``self``.

    Args:
        self: Object that receives the attributes.
        args: Namespace providing ``exploration_type``.
    """
    mode = args.exploration_type
    self.exploration_type = mode

    # Epsilon settings: start value, end value and decay horizon.
    self.epsilon = 0.9
    self.eps_end = 0.05
    # 'exp-option' uses -1 for the decay horizon; every other type uses 100000.
    self.eps_decay_time_steps = -1 if mode == 'exp-option' else 100000
    self.steps_done = 0

    # Extra state initialised here and consumed elsewhere (not in this file).
    self.n = 0
    self.w = 0
    self.type = "r"




def init_log(self, model_dir):
    """Attach the text, CSV and TensorBoard loggers to ``self``.

    Args:
        self: Object that receives the logger attributes.
        model_dir: Passed to ``utils.get_txt_logger``,
            ``utils.get_csv_logger`` and ``tensorboardX.SummaryWriter``.
    """
    self.txt_logger = utils.get_txt_logger(model_dir)

    # get_csv_logger returns a (file, logger) pair for "logs.csv".
    csv_file, csv_logger = utils.get_csv_logger(model_dir, "logs.csv")
    self.csv_file = csv_file
    self.csv_logger = csv_logger

    self.tb_writer = tensorboardX.SummaryWriter(model_dir)

    # Starts empty; nothing in this function fills it.
    self.logs = dict()
