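'''
Script that plots each unit's responses to its selected object, to bodies, and to
other objects, for the units whose responses are visualized in the manuscript.

Input and output paths are resolved relative to the current working directory,
so run it from the repository root.

RUN:
python evaluate/post_analysis/global_analysis/plot_obj_responses.py
'''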

import pickle
import matplotlib.pyplot as plt
import numpy as np
import os

def remove_leading_zeros(s):
    """Strip zero-padding from every number that directly follows ``_unit`` in *s*.

    For example ``'day_05_03_24_unit00030'`` becomes ``'day_05_03_24_unit30'``.
    A number made only of zeros collapses to a single ``'0'``. Strings with no
    zero-padded ``_unit`` number are returned unchanged.
    """
    import re

    def _strip(match):
        prefix, digits = match.groups()
        # When the whole number was zeros nothing remains, so keep one zero.
        return prefix + (digits or '0')

    return re.sub(r'(_unit)0+(\d*)', _strip, s)

def main():
    """Plot selected-object, body-mean and object-mean responses for manuscript units.

    Loads the pickled ``selected_object_dict`` (keyed by unit name). For each
    unit used in the manuscript, three markers are drawn at the same x position:
    - the unit's ``best_object_response``;
    - the mean of its ``body_response``;
    - the mean of its ``other_object_response``.
    Every value is multiplied by 5 before plotting.

    Units are ordered monkey G MSB, monkey G ASB, monkey T MSB, monkey T ASB.
    The figure is saved to ``plots/obj_response.pdf`` under the current working
    directory and then shown.
    """
    # Build the path from components. The previous string mixed '/' with a
    # literal backslash ('\s' is also an invalid escape sequence), so it only
    # resolved on Windows.
    data_path = os.path.join(os.getcwd(), 'evaluate', 'post_analysis',
                             'global_analysis', 'selected_object_dict')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path, "rb") as file:
        selected_object_dict = pickle.load(file)

    # units used in the manuscript
    monkeyG_msb_keys = ['day_05_03_24_unit00030', 'day_05_03_24_unit00029', 'day_12_02_24_unit00001', 'day_12_02_24_unit00005',
                        'day_05_03_24_unit00028', 'day_01_03_24_unit00025', 'day_27_02_24_unit00022', 'day_05_03_24_unit00021',
                        'day_09_02_24_unit00015']

    monkeyT_msb_keys = ['day_30_04_24_unit00004', 'day_30_04_24_unit00014', 'day_29_04_24_unit00008',
                        'day_30_04_24_unit00002', 'day_06_05_24_unit00017', 'day_06_05_24_unit00012']

    monkeyG_asb_keys = ['day_05_03_24_unit00012', 'day_23_02_24_unit00008', 'day_23_02_24_unit00003', 'day_01_03_24_unit00010',
                        'day_23_02_24_unit00013', 'day_01_03_24_unit00012', 'day_29_02_24_unit00001', 'day_01_03_24_unit00001']

    monkeyT_asb_keys = ['day_02_05_24_unit00002', 'day_02_05_24_unit00003', 'day_06_05_24_unit00004']

    # The lists above are zero-padded; the dict keys are looked up without
    # the padding. Concatenation order fixes the left-to-right plot order.
    ordered_keys = monkeyG_msb_keys + monkeyG_asb_keys + monkeyT_msb_keys + monkeyT_asb_keys
    units = [remove_leading_zeros(s) for s in ordered_keys]

    # Responses are scaled by 5 to plot as spike rate; presumably they are
    # counts in a 200 ms window -- TODO confirm against the analysis pipeline.
    rate_scale = 5
    best = [rate_scale * selected_object_dict[u]['best_object_response'] for u in units]
    body = [rate_scale * np.mean(selected_object_dict[u]['body_response']) for u in units]
    other = [rate_scale * np.mean(selected_object_dict[u]['other_object_response']) for u in units]
    # One x slot per plotted unit (independent of how many keys the dict holds).
    pos = np.arange(len(units))

    plt.rcParams.update({'font.size': 7, 'pdf.fonttype': 42})
    marker_size = 4

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5.5 / 2.54, 3 / 2.54))
    # Draw order body -> selected -> other-object sets which markers sit on top
    # when they overlap; the legend order is set explicitly below.
    body_handle = ax.scatter(pos, body, label='Body mean', c='#32bed7', s=marker_size)
    best_handle = ax.scatter(pos, best, label='Selected object', c='#fbad27', s=marker_size)
    other_handle = ax.scatter(pos, other, label='Object mean', c='#9b8476', s=marker_size)

    ax.set_ylim([0, None])
    fig.suptitle('Responses to visualized objects')
    ax.set_xlabel('Recording Channel')
    ax.set_ylabel('Spike Rate (1/s)')
    ax.legend(handles=[best_handle, body_handle, other_handle])

    # savefig does not create missing directories.
    out_dir = os.path.join(os.getcwd(), 'plots')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'obj_response.pdf'), bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    # Only build the figure when executed as a script, not when imported.
    main()
