#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : read_and_plot_results.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

"""
Usage: python3 scripts/read_and_plot_results.py {RESULT-DIRS} [-l {LEGENDS}]
Multiple dirs are supported, and shell globs can be used, e.g. results/{SOME-DATASET}/*
When --regex {REG-EXPRESSION} is given, each dump dir is treated as a parent
directory, and only its subdirectories whose names match the regex (matched
from the start of the name) are visited.
"""
import os, sys


sys.path.append(os.getcwd())

import argparse
import glob
import re
import json
import os.path as osp
import numpy as np

from sweeping import get_stats, get_stats_str, get_summary, is_better, print_list
from sweeping.plot import main as plot
from sweeping.plot import register_plot_args

# Command-line interface. Arguments are parsed at import time; the resulting
# ``args`` namespace is read as a module-level global by the functions below.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

parser.add_argument(
    "dump_dirs",
    type=str,
    nargs="+",
    help="The dump dir(s) for results (with --regex: parent dirs whose subdirs are scanned)",
)
parser.add_argument(
    "--regex",
    "-re",
    type=str,
    default=None,
    help="Only visit subdirs of each dump dir whose names match this regex (re.match)",
)
parser.add_argument(
    "--do-plot",
    "-pl",
    action="store_true",
    help="Plot the per-epoch results",
)
parser.add_argument(
    "--max-epochs",
    "-me",
    type=int,
    default=None,
    help="The maximal number of epochs to plot",
)
parser.add_argument(
    "--max-runs",
    "-mr",
    type=int,
    default=5,
    help="The maximal number of runs to plot with --show-multi-runs",
)
parser.add_argument(
    "--legends",
    "-l",
    type=str,
    nargs="+",
    default=None,
    help="Legends for the plotted curves (replace the generated ones)",
)
parser.add_argument(
    "--params-fname",
    "-pf",
    type=str,
    default="params.json",
    help="The params file of each run dir (must contain 'raw_cmdline')",
)
parser.add_argument(
    "--summary-fname",
    "-sf",
    type=str,
    default="progress.csv",
    help="The name of the summary file",
)
parser.add_argument(
    "--show-multi-runs",
    "-mul",
    action="store_true",
    help="Plot each run separately instead of aggregating over runs",
)
parser.add_argument(
    "--enable-train-res",
    "-etr",
    action="store_true",
    help="Also plot train results",
)
parser.add_argument(
    "--enable-val-res",
    "-evr",
    action="store_true",
    help="Also plot validation results",
)
parser.add_argument(
    "--smaller-better",
    "-sb",
    action="store_true",
    help="smaller is better, useful for regression tasks",
)
parser.add_argument(
    "--acc",
    "-acc",
    action="store_true",
    help="Read the *_acc columns instead of *_res (backward compatible)",
)
parser.add_argument(
    "--silence",
    "-silence",
    action="store_true",
    help="Suppress per-directory reports and the best-run summary",
)
# Plot-specific options (output path etc.) are owned by sweeping.plot.
register_plot_args(parser)
args = parser.parse_args()
# Everything except warnings and the final plot message is gated on this.
args.verbose = not args.silence


def get_xs_ys(runs):
    """Build per-epoch curves of train/val/test results for plotting.

    ``runs`` is the per-row table returned by ``get_summary``. It is read via
    ``runs["run"]``, ``runs["epoch"]`` and ``runs.itertuples()``. Each row
    must expose ``run``, ``epoch`` and ``{train,val,test}_{res|acc}``
    attributes (presumably a pandas DataFrame -- TODO confirm).

    Test curves are always produced. Train and val curves are produced only
    with ``--enable-train-res`` / ``--enable-val-res``. With
    ``--show-multi-runs``, each of the first ``--max-runs`` runs gets its own
    curve. Otherwise all runs are aggregated per epoch with ``get_stats``.
    In both modes, epochs at or beyond ``--max-epochs`` are dropped.

    Returns:
        ``(num_epochs, ys, legends)``:
        - ``num_epochs``: the number of x positions.
        - ``ys``: one dict per curve. It holds a ``mean`` list and, in
          aggregate mode, the ``min``/``max``/``std`` lists from ``get_stats``.
        - ``legends``: one label per curve.
    """
    num_runs = min(args.max_runs, runs["run"].max() + 1)
    num_epochs = runs["epoch"].max() + 1
    if args.max_epochs is not None:
        num_epochs = min(num_epochs, args.max_epochs)

    key = "acc" if args.acc else "res"
    # (legend name, summary column) of every curve to draw, in plot order.
    curves = [
        (name, f"{split}_{key}")
        for name, split, enabled in (
            ("train_res", "train", args.enable_train_res),
            ("val_res", "val", args.enable_val_res),
            ("test_res", "test", True),
        )
        if enabled
    ]

    ys = []
    legends = []
    if args.show_multi_runs:
        # per_run[col][run_id] holds that run's values in row order
        # (assumed to be epoch order).
        per_run = {col: [[] for _ in range(num_runs)] for _, col in curves}
        for row in runs.itertuples():
            # Apply the epoch cap here too, so curves never outgrow the x axis.
            if row.run >= num_runs or row.epoch >= num_epochs:
                continue
            for _, col in curves:
                per_run[col][row.run].append(getattr(row, col))
        for i in range(num_runs):
            for name, col in curves:
                ys.append(dict(mean=per_run[col][i]))
                legends.append(f"{name}_{i}")
    else:
        # per_epoch[col][epoch] pools the values of all runs at that epoch.
        per_epoch = {col: [[] for _ in range(num_epochs)] for _, col in curves}
        for row in runs.itertuples():
            if row.epoch >= num_epochs:
                continue
            for _, col in curves:
                per_epoch[col][row.epoch].append(getattr(row, col))
        for name, col in curves:
            y = dict(mean=[], min=[], max=[], std=[])
            for values in per_epoch[col]:
                for k, v in get_stats(values).items():
                    y[k].append(v)
            ys.append(y)
            legends.append(name)

    return num_epochs, ys, legends


def get_results(dump_dir):
    """Load one result dir, report its best-epoch stats and prepare plot data.

    Reads ``args.params_fname`` (JSON, must contain ``raw_cmdline``) and
    ``args.summary_fname`` (parsed by ``get_summary``) from ``dump_dir``.
    When verbose, prints the command line, the summary statistics of the
    best-epoch test results, and the per-run best epoch/train/val/test lists.

    Returns:
        ``(cmd, stats, plot_ys)``:
        - ``cmd``: the run's raw command line.
        - ``stats``: ``get_stats`` of the best test results.
        - ``plot_ys``: the ``get_xs_ys`` output with ``--do-plot``,
          else ``None``.

    Raises:
        OSError: if the params file cannot be opened.
        ValueError: if the params file is not valid JSON
            (``json.JSONDecodeError`` subclasses it).
        KeyError: if ``raw_cmdline`` is missing.
        Any error raised by ``get_summary`` / ``get_stats`` is propagated.
    """
    # normpath drops a trailing separator: "results/run1/" -> "run1", not "".
    dirname = osp.basename(osp.normpath(dump_dir))
    summary_file = osp.join(dump_dir, args.summary_fname)
    params_file = osp.join(dump_dir, args.params_fname)
    with open(params_file, "r", encoding="utf-8") as f:
        params = json.load(f)
    cmd = params["raw_cmdline"]
    if args.verbose:
        print("-" * 30 + f" Dir: {dirname} " + "-" * 30)
        print(cmd)

    key = "acc" if args.acc else "res"
    runs, summary = get_summary(
        summary_file,
        smaller_better=args.smaller_better,
        key=key,
    )
    best_epoch_id, best_train_ress, best_val_ress, best_test_ress, avg_time = summary
    stats = get_stats(best_test_ress)

    if args.verbose:
        # "mean | std" first so the line can be pasted into a spreadsheet.
        excel_fmt = f'{stats["mean"]:.4f} | {stats["std"]:.4f} '
        print(excel_fmt + get_stats_str(stats) + f" Avg Epoch time: {avg_time:.4f} s")

        print_list("Epoch", best_epoch_id)
        print_list("Train", best_train_ress)
        print_list("Val", best_val_ress)
        print_list("Test", best_test_ress)

    plot_ys = get_xs_ys(runs) if args.do_plot else None
    return cmd, stats, plot_ys


def _matching_subdirs(dump_dir, pattern):
    """Return the sorted subdirectories of ``dump_dir`` whose names ``pattern.match``."""
    return sorted(
        d
        for d in glob.glob(osp.join(dump_dir, "*"))
        if osp.isdir(d) and pattern.match(osp.basename(d))
    )


def main():
    """Report every result dir, print the best one and optionally plot.

    Each entry of ``args.dump_dirs`` is visited directly, or, with
    ``--regex``, its matching subdirectories are visited. Dirs that fail to
    load are skipped; when verbose, a warning naming the dir and the error is
    printed to stderr. The best dir is chosen by the mean of the best-epoch
    test results (respecting ``--smaller-better``).
    """
    pattern = re.compile(args.regex) if args.regex is not None else None
    best_dir = None
    best_cmd = None
    best_mean_stat = None
    max_epochs, ys, legends = 0, [], []
    for dump_dir in args.dump_dirs:
        if pattern is None:
            dirs = [dump_dir]
        else:
            dirs = _matching_subdirs(dump_dir, pattern)
        if not dirs:
            print(f"Warning: no dirs matched in {dump_dir}")
        for d in dirs:
            # normpath drops a trailing separator so the basename is not "".
            dirname = osp.basename(osp.normpath(d))
            try:
                cmd, stats, plot_ys = get_results(d)
            except Exception as e:
                # Best effort: unfinished or broken runs are skipped, but the
                # reason is reported instead of being silently swallowed.
                if args.verbose:
                    print(f"Warning: skipped {d}: {e!r}", file=sys.stderr)
                continue

            if best_mean_stat is None or is_better(
                stats["mean"], best_mean_stat["mean"], smaller_better=args.smaller_better
            ):
                best_mean_stat = stats
                best_dir = dirname
                best_cmd = cmd

            if plot_ys is not None:
                num_epochs, y, legend = plot_ys
                max_epochs = max(max_epochs, num_epochs)
                ys.extend(y)
                # Prefix curve labels with the dir name so that curves from
                # different dirs stay distinguishable.
                if len(legend) > 1:
                    legend = [f"{dirname}_{l}" for l in legend]
                else:
                    legend = [dirname]
                legends.extend(legend)

    if best_cmd is not None and args.verbose:
        print(f"Best dir: {best_dir}")
        print(f"Best cmd: {best_cmd}")
        print(f"Best mean stat: {get_stats_str(best_mean_stat)}")
    if args.do_plot:
        xs = [list(range(max_epochs))]
        if args.legends is not None:
            legends = args.legends
        plot(xs, ys, args, legends)
        print(f"Done plot to {args.output}")


if __name__ == "__main__":
    main()
