from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
sys.path.append('.')
sys.path.append('..')
from pendulum_domain.pendulum import pendulum_config

def Thomas_lower_bound(c, X, delta):
    """Return a high-confidence lower bound on the mean of the samples ``X``.

    Every sample is truncated at the threshold ``c`` (``Y = min(X, c)``) and
    the bound is

        mean(Y) - 7 * c * ln(2/delta) / (3 * (N - 1))
                - sqrt(2 * ln(2/delta) / (N - 1) * var(Y))

    where ``var`` is the population variance (``np.var``, ddof=0).
    NOTE(review): this matches the HCOPE bound of Thomas et al. (2015) with all
    thresholds c_i equal to ``c``; verify against the paper before relying on
    it.

    The function also prints ``c`` and the resulting bound.

    Args:
        c: truncation threshold applied to every sample.
        X: 1-D array-like of N samples.
        delta: value entering the bound through ``ln(2/delta)``.

    Returns:
        The lower bound (a numpy float).

    Raises:
        ValueError: if fewer than 2 samples are given. The bound divides by
            N - 1, so it is undefined there; previously this silently
            produced NaN/inf.
    """
    X = np.asarray(X)
    N = X.shape[0]
    if N < 2:
        raise ValueError(
            'Thomas_lower_bound needs at least 2 samples, got {}'.format(N))
    Y = np.minimum(X, c)
    log_term = np.log(2 / delta)  # shared by both penalty terms
    Y_mean = np.mean(Y)
    Y_var = np.var(Y)
    term2 = 7 * c * log_term / 3 / (N - 1)
    term3 = np.sqrt(2 * log_term / (N - 1) * Y_var)

    lower_bound = Y_mean - term2 - term3
    print('c = {}, lower bound = {}'.format(c, lower_bound))
    return lower_bound

def episodic_reward(REW, gamma):
    """Return the time-average of the discounted per-step rewards of one trajectory.

    Computes ``mean_t(gamma**t * REW[t])`` for ``t = 0 .. truncate_size - 1``.
    Note that this divides by the truncation length, not by the sum of the
    discount weights.

    Args:
        REW: 1-D array of per-step rewards, length ``truncate_size``.
        gamma: discount factor.

    Returns:
        The averaged discounted reward (a numpy float).
    """
    truncate_size = REW.shape[0]
    # Use gamma**t directly rather than exp(t * log(gamma)). With the log form,
    # gamma == 0 gives 0 * log(0) = 0 * -inf = nan at t = 0, so the result was
    # NaN; np.power yields the weights [1, 0, 0, ...] as intended.
    discounts = np.power(gamma, np.arange(truncate_size, dtype=np.float64))
    return np.mean(REW * discounts)

def importance_ratio_trajectory(S, A, policy_target, policy_behavior):
    """Return the trajectory-level importance ratio pi_target / pi_behavior.

    The ratio is formed in log space: the per-step log-probability differences
    from the two policies' ``logpis(S, A)`` are summed, and the sum is
    exponentiated once.

    Args:
        S, A: states and actions of a single trajectory (per the original
            comment, ``(truncate_size, dim)`` arrays).
        policy_target: object providing ``logpis(S, A)``; evaluated first.
        policy_behavior: object providing ``logpis(S, A)``.

    Returns:
        ``exp(sum_t [log pi_target(a_t|s_t) - log pi_behavior(a_t|s_t)])``.
    """
    per_step_log_ratio = policy_target.logpis(S, A) - policy_behavior.logpis(S, A)
    total_log_ratio = np.sum(per_step_log_ratio)
    return np.exp(total_log_ratio)

def data_to_random_variable(SASR, gamma, policy_target, policy_behavior,
                            R_max=0.0, R_min=-100.0):
    """Convert behavior-policy trajectories into importance-weighted returns.

    The per-trajectory reward summary comes from ``episodic_reward``. It is
    min-max rescaled with the fixed range ``[R_min, R_max]`` and then
    multiplied by that trajectory's importance ratio from
    ``importance_ratio_trajectory``.

    Args:
        SASR: tuple ``(S, A, SN, REW)`` generated by ``policy_behavior``.
            ``REW`` must be 2-D ``(n_traj, truncate_size)``. ``S[i]`` and
            ``A[i]`` are the states/actions of trajectory ``i``. ``SN`` is
            not used.
        gamma: discount factor forwarded to ``episodic_reward``.
        policy_target, policy_behavior: objects providing ``logpis(S, A)``.
        R_max, R_min: rescaling range for the episodic rewards. The defaults
            are the values previously hard-coded here. NOTE(review): episodic
            rewards outside this range map outside [0, 1].

    Returns:
        ``(X, R_max, R_min)``, where ``X`` has shape ``(n_traj,)``.

    Raises:
        ValueError: if ``R_max <= R_min``. The rescaling would divide by zero
            or flip the sign of the rewards.
    """
    if R_max <= R_min:
        raise ValueError(
            'R_max ({}) must be greater than R_min ({})'.format(R_max, R_min))
    S, A, _, REW = SASR
    # Unpacking the shape also enforces that REW is 2-D.
    num_trajectory, _ = REW.shape
    ep_reward = np.zeros(num_trajectory)
    im_ratio = np.zeros(num_trajectory)
    for i_traj in range(num_trajectory):
        ep_reward[i_traj] = episodic_reward(REW[i_traj, :], gamma)
        im_ratio[i_traj] = importance_ratio_trajectory(
            S[i_traj], A[i_traj], policy_target, policy_behavior)
    ep_reward = (ep_reward - R_min) / (R_max - R_min)
    X = ep_reward * im_ratio
    return X, R_max, R_min


if __name__ == '__main__':
    config = pendulum_config()
    gamma = config.gamma
    truncate_size = config.truncate_size
    num_seed = 20
    # Hoisted out of the seed loop: delta is the same for every bound computed.
    delta = 0.05
    # Ground truth (presumably the target policy's value on the
    # episodic_reward scale), rescaled with the same [-100, 0] range that
    # data_to_random_variable uses by default: (v + 100) / 100.
    ground_truth = (-37.3473558721 + 100.0) / 100.0
    print('ground truth = {}'.format(ground_truth))

    thresholds = [0.00, 0.001, 0.01, 0.1, 1.0, 10.0]
    # results[i_nt, seed, i_c] is the lower bound for trajectory count
    # config.NT[i_nt], random seed `seed`, and threshold thresholds[i_c].
    results = np.zeros([len(config.NT), num_seed, len(thresholds)])

    for i_nt, num_trajectory in enumerate(config.NT):
        print("===== nt = ", num_trajectory)
        for seed in range(num_seed):
            print('----- seed = ', seed)
            SASR = config.get_trasition_data(
                num_trajectory, truncate_size, config.policy_behavior, seed)
            X, _, _ = data_to_random_variable(
                SASR, gamma, config.policy_target, config.policy_behavior)
            print('importance sampling estimator = {}'.format(X.mean()))
            for i_c, c in enumerate(thresholds):
                results[i_nt, seed, i_c] = Thomas_lower_bound(c, X, delta)
    # NOTE(review): plain concatenation assumes config.result_path ends with a
    # path separator (or is an intentional filename prefix).
    np.save(config.result_path + 'Thomas_lower_bound.npy', results)
