import os
import glob
import torch
import argparse
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR

from utils import str2bool
from datasets import get_dataset
from training import ProtoNetTrainer
from models import SetTaskInterpolator, ProtoNetLinear

# Command-line interface for the training script.
#
# Each row is (flag, type, default, help) and is registered with the parser in
# the order listed, so the ``--help`` output follows this table.
# A help value of None means the option has no help text.
_CLI_OPTIONS = [
    # --- experiment bookkeeping ---
    ('--root', str, '', 'dataset path.'),
    ('--run', str, '0', 'run identifier.'),
    ('--model', str, 'settaskinterpolator_protonet', 'model to train.'),
    ('--wandb_entity', str, '', 'wandb entity name.'),
    ('--wandb_project', str, '', 'project name on wandb.'),
    ('--dataset', str, 'esc50', 'meta dataset.'),
    ('--optimizer', str, 'adam', 'optimizer to use for training.'),
    ('--batch_size', int, 256, 'dataloader batch size.'),
    ('--train_batch_size', int, 256, 'train dataloader batch size.'),
    ('--epochs', int, 3000, 'number of iterations.'),
    # NOTE(review): the three --N* options share the --epochs help text,
    # which looks copy-pasted; confirm what they actually control.
    ('--N', int, 30, 'number of iterations.'),
    ('--N_valid', int, 300, 'number of iterations.'),
    ('--N_test', int, 3000, 'number of iterations.'),
    # --- few-shot episode configuration ---
    ('--ways', int, 2, 'number of ways.'),
    ('--shots', int, 5, 'number of shots.'),
    ('--query_shots', int, 10, 'train time query shots.'),
    ('--epsilon', float, 1.0, 'reptile stepsize.'),
    ('--num_workers', int, 1, 'number of workers in dataloader.'),
    # NOTE(review): same help text as --num_workers; presumably this one is
    # for the train dataloader -- confirm.
    ('--num_workers_train', int, 2, 'number of workers in dataloader.'),
    # --- optimisation hyper-parameters ---
    ('--inner_lr', float, 0.01, 'inner lr for maml'),
    ('--inner_steps', int, 5, 'the number of inner steps for MAML'),
    ('--lr', float, 0.001, 'learning rate.'),
    ('--mt', float, 0.9, 'momentum.'),
    ('--tau', float, 0.5, 'temperature.'),
    ('--adaptation_lr', float, 0.0001, 'learning rate for adaptation stage.'),
    ('--valid_freq', int, 100, 'how often to evaluate on the validation set.'),
    # --- task interpolation ---
    ('--interp', str2bool, False, 'perform task interpolation or not.'),
    ('--num_tasks', int, 2, 'number of tasks to interpolate.'),
    ('--interp_name', str, 'mlti', 'name of interpolation method.'),
    ('--alpha', float, 2.0, 'concentration value for beta distribution'),
    ('--beta', float, 2.0, 'concentration value for beta distribution'),
    ('--devices', str, '0,1,2,3', 'multiple devices for multiprocessing'),
    ('--ilayer', int, 3, 'interpolation layer'),
    ('--num_classes', int, 10, None),
    ('--noise', str, 'sq', 'add noise to support(s) or query(q), both(sq) or none'),
    ('--mix', str, 'sq', 'mix support(s) or query(q), both(sq) or none'),
    ('--outer_episodes', int, 100, 'outer episodes for BO'),
    ('--inner_episodes', int, 50, 'inner episodes for BO'),
    ('--BS', int, 200, 'batch size for validation set in BO'),
    ('--num_runs', int, 1, 'number of forward passes for set transformer.'),
    # NOTE(review): --interpname and --visualize share the same help text,
    # which looks copy-pasted; confirm the intended descriptions.
    ('--interpname', str, '', 'comment to add to logger'),
    ('--visualize', str2bool, False, 'comment to add to logger'),
]

parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_kind, help=_help)

# Parsed at import time, exactly as before; the script body below reads `args`.
args = parser.parse_args()

if __name__ == '__main__':
    # Script entry point: build the model, optimizer(s) and data loaders from
    # the parsed CLI arguments, run the trainer, and print the loss/accuracy
    # pair returned by ``trainer.fit()``.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Model names for which a SetTaskInterpolator (with its own optimizer)
    # is created in addition to the base model.
    settaskinterpolator_models = ['settaskinterpolator_protonet']

    if 'protonet' in args.model:
        if args.dataset in ['esc50']:
            model = ProtoNetLinear(alpha=args.alpha, beta=args.beta, name=args.model, dataset=args.dataset).to(args.device)
        else:
            # The unsupported item on this branch is the dataset, not the
            # model, so report the dataset in the error.
            raise NotImplementedError('{} dataset not implemented'.format(args.dataset))

        if args.optimizer == 'adam':
            optimizer = Adam(model.parameters(), lr=args.lr)
        elif args.optimizer == 'sgd':
            # NOTE(review): --mt (momentum) is not forwarded to SGD here;
            # confirm whether plain SGD is intended.
            optimizer = SGD(model.parameters(), lr=args.lr)
        else:
            raise NotImplementedError('{} optimizer not implemented'.format(args.optimizer))
        # No learning-rate scheduler is used; the trainer receives None.
        scheduler = None

        # Per-dataset interpolator dimensions, keyed by interpolation layer.
        layer_shapes = {
            # layer: [in_dim, num_heads, hidden_dim]
            'esc50': {-1: [640, 4, 1024], 0: [500, 4, 1024], 1: [500, 4, 1024], 2: [500, 4, 1024]},
        }

        interpolator, optimizer_interp = None, None
        if args.model in settaskinterpolator_models:
            layer_shape = layer_shapes[args.dataset]
            # The CLI default (--ilayer 3) has no entry in the table above, so
            # fail with an explicit message instead of a bare KeyError.
            if args.ilayer not in layer_shape:
                raise ValueError(
                    '--ilayer {} is not supported for dataset {}; choose one of {}'.format(
                        args.ilayer, args.dataset, sorted(layer_shape)))
            in_dim, num_heads, hidden_dim = layer_shape[args.ilayer]
            interpolator = SetTaskInterpolator(
                dim_in=in_dim, dim_hidden=hidden_dim, num_heads=num_heads,
                num_runs=args.num_runs, name=args.interpname, layer=args.ilayer,
                noise=args.noise, mix=args.mix).to(args.device)
            # NOTE(review): interpolator learning rate is hard-coded to 1e-4
            # rather than taken from a CLI option; confirm this is intended.
            optimizer_interp = Adam(interpolator.parameters(), 1e-4)

        trainloader, validloader, testloader = get_dataset(args)

        trainer = ProtoNetTrainer(
            model=model, optimizer=optimizer, interpolator=interpolator,
            optimizer_interp=optimizer_interp, scheduler=scheduler,
            trainloader=trainloader, validloader=validloader,
            testloader=testloader, args=args)
        test_loss, test_acc = trainer.fit()
        print('\nTest Loss: {:.4f} Accuracy: {:.4f}'.format(test_loss, test_acc))
    else:
        raise NotImplementedError('{} not implemented'.format(args.model))
