import argparse

import numpy as np
import matplotlib.pyplot as plt
import pickle

from utils.utils import read_config_file
import matplotlib.patheffects as pe

# Smoothing factor for the exponentially smoothed accuracy curves.
smooth_gamma = 0.1


# Plot style pools. They are consumed in parallel with ``labels`` below; only
# the first ``len(labels)`` entries of each list are actually used.
linestyles = ['--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-', '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted']
markers = ['o', '^', 's', 'D', 'v', 'p', '*', 'x', 'o', '^', 's', 'D', 'v', 'p', '*', 'x']
colors = ['blue', 'green', 'purple', 'orange', 'brown', 'pink', 'red', 'gray', 'blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink', 'gray']

# Known method names; "other" is the fallback style entry for unknown labels.
labels = [
    "Origin",
    "ReRoPE",
    "Leaky-ReRoPE",
    "Dynamic-NTK",
    "LM-Infinite",
    "Streaming-LLM",
    "Mesa-Extrapolation",
    "other"
]

# label -> {"linestyle", "marker", "color"} lookup used when plotting.
label_setting = {
    name: {
        "linestyle": style,
        "marker": mark,
        "color": colour,
    }
    for name, style, mark, colour in zip(labels, linestyles, markers, colors)
}

print(label_setting)




def _smooth_length_accuracy(data, gamma, gap=1000, init=0.5):
    """Bucket per-length accuracies and exponentially smooth the bucket means.

    Lengths are visited in ``data``'s iteration order. A length within ``gap``
    tokens of the first length of the current bucket is merged into that
    bucket; otherwise it opens a new bucket. Each bucket contributes one point
    ``nanmean(bucket) * (1 - gamma) + prev * gamma``, where ``prev`` is the
    smoothed value of the previous bucket (``init`` for the first bucket).

    Args:
        data: mapping of token length -> accuracies for that length (a scalar,
            list or array; presumably a list of per-sample scores — confirm
            against the log writer).
        gamma: weight given to the previous bucket's smoothed value.
        gap: maximum distance in tokens for merging into the current bucket.
        init: seed value the first bucket is smoothed against.

    Returns:
        ``(xs, ys)``: the first length of each bucket and its smoothed value.
    """
    xs, ys = [], []
    bucket = []
    prev = init
    for length, values in data.items():
        # ravel().tolist() copies, so extending the bucket never mutates the
        # caller's data, and it also accepts scalar or array values.
        chunk = np.ravel(values).tolist()
        if xs and abs(length - xs[-1]) <= gap:
            # Same bucket: drop its previous point and recompute it below.
            bucket.extend(chunk)
            ys.pop()
        else:
            prev = ys[-1] if ys else init
            xs.append(length)
            bucket = chunk
        ys.append(np.nanmean(bucket) * (1 - gamma) + prev * gamma)
    return xs, ys


def _style_axes(ax, model_name, x_marks):
    """Apply the shared look: white guide lines, axis labels, title, light frame."""
    # Thick white vertical guide lines at every plotted x position; zorder=1
    # keeps them beneath the data lines (matplotlib's default line zorder is 2).
    for x_value in x_marks:
        ax.axvline(x_value, color='white', linestyle='-', linewidth=5.0, zorder=1)

    ax.set_xlabel("Token Length", fontsize=14)
    ax.set_ylabel("Accuracy", fontsize=14)
    ax.set_title("{}".format(model_name))
    ax.set_facecolor('#f8f8f8')  # light background

    # Light-grey frame around the plot area.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_color('lightgrey')
    # Stroke the right/bottom spines for a slight raised look.
    for side in ('right', 'bottom'):
        ax.spines[side].set_path_effects([pe.withStroke(linewidth=2, foreground='lightgrey')])


def main(args=None):
    """Plot smoothed accuracy-vs-token-length curves from pickled eval logs.

    The config at ``args.config_file`` is loaded with ``read_config_file``.
    Its ``General`` section must provide ``files``, ``labels`` and
    ``model_name``, and may provide ``smooth_gamma``. When ``smooth_gamma``
    is omitted, the module-level ``smooth_gamma`` is used.

    Each file path is prefixed with ``../logs/`` unless it already contains
    that string. Each file must hold a pickled dict whose
    ``"all_length_acc"`` entry maps token length -> accuracies. Files and
    labels are paired positionally; extra entries in the longer list are
    ignored. The figure is saved to ``longeval-lines_<model_name>.png`` and
    then shown.

    Raises:
        FileNotFoundError: if ``args`` is None or names no config file.
    """
    if args is None or not getattr(args, "config_file", None):
        raise FileNotFoundError("no config file")

    config = read_config_file(args.config_file)
    general = config['General']
    files = general['files']
    labels = general['labels']
    model_name = general['model_name']
    try:
        gamma = general['smooth_gamma']
    except KeyError:
        gamma = smooth_gamma  # module-level default

    _, ax = plt.subplots()

    all_x = set()
    for _filename, _label in zip(files, labels):
        if "../logs/" not in _filename:
            _filename = "../logs/" + _filename
        # NOTE: pickle.load can execute arbitrary code; only use trusted,
        # locally produced log files here.
        with open(_filename, "rb") as f:
            data = pickle.load(f)["all_length_acc"]

        x, y = _smooth_length_accuracy(data, gamma)
        all_x.update(x)

        style = label_setting.get(_label, label_setting["other"])
        ax.plot(x, y, label=_label, linewidth=1.5, linestyle=style["linestyle"],
                marker=style["marker"], color=style["color"])

    # Guide lines at the union of all curves' x positions, not just the last file's.
    _style_axes(ax, model_name, sorted(all_x))
    plt.subplots_adjust(left=0.1, right=0.9, bottom=0.13, top=0.9)

    # Ticks every 1k tokens from 4k up to (but excluding) 16k; label only even k.
    ticks = np.arange(4 * 1024, 16 * 1024, 1024)
    tick_labels = [f"{int(t) // 1024}k" if (int(t) // 1024) % 2 == 0 else "" for t in ticks]
    plt.xticks(ticks, tick_labels, fontsize=12)

    plt.legend()
    plt.savefig('longeval-lines_{}.png'.format(model_name))
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: only the config-file path is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_file",
        type=str,
        default="../conf/llama2-7b-chat-longeval-lines-result14.json",
    )
    main(cli.parse_args())
