import numpy as np
import pickle

def import_data(data_directory, sample_size=1, fraction_neuron=1.0, *args, **kwargs):
    '''
    Loads the data of a task-trained RNN of motor perturbation learning in order to fit an ltrRNN to it.

    The file ``data.pkl`` inside ``data_directory`` is unpickled into a dict. The keys read here are
    'rnn_activity' (indexed as [trial, :, batch element]), 'condition' (indexed as [trial, batch element]),
    'epochs' (with at least the keys 'perturbation' and 'washout'), 'time' and 'additional_information'
    (with the keys 'preparatory_duration' and 'random_duration').

    Sampling uses the global numpy random state (one permutation per trial, then one uniform draw per
    neuron), so call ``np.random.seed`` beforehand for reproducible results.

    :param data_directory: directory containing the pickled data (str or path-like).
    :param sample_size: the RNN was trained with batching, but we may only want 1 trial per batch.
        Number of batch elements drawn without replacement for every training trial.
    :param fraction_neuron: whether to only sample a fraction of neurons. Each neuron is kept
        independently with this probability, so the kept fraction is only approximate.
    :param args: ignored.
    :param kwargs: ignored.

    :return: rnn_activity: (ndarray) time x trial x neuron. The activity of the RNN over learning.
    :return: data: (dict) the unpickled dict with an extra field 'epoch' containing a #trial-dimensional array indicating the epoch ('BL' baseline, 'AD' perturbed, 'WO' washout).
    :return: angles: (ndarray) trial. The angle of the reach (0 to 2pi).
    :return: condition: (ndarray) trial. The condition i.e. the target number (integer).
    :return: times: (ndarray) time. The time within a trial, relative to the go cue.
    :return: epoch_values: (list) the values of data['epochs'] in stored order, i.e. the trials at which the epoch changes.
    :return: trial_id: (ndarray) trial. The true id of a trial, in case some trials are discarded.

    :raises ValueError: if sample_size is smaller than 1.
    :raises IndexError: if sample_size exceeds the number of batch elements per trial.
    '''
    import os  # function-scope import: only needed to build the pickle path

    # NOTE(security): pickle.load can execute arbitrary code; only use on trusted data files.
    with open(os.path.join(data_directory, 'data.pkl'), 'rb') as f:
        data = pickle.load(f)

    stored_activity = data['rnn_activity']
    stored_condition = data['condition']
    n_trials = stored_condition.shape[0]
    batch_size = stored_condition.shape[1]

    # Fail early with an explicit message instead of deep inside the sampling loop / np.stack.
    if sample_size < 1:
        raise ValueError('sample_size must be at least 1, got {}'.format(sample_size))
    if sample_size > batch_size:
        raise IndexError('sample_size ({}) exceeds the batch size ({})'.format(sample_size, batch_size))

    rnn_activity = []
    condition = []
    for trial in range(n_trials):
        # Draw sample_size distinct batch elements for this training trial.
        picks = np.random.permutation(batch_size)[:sample_size]
        for batch_element in picks:
            rnn_activity.append(stored_activity[trial, :, batch_element])
            condition.append(stored_condition[trial, batch_element])
    rnn_activity = np.stack(rnn_activity, axis=1)

    # Presumably the stored condition is target_index / 8 (8 reach targets) -- TODO confirm.
    condition = (np.array(condition) * 8).astype(int)

    epochs = data['epochs']

    # Every training trial contributes sample_size sampled trials, hence the scaling of the counts.
    epoch_spans = (
        ('BL', epochs['perturbation']),
        ('AD', epochs['washout'] - epochs['perturbation']),
        ('WO', n_trials - epochs['washout']),
    )
    epoch = np.array([label
                      for label, n_training_trials in epoch_spans
                      for _ in range(sample_size * n_training_trials)])

    # The go cue is placed at the end of the preparatory period plus half of the random duration.
    info = data['additional_information']
    go_cue = info['preparatory_duration'] + info['random_duration'] / 2

    angles = np.pi * 2 * condition / (np.max(condition) + 1)

    times = data['time'] - go_cue

    data['epoch'] = epoch

    # Keep each neuron independently with probability fraction_neuron.
    keep_neuron = np.random.rand(rnn_activity.shape[2]) < fraction_neuron
    rnn_activity = rnn_activity[..., keep_neuron]

    epoch_values = list(epochs.values())
    trial_id = np.arange(len(condition))

    return rnn_activity, data, angles, condition, times, epoch_values, trial_id
