import matplotlib.pyplot as plt
import torch
from matplotlib import cm 
import matplotlib.patches as mpatch
import numpy as np
from tqdm import tqdm 
from my_utils import *

# plt.style.use('ggplot')
import seaborn as sns

# Global plotting defaults for every figure produced by this module.
# `sns.set()` is only a legacy alias of `sns.set_theme()`, and "darkgrid" is already
# set_theme's default style, so this single explicit call replaces the former pair.
sns.set_theme(style="darkgrid")
# Route all figure text through LaTeX.
# NOTE: Matplotlib's usetex mode needs a LaTeX toolchain installed on the machine.
# NOTE: the module-level `plt.style.use('ggplot')` further down runs afterwards at
# import time and overrides any settings it shares with this theme.
plt.rcParams['text.usetex'] = True

def print_results():
    """Print and scatter-plot spurious gap vs. validation accuracy of tuned DeiT-S heads.

    For every tuning mode and every ``k`` this loads
    ``ft_heads/deit_small/{mode}_{k}.pth``, prints its stored ``spur_gap`` and
    ``val_acc``, and draws one point (colour = mode, marker area = ``k / 10``).
    The figure is saved to ``test5.jpg`` and then closed.
    """
    modes = ['bottom', 'random', 'bottom_top', 'top']
    ks = [10, 50, 100, 150, 200, 250]
    mode_to_color = dict(zip(modes, ['red', 'black', 'blue', 'yellow']))

    f, ax = plt.subplots(1, 1)
    # Iterate the same `modes` list that defines colours and legend entries, so the
    # three can never drift apart.
    for mode in modes:
        for k in ks:
            results = torch.load(f'ft_heads/deit_small/{mode}_{k}.pth')
            spur_gap, val_acc = results['spur_gap'], results['val_acc']
            print(f'Mode: {mode:<15}, k: {k:<3}, Spurious Gap: {spur_gap:.2f}%, Validation Acc: {val_acc:.2f}%')

            ax.scatter(spur_gap, val_acc, color=mode_to_color[mode], s=k / 10)

    ax.set_xlabel('Spurious Gap')
    ax.set_ylabel('ImageNet Val Accuracy')
    ax.legend(handles=[mpatch.Patch(color=mode_to_color[m], label=m) for m in modes], loc='lower left', ncol=2)
    f.savefig('test5.jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(f)


# Runs at import time, after the seaborn theme configured near the top of this module,
# so the ggplot style's settings take precedence (where the two overlap) for every
# figure the functions below create.
plt.style.use(['ggplot'])
def gap_vs_accuracy():
    """Plot how head tuning moves each backbone in (accuracy, spurious gap) space.

    For every backbone in ``mkeys`` and both tuning modes, this loads
    ``ft_heads3/{mkey}/{mode}_100.pth`` and draws:

    * a circle at the first recorded point ``(val_accs[0], spur_gaps[0])``
      (legend: "Original (No Tuning)"),
    * a marker at the stored best ``(val_acc, spur_gap)`` -- a star in the
      model's colour for ``'bottom'`` (legend: "Low Spuriosity Tuned") or a gray
      square for ``'random'`` (legend: "Randomly Tuned"),
    * an arrow from the first point to the best point.

    The figure is saved to ``plots/gap_vs_accuracy.jpg`` and then closed.
    """
    from matplotlib.lines import Line2D

    f, ax = plt.subplots(1, 1, figsize=(5, 4.5))

    mkeys = ['resnet50', 'deit_small', 'simclr_resnet50', 'moco_vit-s', 'robust_resnet50_linf_eps4.0']
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` works on
    # both old and new releases.
    cmap = plt.get_cmap('tab20')
    # i / 10 picks every other tab20 entry (the first shade of each colour pair).
    mkey_to_color = {m: cmap(i / 10) for i, m in enumerate(mkeys)}

    for mkey in tqdm(mkeys):
        mcolor = mkey_to_color[mkey]
        # (mode, colour of tuned marker/arrow, tuned marker, marker size, arrow linestyle)
        for mode, color, marker, s, ls in [('bottom', mcolor, '*', 150, '--'), ('random', 'gray', 's', 50, '-')]:
            d = torch.load(f'ft_heads3/{mkey}/{mode}_100.pth')
            # Only the first entry of each per-epoch history is needed: it marks the
            # untuned starting point of the arrow.
            start_acc, start_gap = d['val_accs'][0], d['spur_gaps'][0]
            best_val_acc, best_spur_gap = d['val_acc'], d['spur_gap']

            ax.scatter(start_acc, start_gap, marker='o', color=mcolor)
            ax.scatter(best_val_acc, best_spur_gap, marker=marker, color=color, s=s)
            ax.arrow(start_acc, start_gap, best_val_acc - start_acc, best_spur_gap - start_gap,
                     color=color, length_includes_head=True, width=0.05,
                     head_width=0.7 if mode == 'bottom' else 0.1, linestyle=ls)

    ax.set_xlabel('ImageNet Accuracy', fontsize=20)
    ax.set_ylabel('Spurious Gap', fontsize=20)

    extra_handles = [Line2D([0], [0], color='gold', lw=0, markerfacecolor='gold', marker='*', label='Low Spuriosity\nTuned', markersize=10),
                     Line2D([0], [0], color='gray', lw=0, markerfacecolor='gray', marker='s', label='Randomly Tuned'),
                     Line2D([0], [0], color='black', lw=0, markerfacecolor='black', marker='o', label='Original\n(No Tuning)')]

    def _legend_label(mkey):
        # Break long nicknames onto two lines before an "$\ell" term and after "CLR ".
        # Raw strings keep the backslash literal without an invalid-escape warning.
        return nickname(mkey).replace(r'$\ell', '\n' + r'$\ell').replace("CLR ", "CLR\n")

    mkeys_preferred_order = ['robust_resnet50_linf_eps4.0', 'simclr_resnet50', 'moco_vit-s', 'resnet50', 'deit_small']
    model_handles = [mpatch.Patch(color=mkey_to_color[m], label=_legend_label(m)) for m in mkeys_preferred_order]
    ax.legend(handles=model_handles + extra_handles, ncol=1, bbox_to_anchor=(1.01, 0.97), fontsize=14)
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    f.savefig('plots/gap_vs_accuracy.jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(f)

def _smooth_nanmean(values, half_width):
    """Return a centred, NaN-ignoring moving average of ``values``.

    Element ``i`` of the result is ``np.nanmean(values[i - half_width : i + half_width + 1])``
    with the window clipped to the array bounds, so every window contains index ``i``
    and the last element is included like any other.
    """
    n = len(values)
    return np.array([np.nanmean(values[max(i - half_width, 0):min(i + half_width + 1, n)])
                     for i in range(n)])


def accuracy_by_rank(smooth_c=7):
    """Plot smoothed per-spuriosity-rank accuracy for each backbone, side by side.

    The left panel uses heads tuned on random data (``'random'``), the right panel
    heads tuned on low-spuriosity data (``'bottom'``). For each backbone this loads
    ``ft_heads3/{mkey}/{mode}_100.pth``, divides ``accs_by_rank`` by
    ``num_spur_classes`` and plots a centred moving average of it. Each curve uses
    the same per-model colour as ``gap_vs_accuracy``. The figure is saved to
    ``plots/acc_vs_rank.jpg`` and then closed.

    Args:
        smooth_c: Half-width of the moving-average window; the window around rank
            ``i`` spans ranks ``i - smooth_c`` .. ``i + smooth_c`` inclusive.
    """
    f, axs = plt.subplots(1, 2, figsize=(5, 4.5), sharey=True)

    mkeys = ['resnet50', 'deit_small', 'simclr_resnet50', 'moco_vit-s', 'robust_resnet50_linf_eps4.0']
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` works on
    # both old and new releases.
    cmap = plt.get_cmap('tab20')
    mkey_to_color = {m: cmap(i / 10) for i, m in enumerate(mkeys)}

    for mkey in ['deit_small', 'resnet50', 'moco_vit-s', 'simclr_resnet50', 'robust_resnet50_linf_eps4.0']:
        # Raw strings keep the backslash literal without an invalid-escape warning.
        label = nickname(mkey).replace(r'$\ell', '\n' + r'$\ell').replace("CLR ", "CLR\n")
        for ax, mode in zip(axs, ['random', 'bottom']):
            d = torch.load(f'ft_heads3/{mkey}/{mode}_100.pth')
            # presumably a 1-D sequence with one entry per spuriosity rank -- TODO confirm
            # (the original code assumed exactly 50 ranks).
            accs_by_rank = d['accs_by_rank'] / d['num_spur_classes']
            smoothed_accs = _smooth_nanmean(accs_by_rank, smooth_c)

            ax.plot(np.arange(len(smoothed_accs)), smoothed_accs, color=mkey_to_color[mkey], label=label)

    for ax in axs:
        ax.set_xlabel('Spuriosity Rank', fontsize=20)
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
    axs[0].set_ylabel('(Smoothed) Accuracy', fontsize=20)
    axs[1].legend(ncol=1, fontsize=14, bbox_to_anchor=(1.01, 0.8))
    axs[0].set_title('Tuned on\nRandom Data')
    axs[1].set_title('Tuned on\nLow Spuriosity Data')
    f.subplots_adjust(wspace=0.01)
    f.savefig('plots/acc_vs_rank.jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(f)



if __name__ == '__main__':
    # print_results is deliberately not run by default; add it to this tuple to
    # regenerate test5.jpg as well.
    for make_figure in (gap_vs_accuracy, accuracy_by_rank):
        make_figure()