clear
clc
%%
% annotator_lst(d, j) is the ID of the j-th annotator evaluated on dataset d;
% the human-baseline section uses these IDs to index choices_all / choices_gt_all.
annotator_lst = [ ...
    1, 2, 3, 4     % dataset1
    1, 3, 4, 5     % dataset2
    1, 2, 3, 6];   % dataset3
% NOTE(review): presumably provides choices_all and choices_gt_all (indexed
% {annotator, dataset}), which are used below but not defined here — confirm.
load hemisphere_dataset_summary.mat
%%
% Build the list of methods to compare: the "cal" and "dal" baselines, then
% one "dcal" variant per entry of weights_lst (the weight is stored on the
% struct). "random" is defined but not currently included in the comparison.
method_rand = struct("name", "random", "continue", false);
method_cal  = struct("name", "cal", "continue", false);
method_dal  = struct("name", "dal", "continue", false);
methods     = {method_cal, method_dal};

weights_lst = [0.3, 0.4, 0.5];
for i = 1:numel(weights_lst)
    methods{end+1} = struct("name", "dcal", "weight", weights_lst(i), "continue", false); %#ok<SAGROW>
end

%%%% DEBUG %%%%
% methods = {method_cal, method_dal};

% Remove the temporaries used to build `methods` (the original list named
% method_cal twice and left method_dal / method_rand behind). Clearing a name
% that does not exist is a no-op.
clear g i gamma method_rand method_cal method_dal method_dcal method_mab reward_func reward_name weight;

num_methods = length(methods);
%%
% Collect each method's saved evaluation results into one cell array indexed
% as eval_lst{ann1, ann2, ann3, k}. Each results file is expected to contain a
% variable named eval_lst; only its first slice along dimension 4 is kept.
directory    = './results';
eval_lst_all = cell(4, 4, 4, num_methods);
for k = 1:num_methods
    method   = methods{k};
    dspname  = get_legend_name(method);
    filename = sprintf('runner_batchprocess_%s.mat', dspname);
    fullPath = fullfile(directory, filename);
    % Load into a struct instead of the workspace so a file that lacks
    % eval_lst fails here with a clear message rather than later (or by
    % silently reusing a stale workspace variable of the same name).
    loaded = load(fullPath, 'eval_lst');
    if ~isfield(loaded, 'eval_lst')
        error('Variable ''eval_lst'' not found in %s', fullPath);
    end
    eval_lst_all(:,:,:,k) = loaded.eval_lst(:,:,:,1);
end
clear loaded;
eval_lst = eval_lst_all;
%%
% For every method k, average each metric over all 4x4x4 annotator
% combinations. H (the number of points per metric curve) is taken from the
% width of the first ACC entry; only the first H entries of each metric are used.
H = size(eval_lst{1,1,1,1}.ACC, 2);
metric_fields = {'ACC', 'TPR', 'TNR', 'Precision', 'Recall', 'Fscore', 'AUC'};
metric_avgs   = repmat({zeros(num_methods, H)}, 1, numel(metric_fields));

for k = 1:num_methods
    % Same loop nesting as before (ann3 innermost) so sums are accumulated
    % in the same order.
    for ann1 = 1:4
        for ann2 = 1:4
            for ann3 = 1:4
                entry = eval_lst{ann1, ann2, ann3, k};
                for m = 1:numel(metric_fields)
                    metric_avgs{m}(k,:) = metric_avgs{m}(k,:) + entry.(metric_fields{m})(1:H);
                end
            end
        end
    end
    for m = 1:numel(metric_fields)
        metric_avgs{m}(k,:) = metric_avgs{m}(k,:) ./ 64;
    end
end

[avg_acc, avg_tpr, avg_tnr, avg_precision, avg_recall, avg_fscore, avg_auc] = metric_avgs{:};
clear entry m metric_fields metric_avgs;

% Human baseline: for every combination of one annotator per dataset (one
% entry from each row of annotator_lst), concatenate that combination's
% choices and ground truth, score them with get_ex_accuracy, and average the
% resulting metrics over all combinations.
n_ann    = size(annotator_lst, 2);
n_combos = n_ann ^ 3;

avg_acc_human       = 0;
avg_tpr_human       = 0;
avg_tnr_human       = 0;
avg_precision_human = 0;
avg_recall_human    = 0;
avg_fscore_human    = 0;
avg_auc_human       = 0;

for ann1 = 1:n_ann
    ann1_name    = annotator_lst(1, ann1);
    choices_ann1 = choices_all{ann1_name, 1};
    choices_gt1  = choices_gt_all{ann1_name, 1};
    for ann2 = 1:n_ann
        ann2_name    = annotator_lst(2, ann2);
        choices_ann2 = choices_all{ann2_name, 2};
        choices_gt2  = choices_gt_all{ann2_name, 2};
        for ann3 = 1:n_ann
            ann3_name    = annotator_lst(3, ann3);
            choices_ann3 = choices_all{ann3_name, 3};
            choices_gt3  = choices_gt_all{ann3_name, 3};
            % concatenate the choices and choices_gt across the three datasets
            choices_3d   = [choices_ann1, choices_ann2, choices_ann3];
            choices_3dgt = [choices_gt1, choices_gt2, choices_gt3];

            eval_metrics_human = get_ex_accuracy(choices_3d', choices_3dgt');

            avg_acc_human       = avg_acc_human       + eval_metrics_human.ACC       / n_combos;
            avg_tpr_human       = avg_tpr_human       + eval_metrics_human.TPR       / n_combos;
            avg_tnr_human       = avg_tnr_human       + eval_metrics_human.TNR       / n_combos;
            avg_precision_human = avg_precision_human + eval_metrics_human.Precision / n_combos;
            avg_recall_human    = avg_recall_human    + eval_metrics_human.Recall    / n_combos;
            avg_fscore_human    = avg_fscore_human    + eval_metrics_human.Fscore    / n_combos;
            % Accumulate into avg_auc_human itself; a misspelled target here
            % previously kept only the last combination's AUC / 64.
            avg_auc_human       = avg_auc_human       + eval_metrics_human.AUC       / n_combos;
        end
    end
end
%%
% One figure per metric: a curve per method plus the human average as a
% dashed red reference line.
color_map = [0,      0.4470, 0.7410; % blue
             0.8500, 0.3250, 0.0980; % orange
             0.9290, 0.6940, 0.1250; % yellow
             0.4940, 0.1840, 0.5560; % purple
             0.4660, 0.6740, 0.1880; % green
             0.6350, 0.0780, 0.1840];% red
n_colors = size(color_map, 1);
x = (1:H) ./ 10;

% Legend labels do not depend on the metric, so compute them once.
legend_names = cell(1, num_methods);
for k = 1:num_methods
    legend_names{k} = strcat(get_legend_name(methods{k}), '-finetune');
end

% {figure title, per-method averages (num_methods x H), human average}
plot_specs = {
    'AUC',                avg_auc,       avg_auc_human
    'ACC',                avg_acc,       avg_acc_human
    'Recall',             avg_recall,    avg_recall_human
    'Precision',          avg_precision, avg_precision_human
    'True Positive Rate', avg_tpr,       avg_tpr_human
    'True Negative Rate', avg_tnr,       avg_tnr_human
    'F-score',            avg_fscore,    avg_fscore_human
};

for p = 1:size(plot_specs, 1)
    [plot_title, method_vals, human_val] = plot_specs{p, :};
    figure;
    for k = 1:num_methods
        % Wrap around the palette so more than n_colors methods do not error.
        plot(x, method_vals(k,:), 'DisplayName', legend_names{k}, ...
             'Color', color_map(mod(k-1, n_colors)+1, :), 'LineStyle', '-')
        hold on
    end
    legend()
    title(plot_title)
    % Named so the auto-updating legend labels it instead of showing "data1".
    line([1, H]./10, [human_val, human_val], 'Color', 'r', 'LineStyle', '--', ...
         'DisplayName', 'human average');
    hold off
end