import argparse
import os
import os.path as osp
import pandas as pd
import numpy as np
from datetime import datetime


import matplotlib
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt

# Default typography for every figure drawn by this module: the 'CMU Serif'
# family at size 12. The __main__ block re-applies the family with the size
# taken from --font-size.
font = dict(family='CMU Serif', size=12)
matplotlib.rc('font', **font)

# Default line colors, handed out in order to the data dirs when --colors is
# not given on the command line.
color_list = [
    # The ten Tableau colors come first ...
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
    # ... followed by a few extra named / single-letter matplotlib colors.
    'deeppink', 'b', 'g', 'r', 'c', 'm', 'y',
]


def get_file_prefix(params=None):
    """Return the output folder path for a run (the folder is not created).

    The base folder is ``outputs/<exper_name>`` when ``params`` names an
    experiment. Otherwise it is ``outputs/<YYYY-MM-DD>/<HH-MM-SS>``, built
    from the current local time. When ``params`` carries a seed other than
    ``-1``, the seed is appended as one more sub-folder.

    Args:
        params: Optional mapping. The keys ``'exper_name'`` and ``'seed'`` are
            read if present. A missing key is treated as "not set": ``None``
            for the experiment name and ``-1`` for the seed.

    Returns:
        The folder path as a string.
    """
    if params is None:
        params = {}
    # .get instead of [] so a params mapping lacking either key falls back to
    # the "not set" behaviour instead of raising KeyError.
    exper_name = params.get('exper_name')
    seed = params.get('seed', -1)

    if exper_name is not None:
        folder = os.path.join('outputs', exper_name)
    else:
        # The '/' inside the format string yields a nested date/time folder pair.
        date_string = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
        folder = os.path.join('outputs', date_string)
    if seed != -1:
        folder = os.path.join(folder, str(seed))
    return folder


if __name__ == '__main__':
    # Command-line entry point. For each data dir, collect every non-empty
    # progress.csv / progress.txt found beneath it, reduce the chosen y column
    # across runs (mean or median), draw one curve per dir, and save the
    # figure (plus an optional standalone legend) into the folder returned by
    # get_file_prefix().
    parser = argparse.ArgumentParser(
        description='Plot the across-run mean/median of a logged metric, '
                    'one curve per data directory.')
    parser.add_argument('data_dirs', nargs='*',
                        help='directories searched recursively for progress.csv '
                             '(comma-separated) and progress.txt (tab-separated) logs')
    parser.add_argument('--labels', nargs='*',
                        help='legend label per data dir (default: the dir paths)')
    parser.add_argument('--colors', nargs='*',
                        help='matplotlib color per data dir')
    parser.add_argument('--lines', nargs='*',
                        help="matplotlib linestyle per data dir (default: '-')")
    parser.add_argument('--x-axis', type=str, default='TotalEnvInteracts',
                        help='column used for the x values')
    parser.add_argument('--y-axis', type=str, default='AverageTestEpRet',
                        help='column reduced across runs')
    parser.add_argument('--x-ticks', type=int, nargs='*', help='explicit x tick positions')
    parser.add_argument('--y-ticks', type=int, nargs='*', help='explicit y tick positions')
    parser.add_argument('--x-label', type=str, default=None,
                        help='x axis label (default: the --x-axis column name)')
    parser.add_argument('--y-label', type=str, default=None,
                        help='y axis label (default: the --y-axis column name)')
    parser.add_argument('--y-lim', nargs=2, type=float, metavar=('MIN', 'MAX'),
                        help='y axis limits')
    parser.add_argument('--x-lim', nargs=2, type=float, metavar=('MIN', 'MAX'),
                        help='x axis limits')
    parser.add_argument('--title', type=str, default=None)
    parser.add_argument('--no-legend', action='store_true',
                        help='do not draw a legend on the plot')
    parser.add_argument('--gen-legend', action='store_true',
                        help='also save a standalone legend.pdf')
    parser.add_argument('--legend-loc', type=str, default=None,
                        help='matplotlib legend location for the plot legend')
    parser.add_argument('--no-y-label', action='store_true')
    parser.add_argument('--plot-std', action='store_true',
                        help='shade +/- half the standard error across runs')
    parser.add_argument('--smoothing', action='store_true',
                        help='draw an exponential moving average (0.9 * previous '
                             '+ 0.1 * current) over a faint raw curve')
    parser.add_argument('--font-size', default=12, type=int)
    parser.add_argument('--crop', default=0, type=int,
                        help='plot at most this many rows (0: no limit)')
    parser.add_argument('--reduction', type=str, choices=('mean', 'median'), default='mean',
                        help='how the runs of one data dir are combined')
    parser.add_argument('--legend-size', type=int, default=10,
                        help='width in inches of the standalone legend figure')
    parser.add_argument('--legend-ncol', type=int, default=1,
                        help='number of legend columns')
    parser.add_argument('--file-name', type=str, default=None,
                        help="output file name; '.pdf' is appended when it contains "
                             "no '.' (default: plot.pdf)")
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    n_dirs = len(args.data_dirs)
    if args.labels is not None and len(args.labels) != n_dirs:
        parser.error('Must give no labels or same number of labels and data dirs')
    if args.colors is not None and len(args.colors) != n_dirs:
        parser.error('Must give no colors or same number of colors and data dirs')
    if args.lines is not None and len(args.lines) != n_dirs:
        parser.error('Must give no lines or same number of lines and data dirs')

    labels = args.labels if args.labels is not None else args.data_dirs
    # Cycle the default palette. A plain slice would yield fewer colors than
    # data dirs, and the zip below would then silently drop the extra dirs.
    if args.colors is not None:
        colors = args.colors
    else:
        colors = [color_list[i % len(color_list)] for i in range(n_dirs)]
    lines = args.lines if args.lines is not None else ['-'] * n_dirs

    logdir = get_file_prefix()
    # exist_ok: two invocations within the same second resolve to the same folder.
    os.makedirs(logdir, exist_ok=True)

    matplotlib.rc('font', family='CMU Serif', size=args.font_size)

    fig, _ = plt.subplots()
    for data_dir, label, color, line in zip(args.data_dirs, labels, colors, lines):
        dfs = []
        for root, _dirs, files in os.walk(data_dir):
            for log_name, sep in (('progress.csv', ','), ('progress.txt', '\t')):
                if log_name not in files:
                    continue
                try:
                    df = pd.read_csv(osp.join(root, log_name), delimiter=sep)
                except pd.errors.EmptyDataError:
                    # A zero-byte log has nothing to plot; skip only this file.
                    continue
                if not df.empty:
                    dfs.append(df)

        if not dfs:
            parser.error(f'no non-empty progress.csv / progress.txt found under {data_dir!r}')

        # Runs may have logged different numbers of rows; truncate all of
        # them to the shortest so they can be stacked.
        min_len = min(len(df) for df in dfs)
        if args.crop:
            min_len = min(args.crop, min_len)
        print(f'{data_dir}: {len(dfs)} run(s), {min_len} row(s)')

        # One row per run, one column per logged row.
        dat = np.vstack([df[args.y_axis].iloc[:min_len].to_numpy() for df in dfs])
        reducer = {'mean': np.mean, 'median': np.median}[args.reduction]
        dat_mean = reducer(dat, axis=0)
        # x values come from the last run; this presumes every run logged the
        # same x schedule.
        xs = dfs[-1][args.x_axis].iloc[:min_len].to_numpy()

        if args.smoothing:
            # Exponential moving average; the raw curve is drawn faintly behind it.
            dat_smoothed = []
            for point in dat_mean:
                if dat_smoothed:
                    dat_smoothed.append(dat_smoothed[-1] * 0.9 + point * 0.1)
                else:
                    dat_smoothed.append(point)
            plt.plot(xs, dat_smoothed, label=label, c=color, linestyle=line)
            plt.plot(xs, dat_mean, c=color, linestyle=line, alpha=0.1)
        else:
            plt.plot(xs, dat_mean, label=label, c=color, linestyle=line)

        if args.plot_std:
            # NOTE(review): the band is +/- half the standard error (total width
            # one SE) around the reduced curve, not +/- one standard deviation
            # as the flag name suggests. Kept as-is; confirm this is intended.
            dat_ste = np.std(dat, axis=0) / np.sqrt(len(dfs))
            dist = dat_ste / 2
            plt.fill_between(xs, dat_mean - dist, dat_mean + dist, color=color, alpha=0.1)

    if args.title is not None:
        plt.title(args.title)
    plt.xlabel(args.x_label or args.x_axis)
    if not args.no_y_label:
        plt.ylabel(args.y_label or args.y_axis)
    if not args.no_legend:
        plt.legend(loc=args.legend_loc, ncol=args.legend_ncol)
    if args.x_lim is not None:
        plt.xlim(args.x_lim)
    if args.y_lim is not None:
        plt.ylim(args.y_lim)
    if args.x_ticks is not None:
        plt.xticks(args.x_ticks)
    if args.y_ticks is not None:
        plt.yticks(args.y_ticks)
    fig.tight_layout()

    if args.file_name is None:
        file_name = 'plot.pdf'
    elif '.' in args.file_name:
        file_name = args.file_name
    else:
        file_name = args.file_name + '.pdf'
    fig.savefig(osp.join(logdir, file_name))
    plt.close(fig)

    if args.gen_legend:
        # Standalone, frameless legend strip (height 0.3 in) sharing the
        # labels/colors/linestyles of the main plot.
        legend_elements = [
            Line2D([0], [0], color=col, label=lbl, linestyle=ls)
            for lbl, col, ls in zip(labels, colors, lines)
        ]
        fig_legend = plt.figure(figsize=(args.legend_size, .3))
        fig_legend.legend(handles=legend_elements, loc='center', ncol=args.legend_ncol, frameon=False)
        fig_legend.savefig(osp.join(logdir, 'legend.pdf'))
        plt.close(fig_legend)
