import argparse
import os
from config import RESULTS_DIR


def str2bool(v):
    """Convert a command-line token into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either a ``bool`` (returned as-is) or an object with a
            ``lower()`` method, typically a ``str``.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0,
        compared case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If the token matches neither group.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def get_train_args():
    """Parse the command-line options of the training script.

    Builds an ``argparse.ArgumentParser``, registers the dataset, model,
    training, checkpoint and regularization options, and parses the
    process command line.

    Returns:
        argparse.Namespace: The parsed options, one attribute per flag.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Dataset
    add('--dataset', type=str, default=None,
        help="Which dataset to use/download: cifar10, cifar100, mnist")
    add('--classes', type=int, nargs="+",
        help="Binarize multiclass problem: which classes to be considered as class 1")
    add('--num_classes', type=int, default=1, help="Number of classes")
    add('--classes1', type=int, nargs='+',
        help="Which classes to consider as class 1 for the 3-class problem")
    add('--classes2', type=int, nargs='+',
        help="Which classes to consider as class 2 for the 3-class problem")
    add('--classes3', type=int, nargs='+',
        help="Which classes to consider as class 3 for the 3-class problem")

    # Model
    add('--model_name', type=str, default=None)
    add('--depth', type=int, help="Depth of the NN")
    add('--block_name', type=str, default='basicblock',
        help='The building block for Resnet: BasicBlock or Bottleneck')

    # Training
    add('--epochs', type=int, default=200)
    add('--batch_size', type=int, default=128)
    add('--lr', type=float, default=0.1)
    add('--dropout', type=float, default=0)
    # NOTE(review): the second default milestone (225) lies beyond the
    # default --epochs (200) -- confirm this is intended.
    add('--decrease_lr_epochs', type=int, nargs='+', default=[150, 225],
        help="What epochs to decrease the lr at")
    add('--decrease_lr_factor', type=float, default=0.1,
        help="Factor to multiply the lr by at decrease_lr_epochs")
    add('--momentum', type=float, default=0.9)
    add('--wd', type=float, default=1e-4, help="Weight decay")
    add('--widen_factor', type=int, help="Widen factor")
    add('--loss', type=str,
        choices=['mse', 'ce', 'mmce', 'focal', 'kde_ce', 'kde_mse'],
        default='mse',
        help="Which loss function to use from given choices")
    add('--death_mode', choices=['linear', 'uniform'], default='linear',
        help='Death mode for Resnet with stochastic depth')
    add('--death_rate', type=float, default=0.5,
        help='Death rate for Resnet with stochastic depth')
    add('--adaptive', type=str2bool, default=False,
        help='Whether to use adaptive gamma for focal loss')
    add('--loss_param', type=float, default=1,
        help="Gamma for focal components or reg. weight for MMCE and KDE ")

    # Checkpoints
    add('--experiment_dir', type=str, default=None,
        help="An experiment directory containing the last saved checkpoint to start training from")
    add('--how_often_ckp', type=int, default=50,
        help="How often to save a checkpoint")

    # Regularization
    add('--b', type=str, default='auto',
        help="The bandwidth parameter: auto for automatic selection, otherwise float value")
    add('--p', type=float, default=2, help="Which Lp norm to use")
    add('--n_bins', type=int, default=15,
        help="How many bins for binned estimate")
    add('--mc_type', choices=['top_label', 'marginal', 'canonical'],
        default='canonical',
        help="The type of calibration for multiclass classification")

    # Miscellaneous runtime switches
    add('--use_cuda', type=str2bool, default=True)
    add('--plots_on', type=str2bool, default=False)
    add('--seed', type=int, default=0)

    parsed = parser.parse_args()
    return parsed


def get_temp_scaling_args():
    """Parse the command-line options of the temperature-scaling script.

    Registers ``--folders`` (via ``add_folder_arg``), ``--criterion``,
    ``--reg_weight`` and the seed option, then parses the process
    command line.

    Returns:
        dict: The parsed options keyed by option name (``folders``,
        ``criterion``, ``reg_weight``, ``seed``).
    """
    parser = argparse.ArgumentParser()
    add_folder_arg(parser)
    parser.add_argument('--criterion',
                        type=str,
                        default='nll',
                        help='The criterion to minimize for temperature scaling: nll, kde_ce or nll_kde_ce')
    parser.add_argument('--reg_weight',
                        type=float,
                        help='Regularization weight for scaling with nll_kde_ce')
    # The historical single-dash spelling ``-seed`` is kept so existing
    # invocations keep working; ``--seed`` is accepted as well. ``type=int``
    # makes a value given on the command line an int, like the default,
    # instead of a raw string.
    parser.add_argument('-seed', '--seed',
                        type=int,
                        default=0)

    return vars(parser.parse_args())


def get_folder_arg():
    """Parse a command line that only carries the ``--folders`` option.

    Returns:
        dict: The parsed options as a plain dictionary.
    """
    folder_parser = argparse.ArgumentParser()
    add_folder_arg(folder_parser)
    namespace = folder_parser.parse_args()
    return vars(namespace)


def add_folder_arg(parser, results_dir=None):
    """Register the ``--folders`` option on *parser*.

    ``--folders`` takes one or more names. When it is omitted, the default
    is a one-element list holding the lexicographically last entry of the
    results directory, so the parsed value is always a list.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.
        results_dir: Directory whose entries supply the default. Falls back
            to the module-level ``RESULTS_DIR`` when ``None``.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR

    # os.listdir returns entries in arbitrary order; sort so that "last"
    # is deterministic. A missing or unreadable directory is treated as
    # empty so that an explicit --folders still works.
    try:
        entries = sorted(os.listdir(results_dir))
    except OSError:
        entries = []

    # With nothing to fall back on, make the option mandatory so the user
    # gets a clear argparse error rather than an IndexError at setup time.
    parser.add_argument('--folders',
                        nargs='+',
                        type=str,
                        default=entries[-1:] or None,
                        required=not entries)
