import d3rlpy
from IPython import embed
from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.base import _serialize_params
import argparse
import os
import json
# Command-line interface: the (d, n) pair selects which saved DQN to evaluate.
# `required=True` lets argparse reject missing values with a usage message
# (asserts would be stripped under `python -O`).
parser = argparse.ArgumentParser(
    description="Evaluate a saved d3rlpy DQN on CartPole.")
parser.add_argument("--d", type=int, required=True,
                    help="dimension of middle layer")
parser.add_argument("--n", type=int, required=True,
                    help="number of episodes")
parser.add_argument("--trials", type=int, default=32,
                    help="number of evaluation trials (default: 32)")
parser.add_argument("--path", default="models/models_cp/",
                    help="directory holding the saved model files "
                         "(default: models/models_cp/)")
args = parser.parse_args()

d = args.d  # dimension of middle layer
n = args.n  # number of episodes

# Saved artifacts are named after the hyperparameters, e.g. dqn_d64_n100.json
# (serialized hyperparameters) and dqn_d64_n100.pt (weights).
path = args.path
filename = f"dqn_d{d}_n{n}"
filepath_json = os.path.join(path, filename + ".json")
filepath_model = os.path.join(path, filename + ".pt")

# This script only loads models, so fail early with a clear message instead of
# creating an empty directory and letting the loader fail on a missing file.
for filepath in (filepath_json, filepath_model):
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"model file not found: {filepath}")

# get_cartpole() returns (dataset, env); only the environment is needed here.
_, env = get_cartpole()

# Rebuild the DQN from its saved JSON config, then load the saved weights.
dqn = DQN.from_json(filepath_json)
dqn.load_model(filepath_model)

# evaluate_on_environment(...) returns a scorer that is then applied to dqn.
score = evaluate_on_environment(env, n_trials=args.trials, render=False)(dqn)
print(f"\nAverage value of DQN d: {d} n: {n}")
print(score)