from lpcmdp.algorithm.utils import *
import torch
from lpcmdp.algorithm.model import *
from torch.utils.data import DataLoader
import itertools
from tqdm import tqdm
from lpcmdp.env.FrozenLake import FrozenLakeEnv, FrozenLakeEnv_nocost
from lpcmdp.env.datacollector import FrozenLake_DataCollector
# Both FrozenLake environments below are built on the same square grid.
_GRID_SIZE = 8

# Step 1: derive an expert policy by running value iteration on the
# `_nocost` environment variant.
# NOTE(review): the 0.01 argument is presumably the convergence threshold of
# ValueIteration -- confirm against its definition.
env_nocost = FrozenLakeEnv_nocost(ncol=_GRID_SIZE, nrow=_GRID_SIZE)
exp = ValueIteration(env_nocost, 0.01, env_nocost.gamma)
exp.value_iteration()
# Optional debug visualisation of the expert policy (disabled):
# plot_policy(np.array(exp.pi), env_nocost, "w")

# Step 2: build the environment used for training and wrap it, together with
# the expert policy, in the offline data collector.
# NOTE(review): `percent=0.5` looks like the share of expert behaviour in the
# collected data -- verify in FrozenLake_DataCollector.
env = FrozenLakeEnv(ncol=_GRID_SIZE, nrow=_GRID_SIZE)
offline_dataset_collector = FrozenLake_DataCollector(
    env,
    expert_pi=exp.pi,
    percent=0.5,
)
from lpcmdp.algorithm.coptidice import CoptidiceTrainer

# Step 3: train a policy with CoptidiceTrainer from the environment and the
# offline data collector, then draw the resulting policy on the environment.
coptidice_trainer = CoptidiceTrainer(env, offline_dataset_collector)
coptidice_policy = coptidice_trainer.train()
env.plot_policy(coptidice_policy)

# Step 4: print whatever the trainer's logger accessor returns.
_training_log = coptidice_trainer.get_logger()
print(_training_log)