from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import argparse
import numpy as np


# Experiment domain: exactly one domain config should be active. Alternatives:
#   from hiv_domain.hiv import hiv_config;  config = hiv_config()
#   from toy_domain.toy import toy_config;  config = toy_config()
from pendulum_domain.pendulum import pendulum_config
config = pendulum_config()

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so figures
# can be rendered to files without a display.
matplotlib.use('Agg')
# 'normal' is not a font family matplotlib recognises (it emits a findfont
# warning and falls back to a default font); 'sans-serif' is a valid generic
# family and names the default explicitly.
font = {'family' : 'sans-serif',
        'size'   : 28}
matplotlib.rc('font', **font)
import matplotlib.pyplot as plt

# Shared plotting palette: colours and markers are indexed in parallel.
colormap = ['#3498db','#865f84','#FFA500','#2ecc71','#e74c3c','#A9A9A9', '#A300DF']
markermap = ['x', 's', 'v', 'o', 'd', 'x', '*']
lineWidth = 5.0
# NOTE(review): markerSize and marker_shape are not referenced by the plotting
# code in this file; keep them only if other scripts import them.
markerSize = 13.0
marker_shape = ['x', 'o', 'v', 'h', 'p', 's', 'd', '^', 'o']

# Hollow markers (no face fill) with thick edges, splatted into plt.plot calls.
marker_args = dict(
    markerfacecolor = 'None',
    markeredgewidth = 5,
    ms = 14,
)

def reset_fig(xlogscale = False):
    """Start a fresh 7x7 figure for the next plot.

    The previously current figure is closed first, so repeated calls do not
    accumulate open figures.

    Args:
        xlogscale: if True, switch the x-axis of the new figure to log scale.

    Returns:
        The newly created matplotlib Figure (callers may ignore it).
    """
    # plt.figure() below always creates a new figure, so merely clearing the
    # old one (plt.clf) would leave it open; close it instead. No-op if none.
    plt.close()
    fig = plt.figure(figsize=(7, 7), facecolor="blue")
    # NOTE(review): a blue figure background is unusual for result plots and
    # may show up in saved images; confirm it is intended.
    # Global rc change: hides the top/right spines of every axes created from
    # here on. It must run before xscale(), which creates the axes.
    plt.rc('axes.spines', top=False, right=False)
    if xlogscale:
        plt.xscale('log')
    return fig

# File-name pattern of the saved interval results. The placeholders are filled
# with (number of trajectories, config.truncate_size, config.eta, subsample size).
_INTERVAL_FILE = 'interval_nt={}_ts={}_eta={}_size={}.npy'


def _load_interval_means(config, x_values, make_filename, columns):
    """Load one interval-result file per x value and average selected columns.

    Args:
        config: object whose ``result_path`` is prepended verbatim to each
            file name.
        x_values: settings to iterate over; one file is loaded per entry.
        make_filename: callable mapping an x value to its ``.npy`` file name.
        columns: column indices to extract from each loaded 2-D array.

    Returns:
        Tuple with one 1-D array per entry of ``columns``. Element i of an
        array is the mean, over the rows of the i-th loaded array, of that
        column (rows presumably being repeated runs -- TODO confirm).
    """
    per_column = [[] for _ in columns]
    for x in x_values:
        data = np.load(config.result_path + make_filename(x))
        for collected, col in zip(per_column, columns):
            collected.append(data[:, col])
    return tuple(np.mean(np.array(collected), axis=1) for collected in per_column)


def _plot_band(x_values, lower, upper, color, marker=None):
    """Draw dashed upper/lower curves in ``color`` and shade between them.

    When ``marker`` is given, both curves get hollow markers (marker_args).
    """
    line_kwargs = dict(color=color, lw=lineWidth, ls='--')
    if marker is not None:
        line_kwargs.update(marker=marker, markeredgecolor=color, **marker_args)
    plt.plot(x_values, upper, **line_kwargs)
    plt.plot(x_values, lower, **line_kwargs)
    plt.fill_between(x_values, lower, upper, alpha=0.2, color=color, lw=0.1)


def plot_interval_estimation(config, plot_subsample=False):
    """Plot interval estimates normalised by ``config.ground_truth``.

    Figure 1 plots the second interval's mean bounds against ``config.NT``
    (log x-axis), with a dashed reference line at 1. This figure is not saved.
    The function then prints ``config.NT``, the normalised lower bounds, and
    the axis-1 mean of ``Thomas_lower_bound.npy``.

    Figure 2 is drawn only when ``plot_subsample`` is True. It plots the bounds
    against ``config.SSIZE``: the first interval appears as a flat band at its
    value for the last subsample size. The figure is saved to
    ``'./figures/ssize_' + config.figure_name``.

    Args:
        config: experiment configuration. Reads NT, SSIZE, result_path,
            truncate_size, eta, subsample_size, num_trajectory, ground_truth
            and figure_name.
        plot_subsample: also render and save the subsample-size figure.
            Off by default.
    """
    gt = config.ground_truth

    # ---- Figure 1: bounds vs. number of trajectories (config.NT) ----
    reset_fig(xlogscale=True)
    # Only the second interval (columns 2 = lower, 3 = upper) is plotted here.
    # NOTE(review): figure 2 takes the second interval from columns 1/3
    # instead; confirm which column layout the result files actually use.
    lower2, upper2 = _load_interval_means(
        config,
        config.NT,
        lambda nt: _INTERVAL_FILE.format(
            nt, config.truncate_size, config.eta, config.subsample_size),
        columns=(2, 3),
    )
    plt.plot(config.NT, [1] * len(config.NT), color=colormap[4], lw=lineWidth, ls='--')
    lower2_ratio = lower2 / gt
    _plot_band(config.NT, lower2_ratio, upper2 / gt, colormap[0], marker=markermap[1])
    plt.xticks([1, 10], [r'$10^0$', r'$10^1$'])

    print(config.NT)
    print(lower2_ratio)
    thomas = np.load(config.result_path + 'Thomas_lower_bound.npy')
    # NOTE(review): printed without normalisation. If it should be comparable
    # to the ratios above, a candidate normaliser is
    # (-37.3473558721 + 100.0) / 100.0 -- confirm before using.
    print(np.mean(thomas, axis=1).T)

    if not plot_subsample:
        return

    # ---- Figure 2: bounds vs. subsample size (config.SSIZE) ----
    reset_fig(xlogscale=False)
    # Column order here: interval 1 = (0, 2), interval 2 = (1, 3).
    lower1, upper1, lower2, upper2 = _load_interval_means(
        config,
        config.SSIZE,
        lambda size: _INTERVAL_FILE.format(
            config.num_trajectory, config.truncate_size, config.eta, size),
        columns=(0, 2, 1, 3),
    )
    n_sizes = len(config.SSIZE)
    plt.plot(config.SSIZE, [1] * n_sizes, color=colormap[4], lw=lineWidth, ls='--')
    # Interval 1 is drawn flat at its value for the largest (last) subsample size.
    _plot_band(config.SSIZE,
               [lower1[-1] / gt] * n_sizes,
               [upper1[-1] / gt] * n_sizes,
               colormap[1])
    _plot_band(config.SSIZE, lower2 / gt, upper2 / gt, colormap[0], marker=markermap[1])
    plt.xticks([500, 1000])
    plt.savefig('./figures/ssize_' + config.figure_name, bbox_inches='tight')

if __name__ == '__main__':
    # Script entry point: render the plots for the module-level `config`.
    plot_interval_estimation(config=config)
