import argparse
import numpy as np
import os
import pickle
from tqdm import tqdm

import hydra
from utils import *

def compute_forgetting_statistics(diag_stats, npresentations, start_epoch=0):
    """Calculate forgetting statistics per example.

    Args:
        diag_stats: dictionary created during training. String keys (e.g.
            'train', 'test') are skipped; every other key is an example id whose
            value is a per-example record. This function reads from it:
            ``[1]`` per-presentation accuracy, ``[2]`` per-presentation
            misclassification margin, ``[6][0]`` original dataset index and,
            when the record has more than 7 entries, ``[7]`` a norm value.
        npresentations: number of presentations (training epochs) to analyse.
        start_epoch: first presentation of the analysed window
            ``[start_epoch, start_epoch + npresentations)``.

    Returns:
        A 6-tuple of dicts keyed by example id:

        * presentations_needed_to_learn: 1-based window position of the last
          presentation with accuracy 0; 0 if accuracy is never 0; -1 if the
          accuracy sums to 0 (never classified correctly, assuming 0/1 values).
        * unlearned_per_presentation: 1-based window positions at which a
          forgetting event (accuracy 1 -> 0) occurred, or ``[]`` if none.
        * margins_per_presentation: array of margins within the window.
        * first_learned: 0-based window index of the first presentation with
          accuracy 1, or ``np.nan`` if never learned.
        * norms: ``record[7]`` for examples whose record has one.
        * orig_indices: ``record[6][0]`` for every example.
    """
    presentations_needed_to_learn = {}
    unlearned_per_presentation = {}
    margins_per_presentation = {}
    first_learned = {}
    norms = {}
    orig_indices = {}
    window = slice(start_epoch, start_epoch + npresentations)

    for example_id, example_stats in tqdm(diag_stats.items()):
        # String keys hold aggregate entries ('train', 'test'), not examples.
        if isinstance(example_id, str):
            continue

        presentation_acc = np.array(example_stats[1][window])

        # A forgetting event is a drop in accuracy from 1 to 0 between two
        # consecutive presentations. Transition index t compares positions t
        # and t+1 (0-based), so +2 gives the 1-based forgotten presentation.
        transitions = presentation_acc[1:] - presentation_acc[:-1]
        forgotten = np.where(transitions == -1)[0]
        unlearned_per_presentation[example_id] = forgotten + 2 if forgotten.size else []

        # Number of presentations needed to learn the example, i.e. the last
        # presentation at which accuracy is 0.
        zero_positions = np.where(presentation_acc == 0)[0]
        if zero_positions.size == 0:
            presentations_needed_to_learn[example_id] = 0
        elif np.sum(presentation_acc) != 0:
            presentations_needed_to_learn[example_id] = zero_positions[-1] + 1
        else:
            # Never classified correctly within the window.
            presentations_needed_to_learn[example_id] = -1

        # Store per example. The previous code rebound the whole dict to the
        # current example's array, so only the last example survived.
        margins_per_presentation[example_id] = np.array(example_stats[2][window])

        # First presentation at which the example was learned (accuracy 1).
        learned_positions = np.where(presentation_acc == 1)[0]
        first_learned[example_id] = learned_positions[0] if learned_positions.size else np.nan

        orig_indices[example_id] = example_stats[6][0]
        if len(example_stats) > 7:
            norms[example_id] = example_stats[7]

    return (presentations_needed_to_learn, unlearned_per_presentation,
            margins_per_presentation, first_learned, norms, orig_indices)

def sort_examples_by_forgetting(unlearned_per_presentation_all,
                                first_learned_all, npresentations, norms, orig_indices):
    """Sort examples by their total forgetting count over all runs (ascending).

    For each run, an example contributes the number of its forgetting events.
    If it was never learned in that run (``first_learned`` is NaN), it instead
    contributes ``npresentations``, the maximum count.

    Args:
        unlearned_per_presentation_all: list of dicts, one per training run,
            mapping example id -> forgetting positions. The example ids of the
            first run define which examples are sorted.
        first_learned_all: list of dicts, one per training run, mapping
            example id -> first-learned index or NaN.
        npresentations: number of training epochs (the never-learned count).
        norms: dict example id -> norm; may be empty.
        orig_indices: dict example id -> original dataset index.

    Returns:
        ``(orig_indices, indices, fs, norms)``, all ordered by ascending
        forgetting count:

        * list of original dataset indices;
        * numpy array of example ids;
        * numpy array of forgetting counts;
        * list of norms, or ``[]`` when ``norms`` is empty.

        Ties keep their original order (stable sort).
    """
    example_ids = []
    forgetting_counts = []

    for example_id in tqdm(unlearned_per_presentation_all[0].keys()):
        total = 0
        # Runs are paired positionally; callers build both lists in lockstep.
        for unlearned, first_learned in zip(unlearned_per_presentation_all, first_learned_all):
            if np.isnan(first_learned[example_id]):
                total += npresentations
            else:
                total += len(unlearned[example_id])
        example_ids.append(example_id)
        forgetting_counts.append(total)

    # Stable sort so examples with equal counts keep their insertion order
    # rather than an order that depends on numpy's default sort algorithm.
    order = np.argsort(forgetting_counts, kind='stable')
    indices = np.array(example_ids)[order]
    fs = np.asarray(forgetting_counts)[order]

    sorted_orig_indices = [orig_indices[i] for i in indices]
    sorted_norms = [norms[i] for i in indices] if norms else []
    return sorted_orig_indices, indices, fs, sorted_norms

def _split_by_cutoff(triples, num_original):
    """Split ``(index, orig_index, value)`` triples into two dicts keyed by orig_index.

    Triples with ``index < num_original`` go into the first dict, the rest into
    the second. Later triples overwrite earlier ones with the same orig_index.
    """
    original, upsampled = {}, {}
    for i, orig_i, value in triples:
        (original if i < num_original else upsampled)[orig_i] = value
    return original, upsampled


def _combine(original, upsampled, orig_indices, merge):
    """Build ``{orig_index: value}`` for every orig_index in ``orig_indices``.

    Uses ``merge(original_value, upsampled_value)`` when both dicts hold the
    key, otherwise whichever single value exists.
    """
    combined = {}
    for orig_i in orig_indices:
        if orig_i in upsampled and orig_i in original:
            combined[orig_i] = merge(original[orig_i], upsampled[orig_i])
        elif orig_i in upsampled:
            combined[orig_i] = upsampled[orig_i]
        else:
            combined[orig_i] = original[orig_i]
    return combined


def clean(fs_dict, num_original=45000):
    """Re-key forget-score statistics by original dataset index.

    Example ids below ``num_original`` fill the plain maps ('fs', 'fl', 'll',
    'gn'); ids at or above it fill the '_us' maps. Presumably the latter are
    upsampled copies of original examples; confirm against the data pipeline.
    The '_comb' maps merge both per original index:

    * 'fs_comb': sum of the forgetting counts;
    * 'fl_comb' / 'll_comb': builtin ``max`` of first/last learned;
    * 'gn_comb': mean of the norms.

    Only the first run's 'first learned' / 'last learned' dicts are used.

    Args:
        fs_dict: dict with keys 'indices', 'orig indices', 'forgetting counts',
            'first learned', 'last learned' and 'norms', as built by
            ``calculate_forget_scores``.
        num_original: example-id cutoff between the two groups (default 45000,
            the previously hard-coded value).

    Returns:
        Dict with keys 'fs', 'fl', 'll', 'gn' plus their '_us' and '_comb'
        variants. The 'gn*' maps stay empty when ``fs_dict['norms']`` is empty.
    """
    indices = fs_dict['indices']
    orig_indices = fs_dict['orig indices']
    i_to_orig_i = dict(zip(indices, orig_indices))

    def by_example_id(stats):
        # first/last learned are keyed by example id; translate to orig index.
        return ((i, i_to_orig_i[i], c) for i, c in stats.items())

    fs, fs_us = _split_by_cutoff(zip(indices, orig_indices, fs_dict['forgetting counts']),
                                 num_original)
    fl, fl_us = _split_by_cutoff(by_example_id(fs_dict['first learned'][0]), num_original)
    ll, ll_us = _split_by_cutoff(by_example_id(fs_dict['last learned'][0]), num_original)

    clean_dict = {'fs': fs, 'fl': fl, 'll': ll, 'gn': {},
                  'fs_us': fs_us, 'fl_us': fl_us, 'll_us': ll_us, 'gn_us': {},
                  'fs_comb': _combine(fs, fs_us, orig_indices, lambda a, b: a + b),
                  # NOTE(review): builtin max with NaN (never learned) is
                  # argument-order dependent; kept as-is to preserve results.
                  'fl_comb': _combine(fl, fl_us, orig_indices, max),
                  'll_comb': _combine(ll, ll_us, orig_indices, max),
                  'gn_comb': {}}

    gn_og = fs_dict['norms']
    if gn_og:
        gn, gn_us = _split_by_cutoff(zip(indices, orig_indices, gn_og), num_original)
        clean_dict['gn'] = gn
        clean_dict['gn_us'] = gn_us
        # Average the two norms. The previous np.mean(a, b) passed b as `axis`.
        # axis=0 also averages element-wise if norms are equal-length arrays.
        clean_dict['gn_comb'] = _combine(gn, gn_us, orig_indices,
                                         lambda a, b: np.mean([a, b], axis=0))

    return clean_dict

def calculate_forget_scores(args):
    """Compute forget scores for every run under the run directory and pickle them.

    Walks ``<team_path>/<exp_name>/<run_name>`` for files ending in
    ``stats_dict.pkl``. Files whose ``get_fs_stats`` flag is falsy are skipped.
    Forgetting statistics are computed over the epoch window
    ``forget_score_epochs = [start, end]``. The result is written to the run
    directory as ``forget_scores_[start-end].pkl``, or as
    ``clean_forget_scores_[start-end].pkl`` (output of ``clean``) when
    ``args.fs_postprocess`` is set.

    Side effect: sets ``args.local_run_path``.

    Raises:
        AssertionError: if the epoch window or the exp/run names are invalid.
        NameError: if no usable ``stats_dict.pkl`` file is found.
    """
    # Argument validation
    assert args.forget_score_epochs[0]<=args.forget_score_epochs[1] and \
            0<=args.forget_score_epochs[0] and \
            isinstance(args.forget_score_epochs[0],int) and \
            isinstance(args.forget_score_epochs[1],int), \
            '(forget_score_epochs) must be a valid integer bounded interval between [0,inf)'
    assert args.exp_name and args.run_name, \
        'Please specify an experiment name (exp_name) and run name (run_name).'

    # Setup directories
    args.local_run_path = os.path.join(args.team_path, args.exp_name, args.run_name)

    print('Calculating forget scores...')

    start_epoch, end_epoch = args.forget_score_epochs[0], args.forget_score_epochs[1]
    npresentations = end_epoch - start_epoch
    suffix = 'stats_dict.pkl'
    unlearned_per_presentation_all, first_learned_all, last_learned_all = [], [], []
    norms, orig_indices = {}, {}

    for dirpath, dirnames, filenames in os.walk(args.local_run_path):
        # Visit directories and files in a fixed order. os.walk honours in-place
        # edits of dirnames. Run order decides which run is [0] in
        # 'first learned'/'last learned' (used by clean) and which run supplies
        # norms/orig_indices, so it must be reproducible.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith(suffix):
                continue
            print('\nIncluding file: ' + filename)

            # Local training artifacts only: pickle.load must never be pointed
            # at untrusted files.
            with open(os.path.join(dirpath, filename), 'rb') as fin:
                stats_dict = pickle.load(fin)

            if not stats_dict['get_fs_stats']:
                continue

            last_learned, unlearned_per_presentation, _, first_learned, norms, orig_indices = \
                compute_forgetting_statistics(stats_dict['stats'],
                                              npresentations,
                                              start_epoch=start_epoch)

            unlearned_per_presentation_all.append(unlearned_per_presentation)
            first_learned_all.append(first_learned)
            last_learned_all.append(last_learned)

    if not unlearned_per_presentation_all:
        raise NameError(f'No input files found in {args.local_run_path} that end with {suffix}')

    # NOTE(review): norms/orig_indices come from the last processed run only.
    orig_idx_ordered_examples, ordered_examples, ordered_values, norms = \
        sort_examples_by_forgetting(unlearned_per_presentation_all, first_learned_all,
                                    npresentations, norms, orig_indices)

    fs_dict = {
        'indices': ordered_examples,
        'orig indices': orig_idx_ordered_examples,
        'forgetting counts': ordered_values,
        'first learned': first_learned_all,
        'last learned': last_learned_all,
        'norms': norms
    }

    window = f'[{start_epoch}-{end_epoch}]'
    if args.fs_postprocess:
        out_name, payload = f'clean_forget_scores_{window}.pkl', clean(fs_dict)
    else:
        out_name, payload = f'forget_scores_{window}.pkl', fs_dict
    with open(os.path.join(args.local_run_path, out_name), 'wb') as file:
        pickle.dump(payload, file)

    print('\nForget scores saved.\n')

@hydra.main(version_base=None, config_path='.', config_name='forget_scores_params')
def main(cfg):
    """Hydra entry point: compute forget scores from the ``forget_scores_params`` config."""
    # Remove the Hydra job log for this script (<script name>.log) from the
    # current directory. Best effort: depending on Hydra's run-dir/chdir
    # settings the file may not be here, and that must not abort the job.
    log_file = f'{os.path.splitext(os.path.basename(__file__))[0]}.log'
    try:
        os.remove(log_file)
    except FileNotFoundError:
        pass

    # Argument intake
    args = DotDict(cfg)

    # Calculate forget scores
    calculate_forget_scores(args)

if __name__ == "__main__":
    main()