import numpy as np
import pandas as pd
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from pathlib import Path
import tinydb
import anymarkup

import exp_utils as PQ


# tinydb query object exported for callers to build search conditions on runs.
Run = Query()

# str(event-file path) -> config document that ``load`` inserted for that file
# (the document also holds its EventAccumulator); emptied by ``clear_cache``.
_cache = dict()


def patch_tinydb():
    """Attach TensorBoard helper methods to tinydb's ``Document`` class.

    Documents inserted by ``load`` carry an ``EventAccumulator`` under the
    ``'events'`` key. After patching, every document gains:

    * ``read(tag)`` -- the scalar series for *tag* as a DataFrame with the
      columns ``wall_time``, ``step`` and ``value``;
    * ``read_value(tag)`` -- ``<tag>/mean`` and ``<tag>/std`` (plus ``<tag>/n``
      when that tag exists) combined into one DataFrame indexed by ``step``;
    * ``tags()`` -- the accumulator's ``Tags()`` mapping;
    * ``scalars()`` -- the scalar tags, i.e. ``tags()['scalars']``.

    These methods only work on documents that have an ``'events'`` entry.
    """
    cls = tinydb.database.Document
    PQ.log.warning('Patching tinydb')

    def read(self, tag):
        """Return the scalar events logged under *tag* as a DataFrame."""
        # Explicit columns keep the schema stable even when the tag has no events
        # (otherwise pandas would produce a frame with no columns at all).
        return pd.DataFrame(self['events'].Scalars(tag), columns=['wall_time', 'step', 'value'])

    def read_value(self, tag):
        """Combine ``<tag>/mean``, ``<tag>/std`` and optional ``<tag>/n`` by step."""
        mean = self.read(tag + '/mean')
        std = self.read(tag + '/std')
        data = {
            'wall_time': mean['wall_time'],
            'step': mean['step'],
            'mean': mean['value'],
            'std': std['value'],
        }
        if tag + '/n' in self.scalars():
            data['n'] = self.read(tag + '/n')['value']
        # NOTE(review): columns are aligned by row position, not by step, so this
        # assumes mean/std/n were logged at the same steps -- confirm with writers.
        return pd.DataFrame(data).set_index('step')

    def tags(self):
        """Return the accumulator's tag mapping."""
        return self['events'].Tags()

    def scalars(self):
        """Return the scalar tags recorded in this run."""
        return self.tags()['scalars']

    cls.read = read
    cls.read_value = read_value
    cls.tags = tags
    cls.scalars = scalars


# Applied once at import time so every Document returned by queries has the helpers.
patch_tinydb()


def load(base_dir, db: tinydb.TinyDB = None, verbose=False, unload=False):
    """Insert every TensorBoard run found under *base_dir* into a tinydb database.

    Each ``events.out.tfevents.*`` file below *base_dir* (searched recursively,
    after ``~`` expansion) becomes one document. The document is the parsed
    ``config.toml`` or ``config.json5`` from the event file's directory (the
    first one found), plus three extra keys:

    * ``events`` -- the reloaded ``EventAccumulator`` for the file;
    * ``base_dir`` -- ``str(base_dir)`` as passed in;
    * ``log_dir`` -- the directory containing the event file.

    Parsed documents are remembered in the module-level cache, keyed by
    event-file path. A cached document is re-inserted and its accumulator
    ``Reload()``-ed instead of being parsed from scratch.

    Args:
        base_dir: Directory to search (``str`` or path-like).
        db: Database to insert into; a new in-memory one is created when None.
        verbose: Log each event file that is parsed from scratch.
        unload: Ignore the cache and re-parse every event file.

    Returns:
        The database the documents were inserted into.

    Raises:
        FileNotFoundError: An event file's directory has no config file.
    """
    if db is None:
        db = TinyDB(storage=MemoryStorage)
    # sorted(): make insertion order (and thus doc ids) independent of filesystem order.
    for evt_file in sorted(Path(base_dir).expanduser().glob('**/events.out.tfevents.*')):
        key = str(evt_file)
        if key in _cache and not unload:
            doc = _cache[key]
            db.insert(doc)
            # Pick up any events written since this run was first cached.
            doc['events'].Reload()
            continue
        if verbose:
            PQ.log.info(f'Loading {evt_file}')
        evt_acc = EventAccumulator(str(evt_file), purge_orphaned_data=False)
        evt_acc.Reload()

        config = _read_config(evt_file.parent)
        config['events'] = evt_acc
        config['base_dir'] = str(base_dir)
        config['log_dir'] = str(evt_file.parent)

        db.insert(config)
        _cache[key] = config
    return db


def _read_config(log_dir):
    """Parse the first of ``config.toml`` / ``config.json5`` present in *log_dir*.

    Raises:
        FileNotFoundError: Neither file exists. A real exception (not ``assert``)
            so the check survives ``python -O``; under ``-O`` a bare assert
            would let the caller silently reuse a stale config.
    """
    for name in ('config.toml', 'config.json5'):
        path = log_dir / name
        if path.exists():
            return anymarkup.parse_file(path)
    raise FileNotFoundError(f"Can't find config file (config.toml or config.json5) in {log_dir}")


def clear_cache():
    """Forget every run document remembered by ``load``.

    Afterwards, ``load`` / ``get_runs`` parse event files and configs from disk
    again instead of reusing previously built documents.
    """
    _cache.clear()


def get_runs(*log_dirs, unload=False):
    """Collect the runs under all of *log_dirs* into a single database.

    A fresh in-memory tinydb database is created, and ``load`` is applied to
    each directory in turn with *unload* forwarded (True bypasses the cache).
    The populated database is returned.
    """
    runs_db = TinyDB(storage=MemoryStorage)
    for directory in log_dirs:
        load(directory, db=runs_db, unload=unload)
    return runs_db


def resolve_runs(src):
    """Normalize *src* into a collection of runs.

    Accepted forms:

    * a directory (``str`` or ``pathlib.Path``) -- loaded via ``get_runs``;
    * a list/tuple whose first element is a directory -- all loaded into one
      database via ``get_runs``;
    * any other list/tuple, including an empty one (e.g. already-loaded run
      documents) -- returned unchanged;
    * a ``tinydb.TinyDB`` -- returned unchanged.

    Raises:
        TypeError: *src* is none of the above. This is a real exception rather
            than ``assert`` so it is not stripped under ``python -O``.
    """
    if isinstance(src, (str, Path)):
        return get_runs(src)
    if isinstance(src, (list, tuple)):
        # The first element decides the kind; an empty sequence has nothing to load
        # (indexing it used to raise IndexError).
        if src and isinstance(src[0], (str, Path)):
            return get_runs(*src)
        return src
    if isinstance(src, tinydb.TinyDB):
        return src
    raise TypeError(f'Unknown type: {type(src)}')


# Public API of this module; patch_tinydb is not listed (it is run once at import).
__all__ = [
    'Run',
    'load',
    'clear_cache',
    'get_runs',
    'resolve_runs',
]
