import numpy as np
import torch
import gym
import argparse
import os
import d4rl
import datetime
from buffer import ReplayBuffer
from latent_a import OffRL
from tensorboardX import SummaryWriter
from datetime import datetime
import h5py
from tqdm import tqdm
# Wall-clock time captured once at import, rendered as HH:MM:SS. It is used as
# a suffix of the TensorBoard run directory name (presumably to keep separate
# runs apart). NOTE(review): there is no date component, so runs started at the
# same clock time on different days get the same suffix.
now = datetime.now()
current_time = format(now, "%H:%M:%S")


def eval_policy(policy, sim_env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
	"""Roll out ``policy`` in a fresh environment and return its normalized score.

	A new environment is built with ``gym.make(sim_env_name)`` and seeded with
	``seed + seed_offset``. Every observation is standardized as
	``(obs - mean) / std`` before being passed to ``policy.select_action``.
	The environment is always closed before returning, even if a rollout raises.

	Args:
		policy: Object exposing ``select_action(state)``; it receives the
			normalized observation reshaped to a single row ``(1, -1)``.
		sim_env_name: Environment id passed to ``gym.make``.
		seed: Base seed; the evaluation env uses ``seed + seed_offset``.
		mean: Subtracted from each observation (anything broadcastable
			against a ``(1, obs_dim)`` array, e.g. ``0`` to disable).
		std: Divides each centered observation (e.g. ``1`` to disable).
		seed_offset: Offset added to ``seed`` for the evaluation env.
		eval_episodes: Number of full episodes whose returns are averaged.

	Returns:
		The average undiscounted episode return passed through the env's
		``get_normalized_score`` and multiplied by 100.
	"""
	eval_env = gym.make(sim_env_name)
	try:
		eval_env.seed(seed + seed_offset)

		total_reward = 0.
		for _ in range(eval_episodes):
			state, done = eval_env.reset(), False
			while not done:
				# mean/std are expected to be the same statistics used to
				# normalize the training states (TODO confirm at call sites).
				state = (np.array(state).reshape(1, -1) - mean) / std
				action = policy.select_action(state)
				state, reward, done, _ = eval_env.step(action)
				total_reward += reward

		avg_reward = total_reward / eval_episodes
		# get_normalized_score is presumably supplied by the d4rl-registered env.
		return eval_env.get_normalized_score(avg_reward) * 100
	finally:
		# A new env is created on every call; release its simulator resources.
		eval_env.close()


def str2bool(value):
	"""Parse a boolean command-line value.

	``argparse`` with ``type=bool`` converts any non-empty string, including
	``"False"``, to ``True``. This parser accepts the usual spellings instead.

	Args:
		value: Raw string from the command line (a ``bool`` is returned as-is).

	Returns:
		``True`` or ``False``.

	Raises:
		argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
	"""
	if isinstance(value, bool):
		return value
	lowered = value.strip().lower()
	if lowered in ("1", "true", "t", "yes", "y", "on"):
		return True
	if lowered in ("0", "false", "f", "no", "n", "off"):
		return False
	raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	# Experiment
	parser.add_argument("--algo", default="A2PO")                   # Algorithm label used in console output and the TensorBoard run name
	parser.add_argument("--env", default="halfcheetah-random-expert-v2")  # Environment id passed to gym.make / d4rl
	parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
	parser.add_argument("--eval_freq", default=20000, type=int)     # How often (training iterations) we evaluate
	parser.add_argument("--max_timesteps", default=int(1e6), type=int)  # Number of training iterations
	parser.add_argument("--save_model", action="store_true")        # Not read by this script; the model is saved at every evaluation
	parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
	parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
	parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
	parser.add_argument("--doubleq_min", default=1.0, type=float)
	parser.add_argument("--epsilon", default=0.1, type=float)
	parser.add_argument("--policy_noise", default=0.1, type=float)  # Noise added to target policy during critic update
	parser.add_argument("--noise_clip", default=0.2, type=float)    # Range to clip target policy noise
	parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
	# Boolean options use str2bool so that e.g. "--use_cuda False" really is False.
	parser.add_argument("--use_cuda", default=True, type=str2bool)
	parser.add_argument("--adv_step", default=200000, type=int)
	parser.add_argument("--use_discrete", default=False, type=str2bool)
	parser.add_argument("--lr_decay", default=False, type=str2bool)
	parser.add_argument("--alpha", default=1.0, type=float)
	parser.add_argument("--normalize", default=True, type=str2bool)
	parser.add_argument("--reward_tune", default=False, type=str2bool)

	args = parser.parse_args()
	device = torch.device("cpu")
	if args.use_cuda:
		device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	print("---------------------------------------")
	print(f"Setting: Training {args.algo}, Env: {args.env}, Seed: {args.seed}, Discrete: {args.use_discrete}")
	print("---------------------------------------")
	model_path = f"./logs/{args.env}/A2PO/{args.seed}/"
	data_path = model_path  # evaluation curves are stored alongside the model
	os.makedirs(model_path, exist_ok=True)
	sim_env_name = args.env
	env = gym.make(sim_env_name)

	# Set seeds
	env.seed(args.seed)
	env.action_space.seed(args.seed)
	torch.manual_seed(args.seed)
	np.random.seed(args.seed)

	state_dim = env.observation_space.shape[0]
	action_dim = env.action_space.shape[0]
	max_action = float(env.action_space.high[0])

	kwargs = {
		"state_dim": state_dim,
		"action_dim": action_dim,
		"max_action": max_action,
		"discount": args.discount,
		"tau": args.tau,
		"device": device,
		"epsilon": args.epsilon,
		"doubleq_min": args.doubleq_min,
		"policy_noise": args.policy_noise,
		"noise_clip": args.noise_clip,
		"lr_decay": args.lr_decay,
		"alpha": args.alpha,
		"adv_step": args.adv_step,
		"use_discrete": args.use_discrete,
	}

	# Initialize policy
	policy = OffRL(**kwargs)
	dataset = d4rl.qlearning_dataset(env)

	replay_buffer = ReplayBuffer(state_dim, action_dim)
	replay_buffer.convert_D4RL(dataset, args.reward_tune)
	# The same statistics must be applied to observations at evaluation time.
	if args.normalize:
		mean, std = replay_buffer.normalize_states()
	else:
		mean, std = 0, 1
	writer = SummaryWriter(
		logdir=f'runs/{args.algo}_{args.env}_{args.seed}_{current_time}'
	)
	evaluations = []

	max_timesteps = int(args.max_timesteps)
	try:
		for t in tqdm(range(max_timesteps), desc='PI training', ncols=75):
			policy.policy_train(replay_buffer, writer, sim_env_name, args.batch_size)
			# Evaluate on the regular schedule, and also after the final
			# iteration so the saved model and curve reflect the trained policy.
			if t % args.eval_freq == 0 or t == max_timesteps - 1:
				eval_res = eval_policy(policy, sim_env_name, args.seed, mean, std)
				evaluations.append(eval_res)
				writer.add_scalar(f'{args.env}/eval_reward', eval_res, t)
				print(f"| {args.algo} | {args.env}_{args.seed} | iterations: {t} | eval_reward: {eval_res} |")
				np.save(os.path.join(data_path, str(args.seed)), evaluations)
				policy.save(model_path)
	finally:
		# Flush pending TensorBoard events to disk.
		writer.close()
