clear; clc;
%%
% Load the saved evaluation list and method identifiers, then compute the
% evaluation metrics used by the plotting section below.
load('./results/sort50_ebrahimi.mat')
% Keep only what this script needs from the .mat file.
clearvars -except eval_lst methods
results = compute_evaluation_metrics(eval_lst);

% Display (legend) name for every method. The bound follows the number of
% loaded methods instead of a hard-coded 12, so a different method count
% neither errors nor silently skips entries.
n_methods = numel(methods);
dspname = cell(1, n_methods);
for m = 1:n_methods
    dspname{m} = get_legend_name(methods{m});
end
% dspname is retained (it was previously computed and then cleared here).
clearvars -except methods results dspname
%% get human average performance
% Score each human annotator against a reference built from the *other*
% annotators' choices (leave-one-out), then average the metrics across
% annotators.
load("ebrahimi_reanalysis.mat")
choices_all = {guinea_pig_choices, lion_choices, dragon_choices, cheetah_choices};
n_ann = numel(choices_all);

% Reference for annotator k: 2*((sum of the other annotators' choices) > 0) - 1,
% which gives labels in {-1, +1}.
% NOTE(review): if choices are coded +/-1, this is a majority vote over the
% other annotators; if coded 0/1, it is an "any" vote. Confirm the coding in
% ebrahimi_reanalysis.mat.
choices_gt_all = cell(1, n_ann);
for k = 1:n_ann
    vote = zeros(size(choices_all{k}));
    for j = setdiff(1:n_ann, k)
        vote = vote + choices_all{j};
    end
    choices_gt_all{k} = 2*(vote > 0) - 1;
end

% Per-annotator metrics, one column per annotator.
matrix_acc_human       = zeros(1, n_ann);
matrix_tpr_human       = zeros(1, n_ann);
matrix_tnr_human       = zeros(1, n_ann);
matrix_precision_human = zeros(1, n_ann);
matrix_recall_human    = zeros(1, n_ann);
matrix_fscore_human    = zeros(1, n_ann);

for ann = 1:n_ann
    % get_ex_accuracy receives the transposed choice / reference vectors.
    eval_metrics_human = get_ex_accuracy(choices_all{ann}', choices_gt_all{ann}');
    matrix_acc_human(ann)       = eval_metrics_human.ACC;
    matrix_tpr_human(ann)       = eval_metrics_human.TPR;
    matrix_tnr_human(ann)       = eval_metrics_human.TNR;
    matrix_precision_human(ann) = eval_metrics_human.Precision;
    matrix_recall_human(ann)    = eval_metrics_human.Recall;
    matrix_fscore_human(ann)    = eval_metrics_human.Fscore;
end

% Mean across annotators.
avg_acc_human       = mean(matrix_acc_human);
avg_tpr_human       = mean(matrix_tpr_human);
avg_tnr_human       = mean(matrix_tnr_human);
avg_precision_human = mean(matrix_precision_human);
avg_recall_human    = mean(matrix_recall_human);
avg_fscore_human    = mean(matrix_fscore_human);

% Despite the *_std names, these are standard errors of the mean across
% annotators: std / sqrt(number of annotators).
avg_acc_human_std       = std(matrix_acc_human)       / sqrt(n_ann);
avg_tpr_human_std       = std(matrix_tpr_human)       / sqrt(n_ann);
avg_tnr_human_std       = std(matrix_tnr_human)       / sqrt(n_ann);
avg_precision_human_std = std(matrix_precision_human) / sqrt(n_ann);
avg_recall_human_std    = std(matrix_recall_human)    / sqrt(n_ann);
avg_fscore_human_std    = std(matrix_fscore_human)    / sqrt(n_ann);
%%
% Plot each metric for a subset of methods against the x-axis below, together
% with the human-annotator mean and its standard error.
method_indices = [1,2,3,6,8,10,12];   % method rows to plot

% One RGB row per entry of method_indices.
color_map = [0,      0.4470, 0.7410; % blue
             0.8500, 0.3250, 0.0980; % orange
             0.9290, 0.6940, 0.1250; % yellow
             0.4940, 0.1840, 0.5560; % purple
             0.4660, 0.6740, 0.1880; % green
             0.6350, 0.0780, 0.1840; % red
             0.3010, 0.7450, 0.9330];  % light blue
% Dims of results.<metric>: num_methods x lam (lam == 4 is the auto option)
% x annotators x 558 (up to 50%)
% NOTE(review): lam_all(4) == 1, yet the note above says lam index 4 is the
% "auto" option; confirm which slice is intended.
% NOTE(review): if 558 positions span 50%, then x = h/10 below does not map
% position h to a percentage; confirm the x-axis scaling.
lam_all = [1e-3,1e-2,1e-1,1,10];
lam_idx = 4;   % lam slice that is plotted
H = 500;       % number of leading positions that are plotted
% x = linspace(1,50,H);
x = (1:H) / 10;

% Take the lam_idx slice and first H positions: methods x annotators x H.
pick = @(M) squeeze(M(:, lam_idx, :, 1:H));

ACC       = pick(results.ACC);
TPR       = pick(results.TPR);
TNR       = pick(results.TNR);
AUC       = pick(results.AUC);
Precision = pick(results.Precision);
Recall    = pick(results.Recall);
Fscore    = pick(results.Fscore);

% Average over annotators (dim 2): methods x H.
avg_acc       = squeeze(mean(ACC, 2));
avg_tpr       = squeeze(mean(TPR, 2));
avg_tnr       = squeeze(mean(TNR, 2));
avg_precision = squeeze(mean(Precision, 2));
avg_recall    = squeeze(mean(Recall, 2));
avg_fscore    = squeeze(mean(Fscore, 2));
avg_auc       = squeeze(mean(AUC, 2));

ACC_std       = pick(results.ACC_std);
TPR_std       = pick(results.TPR_std);
TNR_std       = pick(results.TNR_std);
Precision_std = pick(results.Precision_std);
Recall_std    = pick(results.Recall_std);
Fscore_std    = pick(results.Fscore_std);
AUC_std       = pick(results.AUC_std);

% Per-annotator spread values averaged over annotators (a plain mean, not a
% pooled standard deviation).
avg_acc_std       = squeeze(mean(ACC_std, 2));
avg_tpr_std       = squeeze(mean(TPR_std, 2));
avg_tnr_std       = squeeze(mean(TNR_std, 2));
avg_precision_std = squeeze(mean(Precision_std, 2));
avg_recall_std    = squeeze(mean(Recall_std, 2));
avg_fscore_std    = squeeze(mean(Fscore_std, 2));
avg_auc_std       = squeeze(mean(AUC_std, 2));

% AUC has no human reference, so 0 is passed for the human mean and error.
plot_metric(avg_auc, avg_auc_std, 0, 0, methods, method_indices, x, H, color_map, 'AUC', [0.8,1])
plot_metric(avg_acc, avg_acc_std, avg_acc_human, avg_acc_human_std, methods, method_indices, x, H, color_map, 'Accuracy', [0.6,1])
plot_metric(avg_tpr, avg_tpr_std, avg_tpr_human, avg_tpr_human_std, methods, method_indices, x, H, color_map, 'True Positive Rate', [0.7,1])
plot_metric(avg_tnr, avg_tnr_std, avg_tnr_human, avg_tnr_human_std, methods, method_indices, x, H, color_map, 'True Negative Rate', [0.6,1])
plot_metric(avg_precision, avg_precision_std, avg_precision_human, avg_precision_human_std, methods, method_indices, x, H, color_map, 'Precision', [0.7,1])
plot_metric(avg_recall, avg_recall_std, avg_recall_human, avg_recall_human_std, methods, method_indices, x, H, color_map, 'Recall', [0.7,1])
% Fixed: the human mean for F-score previously passed avg_recall_human.
plot_metric(avg_fscore, avg_fscore_std, avg_fscore_human, avg_fscore_human_std, methods, method_indices, x, H, color_map, 'F-score', [0.7,1])


% plot(x,squeeze(mean(ACC,2)))
% legend(methods)
% xlabel('Percentage of sorted')
% ylabel('ACC')