# coding=utf-8

import argparse
import random
import os
import numpy as np
import torch

def get_args(argv=None):
    """Parse the experiment's command-line configuration.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the original zero-argument call did. Passing an explicit
            list lets callers build a configuration programmatically.

    Returns:
        The ``argparse.Namespace`` holding every parsed option.

    Note:
        ``--data_dir``, ``--bert_model``, ``--vocab_file`` and
        ``--config_file`` are required; argparse prints an error and exits
        when any of them is missing or when a ``choices`` option receives
        an unlisted value.
    """
    delay_types = ['fixed', 'random']
    parser = argparse.ArgumentParser(description="meta config of experiment")

    # General experiment / training options (dash-separated flags; argparse
    # exposes them with underscores, e.g. ``args.num_epochs``).
    parser.add_argument('--dataset', default=None, type=str, metavar='data')
    parser.add_argument('--model', default='bert', type=str, metavar='model')
    parser.add_argument('--num-epochs', default=200, type=int, metavar='N', help='number of epochs')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--delay', default=16, type=int)
    parser.add_argument('--delay-type', default='random', type=str, choices=delay_types)
    parser.add_argument('--num-workers', default=8, type=int, metavar='W')
    parser.add_argument('--batch-size', default=16, type=int, metavar='b', help='batch size per worker')
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--logdir', default='./log/', type=str)
    parser.add_argument('--lr-schedule', default='const', type=str, choices=['const', 'decay', 't'])
    parser.add_argument('--lr-decay', default=0.1, type=float)
    parser.add_argument('--cuda-device', default=0, type=int, metavar='c')
    parser.add_argument('--print-freq', default=50, type=int, metavar='p')
    parser.add_argument('--eval-freq', default=0, type=int)
    parser.add_argument("--cuda-ps", action='store_true')

    # BERT-specific options (underscore-separated flags).
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data "
        "files) for the task.",
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, "
        "bert-base-multilingual-uncased, bert-base-multilingual-cased, "
        "bert-base-chinese.",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece "
        "tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--vocab_file',
                        type=str,
                        default=None,
                        required=True,
                        help="Vocabulary mapping/file BERT was pretrained on")
    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The BERT model config")

    # argv=None makes argparse read sys.argv[1:], preserving the old behavior.
    return parser.parse_args(argv)

def seed_torch(seed):
    """Seed the random number generators this module relies on.

    Passes ``seed`` to Python's ``random``, NumPy's global generator,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``, stores it (as a
    string) in the ``PYTHONHASHSEED`` environment variable, and switches
    cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value handed unchanged to every seeding call.
    """
    # Environment variables must be strings, hence the conversion.
    # NOTE(review): presumably this only influences processes started
    # afterwards, not the hashing of the running interpreter -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each library's generator in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic behaviour and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
