import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=40)

parser.add_argument('--log_interval', type=int, default=400)

parser.add_argument('--truncation', type=int, default=25)   # T

parser.add_argument('--arch', type=str, default='MLP', choices=['MLP', 'conv'])
parser.add_argument('--dataset', type=str, default='fixedmnist', choices=['omniglot', 'fixedmnist'])
parser.add_argument('--dataset_dir', type=str, default='')

parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--test_batch_size', type=int, default=20)
parser.add_argument('--epochs', type=int, default=1000)

parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--arch_learning_rate', type=float, default=1e-3,
                    help="learning rate for architecture related parameters")

parser.add_argument('--test_arch_n', type=int, default=1)
parser.add_argument('--S', type=int, default=8)           # S
parser.add_argument('--M', type=int, default=8)           # M
parser.add_argument('--K', type=int, default=8)           # K

parser.add_argument('--h_dim', type=int, default=200)     #O
parser.add_argument('--z_dim', type=int, default=50)      #h
parser.add_argument("--arch_beta", type=float, default=0.05)  # architecture kl loss scaler

parser.add_argument('--loss_logfile', type=str, default='')
parser.add_argument('--model_file', type=str, default='')


def get_args():
    args = parser.parse_args()
    file_string = f"{args.dataset}_{args.arch}_{args.S}x{args.M}x{args.K}" \
                  f"_hdim_{args.h_dim}_zdim_{args.z_dim}_truncation_{args.truncation}_seed_{args.seed}"

    #file naming for the saved model file and the log file
    if args.loss_logfile == '':
        args.loss_logfile = f"loss_logs/{args.dataset}/log_vae_{file_string}.txt"
    if args.model_file == '':
        args.model_file = f"saved_models/{args.dataset}/vae_{file_string}.pt"
    print(args.model_file)

    # number of posterior samples for negative log likelihood estimation
    if args.arch == "MLP":
        args.log_likelihood_k = int(5000/args.S)*args.S
        args.img_shape = (28, 28)
    else:
        args.log_likelihood_k = int(3000 / args.S) * args.S  # using fewer samples for CNN layers to avoid out of memory error
        args.img_shape = (1, 28, 28)

    return args
