import argparse


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', '-dataset', default='mnist')
    p.add_argument('--log_dir_root', '-o', default='/data1/mclgan_exp/test')
    p.add_argument('--manual_seed', '-seed', type=int)
    p.add_argument('--n_disc', '-ndisc', type=int, default=10) # number of discriminators
    p.add_argument('--n_expert', '-nexp', type=int, default=1) # should be <= n_gens
    p.add_argument('--gan_type', type=str, default='dcgan') # dcgan / lsgan / hinge
    p.add_argument('--use_bn', '-use_bn', type=str2bool, default=False) # batch norm in discriminator
    p.add_argument('--d_batch_size', '-d_batch', type=int, default=64) 
    p.add_argument('--g_batch_size', '-g_batch', type=int, default=128) 
    p.add_argument('--fixed_z_batch_size', '-fixed_z_batch', type=int, default=100) 
    p.add_argument('--z_prior', '-zp', type=str, default='n') # u :uniform, n :normal
    p.add_argument('--n_epoch', '-nepoch', type=int, default=40) 
    p.add_argument('--img_size', '-img_size', type=int, default=32) # side length of square image
    p.add_argument('--num_channel', '-num_channel', type=int, default=3)  
    p.add_argument('--d_learning_rate', '-d_lr', type=float, default=0.0002)
    p.add_argument('--g_learning_rate', '-g_lr', type=float, default=0.0002)  
    p.add_argument('--d_weight_decay', '-d_wd', type=float, default=0.0)
    p.add_argument('--g_weight_decay', '-g_wd', type=float, default=0.0)
    p.add_argument('--beta1', '-beta1', type=float, default=0.5) # for adam optimizer
    p.add_argument('--beta2', '-beta2', type=float, default=0.999) # for adam optimizer
    p.add_argument('--lr_update_freq', '-lrf', type=int, default=1) # learning rate update frequency
    p.add_argument('--lr_gamma', '-lrg', type=float, default=0.0) 
    p.add_argument('--kld_decay', '-kd', type=float, default=0.9) # kld loss weight decay
    p.add_argument('--nz', '-nz', type=int, default=100)
    p.add_argument('--ndf', '-ndf', type=int, default=128)
    p.add_argument('--ngf', '-ngf', type=int, default=128)
    p.add_argument('--d_lambda_kld', '-dkld', type=float, default=0.5) # balance loss weight for discriminator (kld loss)
    p.add_argument('--g_lambda_kld', '-gkld', type=float, default=0.0) # balance loss weight for generator (kld loss)
    p.add_argument('--lambda_ne', '-lambda_ne', type=float, default=1.0) # nonexpert loss weight
    p.add_argument('--temperature', '-t', type=float, default=1) # softmax temperature
    p.add_argument('--lambda_l1', '-lambda_l1', type=float, default=0.0) # L1 loss weight
    p.add_argument('--nonexpert_label', '-ne_label', type=float, default=0.5) # soft label for real data for nonexpert loss
    p.add_argument('--kld_update_freq', '-kf', type=int, default=10) # kld loss weight update frequency

    opt = p.parse_args()
    return opt