'''Modified from https://github.com/alinlab/LfF and https://github.com/kakaoenterprise/Learning-Debiased-Disentangled'''

from absl import app, flags

from finetuning_learner import Learner
from torch_util import set_seed

# ---------------------------------------------------------------------------
# Command-line flags. The whole FLAGS object is handed to Learner (see main),
# so every flag below is potentially read by finetuning_learner.Learner.
# ---------------------------------------------------------------------------

# data loading / hardware
flags.DEFINE_integer('mini_batch_size', 256, 'mini_batch_size')
flags.DEFINE_integer('gpu', 0, 'gpu')

# optimization
flags.DEFINE_float('lr', 0.001, 'learning rate')
# Help text previously duplicated 'learning rate', which misdescribed this flag.
flags.DEFINE_float('lr_decay_rate', 0.1, 'learning rate decay rate')
flags.DEFINE_float('weight_decay', 1e-4, 'weight_decay')
flags.DEFINE_integer('num_workers', 4, 'workers number')
flags.DEFINE_string('device', 'cuda', 'cuda or cpu')
flags.DEFINE_integer('num_steps', 100, '# of iterations')

# dataset / attribute selection
flags.DEFINE_integer('target_attr_idx', 0, 'target_attr_idx')
flags.DEFINE_integer('bias_attr_idx', 1, 'bias_attr_idx')
flags.DEFINE_string('dataset', 'cifar10c', 'data to train, [cmnist, cifar10c, bffhq, cifar10_lff, waterbird]')
flags.DEFINE_string('percent', '0.5pct', 'percentage of conflict')
flags.DEFINE_bool('use_type0', False, 'whether to use type 0 CIFAR10C')
flags.DEFINE_bool('use_type1', False, 'whether to use type 1 CIFAR10C')

# input paths (data, precomputed influence scores, pretrained model checkpoint)
flags.DEFINE_string('data_dir', './dataset', 'path for loading data')
flags.DEFINE_string('inf_path', './influence.tsv', 'path for loading influence')
flags.DEFINE_string('model_path', './model.pt', 'path for loading model')

# retrain
flags.DEFINE_integer('seed', 0, 'seed')
flags.DEFINE_float('b_weight', 0.1, 'weight for the remaining dataset')
flags.DEFINE_bool('cosine', False, 'use cosine annealing')
# Selects Learner.retrain_dfa instead of Learner.retrain in main().
flags.DEFINE_bool('dfa', False, 'whether pretrained model is DFA')

# Global flag-value handle; attribute values are available once app.run has
# parsed the command line (i.e. inside main).
args = flags.FLAGS

# actual training
def main(_):
    """Entry point invoked by absl's ``app.run``.

    Calls ``set_seed`` with ``--seed``, constructs a ``Learner`` from the
    parsed flags, then runs ``Learner.retrain_dfa`` when ``--dfa`` is set and
    ``Learner.retrain`` otherwise. Either routine receives the flags object.

    Args:
        _: Leftover command-line arguments supplied by ``app.run``; unused.
    """
    set_seed(args.seed)
    learner = Learner(args)

    # Pick the retraining routine matching the kind of pretrained model.
    run_retrain = learner.retrain_dfa if args.dfa else learner.retrain
    run_retrain(args)

if __name__ == '__main__':
    # Hand control to absl: it parses the flags declared in this module and
    # then invokes main with the remaining argv.
    app.run(main=main)