from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
sys.path.append('.')
sys.path.append('..')
from methods.lips_bound import lips_bound_evaluation, estimate_rpi

class hiv_config(object):
    """Settings and interval-estimation driver for the HIV domain.

    The class attributes are plain configuration values.  Only ``state_dim``,
    ``action_size``, ``gamma``, ``truncate_size``, ``data_path`` and
    ``max_iteration`` are read by the methods of this class; the remaining
    attributes are presumably consumed by other scripts (TODO confirm).
    """

    # domain parameters
    state_dim = 6       # number of state columns in a transition row (see decode)
    action_size = 4     # number of action / policy columns (see decode)

    gamma = 0.75        # forwarded to lips_bound_evaluation / estimate_rpi
    num_trajectory = 50     # presumably the default for interval_estimation -- TODO confirm
    truncate_size = 40      # rows taken per trajectory: N = num_trajectory * truncate_size
    eta = 40.0              # presumably the default for interval_estimation -- TODO confirm
    subsample_size = 500    # presumably the default for interval_estimation -- TODO confirm
    # Presumably sweep grids for num_trajectory / eta / subsample_size;
    # not read inside this class -- verify against the calling script.
    NT = [1,2,4,6,10,15,20,30,50,100]
    ETA = [20.0, 30.0, 35.0, 40.0, 45.0, 50.0, 60.0]
    SSIZE = [100, 200, 300, 400, 500, 600, 800, 1000, 1500]
    result_path = './results/hiv_results/'
    data_path = './transition_data/hiv_data/'   # directory holding hiv_traj_*.npy / hiv_init_*.npy
    figure_name = 'hiv.pdf'
    ground_truth = 4.49778  # not read in this class; presumably a reference value -- TODO confirm

    # Not read in this class.
    behavior_eps = 0.10
    target_eps = 0.05
    ins = 20

    # Not read in this class.
    hidden_dim = 100
    feature_dim = 3
    Learning_rate = 3e-5

    max_iteration = 200     # iteration budget passed to lips_bound_evaluation

    def decode(self, SASRpi, s0pi0):
        """Split the packed 2-D arrays into their named column groups.

        ``SASRpi`` columns are read as
        ``[S (state_dim) | A (action_size) | SN (state_dim) | REW (1) | ...]``
        with ``pi`` taken from the last ``action_size`` columns.
        ``s0pi0`` is read as ``s0`` (first ``state_dim`` columns) and ``pi0``
        (last ``action_size`` columns).

        Returns:
            Tuple ``(S, A, SN, REW, pi, s0, pi0)`` of slices of the inputs;
            ``REW`` is a single column and therefore one-dimensional.
        """
        sd = self.state_dim
        na = self.action_size
        next_state_start = sd + na
        reward_col = 2 * sd + na

        S = SASRpi[:, :sd]
        A = SASRpi[:, sd:next_state_start]
        SN = SASRpi[:, next_state_start:reward_col]
        # Reward is used unscaled (an earlier "/ 1000000.0" rescaling is disabled).
        REW = SASRpi[:, reward_col]
        pi = SASRpi[:, -na:]
        s0 = s0pi0[:, :sd]
        pi0 = s0pi0[:, -na:]
        return S, A, SN, REW, pi, s0, pi0

    def feature_naive(self, S, A):
        """Return ``S`` with two binary action indicator columns appended.

        Column 0 of the indicator is 1 where ``A[:,1] + A[:,3] > 0`` and
        column 1 is 1 where ``A[:,2] + A[:,3] > 0``.  When ``A`` is a one-hot
        encoding of four actions, actions 0..3 map to (0,0), (1,0), (0,1),
        (1,1) respectively.
        """
        N = S.shape[0]
        AA = np.zeros([N, 2])
        AA[A[:, 1] + A[:, 3] > 0, 0] = 1
        AA[A[:, 2] + A[:, 3] > 0, 1] = 1
        return np.hstack([S, AA])

    def _scaled_max_iteration(self, subsample_size):
        """Return the iteration budget to use for ``subsample_size``.

        Up to 500 samples the configured ``max_iteration`` is used unchanged;
        above that it is scaled by ``500 / subsample_size`` and truncated to
        an int, so that ``iterations * subsample_size`` stays roughly constant.
        """
        if subsample_size > 500:
            # Scale the configured budget rather than a hard-coded 200 so that
            # changing ``max_iteration`` is honoured for large subsamples too.
            return int(self.max_iteration * 500 / subsample_size)
        return self.max_iteration

    def interval_estimation(self, num_trajectory, eta, subsample_size, seedID):
        """Load one stored dataset and compute lower/upper value estimates.

        Args:
            num_trajectory: number of trajectories to use; the first
                ``num_trajectory * truncate_size`` transition rows are taken.
            eta: forwarded to ``lips_bound_evaluation`` and ``estimate_rpi``.
            subsample_size: forwarded to ``lips_bound_evaluation``; also
                controls the iteration budget (see ``_scaled_max_iteration``).
            seedID: selects the dataset: file index ``seedID // 30`` and row
                ``seedID % 30`` inside that file.

        Returns:
            ``(Q0_lower, Q0_lower2, Q0_upper, Q0_upper2)`` where the ``2``
            variants come from the run with ``double_sample = True``.
        """
        print('======== Current Setting for hiv =========')
        print('---nt = {}, ts = {}, eta = {}, sample_size = {}, seed = {}---'.format(num_trajectory, self.truncate_size, eta, subsample_size, seedID))
        concurrent_size = 30    # number of datasets stored per .npy file
        i, j = divmod(seedID, concurrent_size)
        N = num_trajectory * self.truncate_size
        traj_file = os.path.join(self.data_path, 'hiv_traj_{}.npy'.format(i))
        init_file = os.path.join(self.data_path, 'hiv_init_{}.npy'.format(i))
        SASRpi = np.load(traj_file)[j,:N,:]
        s0pi0 = np.load(init_file)[j,:,:]
        S, A, SN, REW, pi, s0, pi0 = self.decode(SASRpi, s0pi0)

        max_iteration = self._scaled_max_iteration(subsample_size)

        # First pass: default sampling.
        Q_lower, Q_upper = lips_bound_evaluation(s0, [S, A, SN, REW], pi, self.feature_naive, self.gamma, eta, subsample_size = subsample_size, pi0 = pi0, max_iteration = max_iteration, discrete_action = True)
        Q0_lower, Q0_upper = estimate_rpi(s0, pi0, self.feature_naive, S, A, Q_lower, Q_upper, self.gamma, eta, discrete_action = True)

        # Second pass: identical except for double_sample = True.
        Q_lower, Q_upper = lips_bound_evaluation(s0, [S, A, SN, REW], pi, self.feature_naive, self.gamma, eta, subsample_size = subsample_size, pi0 = pi0, double_sample = True, max_iteration = max_iteration, discrete_action = True)
        Q0_lower2, Q0_upper2 = estimate_rpi(s0, pi0, self.feature_naive, S, A, Q_lower, Q_upper, self.gamma, eta, discrete_action = True)
        print('-----end calculation-----')
        print('lower = {}, upper = {}'.format(Q0_lower, Q0_upper))
        print('double sample: lower = {}, upper = {}'.format(Q0_lower2, Q0_upper2))
        print('============================')
        sys.stdout.flush()
        return Q0_lower, Q0_lower2, Q0_upper, Q0_upper2


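# NOTE: everything below is disabled (commented-out) code for a fitted Q
# iteration model written against TensorFlow 1.x-style APIs (tf.placeholder,
# tf.Session).  It references `tf`, `pickle` and `model`, none of which are
# imported at the top of this file, so it cannot be re-enabled as-is.
#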
# class fitted_q_iteration_hiv(object):
#     eps = 0.05
#     env = model(perturb_rate = 0.05, dt=10)
#     def __init__(self, feature_dim, hidden_dim = 100, Learning_rate = 1e-3, reg_weight = 1e-2, seedID = 43, fixed_model = True):
#         self.state_dim = 6
#         self.action_dim = 4
#         self.state_action_dim = self.state_dim + self.action_dim
#         self.feature_dim = feature_dim
#         self.hidden_dim = hidden_dim
#
#         tf.set_random_seed(seedID)
#         # Input
#         self.xs = tf.placeholder(tf.float32, [None, self.state_action_dim])
#         self.BQs = tf.placeholder(tf.float32, [None])
#
#         # Calculation
#         self.features, self.Qs, params = self._build_q_network('current_Q', True)
#         _, self.target_Qs, target_params = self._build_q_network('target_Q', False)
#
#         # Loss and operation
#         if fixed_model == False:
#             self.update_target_op = [target.assign(current) for current, target in zip(params, target_params)]
#             self.loss = tf.reduce_mean(tf.square(self.Qs - self.BQs))
#             self.reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'current_Q'))
#             self.train_op = tf.train.AdamOptimizer(Learning_rate).minimize(self.loss + reg_weight * self.reg_loss)
#             # self.train_op = tf.train.GradientDescentOptimizer(Learning_rate).minimize(self.loss + reg_weight * self.reg_loss)
#
#         # Debug
#
#         # sess and cpu
#         self.sess = tf.Session()
#         self.sess.run(tf.global_variables_initializer())
#         self.saver = tf.train.Saver()
#
#     def reset(self):
#         self.sess.run(tf.global_variables_initializer())
#
#     def save_model(self, filename = './hiv_domain/hiv_model/feature_q_fitted_hiv.ckpt'):
#         self.saver.save(self.sess, filename)
#
#     def load_model(self, filename = './hiv_domain/hiv_model/feature_q_fitted_hiv.ckpt'):
#         self.saver.restore(self.sess, filename)
#
#     def _build_q_network(self, name, trainable):
#         with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
#             l1 = tf.layers.dense(self.xs, self.hidden_dim, tf.nn.relu, trainable = trainable)
#             features = tf.concat([tf.layers.dense(l1, self.feature_dim, tf.nn.relu, trainable = trainable) , self.xs], -1)
#             W2 = tf.get_variable('W2', initializer = tf.zeros(shape = [self.state_action_dim + self.feature_dim, 1]), regularizer = tf.contrib.layers.l2_regularizer(0.), trainable = trainable)
#             b2 = tf.get_variable('b2', initializer = tf.zeros([1]), regularizer = tf.contrib.layers.l2_regularizer(0.), trainable = trainable)
#             Qs = tf.squeeze(tf.matmul(features, W2) + b2)
#             features = features * tf.squeeze(W2) / np.sqrt(self.feature_dim + self.state_action_dim)
#         params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope = name)
#         return features, Qs, params
#
#     def get_features(self, S, A):
#         return self.sess.run(self.features, feed_dict = {
#             self.xs: np.hstack([S,A])
#         })
#
#     def get_Q_value(self, S, A):
#         return self.sess.run(self.Qs, feed_dict = {
#             self.xs: np.hstack([S,A])
#         })
#
#     def encode(self, action):
#         encode_action = np.zeros([self.action_dim])
#         encode_action[action] = 1
#         return encode_action
#
#     def get_all_q_value(self, S):
#         if S.ndim == 1:
#             return self.get_Q_value(np.zeros([self.action_dim, self.state_dim]) + S, np.eye(self.action_dim))
#         else:
#             N = S.shape[0]
#             return np.vstack([self.get_Q_value(S, self.encode(a) + np.zeros([N, self.action_dim]))] for a in range(self.action_dim))    # shape: action_dim * N
#
#     def choose_action(self, s, eps = 0):
#         if s.ndim == 1:
#             if np.random.rand() < eps:
#                 return np.random.randint(self.action_dim)
#             else:
#                 return self.get_all_q_value(s).argmax()
#         else:
#             N = s.shape[0]
#             index = np.random.rand(N) < eps
#             action = np.argmax(self.get_all_q_value(s), axis = 0)    # greedy_action
#             rand_action = np.random.randint(self.action_dim, size = N)
#             action[index] = rand_action[index]
#             return action
#
#
#     def action_prob(self, S, eps = -1):
#         if eps < 0:
#             eps = self.eps
#         if S.ndim == 1:
#             greedy_action = self.get_all_q_value(S).argmax()
#             pi = np.ones(self.action_dim) * eps / self.action_dim
#             pi[greedy_action] = 1 - (self.action_dim - 1) * eps / self.action_dim
#             return pi
#         else:
#             N = S.shape[0]
#             greedy_action = np.argmax(self.get_all_q_value(S), axis = 0)
#             pi = np.ones([N, self.action_dim]) * eps / self.action_dim
#             pi[range(N), greedy_action] = 1 - (self.action_dim - 1) * eps / self.action_dim
#             return pi
#
#     def roll_out(self, num_trajectory, truncate_size, gamma, seed = 0):
#         eps = self.eps
#         np.random.seed(seed)
#         env = self.env
#         with open('./hiv_domain/hiv_simulator/hiv_preset_hidden_params', 'rb') as f:
#             preset_hidden_params = pickle.load(f, encoding='latin1')
#
#         S = []; A = []; SN = []; REW = []; average_rew = np.zeros(truncate_size)
#         for i_trajectory in range(num_trajectory):
#             # print(i_trajectory)
#             env.reset(perturb_params = True, **preset_hidden_params[20])
#             state = env.observe()
#             for i_t in range(truncate_size):
#                 action = self.choose_action(state, eps = eps)
#                 reward, next_state = env.perform_action(action, perturb_params = True, **preset_hidden_params[20])
#                 S.append(state); A.append(self.encode(action)); REW.append(reward); SN.append(next_state)
#                 average_rew[i_t] += reward
#                 state = next_state
#         average_rew /= num_trajectory
#         discounted = np.exp(np.arange(truncate_size) * np.log(gamma))
#         print('eps = {}, on policy average reward = {}'.format(eps, np.sum(average_rew * discounted ) / np.sum(discounted)))
#         return np.array(S), np.array(A), np.array(SN), np.array(REW)
#
#     def roll_out_initial(self, num_trajectory, seed = 0):
#         np.random.seed(seed)
#         env = self.env
#         with open('./hiv_domain/hiv_simulator/hiv_preset_hidden_params', 'rb') as f:
#             preset_hidden_params = pickle.load(f, encoding='latin1')
#
#         S0 = []
#         for i_trajectory in range(num_trajectory):
#             env.reset(perturb_params = True, **preset_hidden_params[20])
#             state = env.observe()
#             S0.append(state)
#         return np.array(S0)
#
#     def fitted_q_update(self, num_trajectory, truncate_size, batch_size, gamma, eps = 0.15, fitted_iter = 4000, max_iter = 30):
#         for i_iter in range(max_iter):
#             print('===iteration {}===='.format(i_iter))
#             self.sess.run(self.update_target_op)
#             S, A, SN, REW = self.roll_out(num_trajectory, truncate_size, gamma, seed = i_iter)
#             N = S.shape[0]
#             j = N
#             for i_train in range(fitted_iter):
#                 if j + batch_size > N:
#                     perm = np.random.permutation(N)
#                     j = 0
#                 subsamples = perm[j:j+batch_size]
#                 xs = np.hstack([S[subsamples], A[subsamples]])
#                 r = REW[subsamples]
#                 sn = SN[subsamples]
#                 BQs = r + gamma * np.max(self.get_all_q_value(sn), axis = 0)
#                 loss, reg_loss, _ = self.sess.run([self.loss, self.reg_loss, self.train_op], feed_dict = {
#                     self.xs: xs,
#                     self.BQs: BQs
#                 })
#                 if i_train % 500 == 0:
#                     print('---train iter {}, loss = {}'.format(i_train, loss))
#                 j += batch_size
#         return
#
#     def fitted_q_evaluation(self, SASR, policy_target, S0, batch_size, gamma, fitted_iter = 2000, max_iter = 5):
#         S, A, SN, REW = SASR
#         N = S.shape[0]
#         for i_iter in range(max_iter):
#             print('===iteration {}===='.format(i_iter))
#             self.sess.run(self.update_target_op)
#             j = N
#             for i_train in range(fitted_iter):
#                 if j + batch_size > N:
#                     perm = np.random.permutation(N)
#                     j = 0
#                 subsamples = perm[j:j+batch_size]
#                 xs = np.hstack([S[subsamples], A[subsamples]])
#                 r = REW[subsamples]
#                 sn = SN[subsamples]
#                 Q_nexts = self.get_all_q_value(sn).T        # shape: batch_size * action_dim
#                 pi_nexts = policy_target.action_prob(sn)    # shape: batch_size * action_dim
#                 BQs = r + gamma * np.sum(Q_nexts * pi_nexts, axis = -1)
#                 loss, reg_loss, _ = self.sess.run([self.loss, self.reg_loss, self.train_op], feed_dict = {
#                     self.xs: xs,
#                     self.BQs: BQs
#                 })
#                 if i_train % 500 == 0:
#                     print('---train iter {}, loss = {}'.format(i_train, loss))
#                     print('est_model = {}'.format(self.est_reward(S0, policy_target, gamma)))
#                 j += batch_size
#
#     def est_reward(self, S0, policy_target, gamma):
#         Q0 = self.get_all_q_value(S0).T
#         pi0 = policy_target.action_prob(S0)
#         return (1 - gamma) * np.mean(np.sum(Q0 * pi0, axis = -1))
