import sys
import argparse
import time

def _build_parser():
    """Return the argparse parser describing every command-line option of the AG attack."""
    parser = argparse.ArgumentParser(description='AG attack')
    parser.add_argument('--dataset', type=str, choices=['MNIST', 'CIFAR10', 'TinyImgNet'], default='CIFAR10', help='Which dataset to use')
    parser.add_argument('--ds_type', type=str, choices=['train', 'test'], default='train', help='Which split of the dataset to use')

    # Neptune
    parser.add_argument('--neptune', type=str, help='Neptune project name, leave empty to not use neptune', default=None)
    parser.add_argument('--neptune_label', dest='label', type=str, help='Neptune label of the experiment')
    parser.add_argument('--neptune_offline', action='store_true', help='Run Neptune in offline mode')

    parser.add_argument('--random_seed', type=int, default=0)
    parser.add_argument('--B', type=int, required=True, help="Batch size")
    parser.add_argument('--L', type=int, default=6, help="Layers")
    parser.add_argument('--W', type=int, required=True, help="Width of Layers")
    parser.add_argument('--N', type=int, default=int(1e+9), help="Number of samples")
    # Stored as ``args.par_SVD`` (argparse turns '-' into '_').
    parser.add_argument('--par-SVD', type=int, default=int(5e+5), help="Number of parallel SVDs")
    parser.add_argument('--pFN', type=float, default=1e-5, help="The probability of not finding a vector")

    parser.add_argument('--en', type=int, default=50, help="How many samples to apply on")
    parser.add_argument('--steps', type=int, help="At which training step to attack")
    parser.add_argument('--st', type=int, default=0, help="How many samples to skip")
    parser.add_argument('--cond', type=str, choices=['gt', 'early', 'samples'], default='early', help='When to stop collecting directions q_i')
    parser.add_argument('--sigma_tol', type=float, default=1e-7, help='Tolerance of sigma calc')
    # '--sigma_treshold' (original spelling) is kept for backward compatibility;
    # '--sigma_threshold' is accepted as a correctly spelled alias. Both store into
    # ``args.sigma_treshold``.
    parser.add_argument('--sigma_treshold', '--sigma_threshold', dest='sigma_treshold', type=float, default=0.99, help='Threshold on sigma for early stopping')
    parser.add_argument('--sparsity_tol', type=float, default=1e-6, help='Floating point tolerance for accepting sparsity in parallel SVD')

    parser.add_argument('--count_hack', action='store_true', default=False, help='Only count, do not run the recovery')
    parser.add_argument('--true_B', action='store_true', default=False, help='Should the algorithm use the correct B or estimate it from gradients')
    return parser


def _init_neptune(args, argv):
    """Start a Neptune run, log every parsed argument to it, and attach it to ``args``.

    On return ``args.neptune`` holds the run object (replacing the project name)
    and ``args.neptune_id`` holds the run id: the literal 'AG-0' in offline mode,
    otherwise the value of ``sys/id`` fetched from the run.
    """
    # Imported lazily so the neptune package is only needed when logging is requested.
    import neptune.new as neptune

    nep_par = {'project': f"{args.neptune}", 'source_files': ["*.py"]}
    if args.neptune_offline:
        nep_par['mode'] = 'offline'
        # The real id is only fetched in online mode below; offline runs use a
        # fixed placeholder. Set before logging so it is recorded as a parameter.
        args.neptune_id = 'AG-0'

    run = neptune.init(**nep_par)
    for key, value in vars(args).items():
        # NOTE(review): N is logged as a float, presumably so values beyond
        # Neptune's integer field range can be stored — confirm.
        run[f"parameters/{key}"] = float(value) if key == 'N' else value
    args.neptune = run

    if not args.neptune_offline:
        print('waiting...')
        start_wait = time.time()
        run.wait()
        print('waited: ', time.time() - start_wait)
        args.neptune_id = run['sys/id'].fetch()
    print('\n\n\nArgs:', *argv, '\n\n\n')


def get_args(argv=None):
    """Parse the AG-attack command line and optionally set up Neptune logging.

    Args:
        argv: list of argument strings to parse; defaults to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. When ``--neptune`` is given,
        ``args.neptune`` is replaced by the Neptune run object, ``args.neptune_id``
        is set, and all arguments are logged under ``parameters/`` in the run.

    Raises:
        SystemExit: on invalid arguments, or when ``--neptune`` is given without
            ``--neptune_label``.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.neptune is not None:
        # argparse always defines ``label`` (default None), so a membership test
        # like `'label' in args` can never fail; check the value instead.
        if args.label is None:
            parser.error('--neptune_label is required when --neptune is set')
        _init_neptune(args, argv)
    return args
