% This script plots the hemisphere AL results, including the average
% performance of each method, the standard deviation of each method, and a
% comparison with human performance.
%% load data
clear;clc;
load('./results-balance0/sort50_allmethods.mat')
%% experiment layout
% eval_lst is indexed as eval_lst{dataset, annotator, method}; each entry
% exposes metric curves (ACC, TPR, TNR, Precision, Recall, Fscore, AUC) and
% matching *_std curves.
num_dataset = 3;
num_anns = 4;
num_runs = num_dataset * num_anns;  % (dataset, annotator) pairs averaged over

% num_methods is otherwise only assigned in the "print results" section at
% the end of the script, after the preallocations below already need it.
% Derive it from the loaded method list in case the .mat file does not
% provide it.
if ~exist('num_methods', 'var')
    num_methods = length(methods);
end

%% average over all (dataset, annotator) pairs
% Curve length H is capped at H_max points. It is never longer than the
% shortest stored ACC curve, so the (1:H) indexing below cannot run past the
% end of any result. (Assumes the other metric curves have the same length
% as ACC -- TODO confirm.)
H_max = 30;
H_available = inf;
for k = 1:num_methods
    for d = 1:num_dataset
        for ann = 1:num_anns
            H_available = min(H_available, size(eval_lst{d, ann, k}.ACC, 2));
        end
    end
end
H = min(H_available, H_max);

avg_acc = zeros(num_methods, H);
avg_tpr = zeros(num_methods, H);
avg_tnr = zeros(num_methods, H);
avg_precision = zeros(num_methods, H);
avg_recall    = zeros(num_methods, H);
avg_fscore    = zeros(num_methods, H);
avg_auc       = zeros(num_methods, H);

avg_acc_std = zeros(num_methods, H);
avg_tpr_std = zeros(num_methods, H);
avg_tnr_std = zeros(num_methods, H);
avg_precision_std = zeros(num_methods, H);
avg_recall_std    = zeros(num_methods, H);
avg_fscore_std    = zeros(num_methods, H);
avg_auc_std       = zeros(num_methods, H);

% Accumulate the first H points of every run's curves, per method.
for k = 1:num_methods
    for d = 1:num_dataset
        for ann = 1:num_anns
            res = eval_lst{d, ann, k};
            avg_acc(k,:) = avg_acc(k,:) + res.ACC(1:H);
            avg_tpr(k,:) = avg_tpr(k,:) + res.TPR(1:H);
            avg_tnr(k,:) = avg_tnr(k,:) + res.TNR(1:H);
            avg_precision(k,:) = avg_precision(k,:) + res.Precision(1:H);
            avg_recall(k,:)    = avg_recall(k,:) + res.Recall(1:H);
            avg_fscore(k,:)    = avg_fscore(k,:) + res.Fscore(1:H);
            avg_auc(k,:)       = avg_auc(k,:) + res.AUC(1:H);

            avg_acc_std(k,:) = avg_acc_std(k,:) + res.ACC_std(1:H);
            avg_tpr_std(k,:) = avg_tpr_std(k,:) + res.TPR_std(1:H);
            avg_tnr_std(k,:) = avg_tnr_std(k,:) + res.TNR_std(1:H);
            avg_precision_std(k,:) = avg_precision_std(k,:) + res.Precision_std(1:H);
            avg_recall_std(k,:)    = avg_recall_std(k,:) + res.Recall_std(1:H);
            avg_fscore_std(k,:)    = avg_fscore_std(k,:) + res.Fscore_std(1:H);
            avg_auc_std(k,:)       = avg_auc_std(k,:) + res.AUC_std(1:H);
        end
    end
end

% Turn the sums into means over all runs. The *_std curves are therefore the
% mean of the per-run std curves, not a pooled standard deviation.
avg_acc = avg_acc ./ num_runs;
avg_tpr = avg_tpr ./ num_runs;
avg_tnr = avg_tnr ./ num_runs;
avg_precision = avg_precision ./ num_runs;
avg_recall    = avg_recall ./ num_runs;
avg_fscore    = avg_fscore ./ num_runs;
avg_auc       = avg_auc ./ num_runs;

avg_acc_std = avg_acc_std ./ num_runs;
avg_tpr_std = avg_tpr_std ./ num_runs;
avg_tnr_std = avg_tnr_std ./ num_runs;
avg_precision_std = avg_precision_std ./ num_runs;
avg_recall_std    = avg_recall_std ./ num_runs;
avg_fscore_std    = avg_fscore_std ./ num_runs;
avg_auc_std       = avg_auc_std ./ num_runs;
%% get human average performance
% Score each human annotator's choices against the ground truth with
% get_ex_accuracy, then average each metric over all (dataset, annotator)
% pairs.
num_human = num_dataset * num_anns;
matrix_acc_human = zeros(1, num_human);
matrix_tpr_human = zeros(1, num_human);
matrix_tnr_human = zeros(1, num_human);
matrix_precision_human = zeros(1, num_human);
matrix_recall_human    = zeros(1, num_human);
matrix_fscore_human    = zeros(1, num_human);

for d = 1:num_dataset
    for ann = 1:num_anns
        % annotator_lst maps (dataset, annotator slot) to the row used in
        % choices_all / choices_gt_all -- presumably a global annotator id; confirm.
        ann_idx = annotator_lst(d, ann);
        choices    = choices_all{ann_idx,d};
        choices_gt = choices_gt_all{ann_idx,d};

        eval_metrics_human = get_ex_accuracy(choices', choices_gt');
        i = (d-1)*num_anns + ann;   % flat position of the (d, ann) pair
        matrix_acc_human(i) = eval_metrics_human.ACC;
        matrix_tpr_human(i) = eval_metrics_human.TPR;
        matrix_tnr_human(i) = eval_metrics_human.TNR;
        matrix_precision_human(i) = eval_metrics_human.Precision;
        matrix_recall_human(i)    = eval_metrics_human.Recall;
        matrix_fscore_human(i)    = eval_metrics_human.Fscore;
    end
end

avg_acc_human = mean(matrix_acc_human);
avg_tpr_human = mean(matrix_tpr_human);
avg_tnr_human = mean(matrix_tnr_human);
avg_precision_human = mean(matrix_precision_human);
avg_recall_human    = mean(matrix_recall_human);
avg_fscore_human    = mean(matrix_fscore_human);

% NOTE: despite the *_std names, these are standard errors of the mean
% (sample std / sqrt(n)). The method *_std curves are instead averages of
% per-run std values.
avg_acc_human_std = std(matrix_acc_human) / sqrt(num_human);
avg_tpr_human_std = std(matrix_tpr_human) / sqrt(num_human);
avg_tnr_human_std = std(matrix_tnr_human) / sqrt(num_human);
avg_precision_human_std = std(matrix_precision_human) / sqrt(num_human);
avg_recall_human_std    = std(matrix_recall_human) / sqrt(num_human);
avg_fscore_human_std    = std(matrix_fscore_human) / sqrt(num_human);

%% plotting
% NOTE(review): column 50 is out of range (H is at most 30 in this script);
% index with `end` (or a column <= H) before re-enabling.
% [max_values, max_indices] = maxk(avg_acc(4:end, 50), 3);

% Methods (indices into methods / rows of the avg_* arrays) to plot and print.
method_indices = [1,2,3,6,8,10,12];
% method_indices = [1,2,3,6, 21];

% One RGB row per entry of method_indices (7 colours for 7 methods).
color_map = [0,      0.4470, 0.7410; % blue
             0.8500, 0.3250, 0.0980; % orange
             0.9290, 0.6940, 0.1250; % yellow
             0.4940, 0.1840, 0.5560; % purple
             0.4660, 0.6740, 0.1880; % green
             0.6350, 0.0780, 0.1840; % red
             0.3010, 0.7450, 0.9330];  % light blue

% x-axis positions of the H curve points: 0.1, 0.2, ..., H/10.
% NOTE(review): the unit of the /10 scaling is not visible here -- confirm.
x = (1:H) ./ 10;

% plot_metric(avg_auc, avg_auc_std, 0, 0, methods, method_indices, x, H, color_map, 'AUC', [0.7,1])
% plot_metric(avg_acc, avg_acc_std, avg_acc_human, avg_acc_human_std, methods, method_indices, x, H, color_map, 'Accuracy', [0.8,0.87])
% plot_metric(avg_tpr, avg_tpr_std, avg_tpr_human, avg_tpr_human_std, methods, method_indices, x, H, color_map, 'True Positive Rate', [0.97,0.995])
% plot_metric(avg_tnr, avg_tnr_std, avg_tnr_human, avg_tnr_human_std, methods, method_indices, x, H, color_map, 'True Negative Rate', [0.6,0.78])
% plot_metric(avg_precision, avg_precision_std, avg_precision_human, avg_precision_human_std, methods, method_indices, x, H, color_map, 'Precision', [0.9,1])
% plot_metric(avg_recall, avg_recall_std, avg_recall_human, avg_recall_human_std, methods, method_indices, x, H, color_map, 'Recall', [0.9,1])
% plot_metric(avg_fscore, avg_fscore_std, avg_fscore_human, avg_fscore_human_std, methods, method_indices, x, H, color_map, 'F-score', [0.9,1])

% NOTE(review): the output filename says "balance1" but the results were
% loaded from ./results-balance0/ -- confirm the name before re-enabling.
% save("hemisphere_al_plotting_balance1_results.mat", "avg_auc", "avg_auc_std", "avg_acc", "avg_acc_std", ...
%     "avg_tpr", "avg_tpr_std", "avg_tnr", "avg_tnr_std", "avg_precision", "avg_precision_std", ...
%     "avg_recall", "avg_recall_std", "avg_fscore", "avg_fscore_std", "-v7.3")
%% print results
% Print the human baseline, then, for every selected method, the averaged
% metric values at the last point (column H) of each curve.
num_methods = length(methods);
fprintf("Human performance: ACC: %.4f, TPR: %.4f, TNR: %.4f, Precision: %.4f, Recall: %.4f, F-score: %.4f\n", ...
                avg_acc_human, avg_tpr_human, avg_tnr_human, avg_precision_human, avg_recall_human, avg_fscore_human)
for k = method_indices
    legend_label = get_legend_name(methods{k});

    % Order must match the %.4f placeholders in the format string below.
    final_vals = [avg_acc(k, end), avg_tpr(k, end), avg_tnr(k, end), ...
                  avg_precision(k, end), avg_recall(k, end), avg_fscore(k, end), ...
                  avg_auc(k, end)];

    fprintf("Method %s performance: ACC: %.4f, TPR: %.4f, TNR: %.4f, Precision: %.4f, Recall: %.4f, F-score: %.4f, AUC: %.4f\n", ...
                legend_label, final_vals)
end