import pickle
import time
from os.path import join

import numpy as np
import torch.nn as nn
import torch.utils.data
import torchvision
from einops import reduce
from scipy.stats import pearsonr
from tqdm import tqdm

from model.Dataset import FullDataset


def get_idx(idx_path, config):
    """
    Util function to align indices of stimuli and responses in the spike matrix.

    Loads the pickled index dict stored at ``idx_path`` (entries ``'train'`` and
    ``'val'``), moves ``config['test_size']`` randomly drawn validation ids into a
    test split, and computes, for every split, the positions its ids occupy in the
    sorted union of all ids.

    Args:
        idx_path: path to a pickle file holding a dict with 'train' and 'val' entries
        config: mapping providing 'test_size' (number of validation ids moved to test)

    Returns: train/val/test idx for images and targets, as the tuple
        (train_idx, val_idx, test_idx, train_target_idx, val_target_idx, test_target_idx)
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(idx_path, 'rb') as fp:
        idx_dict = pickle.load(fp)
    train_idx = sorted(idx_dict['train'])
    val_idx = sorted(idx_dict['val'])

    # Randomly select test images from the validation set
    # (uses numpy's global RNG, so seed it beforehand for a reproducible split).
    test_idx = sorted(np.random.choice(val_idx, size=config['test_size'], replace=False))
    test_lookup = set(test_idx)
    # Filter in one pass instead of calling list.remove() per test id.
    # Assumes ids are unique within the validation list -- TODO confirm.
    val_idx = [stim_id for stim_id in val_idx if stim_id not in test_lookup]

    # Align the rows of the spike matrix with the train/test/val set:
    # a target index is the rank of a stimulus id among all ids in sorted order.
    train_lookup = set(train_idx)
    val_lookup = set(val_idx)
    train_target_idx, val_target_idx, test_target_idx = [], [], []

    for row, stim_id in enumerate(sorted(train_idx + val_idx + test_idx)):
        # Independent checks (not elif): an id listed in several splits is
        # recorded for each of them.
        if stim_id in train_lookup:
            train_target_idx.append(row)
        if stim_id in val_lookup:
            val_target_idx.append(row)
        if stim_id in test_lookup:
            test_target_idx.append(row)

    return train_idx, val_idx, test_idx, train_target_idx, val_target_idx, test_target_idx

def make_train_set(model, raw_train_set):
    '''
    Cache the backbone activations for every training image.

    Training only optimizes on top of the latent activations of the CNN backbone,
    so they are computed once here (without gradients) instead of in every epoch.
    Args:
        model: instance of class GaussianReadoutModel; only model.core_model is called
        raw_train_set: instance of class Dataset yielding (image, target) pairs

    Returns: TensorDataset consisting of activations output by backbone CNN,
        paired with the corresponding targets (both kept on the CPU)
    '''
    loader = torch.utils.data.DataLoader(raw_train_set, batch_size=128)

    cached_features, cached_targets = [], []
    with torch.no_grad():
        for images, responses in tqdm(loader):
            # Images go through the backbone on the GPU; results are moved back
            # to the CPU so the cache does not occupy GPU memory.
            latent = model.core_model(images.to('cuda'))
            cached_features.append(latent.cpu().detach())
            cached_targets.append(responses.cpu())

    return torch.utils.data.TensorDataset(torch.cat(cached_features),
                                          torch.cat(cached_targets))


def build_datasets(model, config, idx_path, spike_array_path, use_silhouettes=False):
    '''
    Builds train/val/test datasets.

    The val and test sets are FullDataset instances over the raw images, while the
    train set holds activations precomputed by make_train_set.
    Args:
        model: model handed to make_train_set for precomputing the train features
        config: dict providing 'base_path' and 'stimulus_path' (and whatever get_idx reads)
        idx_path: Path leading to dict containing the indices of images presented to the monkey.
        spike_array_path: Path leading to neural data
        use_silhouettes: Whether to use silhouettes for augmentation

    Returns: train/val/test set, indices of the random split

    '''
    splits = get_idx(idx_path, config)
    train_idx, val_idx, test_idx, train_target_idx, val_target_idx, test_target_idx = splits

    # Put the used indices in a dict for later reference
    split_names = ["train", "val", "test", "train_target", "val_target", "test_target"]
    idx_dict = dict(zip(split_names, splits))

    image_dir = join(config['base_path'], config['stimulus_path'])

    def _dataset(idx, target_idx, **extra):
        # All datasets share the same image folder, transform and spike file;
        # only the selected indices (and the silhouette flag) differ.
        return FullDataset(stimulus_path=image_dir,
                           name_path=None, transform=torchvision.transforms.ToTensor(),
                           spike_array_path=spike_array_path, idx=idx, target_idx=target_idx,
                           **extra)

    # Make val and test set
    val_set = _dataset(val_idx, val_target_idx)
    test_set = _dataset(test_idx, test_target_idx)

    # Get output of core model because features are precomputed for the train images
    train_set = make_train_set(model, _dataset(train_idx, train_target_idx))

    if use_silhouettes:
        # Same train images again, this time requested as silhouettes, appended
        # to the regular train features.
        silhouette_train_set = make_train_set(
            model, _dataset(train_idx, train_target_idx, silhouettes=True))
        train_set = torch.utils.data.ConcatDataset([train_set, silhouette_train_set])

    return train_set, val_set, test_set, idx_dict


def l2_loss(output, target):
    '''
    Standard l2 (mean squared error) loss.

    Args:
        output: predictions, laid out as (images, neurons)
        target: ground truth with exactly the same shape as output

    Returns: tuple of overall loss (mean over all entries, kept as a length-1 array)
        and loss per recording channel (mean over the images axis)
    '''
    assert output.shape == target.shape

    squared_error = (output - target) ** 2
    overall = reduce(squared_error, 'images neurons -> 1', 'mean')
    per_neuron = reduce(squared_error, 'images neurons -> neurons', 'mean')
    return overall, per_neuron


def compute_ro(output, target):
    '''
    Column-wise Pearson correlation between two 2-D arrays.

    Args:
        output: array with a 2-element shape (rows, columns)
        target: array indexable like output

    Returns: numpy array with one correlation coefficient per column; undefined
        coefficients (NaN, e.g. for a constant column) are replaced by 0.
        If the correlations cannot be computed at all (e.g. the input is not
        2-D or has too few rows), np.inf is returned instead.
    '''
    try:
        _, n_columns = output.shape
        ros = [pearsonr(output[:, col], target[:, col])[0] for col in range(n_columns)]
        return np.nan_to_num(ros)
    except Exception:
        # Deliberate best-effort fallback. Catching Exception rather than using a
        # bare except lets KeyboardInterrupt / SystemExit propagate.
        return np.inf


def train(model, config, spike_array_path, idx_path, n_epochs, lambd=1, use_silhouettes=False):
    '''
    Trains the readout of the model on precomputed backbone features.

    Only model.readout._mu, model.readout.sigma and model.readout._features are
    optimized. Every 50 epochs (and in the last epoch) the model is evaluated and,
    per neuron, the readout parameters with the lowest validation loss so far are
    remembered. In the last epoch these best parameters are written back into the
    model before the final evaluation.
    Args:
        model: instance of class GaussianReadoutModel
        config: config containing training parameters ('lr', plus the entries
            read by build_datasets)
        spike_array_path: path to spike data
        idx_path: path to the pickled dict containing stimuli shown during recordings
        n_epochs: n_epochs (no parameter update is performed in epoch 0, which
            therefore only measures the initial state)
        lambd: weight for regularizer; a tensor with one entry per neuron, or a
            plain number applied to all neurons
        use_silhouettes: whether to augment data using silhouettes

    Returns: dict containing training results with keys 'Model', 'Train Loss',
        'Val Loss' (per-neuron losses of the last evaluation) and 'idx_dict'
    '''
    if use_silhouettes:
        print('Including silhouettes for training.')

    # The regularizer calls tensor methods on lambd, so promote plain numbers
    # (such as the default value) to a tensor; tensors are used as passed in.
    if not torch.is_tensor(lambd):
        lambd = torch.as_tensor(lambd, dtype=torch.float32).to('cuda')
    # Trailing axis indexes neurons; leading singleton axes broadcast over _features.
    lambd_broadcast = lambd.unsqueeze(0).unsqueeze(0).unsqueeze(0)

    # Initialize best parameters
    best_weights = torch.rand(model.readout._features.shape) * (1 / 1000)
    best_mu = torch.zeros(model.readout.mu.shape)
    best_sigma = torch.zeros(model.readout.sigma.shape)
    best_val_loss = torch.zeros(best_weights.shape[-1]) + 1e12

    train_set, val_set, test_set, idx_dict = build_datasets(model, config, idx_path, spike_array_path,
                                                  use_silhouettes=use_silhouettes)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    print("Train set size:", len(train_set))
    has_val = val_set is not None
    if has_val:
        print("Val set size:", len(val_set))
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=128)
    else:
        print("Not using a validation set.")

    model = model.to('cuda')
    model.eval()

    optimizer = torch.optim.Adam([
        {'params': [model.readout._mu, model.readout.sigma]},
        {'params': model.readout._features, 'lr': config['lr']}], lr=config['lr'])

    train_loss, val_loss = [], []
    train_corr, val_corr = [], []

    for epoch in range(n_epochs):
        # Train
        model.readout.train()
        for x, target in train_loader:
            x, target = x.to('cuda'), target.to('cuda')
            # The train set holds cached backbone features, so only the readout runs.
            output = model.readout(x)
            prediction_term, per_neuron_loss = l2_loss(output, target)
            # Sparsity penalty: per-neuron weighted sum of sqrt(|w|) over the features.
            l1_term = lambd_broadcast * (torch.abs(model.readout._features) ** (1 / 2))
            l1_term = l1_term.squeeze()
            l1_term = reduce(l1_term, 'features neurons -> neurons', 'sum')
            l1_term = reduce(l1_term, 'neurons -> 1', 'mean')
            loss = prediction_term + l1_term
            if epoch > 0:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        # Without a validation set no best parameters were ever recorded, so the
        # trained parameters are kept instead of being overwritten.
        if epoch == n_epochs - 1 and has_val:   # Set best weights
            print('Setting best weights found during training...')
            model.readout._features = nn.Parameter(best_weights.to('cuda'))
            model.readout._mu = nn.Parameter(best_mu)
            model.readout.sigma = nn.Parameter(best_sigma)
            model = model.to('cuda')

        if epoch % 50 == 0 or epoch == n_epochs - 1:
            # Validate
            model.readout.eval()

            # NOTE(review): the batch-mean losses are summed and then divided by the
            # number of samples (not batches), so the reported values are scaled
            # down accordingly. Kept as is because the per-neuron comparison across
            # epochs only needs a consistent scale -- confirm before relying on the
            # absolute numbers.
            epoch_train_loss, epoch_val_loss = 0, 0
            epoch_per_neuron_loss_train = torch.zeros(best_val_loss.shape)
            epoch_per_neuron_loss_val = torch.zeros(best_val_loss.shape)

            # Evaluate on train set
            y, yhat = [], []
            with torch.no_grad():
                for x, target in train_loader:
                    x, target = x.to('cuda'), target.to('cuda')
                    output = model.readout(x)
                    prediction_term, per_neuron_loss = l2_loss(output, target)
                    epoch_train_loss += prediction_term.detach()
                    epoch_per_neuron_loss_train += per_neuron_loss.detach().cpu()
                    y.append(target.detach())
                    yhat.append(output.detach())
                y = torch.cat(y).cpu().numpy()
                yhat = torch.cat(yhat).cpu().numpy()
                train_ros = compute_ro(y, yhat)

            epoch_train_loss /= len(train_set)
            epoch_per_neuron_loss_train /= len(train_set)
            train_loss.append(epoch_train_loss.cpu().numpy())
            train_corr.append(np.mean(train_ros))

            if has_val:
                # Evaluate on val set
                y, yhat = [], []
                with torch.no_grad():
                    for x, target in val_loader:
                        x, target = x.to('cuda'), target.to('cuda')
                        # The val set holds raw images, so the full model is applied.
                        output = model(x)
                        prediction_term, per_neuron_loss = l2_loss(output, target)
                        epoch_val_loss += prediction_term
                        epoch_per_neuron_loss_val += per_neuron_loss.detach().cpu()
                        y.append(target)
                        yhat.append(output)
                    y = torch.cat(y).cpu().numpy()
                    yhat = torch.cat(yhat).cpu().numpy()
                    val_ros = compute_ro(y, yhat)
                epoch_val_loss /= len(val_set)
                epoch_per_neuron_loss_val /= len(val_set)
                val_loss.append(epoch_val_loss.cpu().numpy())
                val_corr.append(np.mean(val_ros))
            else:
                epoch_val_loss, val_ros = None, np.array([0])

            # save best parameters
            if has_val:
                # Snapshot detached CPU copies once, outside autograd, so the
                # best_* buffers stay plain tensors that are not tied to the graph.
                with torch.no_grad():
                    current_weights = model.readout._features.detach().cpu()
                    current_mu = model.readout._mu.detach().cpu()
                    current_sigma = model.readout.sigma.detach().cpu()
                    for n in range(best_weights.shape[-1]):
                        if epoch_per_neuron_loss_val[n] < best_val_loss[n]:
                            best_val_loss[n] = epoch_per_neuron_loss_val[n]
                            best_weights[:, :, :, n] = current_weights[:, :, :, n]
                            best_mu[:, n, :, :] = current_mu[:, n, :, :]
                            best_sigma[:, n, :, :] = current_sigma[:, n, :, :]

            # Without a validation set there is no loss to report; print NaN.
            val_loss_value = epoch_val_loss.item() if has_val else float('nan')
            print('Epoch: {}. Train Loss: {}. Val Loss: {}. Train Corr: {}. Val Corr: {}'.format(epoch, epoch_train_loss.item(), val_loss_value, np.mean(train_ros), np.mean(val_ros)))

    out_dict = {'Model': model, 'Train Loss': epoch_per_neuron_loss_train, 'Val Loss': epoch_per_neuron_loss_val,
                'idx_dict': idx_dict}
    return out_dict


def run_training(model, config, spike_array_path):
    '''
    Entry point of the training pipeline (this is the function called by main.py).

    Reads the number of neurons from the last axis of the spike array, expands the
    scalar regularization weight config['lambda'] to one value per neuron and runs
    train() with silhouette augmentation switched on.
    Args:
        model: instance of class GaussianReadoutModel
        config: dict containing training parameters ('n_epochs', 'base_path',
            'idx_path', 'lambda', plus whatever train() reads)
        spike_array_path: path to neural data (loaded with np.load)

    Returns: dict containing training results
    '''
    start = time.time()

    n_epochs = config['n_epochs']
    idx_path = join(config['base_path'], config['idx_path'])

    # Only the size of the last axis (number of neurons) is needed from the data.
    n_neurons = np.load(spike_array_path).shape[-1]

    # One regularization weight per neuron, all starting at the same value
    lambd = config['lambda'] * torch.ones(n_neurons).to('cuda')

    results = train(model, config, spike_array_path, idx_path, n_epochs=n_epochs, lambd=lambd,
                    use_silhouettes=True)

    print("Time for fit:", time.time() - start)
    print()
    return results