import os
import torch

from absl import app, flags
import jax
import flax
from flax import linen as nn
import numpy as np
import random

from influence.estimate import compute_influence
import pandas as pd
from copy import deepcopy
import gc

# additional hyper-parameters
flags.DEFINE_enum('dataset', 'cifar10c',
                  ['cmnist', 'cifar10c', 'bffhq', 'cifar10_lff', 'waterbird'],
                  help='training dataset')
flags.DEFINE_enum('model', 'resnet', ['resnet'],
                  help='network architecture')
flags.DEFINE_integer('seed', 0,
                     help='random number seed')
flags.DEFINE_integer('test_batch_size_total', 1000,
                     help='total batch size (not device-wise) for evaluation')
flags.DEFINE_integer('target_epoch', 5,
                     help='epoch whose saved results under networks/<target_epoch>/ '
                          'are read and written')

# Dataset Spec
flags.DEFINE_string("percent", "0.5pct",
                    help="percentage of conflict")
flags.DEFINE_integer('num_workers', 4,
                     help='workers number')
flags.DEFINE_bool("use_type0", False,
                  help="whether to use type 0 CIFAR10C")
flags.DEFINE_bool("use_type1", False,
                  help="whether to use type 1 CIFAR10C")
flags.DEFINE_integer("target_attr_idx", 0,
                     help="target_attr_idx")
# The previous help text for this flag was copy-pasted from `percent`.
flags.DEFINE_string("data_dir", "../dataset",
                    help="directory containing the datasets")

# Optimization Spec
flags.DEFINE_float('lr', 0.001,
                   help='learning rate')

# tunable hparams for generalization
flags.DEFINE_float('weight_decay', 0,
                   help='l2 regularization coefficient')
flags.DEFINE_integer('train_batch_size_total', 1000,
                     help='total batch size (not device-wise) for training')

flags.DEFINE_integer('topk', 100,
                     help='number of highest self-influence samples kept per '
                          'class in each run; 0 derives it from the number of '
                          'samples whose true label differs from the bias '
                          'label, divided by the number of classes')

FLAGS = flags.FLAGS

def load_influence_table(file_name):
    """Read one per-sample influence table written by a previous run.

    Args:
      file_name: path to a tab-separated file with the columns ``index``,
        ``true_label``, ``bias_label`` and ``self_influence``.

    Returns:
      Tuple ``(true_label, bias_label, self_influence)`` of 1-D numpy arrays,
      one entry per row.

    Raises:
      ValueError: if the ``index`` column is not exactly ``0, 1, ..., n-1``.
        Selections from several runs are combined position-wise, so rows must
        be in sample order.
    """
    df = pd.read_csv(file_name, sep='\t')
    index = df['index'].values.astype(np.int64)
    # Element-wise comparison: summing the differences would also accept any
    # permutation of 0..n-1. Raise instead of assert so it survives `-O`.
    if not np.array_equal(index, np.arange(len(index))):
        raise ValueError(
            f'{file_name}: "index" column must be 0..n-1 in row order')
    true_label = df['true_label'].values.astype(np.int64)
    bias_label = df['bias_label'].values.astype(np.int64)
    self_influence = df['self_influence'].values
    return true_label, bias_label, self_influence


def select_top_self_influence(true_label, self_influence, num_sample_per_class):
    """Mark, for each class, the samples with the largest self-influence.

    Args:
      true_label: 1-D integer array of class labels; classes ``0..max`` are
        visited.
      self_influence: 1-D array of scores aligned with ``true_label``.
      num_sample_per_class: how many samples to keep per class. Values <= 0
        select nothing; values larger than a class's size select the whole
        class.

    Returns:
      Boolean array of the same length as ``true_label``, True for the
      selected samples.
    """
    true_label = np.asarray(true_label)
    self_influence = np.asarray(self_influence)
    k = max(int(num_sample_per_class), 0)
    selected = np.zeros(len(true_label), dtype=bool)
    if len(true_label) == 0:
        return selected

    sort_idx = np.argsort(self_influence)  # ascending self-influence
    num_classes = int(true_label.max()) + 1
    for c in range(num_classes):
        class_sorted = sort_idx[true_label[sort_idx] == c]
        # Explicit start index: `[-k:]` with k == 0 would keep the whole
        # class instead of none of it.
        selected[class_sorted[max(len(class_sorted) - k, 0):]] = True
    return selected


def main(_):
    """Intersect the top self-influence samples of several runs and save them.

    For each seed in ``FLAGS.seed .. FLAGS.seed + 2`` this reads
    ``./res/<dataset>_<percent>/<model>_<lr>_<train_batch_size_total>_<seed>/
    networks/<target_epoch>/influence_train_df.tsv``. It keeps, per class, the
    ``num_sample_per_class`` samples with the largest self-influence, and
    intersects these selections across the seeds. The surviving sample
    indices and their labels are written as a TSV into the
    ``networks/<target_epoch>`` directory of the ``FLAGS.seed`` run.
    """
    os.environ['PYTHONHASHSEED'] = str(FLAGS.seed)
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)
    torch.cuda.manual_seed(FLAGS.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Number of runs (consecutive seeds starting at FLAGS.seed) whose
    # selections are intersected; it is also part of the output file name.
    num_inf_seeds = 3

    def run_dir(seed):
        """Result directory of the run trained with `seed` and the current flags."""
        hparams = '_'.join(map(str, [
            FLAGS.model,
            FLAGS.lr,
            FLAGS.train_batch_size_total,
            seed,
        ]))
        return f'./res/{FLAGS.dataset}_{FLAGS.percent}/' + hparams

    cur_directory = os.path.join(run_dir(FLAGS.seed), 'networks',
                                 str(FLAGS.target_epoch))

    bias_conflict_mask = None
    ref_labels = None
    for inf_seed in range(FLAGS.seed, FLAGS.seed + num_inf_seeds):
        inf_file_name = (f'{run_dir(inf_seed)}/networks/{FLAGS.target_epoch}'
                         '/influence_train_df.tsv')
        true_label, bias_label, self_influence = load_influence_table(
            inf_file_name)

        # The masks are AND-ed position-wise and the labels of one run are
        # written out, so every run must describe the same samples.
        if ref_labels is None:
            ref_labels = (true_label, bias_label)
        elif not (np.array_equal(true_label, ref_labels[0])
                  and np.array_equal(bias_label, ref_labels[1])):
            raise ValueError(
                f'{inf_file_name}: labels differ from seed {FLAGS.seed}')

        if FLAGS.topk == 0:
            num_classes = true_label.max() + 1
            num_sample_per_class = int(
                (true_label != bias_label).sum() / num_classes)
        else:
            num_sample_per_class = FLAGS.topk

        new_mask = select_top_self_influence(
            true_label, self_influence, num_sample_per_class)
        bias_conflict_mask = (new_mask if bias_conflict_mask is None
                              else bias_conflict_mask & new_mask)

    unbias_idx = np.flatnonzero(bias_conflict_mask)
    df = pd.DataFrame({
        'index': unbias_idx,
        'true_label': true_label[unbias_idx],
        'bias_label': bias_label[unbias_idx],
    })

    out_name = (f'bias_conflict_detect_df_{num_sample_per_class}'
                f'_infonly_{num_inf_seeds}.tsv')
    df.to_csv(os.path.join(cur_directory, out_name), sep='\t', index=False)

if __name__ == "__main__":
    app.run(main)
