import sys
sys.path.append("../")

# from __future__ import division
import tensorflow as tf
import numpy as np
from collections import deque
import random
import gym
from gym import wrappers
from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv
from core import *
from utils_latentPolicy_sac_lstm_zt_zt1 import *
import os
import tensorflow_probability as tfp
import multiprocessing as mp
import os
import d4rl
import json
import pandas as pd
# Expose only GPU 0 to TensorFlow in this process (and in forked workers).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Short aliases for the TF1 contrib / TFP modules used by the model code.
slim, rnn = tf.contrib.slim, tf.contrib.rnn
tfd = tfp.distributions

# Shared tf.Session configuration: no device-placement logging, and let GPU
# memory grow on demand rather than reserving it all up front.
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True

CLUSTER = 4
ENV_NAME = 'ITS'
ENV_INFO = 'ITS_cluster_' + str(CLUSTER)

# Normalised training segments for the chosen cluster. Within this script they
# are only used to derive the state/action dimensions, the episode length and
# the pooled mean/std statistics.
DATA = np.load('../processed_data/{}/train_cluster_{}.npy'.format(ENV_NAME, CLUSTER),
               allow_pickle=True)


def _truncated_normal_log_prob(x, mu, std, low, high, eps=1e-8):
    """Return log p(x) for Normal(mu, std) truncated to [low, high], summed over axis 1.

    Density: phi((x - mu) / std) / (std * (Phi((high - mu) / std) - Phi((low - mu) / std))).
    The normaliser depends on ``mu``, not on ``x``.
    Assumes ``x``, ``mu`` and ``std`` are (batch, dim) tensors (TODO confirm
    for other callers).
    """
    std_eps = std + eps
    unit = tfd.Normal(0., 1.)
    z = unit.cdf((high - mu) / std_eps) - unit.cdf((low - mu) / std_eps)
    log_pdf = (-0.5 * ((x - mu) / std_eps) ** 2
               - 0.5 * np.log(2 * np.pi)
               - tf.log(std * z + eps))  # eps guards log(0) when z underflows
    return tf.reduce_sum(log_pdf, axis=1, name="log_prob")


def _build_behavior_policy(state, action, code_size, reuse=tf.AUTO_REUSE,
                           is_training=True, var_scope="BC", eps=1e-8):
    """Build the state-conditioned truncated-normal behaviour policy.

    Variable/scope names (``BC/fc1``, ``BC/fc2``, ``BC/loc``, ``BC/scale`` with
    slim batch-norm) must stay exactly as they are so the behaviour checkpoint
    can be restored.

    Returns:
        (sample, log_prob): an action sample in [-1, 1] per dimension, and the
        log-density of ``action`` under the same distribution.
    """
    with tf.variable_scope(var_scope, reuse=reuse) as scope:
        with slim.arg_scope([slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.glorot_uniform_initializer,
                            weights_regularizer=slim.l2_regularizer(0.001),
                            biases_regularizer=slim.l2_regularizer(0.001),
                            normalizer_fn=slim.batch_norm,
                            normalizer_params={"is_training": is_training},
                            reuse=reuse,
                            scope=scope):
            hidden = slim.fully_connected(state, 128, scope="fc1")
            hidden = slim.fully_connected(hidden, 64, scope="fc2")
            loc = slim.fully_connected(hidden, code_size, activation_fn=None, scope="loc")
            scale = slim.fully_connected(hidden, code_size, activation_fn=tf.nn.softplus,
                                         scope="scale")
            sample = tfd.TruncatedNormal(loc, scale, -1., 1.).sample()  # -1, 1 bound
            log_prob = _truncated_normal_log_prob(action, loc, scale, -1., 1., eps)
    return sample, log_prob


class _LearnedEnv(object):
    """Minimal gym-like wrapper that steps a restored OPE model as an environment."""

    def __init__(self, model, action_dim):
        self.model = model
        self.action_dim = action_dim
        self.obs = None

    def reset(self):
        """Re-initialise the model's latent state and return the decoded first state (flat)."""
        self.model.init_z0_s0()
        s0 = self.model.sess.run(
            self.model.decoder_state_sample,
            feed_dict={self.model.decoder_zt_holder: self.model.zt},
        ).reshape(-1)
        self.obs = s0
        return s0

    def step(self, u):
        """Advance the model one step with action ``u``; never reports termination."""
        new_obs, reward = self.model.get_zt1_s2_r(np.reshape(u, (1, self.action_dim)))
        self.obs = new_obs
        self.model.update_zt()
        return new_obs, reward, False, {}


def _flat_stats(data, key):
    """Return a pooled scalar (mean, std) over every value stored under ``key`` in ``data``.

    Values are pooled across all segments and all feature dimensions. They are
    accumulated in float64 for accuracy.
    """
    values = np.concatenate([np.ravel(seg[key]) for seg in data]).astype(np.float64)
    return values.mean(), values.std()


def _augmented_output_path(ope_path, out_dir='./saved_augmented_data/'):
    """Map './saved_model/<run>/aug_best.ckpt' to '<out_dir><run>_augmented_segment.npy'."""
    run_name = ope_path.replace('./saved_model/', '').replace('/aug_best.ckpt', '')
    return out_dir + run_name + '_augmented_segment.npy'


def evaluate(args, num_segments=50):
    """Generate synthetic segments by rolling a behaviour policy through one OPE model.

    The learned OPE model is restored from ``args`` (a checkpoint path). The
    behaviour policy is restored from ``./saved_dist/state_action_dist.ckpt``.
    Then ``num_segments`` rollouts of ``len(DATA[0]['observations'])`` steps
    are collected. States and rewards produced by the model are de-normalised
    with pooled dataset statistics. The list of segment dicts
    ('observations', 'actions', 'rewards', 'next_observations') is saved with
    ``np.save`` under ``./saved_augmented_data/``.

    Args:
        args: checkpoint path of the OPE model, e.g.
            './saved_model/<run>/aug_best.ckpt'. Named ``args`` so it can be
            used directly with ``Pool.map``.
        num_segments: number of rollouts to generate (default 50).
    """
    ope_path = args

    # Hyperparameters. Those passed to OPE_Model or folded into file_appendix
    # must match the values the checkpoint was trained with.
    LR = 0.0003
    RANDOM_SEED = 2599
    MAX_EPISODES = 2000
    MAX_EPISODE_LEN = len(DATA[0]['observations'])
    MINIBATCH_SIZE_OPE = 4
    CODE_SIZE = 16
    REPEAT = 1  # query the behaviour policy every REPEAT steps, hold action otherwise
    BUFFER_SIZE_OPE = 3000
    beta = 1.
    OPE_LR = 0.001
    OPE_DS = 1000
    OPE_DR = 0.98
    EPS = 1e-8
    BEHAVIOR_CKPT = './saved_dist/state_action_dist.ckpt'

    # NOTE(review): `env_name` is not defined in this file; it presumably comes
    # from one of the star imports (core / utils_latentPolicy_...). Confirm,
    # otherwise this line raises NameError.
    file_appendix = (
        "lstm_vae_" + env_name + "_" + str(MAX_EPISODES)
        + "epi_repeat" + str(REPEAT) + "_" + str(LR) + "_"
        + str(OPE_LR) + "_"
        + str(OPE_DS) + "_"
        + str(OPE_DR) + "_"
        + str(CODE_SIZE) + "_"
        + str(beta) + "_"
        + str(RANDOM_SEED)
    )

    np.random.seed(RANDOM_SEED)

    env_state_dim = DATA[0]['observations'].shape[1]
    env_action_dim = DATA[0]['actions'].shape[1]
    env_state_bound = None
    # Pooled scalar statistics used to undo the model's normalisation.
    # Assumes rewards are stored 1-D per segment (TODO confirm).
    obs_mean, obs_std = _flat_stats(DATA, 'observations')
    rew_mean, rew_std = _flat_stats(DATA, 'rewards')

    graph_ope_models = tf.Graph()
    graph_behavior = tf.Graph()
    vae_seg = []

    with tf.Session(config=config, graph=graph_behavior) as sess_behavior, \
            tf.Session(config=config, graph=graph_ope_models) as sess_ope_models:

        with graph_ope_models.as_default():
            # A graph-level seed only affects ops created after it is set,
            # so seed before building the model.
            tf.set_random_seed(RANDOM_SEED)
            ope_model = OPE_Model(
                graph_ope_models, sess_ope_models, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,
                env_state_dim, env_state_bound, env_action_dim, file_appendix,
                BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN, beta,
                is_training=False
            )
            ope_model.saver.restore(sess_ope_models, ope_path)

            learned_env = _LearnedEnv(ope_model, env_action_dim)
            np.random.seed(RANDOM_SEED)

            with graph_behavior.as_default():
                tf.set_random_seed(RANDOM_SEED)  # makes the policy's sampling op reproducible
                state_holder = tf.placeholder(shape=[None, env_state_dim], dtype=tf.float32,
                                              name='state_holder')
                action_holder = tf.placeholder(shape=[None, env_action_dim], dtype=tf.float32,
                                               name='action_holder')
                behavior_sample, _ = _build_behavior_policy(
                    state_holder, action_holder, env_action_dim,
                    reuse=tf.AUTO_REUSE, is_training=False, eps=EPS)
                tf.train.Saver().restore(sess_behavior, BEHAVIOR_CKPT)

            for _ in range(num_segments):
                user_seg = {'observations': [], 'actions': [], 'rewards': [],
                            'next_observations': []}
                s = learned_env.reset().reshape(env_state_dim) * obs_std + obs_mean

                for j in range(MAX_EPISODE_LEN):
                    user_seg['observations'].append(s)
                    if j % REPEAT == 0:
                        # The policy expects a batch, so feed [s] and take row 0.
                        a = sess_behavior.run(behavior_sample,
                                              feed_dict={state_holder: [s]})[0]

                    s2, r, _, _ = learned_env.step(a)
                    r = r * rew_std + rew_mean
                    s2 = s2.reshape(env_state_dim) * obs_std + obs_mean

                    user_seg['next_observations'].append(s2)
                    user_seg['rewards'].append(r)
                    user_seg['actions'].append(a)
                    s = s2

                vae_seg.append(user_seg)

    out_path = _augmented_output_path(ope_path)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # the original crashed if the dir was missing
    with open(out_path, 'wb') as f:
        np.save(f, vae_seg)

if __name__ == '__main__':
    # One OPE checkpoint per sub-directory of ./saved_model/. Non-directory
    # entries are skipped, and order is sorted so runs are reproducible.
    # The path form must stay './saved_model/<run>/aug_best.ckpt', because
    # evaluate derives its output file name from it.
    model_root = "./saved_model/"
    OPEs = [model_root + name + "/aug_best.ckpt"
            for name in sorted(os.listdir(model_root))
            if os.path.isdir(os.path.join(model_root, name))]
    pool = mp.Pool(3)
    try:
        pool.map(evaluate, OPEs)
    finally:
        pool.close()
        pool.join()