
# Fixed hyperparameter presets, keyed by dataset name.  Each preset lists
# exactly the attributes that reproduce() overwrites for that dataset;
# attributes absent from a preset (e.g. ``momentum`` for the Adam presets,
# ``n_lr`` for 'biased_mnist') are left untouched on ``args``.
_REPRODUCE_PRESETS = {
    'colored_mnist': {
        'model': 'CONV',
        'opt': 'SGD',
        'batch_size': 128,
        'lr': 0.02,
        'n_lr': 0.02,
        'weight_decay': 0.001,
        'momentum': 0.9,
        'num_class': 10,
        'use_lr_decay': True,
        'lr_decay': 0.1,
        'lr_decay_step': 40,
        'epochs': 100,
    },
    'biased_mnist': {
        'model': 'CONV2',
        'opt': 'SGD',
        'batch_size': 32,
        'lr': 0.01,
        'weight_decay': 0.0001,
        'momentum': 0.9,
        'num_class': 10,
        'use_lr_decay': True,
        'lr_decay_step': 100,
        'lr_decay': 0.1,
        'epochs': 100,
    },
    # NOTE: this preset previously carried commented-out alternatives
    # (opt='SGD', batch_size=128, momentum=0.9, weight_decay=0.0,
    # lr_decay=0.6, epochs=100); the values below are the active ones.
    'corrupted_cifar': {
        'model': 'ResNet18',
        'opt': 'Adam',
        'batch_size': 256,
        'lr': 0.001,
        'n_lr': 0.1,
        'weight_decay': 0.001,
        'num_class': 10,
        'use_lr_decay': True,
        'lr_decay': 0.5,
        'lr_decay_step': 40,
        'epochs': 200,
    },
    'bar': {
        'model': 'ResNet18',
        'opt': 'SGD',
        'batch_size': 16,
        'lr': 0.0005,
        'n_lr': 0.001,
        'n_epochs': 20,
        'momentum': 0.9,
        'weight_decay': 1e-5,
        'num_class': 6,
        'use_lr_decay': True,
        'lr_decay': 0.1,
        'lr_decay_step': 20,
        'epochs': 100,
    },
    'bffhq': {
        'model': 'ResNet18',
        'opt': 'Adam',
        'batch_size': 64,
        'lr': 0.0001,
        'n_lr': 0.0001,
        'weight_decay': 0,
        'num_class': 2,
        'lr_decay': 0.1,
        'use_lr_decay': True,
        'lr_decay_step': 32,
        'epochs': 160,
        'n_epochs': 1,
    },
}


def reproduce(args):
    """Overwrite hyperparameters on ``args`` with the preset for its dataset.

    The preset is selected by ``args.dataset``.  Every attribute named in the
    preset is assigned onto ``args`` in place, replacing any existing value;
    attributes not named in the preset are left as they were.

    Args:
        args: Any object with a ``dataset`` attribute that accepts attribute
            assignment (for example an ``argparse.Namespace``).

    Returns:
        The same ``args`` object, after mutation.

    Raises:
        SystemExit: With exit status 1 when ``args.dataset`` does not name a
            known preset.  A diagnostic is written to stderr first.
    """
    dataset = args.dataset
    try:
        preset = _REPRODUCE_PRESETS.get(dataset)
    except TypeError:
        # An unhashable value can never name a preset; treat it as unknown
        # instead of leaking a TypeError from the dict lookup.
        preset = None

    if preset is None:
        import sys
        known = ", ".join(sorted(_REPRODUCE_PRESETS))
        print(
            "Wrong data: unknown dataset {!r} (expected one of: {})".format(
                dataset, known
            ),
            file=sys.stderr,
        )
        # Non-zero status so shells and wrapper scripts see this as a failure.
        sys.exit(1)

    for name, value in preset.items():
        setattr(args, name, value)
    return args


# python train.py --reproduce --dataset colored_mnist --bratio 0.005 --nratio 0 --exp ablation --save_stats --train --seed 0 --algorithm lff --gpu 0 
