import os
import argparse
from collections import deque, defaultdict
import numpy as np
from tabulate import tabulate
import matplotlib.pyplot as plt
from copy import deepcopy
import subprocess
plt.switch_backend('agg')


class Log:
    """Win-rate curves parsed from one experiment folder.

    ``<folder>/train.log`` (when it exists) is parsed at construction into
    ``train_win_rate``: a dict mapping a series key to a list of floats in
    file order. Keys written by the parser:

    * ``'a-overall'`` -- one value per line starting with ``win`` (the number
      on that line, with its trailing character dropped, divided by 100);
    * ``'resource'`` -- the last number on each ``avg resource scale`` line;
    * one key per entry of ``_UNIT_TYPES`` -- ``1 - rate`` for each line
      starting with that prefix.

    ``eval_win_rate`` has the same layout. Eval-log parsing is currently
    disabled (see ``__init__``), so it stays empty.
    """

    # Line prefixes recognised by _parse_unit_type_win_rate.
    _UNIT_TYPES = (
        'strong',
        'simple-SPEARMAN',
        'simple-SWORDMAN',
        'simple-CAVALRY',
        'simple-DRAGON',
        'simple-ARCHER',
        'medium-SPEARMAN',
        'medium-SWORDMAN',
        'medium-CAVALRY',
        'medium-DRAGON',
        'medium-ARCHER',
    )

    # Number of most recent values averaged by the *_recent / train_resource
    # properties.
    _RECENT_WINDOW = 3

    def __init__(self, name, folder):
        """
        Args:
            name: label used as the plot title.
            folder: experiment directory; ``train.log`` is read from it and
                ``winrate.png`` is written into it.
        """
        self.name = name
        self.folder = folder
        self.train_win_rate = defaultdict(list)
        self.eval_win_rate = defaultdict(list)

        # Number of 'win' lines in train.log; 0 when the file is absent.
        self.num_epochs = 0
        train_log = os.path.join(folder, 'train.log')
        if os.path.exists(train_log):
            self.num_epochs = self._parse_log(train_log, 'train')

        # eval_log = os.path.join(folder, 'eval.log')
        # if os.path.exists(eval_log):
        #     self._parse_log(eval_log, 'eval')

    @staticmethod
    def _max_or_zero(values):
        """Return the maximum of *values*, or 0 when it is empty."""
        return np.max(values) if values else 0

    @classmethod
    def _recent_mean(cls, values):
        """Return the mean of the last ``_RECENT_WINDOW`` values, or 0 if none."""
        tail = values[-cls._RECENT_WINDOW:]
        return np.mean(tail) if tail else 0

    @property
    def train_max(self):
        """Best overall train win rate (0 if nothing was parsed)."""
        return self._max_or_zero(self.train_win_rate['a-overall'])

    @property
    def train_recent(self):
        """Mean of the last few overall train win rates (0 if none)."""
        return self._recent_mean(self.train_win_rate['a-overall'])

    @property
    def train_resource(self):
        """Mean of the last few 'avg resource scale' values (0 if none)."""
        return self._recent_mean(self.train_win_rate['resource'])

    @property
    def eval_max(self):
        """Best overall eval win rate (0 if nothing was parsed)."""
        return self._max_or_zero(self.eval_win_rate['a-overall'])

    @property
    def eval_recent(self):
        """Mean of the last few overall eval win rates (0 if none)."""
        return self._recent_mean(self.eval_win_rate['a-overall'])

    def _parse_log(self, filename, mode):
        """Parse *filename* line by line into the series dict for *mode*.

        Args:
            filename: path of the log file.
            mode: ``'train'`` stores into ``train_win_rate``; any other value
                stores into ``eval_win_rate``.

        Returns:
            The number of lines starting with ``win`` (one per epoch).
        """
        series = self.train_win_rate if mode == 'train' else self.eval_win_rate
        num_epochs = 0
        with open(filename, 'r') as f:
            for line in f:
                if line.startswith('win'):
                    num_epochs += 1
                    # Last space-separated token is a number followed by one
                    # extra character (presumably '%'), which is dropped.
                    token = line.split(' ')[-1].strip()
                    series['a-overall'].append(float(token[:-1]) / 100)
                elif line.startswith('avg resource scale'):
                    series['resource'].append(float(line.split()[-1]))
                else:
                    self._parse_unit_type_win_rate(line, mode)
        return num_epochs

    def _parse_unit_type_win_rate(self, line, mode):
        """Record ``1 - rate`` for *line* if it starts with a known unit type.

        In train mode the rate is the 4th whitespace-separated token. In any
        other mode it is the last token. In both cases the token's trailing
        character is dropped and the number is divided by 100. Lines that
        match no prefix are ignored.
        """
        for ut in self._UNIT_TYPES:
            if not line.startswith(ut):
                continue
            tokens = line.split()
            if mode == 'train':
                raw, target = tokens[3], self.train_win_rate
            else:
                raw, target = tokens[-1], self.eval_win_rate
            win_rate = float(raw.strip()[:-1]) / 100
            # The complement is stored. NOTE(review): presumably the log
            # reports the opponent unit type's win rate -- confirm.
            target[ut].append(1 - win_rate)
            break  # the prefixes are disjoint, so at most one can match

    def render_winrate(self, first_k):
        """Plot every train series and save the figure to <folder>/winrate.png.

        Args:
            first_k: when > 0, only the first *first_k* points of each series
                are plotted; otherwise the whole series is plotted.
        """
        _, ax = plt.subplots(1, 2, figsize=(18, 9))
        for key in sorted(self.train_win_rate):
            rates = self.train_win_rate[key]
            if first_k > 0:
                rates = rates[:first_k]
            # Each series gets its own x range. Series come from different
            # line types and need not have num_epochs entries, and plot()
            # rejects x/y arrays of different lengths.
            ax[0].plot(range(len(rates)), rates, label=key)
        if self.train_win_rate:
            ax[0].set_title('train')
            ax[0].legend(loc='lower right')

        # for key in sorted(self.eval_win_rate.keys()):
        #     rates = self.eval_win_rate[key]
        #     xx = [20 * i for i in range(len(rates))]
        #     ax[1].plot(xx, rates, label=key)
        #     ax[1].set_title('eval')
        #     ax[1].legend(loc='lower right')

        plt.suptitle(self.name)
        plt.tight_layout()
        img_path = os.path.join(self.folder, 'winrate.png')
        print('writing to', img_path)
        plt.savefig(img_path)
        plt.close()


def parse_from_root(root, pattern):
    """Load the experiments found directly under *root*.

    An entry of *root* is considered when its name contains *pattern*, or
    always when *pattern* is empty. A considered entry gets a ``Log`` only
    if it is a directory containing ``train.log``. Job IDs are still looked
    up for every considered entry, including ones without a log.

    Returns:
        ``(logs, name2id)``: ``logs`` maps entry name -> ``Log``;
        ``name2id`` is ``get_jobid`` applied to every considered name.
    """
    root_dir = os.path.abspath(root)
    logs = {}
    considered = []
    for entry in os.listdir(root_dir):
        if len(pattern) and pattern not in entry:
            continue
        considered.append(entry)
        entry_path = os.path.join(root_dir, entry)
        log_path = os.path.join(entry_path, 'train.log')
        if not (os.path.isdir(entry_path) and os.path.exists(log_path)):
            continue
        logs[entry] = Log(entry, entry_path)

    return logs, get_jobid(considered)


def get_jobid(exp_names):
    """
    return -1 if the job is already dead
    """
    log = subprocess.run(
        ['sacct', '-s', 'running', '--format=JobID,Jobname%200,elapsed'],
        stdout=subprocess.PIPE)
    log = log.stdout.decode().split('\n')
    lines = [l.split() for l in log]
    name2id = {}
    for name in exp_names:
        for l in lines:
            if len(l) < 3:
                continue
            if name == l[1]:
                name2id[name] = int(l[0])
        if name not in name2id:
            name2id[name] = -1
    return name2id


def _tabulate_rows(logs, name2id):
    """Return one summary row per log, sorted by experiment name.

    Each row is ``[job_id, name, num_epochs, train_max, train_recent,
    train_resource]``. ``job_id`` is -1 when *name2id* has no entry for the
    name.
    """
    return [
        [name2id.get(name, -1),
         name,
         log.num_epochs,
         log.train_max,
         log.train_recent,
         log.train_resource,
         # log.eval_max,
         # log.eval_recent
         ]
        for name, log in sorted(logs.items(), key=lambda item: item[0])
    ]


def print_tabulate(logs, name2id):
    """Print one table row per experiment log, sorted by name.

    Args:
        logs: dict mapping experiment name -> ``Log``.
        name2id: dict mapping experiment name -> job ID. Missing names are
            shown as -1, the same "not running" marker ``get_jobid`` uses.
    """
    headers = [
        'JobId', 'Name', 'Epoch', 'TMax', 'TRe', 'Res',
        # 'EMax', 'ERe'
    ]
    print(tabulate(_tabulate_rows(logs, name2id), headers, floatfmt='.2f'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='analysis')
    parser.add_argument('--root', type=str,
                        help='directory whose sub-folders are experiments; '
                             'prints a summary table')
    parser.add_argument('--folder', type=str,
                        help='single experiment folder; writes winrate.png '
                             'into it')
    parser.add_argument('--first-k', type=int, default=-1,
                        help='plot only the first K points of each series '
                             '(<= 0 plots everything)')
    parser.add_argument('--pattern', type=str, default='',
                        help='with --root, only include experiments whose '
                             'name contains this string')

    args = parser.parse_args()

    if args.root is None and args.folder is None:
        # Nothing to do; tell the user how to invoke the script.
        parser.print_help()

    if args.root is not None:
        logs, name2id = parse_from_root(args.root, args.pattern)
        print_tabulate(logs, name2id)

    if args.folder is not None:
        # normpath drops a trailing separator, so "runs/exp/" still yields
        # "exp"; split('/')[-1] would give an empty name.
        name = os.path.basename(os.path.normpath(args.folder))
        log = Log(name, args.folder)
        log.render_winrate(args.first_k)
