'''Training helper script of parametric pacmap.
'''
import argparse
import os
import time
import pickle as pkl

import torch
import torch.utils.data
import numpy as np
from sklearn import preprocessing

from paramrepulsor.models import module, dataset
from paramrepulsor.utils import data, utils


def get_args():
    '''Parse the command-line arguments of this script.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``. Its only
        declared option is ``--config`` (a string, ``None`` when omitted).
    '''
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str)
    return cli.parse_args()


def get_config(args):
    '''Build the run configuration from the parsed command-line arguments.

    ``args.config`` is handed to ``utils.read_yaml`` and the result is passed
    through ``utils.impute_default`` together with ``utils.DEFAULT_CONFIG``
    (presumably filling in any keys the file leaves out -- see ``utils``).

    Returns:
        Whatever ``utils.impute_default`` returns.
    '''
    loaded = utils.read_yaml(args.config)
    return utils.impute_default(loaded, utils.DEFAULT_CONFIG)


def convert_pairs(pair_neighbors, pair_FP, pair_MN, N):
    '''Turn flat pair arrays into per-row tables of their second column.

    For each input, column 1 is extracted and reshaped to ``N`` rows; the
    number of columns is inferred from the array length.
    NOTE(review): this presumably relies on the pairs being grouped by their
    first-column index, with an equal count per index -- confirm against
    ``data.generate_pair``.

    Returns:
        A tuple ``(pair_neighbors, pair_FP, pair_MN)`` of the reshaped arrays.
    '''
    def _partners_by_row(pairs):
        # Keep only the second entry of every pair, one row per index.
        return pairs[:, 1].reshape((N, -1))

    return tuple(
        _partners_by_row(pairs) for pairs in (pair_neighbors, pair_FP, pair_MN)
    )


def _make_loader(split, batch_size, shuffle, num_workers):
    '''Wrap a dataset in a DataLoader with the settings shared by all splits.'''
    return torch.utils.data.DataLoader(
        dataset=split,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        pin_memory=True,
        num_workers=num_workers,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )


def get_loaders(config):
    '''
    Prepare the dataloaders for training based on the dataset name and desired shape.

    Args:
        config: Mapping with the keys read below: ``dataset``, ``datasize``,
            ``datadim``, ``datapca``, ``n_neighbors``, ``n_FP``, ``n_MN``,
            ``distance``, ``datascale``, ``use_negative_sampling``,
            ``datareshape``, ``batch_size``, ``batch_size_val``,
            ``batch_size_inference`` and ``dlworker``.

    Returns:
        ``(train_loader, val_loader, test_loader, input_dims)`` where
        ``input_dims`` is ``X.shape[1]`` of the loaded data.

    Raises:
        AssertionError: if ``config['use_negative_sampling']`` is not a bool.
    '''
    # Load the data (the second value returned by data_prep is not needed here).
    X, _ = data.data_prep(dataset=config['dataset'],
                          size=config['datasize'],
                          dim=config['datadim'],
                          pca=config['datapca'],)
    input_dims = X.shape[1]

    # Construct the pairs. This is done on X *before* any scaling below.
    pair_neighbors, pair_MN, pair_FP, _ = data.generate_pair(
        X, n_neighbors=config['n_neighbors'], n_MN=config['n_MN'],
        n_FP=config['n_FP'], distance=config['distance'], verbose=False
    )

    # Optional feature scaling: 1 -> standardize, 2 -> min-max; any other
    # value leaves X untouched.
    if config['datascale'] == 1:
        X = preprocessing.StandardScaler().fit_transform(X)
    elif config['datascale'] == 2:
        X = preprocessing.MinMaxScaler().fit_transform(X)

    nn_pairs, fp_pairs, mn_pairs = convert_pairs(pair_neighbors, pair_FP, pair_MN, X.shape[0])
    paired_kwargs = dict(
        data=X,
        nn_pairs=nn_pairs,
        fp_pairs=fp_pairs,
        mn_pairs=mn_pairs,
        reshape=config['datareshape'],
    )
    num_workers = config['dlworker']

    # Guard against truthy non-bool values (e.g. a string) silently selecting
    # the negative-sampling dataset.
    assert isinstance(config['use_negative_sampling'], bool), \
        "config['use_negative_sampling'] must be a bool"
    if config['use_negative_sampling']:
        train_set = dataset.NegativeSamplingDataset(**paired_kwargs)
    else:
        train_set = dataset.PaCMAPDataset(**paired_kwargs)
    train_loader = _make_loader(train_set, config['batch_size'],
                                shuffle=True, num_workers=num_workers)

    # Validation always uses the plain PaCMAP dataset over the same data.
    val_set = dataset.PaCMAPDataset(**paired_kwargs)
    val_loader = _make_loader(val_set, config['batch_size_val'],
                              shuffle=True, num_workers=num_workers)

    # The test loader iterates the raw samples in order (no pairs, no shuffle).
    test_set = dataset.TensorDataset(data=X, reshape=config['datareshape'])
    test_loader = _make_loader(test_set, config['batch_size_inference'],
                               shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader, input_dims

