from tqdm import tqdm
import h5py
import math
import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
import copy
import numpy as np
import os
import json
import scipy.stats
import time
from types import SimpleNamespace
import random
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Axis limits (presumably of the Matlab figure the electrode coordinates were
# exported from -- confirm); used below to convert coordinates to image pixels.
matlab_xlim = (-108.0278, 108.0278)
matlab_ylim = (-72.9774, 72.9774)

# Default output directory, and the directory holding the input assets.
save_dir = 'brain_plotting/'
base_path = 'brain_plotting/'
left_hem_file_name = 'left_hem_clean.png'
right_hem_file_name = 'right_hem_clean.png'
coords_file_name = 'elec_coords_full.csv'
correlations_file = 'lag_correlation.json'

# Background images, loaded once at import time (left first, then right).
left_hem_img, right_hem_img = (
    plt.imread(os.path.join(base_path, image_file_name))
    for image_file_name in (left_hem_file_name, right_hem_file_name)
)

# Electrode table. The 'ID' column is split on '-' and its first two pieces
# become the 'Subject' and 'Electrode' columns.
coords_df = pd.read_csv(os.path.join(base_path, coords_file_name))
split_elec_id = coords_df['ID'].str.split('-')
coords_df['Subject'] = [id_parts[0] for id_parts in split_elec_id]
coords_df['Electrode'] = [id_parts[1] for id_parts in split_elec_id]
# Scale Matlab electrode locations to Python format
def scale(x, s, d):
    """Return ``-(x - d) * s``: shift ``x`` by ``d``, flip its sign, scale by ``s``.

    An input equal to ``d`` maps to zero, and inputs below ``d`` map to
    positive values (for positive ``s``). Works on scalars and on anything
    supporting element-wise arithmetic (e.g. a pandas Series).
    """
    offset = x - d
    return -offset * s

# Pixels per coordinate unit, derived from the image sizes and the axis limits.
x_scale = left_hem_img.shape[1] / (matlab_xlim[1] - matlab_xlim[0])
y_scale_l = left_hem_img.shape[0] / (matlab_ylim[1] - matlab_ylim[0])

y_scale_r = right_hem_img.shape[0] / (matlab_ylim[1] - matlab_ylim[0])

# Work on a copy so coords_df keeps the untouched source coordinates.
scaled_coords_df = coords_df.copy()

# Row selectors: 'Hemisphere' == 1 is treated as left, == 0 as right.
_left_rows = coords_df['Hemisphere'] == 1
_right_rows = coords_df['Hemisphere'] == 0

# scale left hemisphere coordinates (measured from the upper axis limits)
scaled_coords_df.loc[_left_rows, 'X'] = scale(coords_df.loc[_left_rows, 'X'], x_scale, matlab_xlim[1])
scaled_coords_df.loc[_left_rows, 'Y'] = scale(coords_df.loc[_left_rows, 'Y'], y_scale_l, matlab_ylim[1])

# scale right hemisphere coordinates; X is measured from the lower x limit and
# negated again, so it increases with the source X (mirrored w.r.t. the left).
# NOTE(review): X is scaled with y_scale_r (image height / y-range) rather than
# a width-based factor -- identical only if the image aspect ratio matches the
# axis limits. Confirm whether this is intentional.
scaled_coords_df.loc[_right_rows, 'X'] = -scale(coords_df.loc[_right_rows, 'X'], y_scale_r, matlab_xlim[0])
scaled_coords_df.loc[_right_rows, 'Y'] = scale(coords_df.loc[_right_rows, 'Y'], y_scale_r, matlab_ylim[1])

def plot_hemispheres_separately(significant_electrodes=None, title=None, save_dir='brain_plotting/', electrode_values=None, circled_electrodes=None):
    """Save one anonymized electrode plot per hemisphere.

    Writes ``left_all_electrodes.png`` and ``right_all_electrodes.png`` into
    ``save_dir``; each is a 10x10 inch figure drawn by
    ``plot_hemisphere_axis`` with ``anonymize=True``.

    Args:
        significant_electrodes: Optional ``{subject_name: [electrode names]}``
            restricting which electrodes are drawn.
        title: Optional text appended to each plot title.
        save_dir: Output directory; created if it does not exist.
        electrode_values: Optional ``{subject_name: {electrode_name: value}}``
            used to colour the markers.
        circled_electrodes: Optional ``{subject_name: electrodes}`` to
            highlight.
    """
    # savefig does not create directories; an empty save_dir means the
    # current working directory and must not be passed to makedirs.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for hemisphere in ("left", "right"):
        fig, ax = plt.subplots(1, 1, figsize=(10,10))
        plot_hemisphere_axis(ax, hemisphere=hemisphere, significant_electrodes=significant_electrodes, title=title, save_dir=save_dir, electrode_values=electrode_values, anonymize=True, circled_electrodes=circled_electrodes)
        fig.tight_layout()
        # Save and close this specific figure instead of relying on pyplot's
        # implicit "current figure".
        fig.savefig(os.path.join(save_dir, f'{hemisphere}_all_electrodes.png'), bbox_inches='tight')
        plt.close(fig)

def plot_hemisphere(significant_electrodes=None, title=None, save_dir='brain_plotting/', electrode_values=None, circled_electrodes=None):
    """Save a side-by-side plot of both hemispheres as ``all_electrodes.png``.

    The left hemisphere is drawn in the first column and the right hemisphere
    in the second, each by ``plot_hemisphere_axis``. If ``title`` is given it
    is also written once above the figure, positioned relative to the right
    axes.

    Args:
        significant_electrodes: Optional ``{subject_name: [electrode names]}``
            restricting which electrodes are drawn.
        title: Optional figure title; also appended to each subplot title.
        save_dir: Output directory; created if it does not exist.
        electrode_values: Optional ``{subject_name: {electrode_name: value}}``
            used to colour the markers.
        circled_electrodes: Optional ``{subject_name: electrodes}`` to
            highlight.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,30))
    for ax, hemisphere in ((ax1, "left"), (ax2, "right")):
        plot_hemisphere_axis(ax, hemisphere=hemisphere, significant_electrodes=significant_electrodes, title=title, save_dir=save_dir, electrode_values=electrode_values, circled_electrodes=circled_electrodes)

    if title:
        # Attach the text to ax2 explicitly: plt.text would add it to the
        # *current* axes, which is not guaranteed to be ax2 once the helper
        # has created extra (colorbar) axes.
        ax2.text(-0.10, 1.08, title,
                 horizontalalignment='center',
                 fontsize=20,
                 transform=ax2.transAxes)
    fig.tight_layout()

    # savefig does not create directories; an empty save_dir means the
    # current working directory and must not be passed to makedirs.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'all_electrodes.png'), bbox_inches='tight')
    plt.close(fig)

def plot_hemisphere_axis(ax=None, hemisphere="left", significant_electrodes=None, title=None, save_dir='brain_plotting/', electrode_values=None, anonymize=False, vmin=None, vmax=None, circled_electrodes=None, fig_size="large"):
    """Draw one hemisphere image on ``ax`` and scatter its electrodes on top.

    Electrode positions come from the module-level ``scaled_coords_df``
    (rows with ``Hemisphere == 1`` for left, ``== 0`` for right). One scatter
    series is drawn per subject, in sorted subject order so that legends and
    anonymized labels are reproducible. If no electrode remains after
    filtering, only the background image is drawn and the function returns.

    Args:
        ax: Matplotlib axes to draw on; defaults to the current axes.
        hemisphere: ``"left"`` or ``"right"``.
        significant_electrodes: Optional ``{subject_name: [electrode names]}``.
            When given and non-empty, only these electrodes are plotted.
        title: Optional text appended to the "Left/Right Hemisphere" title.
        save_dir: Not used by this function; kept for signature compatibility.
        electrode_values: Optional ``{subject_name: {electrode_name: value}}``.
            When given, markers are coloured by value on one colour scale
            shared by all subjects, and a colorbar is added. Otherwise each
            subject gets a single ``tab20b`` colour chosen from the integer
            that follows the first character of its name.
        anonymize: If true, legend labels are ``Subject_0``, ``Subject_1``,
            ... instead of the subject names.
        vmin: Lower colour-scale limit for ``electrode_values``; defaults to
            the smallest plotted value.
        vmax: Upper colour-scale limit for ``electrode_values``; defaults to
            the largest plotted value.
        circled_electrodes: Optional ``{subject_name: electrodes}``; listed
            electrodes get a thick red edge. Subjects missing from the
            mapping simply have nothing circled.
        fig_size: ``"large"`` uses bigger markers and fonts and hides the
            axes frame; any other value uses matplotlib defaults.

    Raises:
        ValueError: If ``hemisphere`` is not ``"left"`` or ``"right"``, or if
            a subject name cannot be mapped to a palette colour.
        KeyError: If ``electrode_values`` lacks a plotted subject/electrode.
    """
    import matplotlib.cm as cm

    # Validate before touching the axes so a bad argument leaves no
    # half-drawn plot behind.
    if hemisphere not in ("left", "right"):
        raise ValueError(f'hemisphere must be "left" or "right", got {hemisphere!r}')
    if ax is None:
        ax = plt.gca()

    is_left = hemisphere == "left"
    ax.set_aspect('equal')
    ax.imshow(left_hem_img if is_left else right_hem_img)

    hem_index = 1 if is_left else 0
    selected = scaled_coords_df[scaled_coords_df['Hemisphere'] == hem_index]
    if significant_electrodes:
        selected = pd.concat([
            selected[(selected.Subject == subject) & selected.Electrode.isin(electrodes)]
            for subject, electrodes in significant_electrodes.items()
        ])

    if len(selected) == 0:
        return

    plot_title = 'Left Hemisphere' if is_left else 'Right Hemisphere'
    if title:
        plot_title += f' {title}'

    if electrode_values:
        # Every subject is a separate scatter call; without explicit limits
        # each call would normalise its colours independently and the
        # colorbar (built from the last call) would not describe the others.
        plotted_values = [electrode_values[subj][elec]
                          for subj, elec in zip(selected.Subject, selected.Electrode)]
        if vmin is None:
            vmin = np.nanmin(plotted_values)
        if vmax is None:
            vmax = np.nanmax(plotted_values)
        color_scale = {'vmin': vmin, 'vmax': vmax}
    else:
        # vmin/vmax are meaningless for explicit RGB colours.
        color_scale = {}

    palette = cm.tab20b.colors
    marker_size = 300 if fig_size == "large" else 200
    sc = None
    # sorted(): iterating a bare set of strings is not stable between runs.
    for subject_index, subject in enumerate(sorted(set(selected.Subject))):
        subject_rows = selected[selected.Subject == subject]
        electrodes = subject_rows.Electrode.tolist()

        if electrode_values:
            colors = [electrode_values[subject][elec] for elec in electrodes]
        else:
            subject_color = palette[int(subject[1:]) % len(palette)]
            colors = [subject_color] * len(electrodes)

        circled = circled_electrodes.get(subject, ()) if circled_electrodes else ()
        edge_colors = ['red' if elec in circled else 'black' for elec in electrodes]
        line_widths = [4 if color == 'red' else 1 for color in edge_colors]

        label = f'Subject_{subject_index}' if anonymize else subject
        sc = ax.scatter(subject_rows['X'].tolist(), subject_rows['Y'].tolist(),
                        label=label,
                        s=marker_size,
                        linewidths=line_widths,
                        c=colors,
                        edgecolors=edge_colors,
                        **color_scale)

    if electrode_values:
        #https://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        ax.figure.colorbar(sc, cax=cax)

    if fig_size == "large":
        ax.legend(fontsize=16)
        # ax.axis, not plt.axis: the pyplot call targets the *current* axes,
        # which need not be `ax` (e.g. a sibling subplot or the colorbar axes).
        ax.axis("off")
        ax.set_title(plot_title, fontsize=30)
    else:
        ax.set_title(plot_title)
        ax.legend()
