'''
Code for running the parallel backpropagation pipeline.

RUN:
python parallel_backpropagation.py -d day_05_03_24
'''

import argparse
import os
import pickle

import torch
import torchvision
from einops import rearrange
from tqdm import tqdm

from evaluate.build_model import build_model
from evaluate.post_analysis.PostData import PostData
from evaluate_on_objects import compute_activations
from evaluate.utils import *
from model.Dataset import FullDataset, ImageDataset
from model.Gradients import ParallelBackprop, IntegratedParallelBackprop


def parallel_backprop(recording_day, model, use_integrated_gradients=False, get_negative_examples=False,
                      use_weak_activators=False, n_top_avatars=15, n_top_objects=5):
    """Compute and save parallel-backprop gradient plots for every neuron.

    For each neuron (last axis of the recorded spike matrix) a set of
    candidate avatar (body) images and candidate object images is selected
    from the recorded responses. The latent features of the candidates are
    weighted by the neuron's readout weights, L2-normalised, and compared
    pairwise. One (object, avatar) pair is picked and handed to the gradient
    analysis, whose side-by-side plot is written to disk.

    Args:
        recording_day: Name of the recording day. Used to locate
            ``submission_data/spike_data/<recording_day>.npy`` and to name the
            output directory.
        model: Model whose ``readout._features`` holds the readout weights.
            It is switched to eval mode and moved to CUDA.
        use_integrated_gradients: Use ``IntegratedParallelBackprop`` instead
            of ``ParallelBackprop``.
        get_negative_examples: Instead of the most similar pair, pick the
            object whose best-matching avatar has the lowest similarity.
        use_weak_activators: Use the objects with the lowest responses as
            candidates instead of the highest ones. Also switches pair
            selection to the low-similarity rule.
        n_top_avatars: Number of highest-response avatars used as candidates.
        n_top_objects: Number of objects used as candidates.

    Raises:
        ValueError: If ``n_top_avatars`` or ``n_top_objects`` is smaller
            than 1.

    Side effects:
        Rewrites ``evaluate/configs/training.json`` with the paths adjusted
        by ``set_config`` and creates
        ``plots/gradients/<recording_day>[_suffixes]`` below the current
        working directory.
    """
    if n_top_avatars < 1 or n_top_objects < 1:
        raise ValueError('n_top_avatars and n_top_objects must be at least 1')

    base_path = os.getcwd()
    print('Computing gradients for:', recording_day)
    if use_integrated_gradients:
        print('Using integrated gradients.')
    if get_negative_examples:
        print('Gathering negative examples (low similarity images)')
    if use_weak_activators:
        print('Using weakly activating objects instead of strongly activating ones.')

    # Load configs and paths
    config_path = join(base_path, 'evaluate/configs', 'training.json')
    config = load_config(config_path=config_path)

    # Set correct paths for this machine
    config = set_config(config, base_path)

    # Save paths to config
    with open(join(base_path, 'evaluate/configs/training.json'), 'w') as f:
        json.dump(config, f)

    post_data = PostData(config, recording_day)
    spike_array = np.copy(post_data.spike_matrix)

    pre_spike_array_path = join(config['base_path'], 'submission_data', 'spike_data', recording_day) + '.npy'

    model.eval()
    model.readout.eval()

    if use_integrated_gradients:
        gradient_analysis = IntegratedParallelBackprop(model.to('cuda'), 'cuda')
    else:
        gradient_analysis = ParallelBackprop(model.to('cuda'), 'cuda')

    # Load avatar set
    # NOTE(review): pickle executes arbitrary code on load; only use trusted index files.
    with open(config['idx_path'], 'rb') as fp:
        idx_dict = pickle.load(fp)
    avatar_idx = sorted(idx_dict['val'])
    avatar_target_idx = range(len(avatar_idx))

    avatar_set = FullDataset(stimulus_path=join(config['base_path'], config['stimulus_path']),
                             name_path=None, transform=torchvision.transforms.ToTensor(),
                             spike_array_path=pre_spike_array_path, idx=avatar_idx,
                             target_idx=avatar_target_idx)

    # Load object set
    object_set = ImageDataset(stimulus_path=join(config['base_path'], config['object_path']))

    # Resolve the object rows to explicit integer row indices once. Works for
    # index lists/arrays, boolean masks and slices alike.
    object_rows = np.arange(spike_array.shape[0])[post_data.object_rows]

    # Compute latent features of both sets. Necessary to compute pairwise similarities s()
    print('Computing object activations...')
    object_a_, _ = compute_activations(model, object_set, return_predictions=True)
    print('Computing avatar activations...')
    avatar_a_, _ = compute_activations(model, avatar_set, avatar_set=True, return_predictions=True)

    save_path = join(base_path, 'plots', 'gradients', recording_day)

    if use_integrated_gradients:
        save_path = save_path + '_integrated_gradients'

    if get_negative_examples:
        save_path = save_path + '_negative_examples'

    if use_weak_activators:
        save_path = save_path + '_weak_activators'

    os.makedirs(save_path, exist_ok=True)
    print('Saving to: ', save_path)

    # The readout weights do not change inside the loop, so fetch them once.
    readout_weights = model.readout._features.detach().cpu().numpy().squeeze()

    print('Looping over neurons to compute parallel backprop...')
    for neuron_idx in tqdm(range(spike_array.shape[-1])):
        # Get readout weights of model to compute s()
        w = readout_weights[:, neuron_idx]

        # Get top-activating bodies. The slice runs to the end of the sorted
        # order so that the strongest responder is part of the candidates.
        # NOTE(review): the resulting indices are positions within
        # post_data.avatar_rows and are used to index avatar_a_ / avatar_set
        # directly -- assumes both share that ordering; confirm.
        avatar_responses = spike_array[post_data.avatar_rows, neuron_idx]
        best_avatar_idx = np.argsort(avatar_responses)[-n_top_avatars:]

        # Rank the object rows only. Indexing (instead of multiplying by a
        # mask) keeps non-object rows out of the ranking regardless of the
        # sign of the responses and avoids inf * 0 = NaN artefacts.
        object_order = np.argsort(spike_array[object_rows, neuron_idx])
        if use_weak_activators:
            chosen = object_order[:n_top_objects]
        else:
            chosen = object_order[-n_top_objects:]
        best_object_matrix_idx = object_rows[chosen]  # index in terms of spike matrix

        # The object dataset index is encoded in the file name.
        best_object_idx = [int(post_data.filenames[row].split('_')[1][:7])
                           for row in best_object_matrix_idx]  # index in terms of object dataset

        # Weight latent features based on their relevance for the neurons
        weighted_avatar = avatar_a_[best_avatar_idx, :, neuron_idx] * w
        weighted_object = object_a_[best_object_idx, :, neuron_idx] * w

        weighted_avatar = weighted_avatar / torch.linalg.norm(weighted_avatar, axis=1, keepdim=True)
        weighted_object = weighted_object / torch.linalg.norm(weighted_object, axis=1, keepdim=True)

        # Compute pairwise similarities between bodies and objects
        # (rows: candidate avatars, columns: candidate objects)
        sim = weighted_avatar @ weighted_object.T
        sim = sim.numpy()
        if not get_negative_examples and not use_weak_activators:
            # Most similar (avatar, object) pair among the candidates
            avatar_id, object_id = np.unravel_index(np.argmax(sim), sim.shape)
        else:
            # Get the object image which has the lowest maximum similarity
            avatar_id_per_object = np.argmax(sim, axis=0)   # 'avatar object -> object'
            max_sim_per_object = np.max(sim, axis=0)
            object_id = np.argmin(max_sim_per_object)
            avatar_id = avatar_id_per_object[object_id]

        best_avatar_id = best_avatar_idx[avatar_id]
        best_object_id = best_object_idx[object_id]

        # Select the chosen images within their respective datasets
        avatar_img = avatar_set[best_avatar_id][0]
        object_img = object_set[best_object_id]

        imgs = [object_img.cuda(), avatar_img.cuda()]

        # compute gradients
        if use_integrated_gradients:
            grads = gradient_analysis.integrated_gradient(imgs, neuron=neuron_idx)
        else:
            grads = gradient_analysis.joint_gradient(imgs, neuron=neuron_idx)

        # Zero-padded neuron index keeps the saved files sorted by neuron
        save_id = f'{neuron_idx:05d}'

        # plot and save
        gradient_analysis.show_side_by_side([rearrange(img, 'c h w -> h w c').cpu().numpy() for img in imgs], grads,
                                            save_path=join(save_path, save_id))


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Parse the command line, build the model and run ``parallel_backprop``.

    The readout weights are loaded from ``--model_checkpoint``. If no
    checkpoint is given, the path defaults to
    ``submission_data/pretrained_models/<recording_day>_0.1/model`` below the
    current working directory.

    The three boolean options accept an explicit value
    (``--integrated_gradients True`` / ``False``) or can be given as a bare
    flag (``--integrated_gradients``), which means True.

    Raises:
        RuntimeError: If the checkpoint cannot be loaded into the readout.
    """
    base_path = os.getcwd()
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--recording_day',
                        help='name of the np file containing data of the recoding day', type=str,
                        default='day_05_03_24')
    parser.add_argument('-m', '--model_checkpoint',
                        help='path to the model checkpoint containing the readout weights', type=str,
                        default=None)
    # type=bool would turn every non-empty string (including 'False') into
    # True, so the values are parsed explicitly. nargs='?' with const=True
    # additionally allows the option to be used as a plain flag.
    parser.add_argument('--integrated_gradients',
                        help='if TRUE, use integrated gradients instead of standard gradients',
                        type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--negative_examples',
                        help='if TRUE, use use objects without similarity to any of the bodies',
                        type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--weak_activators',
                        help='if TRUE, use weakly activating objects',
                        type=_str2bool, nargs='?', const=True, default=False)

    args = parser.parse_args()

    # build model
    if args.model_checkpoint is None:
        args.model_checkpoint = join(base_path, 'submission_data',
                                     'pretrained_models', args.recording_day + '_0.1', 'model')

    # Load configs and paths
    config_path = join(base_path, 'evaluate/configs', 'training.json')
    config = load_config(config_path=config_path)
    config = set_config(config, base_path)

    # Load spike array, only needed to build the model correctly
    pre_spike_array_path = join(config['base_path'], 'submission_data', 'spike_data', args.recording_day) + '.npy'
    model = build_model(config, pre_spike_array_path)
    print('Loading weights from:', args.model_checkpoint)
    try:
        model.readout.load_state_dict(torch.load(args.model_checkpoint))
    except Exception as err:
        # Keep the original exception attached so the real cause (missing
        # file, shape mismatch, ...) stays visible in the traceback.
        raise RuntimeError('Error trying to load the model. Does the model match the recording day?') from err

    parallel_backprop(args.recording_day, model, args.integrated_gradients, args.negative_examples,
                      args.weak_activators)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()