import sys
sys.path.append("../")

# from __future__ import division
import tensorflow as tf
import numpy as np
from collections import deque
import random
import gym
from gym import wrappers
from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv
from core import *
from utils_latentPolicy_sac_lstm_zt_zt1 import *
import os
import tensorflow_probability as tfp
import multiprocessing as mp
import os
import d4rl
import json
# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Short aliases for the TF1 contrib / TFP sub-packages.
tfd = tfp.distributions
rnn = tf.contrib.rnn
slim = tf.contrib.slim

# Session configuration used by the tf.Session objects created in main():
# no device-placement logging, and GPU memory is grown on demand.
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True


# Dataset selection: environment name and the cluster split to train on.
ENV_NAME = 'ITS'
CLUSTER = 4
ENV_INFO = 'ITS_cluster_'+str(CLUSTER)

# use normalized pattern-seg
# NOTE: allow_pickle=True unpickles arbitrary Python objects, so this must
# only ever be pointed at trusted, locally produced files.
with open('../processed_data/{}/train_cluster_{}.npy'.format(ENV_NAME, CLUSTER), 'rb') as f:
    DATA = np.load(f, allow_pickle=True)


def main(args):
    """Train the OPE model on the pre-loaded offline data for one configuration.

    Takes a single tuple argument so it can be used directly with
    ``multiprocessing.Pool.map``.

    Args:
        args: 5-tuple ``(lr, ope_lr, ope_ds, ope_dr, beta)``. ``lr`` is only
            recorded in ``rl_params`` and in the run name. ``ope_lr``,
            ``ope_ds``, ``ope_dr`` and ``beta`` are forwarded unchanged to
            ``OPE_Model`` (presumably learning rate, decay steps, decay rate
            and a divergence weight -- confirm against ``OPE_Model``).

    Returns:
        None. Returns early, without logging that episode, as soon as the
        evaluated ELBO becomes NaN.

    Side effects:
        * Appends one line per episode to ``./rl_stats/<run name>.txt`` and
          prints the same line to stdout.
        * Saves a checkpoint through ``ope_model.saver`` whenever the ELBO
          improves on the best value seen so far in this run.
    """
    lr, ope_lr, ope_ds, ope_dr, beta = args
    # Example setting: lr, ope_lr, ope_ds, ope_dr, beta = 0.0003, 0.0001, 1000, 0.95, 1.

    LR = lr
    BUFFER_SIZE_SAC = 2*10**6
    MINIBATCH_SIZE_SAC = 256
    MINIBATCH_SIZE_OPE = 4  # ADJUST ACCORDING TO DATA SIZE
    RANDOM_SEED = 2599
    MAX_EPISODES = 100  # ADJUST ACCORDING TO DATA SIZE
    MAX_EPISODE_LEN = len(DATA[0]['observations'])
    CODE_SIZE = 16
    REPEAT = 1
    BUFFER_SIZE_OPE = 3000

    OPE_LR = ope_lr
    OPE_DS = ope_ds
    OPE_DR = ope_dr

    best_elbo = -9999.

    # `env_name` is not defined anywhere in this file.  If one of the star
    # imports provides it, keep using that value; otherwise fall back to the
    # dataset tag defined at module level instead of raising NameError.
    run_env_name = globals().get('env_name', ENV_INFO)

    rl_params = {
        'env_name': run_env_name,

        # control params
        'seed': RANDOM_SEED,
        'epochs': MAX_EPISODES,
        'actor_critic': mlp_actor_critic,
        'steps_per_epoch': MAX_EPISODE_LEN,
        'replay_size': BUFFER_SIZE_SAC,
        'batch_size': MINIBATCH_SIZE_SAC,
        'start_epis': 0,
        'max_ep_len': MAX_EPISODE_LEN,
        'save_freq': 10,
        'render': False,

        # rl params
        'gamma': 0.99,
        'polyak': 0.995,
        'lr': LR,
        'grad_clip_val': None,

        # entropy params
        'alpha': 'auto',
        'target_entropy': 'auto'  # fixed or auto define with -act_dim
    }

    # Run identifier: used for the stats file name and passed to OPE_Model.
    file_appendix = (
        "lstm_vae_" + rl_params['env_name'] + "_" + str(MAX_EPISODES)
        + "epi_repeat" + str(REPEAT) + "_" + str(LR) + "_"
        + str(OPE_LR) + "_"
        + str(OPE_DS) + "_"
        + str(OPE_DR) + "_"
        + str(CODE_SIZE) + "_"
        + str(beta) + "_"
        + str(RANDOM_SEED)
    )

    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    env_state_dim = DATA[0]['observations'].shape[1]  # NEED MOD
    env_action_dim = DATA[0]['actions'].shape[1]  # NEED MOD
    env_state_bound = None

    # Scalar normalisation statistics taken over every observation entry
    # (all trajectories, all steps, all dimensions) and over every reward.
    ob = [i for u in DATA for j in u['observations'] for i in j]
    obs_mean = sum(ob)/len(ob)
    obs_std = np.std(ob)

    rw = [j for u in DATA for j in u['rewards']]
    rew_mean = sum(rw)/len(rw)
    rew_std = np.std(rw)

    # Make sure the log directory exists.  exist_ok avoids a race when
    # several pool workers start at the same time.
    stats_dir = "./rl_stats/"
    os.makedirs(stats_dir, exist_ok=True)
    stats_path = stats_dir + file_appendix + ".txt"

    log_template = (
        '| Reward: {:d} | Episode: {:d}  | ELBO: {:.4f} | DIV1: {:.4f} | DIV2: {:.4f} | DIV3: {:.4f} | P_ns: {:.4f} | P_r: {:.4f} | MSE: {:.4f} \n'
    )
    # Attributes read from the model after a training step, listed in the
    # order the log template consumes them.
    metric_attrs = (
        'elbo_evaluated',
        'divergence1_evaluated',
        'divergence2_evaluated',
        'divergence3_evaluated',
        'likelihood_s_evaluated',
        'likelihood_r_evaluated',
        'encoder_decoder_lstm_states_mse_evaluated',
    )

    # The training model and the evaluation model live in separate graphs,
    # each with its own session.
    graph_ope_models = tf.Graph()
    graph_ope_models_eval = tf.Graph()

    with tf.Session(config=config, graph=graph_ope_models) as sess_ope_models, \
            tf.Session(config=config, graph=graph_ope_models_eval) as sess_ope_models_eval:

        with graph_ope_models.as_default():
            ope_model = OPE_Model(
                graph_ope_models, sess_ope_models, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,
                env_state_dim, env_state_bound, env_action_dim, file_appendix,
                BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN, beta
            )

            sess_ope_models.run(tf.global_variables_initializer())

            # Fill the model's replay buffer from the offline trajectories.
            ope_model.replay_buffer.port_d4rl_data(
                DATA,
                obs_mean,
                obs_std,
                rew_mean,
                rew_std,
            )

        with graph_ope_models_eval.as_default():
            # Built but not used further in this function; kept because
            # constructing it populates the evaluation graph.
            ope_model_eval = OPE_Model(
                graph_ope_models_eval, sess_ope_models_eval, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,
                env_state_dim, env_state_bound, env_action_dim, file_appendix,
                BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN,
                beta, is_training=False
            )

        # NOTE(review): not used below; kept in case the constructor has
        # side effects -- confirm and remove if it does not.
        actor_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(env_action_dim))

        for i in range(MAX_EPISODES):

            # There is no environment interaction here, so the reward column
            # of the log line is always 0; it is kept for format stability.
            ep_reward = 0
            # Metrics stay NaN for episodes in which no training step ran.
            metrics = [float('nan')] * len(metric_attrs)

            if ope_model.replay_buffer.size > MINIBATCH_SIZE_OPE:

                batch = ope_model.replay_buffer.sample_batch(MINIBATCH_SIZE_OPE)

                # NOTE(review): the same batch is trained on twice, as in the
                # original code -- confirm that this is intentional.
                ope_model.train(batch)
                ope_model.train(batch)

                metrics = [np.mean(getattr(ope_model, attr)) for attr in metric_attrs]
                elbo = metrics[0]

                if elbo > best_elbo:
                    best_elbo = elbo
                    ope_model.saver.save(
                        ope_model.sess,
                        ope_model.save_appendix.replace("ope.ckpt", "aug_best.ckpt")
                    )

                # Training has diverged; abandon this configuration.
                if np.isnan(elbo):
                    return

            log_line = log_template.format(int(ep_reward), i, *metrics)

            with open(stats_path, "a") as myfile:
                myfile.write(log_line)

            print(log_line)


if __name__ == '__main__':
    # Hyper-parameter grid: every combination below becomes one call to
    # main(), run on a pool of 3 worker processes.
    LRs = [0.0003]
    OPE_LRs = [0.003, 0.0001, 0.0003, 0.0005, 0.0007]
    OPE_DSs = [1000]
    OPE_DRs = [.9]
    BETAs = [1., .1, .05, .01, 5., 10.]

    grid = [
        (lr, ope_lr, ope_ds, ope_dr, beta)
        for lr in LRs
        for ope_lr in OPE_LRs
        for ope_ds in OPE_DSs
        for ope_dr in OPE_DRs
        for beta in BETAs
    ]

    pool = mp.Pool(3)
    try:
        pool.map(main, grid)
    finally:
        # Always shut the workers down, even if one configuration raised,
        # so no worker processes are left behind.
        pool.close()
        pool.join()