import configparser
import json

import matplotlib.pyplot as plt
import os, sys
# Project root = parent of the directory containing this script. It is appended
# to sys.path so the `utils.utils` import below resolves no matter where the
# script is launched from (entries already on sys.path keep precedence).
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed, read_config_file
import numpy as np
import os

def _exponential_smooth(values, gamma):
    """Return the exponential moving average of *values*.

    ``out[0] = values[0]`` and
    ``out[i] = values[i] * (1 - gamma) + out[i - 1] * gamma``.
    The series is seeded with its first observation rather than 0, so the
    beginning of the curve is not pulled toward zero. ``gamma == 0`` returns
    the input unchanged.

    Args:
        values: Iterable of numbers.
        gamma: Weight given to the previous smoothed value.

    Returns:
        A list with one smoothed value per input (empty for empty input).
    """
    smoothed = []
    for value in values:
        if smoothed:
            value = value * (1 - gamma) + smoothed[-1] * gamma
        smoothed.append(value)
    return smoothed


def main(args=None):
    """Plot smoothed per-position NLL curves for every result file in a config.

    The config's ``General`` section supplies ``files`` (pickled result
    files), ``labels`` (one legend label per file), ``smooth_gamma`` (EMA
    weight applied to the mean curve) and ``model_name`` (used in the output
    file name). Each pickle must contain ``"nll_stats_token"``: a mapping
    whose values are dicts with ``"mean"`` and ``"var"`` entries. These are
    plotted in the mapping's iteration order. The figure is saved to
    ``ppl_<model_name>.png`` and then shown.

    Args:
        args: Namespace with a ``config_file`` attribute (as produced by the
            command-line parser).

    Raises:
        FileNotFoundError: If no config file was given, or a listed result
            file does not exist.
        ValueError: If ``files`` and ``labels`` have different lengths.
    """
    config_file = getattr(args, "config_file", None)
    if not config_file:
        raise FileNotFoundError("no config file")

    general = read_config_file(config_file)['General']
    files = general['files']
    labels = general['labels']
    smooth_gamma = general['smooth_gamma']
    model_name = general['model_name']
    if len(files) != len(labels):
        # zip() below would otherwise silently drop the unmatched entries.
        raise ValueError(
            "config lists {} files but {} labels".format(len(files), len(labels))
        )

    for filename, label in zip(files, labels):
        # pickle.load can execute arbitrary code: only load result files you
        # produced yourself.
        with open(filename, "rb") as f:
            data = pickle.load(f)["nll_stats_token"]
        stats = list(data.values())

        y = np.asarray(_exponential_smooth([d["mean"] for d in stats], smooth_gamma))
        # Band half-width is the standard deviation; mean +/- variance mixes units.
        # NOTE(review): assumes "var" really holds a variance -- confirm with the
        # code that writes these pickles.
        std = np.sqrt(np.asarray([d["var"] for d in stats]))
        x = np.arange(len(y))

        line, = plt.plot(x, y, label=label, linewidth=1)
        # fill_between(x, y1, y2): x positions must be passed explicitly, and
        # reusing the line colour keeps each band matched to its curve.
        plt.fill_between(x, y - std, y + std, color=line.get_color(), alpha=0.1)

    # Paper-figure layout (axis labels, fixed y-range, no title).
    plt.xlabel("Token Length", fontsize=12)
    plt.ylabel("NLL", fontsize=12)
    plt.ylim(bottom=0, top=15)
    plt.subplots_adjust(left=0.09, right=0.97, bottom=0.1, top=0.97)

    plt.legend()
    plt.savefig('ppl_{}.png'.format(model_name))
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Anchor the default to the project root (BASE_DIR) rather than the current
    # working directory. This is the same file "../conf/..." resolved to when
    # the script was run from its own directory, but now works from anywhere.
    parser.add_argument(
        "-c", "--config_file",
        type=str,
        default=os.path.join(BASE_DIR, "conf", "mpt-7b-ppl-pile-result17.json"),
        help="path to the plotting config file",
    )
    args = parser.parse_args()
    main(args)
