% Start from an empty workspace and a clean command window.
clear
clc

% Panel (a) occupies the first of three side-by-side axes.
subplot(1,3,1)

% The data file supplies s_all_ratio, the cell array of trajectory sets
% that the rest of the script iterates over.
load('data_test_1');

% Angle samples used to trace the outline of every disc drawn below.
theta = 0:pi/50:2*pi;

% ---- Goal region ---------------------------------------------------------
% Centre (pxg, pyg); rg_s is the radius of the disc that gets drawn.
% rg is defined here but not referenced in the visible part of the script.
pxg  = 15;
pyg  = 15;
rg_s = 2;
rg   = 1.5;

% ---- Circular obstacles --------------------------------------------------
% Obstacle 1: centre (px_obs_1, py_obs_1), radius r_obs_1.
px_obs_1 = 7.5;
py_obs_1 = 10;
r_obs_1  = 2.5;

% Obstacle 2: centre (px_obs_2, py_obs_2), radius r_obs_2.
px_obs_2 = 10;
py_obs_2 = 5;
r_obs_2  = 2.5;

% ---- Outline coordinates -------------------------------------------------
% Each outline is centre + radius * (cos(theta), sin(theta)).
xCoord_g     = pxg + rg_s*cos(theta);
yCoord_g     = pyg + rg_s*sin(theta);

xCoord_obs_1 = px_obs_1 + r_obs_1*cos(theta);
yCoord_obs_1 = py_obs_1 + r_obs_1*sin(theta);

xCoord_obs_2 = px_obs_2 + r_obs_2*cos(theta);
yCoord_obs_2 = py_obs_2 + r_obs_2*sin(theta);




% ---- Panel (a): background ----------------------------------------------
% Obstacles are shaded red and the goal disc black, all at 10% opacity.
% Hold is switched on once the first patch has initialised the axes.
fill(xCoord_obs_1, yCoord_obs_1, 'r', 'FaceAlpha', .1, 'EdgeAlpha', .1)
hold on
fill(xCoord_obs_2, yCoord_obs_2, 'r', 'FaceAlpha', .1, 'EdgeAlpha', .1)
fill(xCoord_g, yCoord_g, 'k', 'FaceAlpha', .1, 'EdgeAlpha', .1)

% ---- Panel (a): trajectories --------------------------------------------
% nc_ratio selects which entry of s_all_ratio is drawn; it is reused when
% the trajectories of panel (b) are plotted.
nc_ratio = 15;
i_ratio = nc_ratio;
s_all = s_all_ratio{i_ratio};
n_tr = length(s_all);
for i_tr = 1:n_tr
    % Rows 1 and 2 of each trajectory hold the x and y positions.
    px_itr = s_all{i_tr}(1,:);
    py_itr = s_all{i_tr}(2,:);

    % r^2 - |p - c|^2 is positive exactly where a point lies strictly
    % inside the corresponding obstacle disc.
    distance_1 = r_obs_1^2 - (px_itr' - px_obs_1).^2 - (py_itr' - py_obs_1).^2;
    distance_2 = r_obs_2^2 - (px_itr' - px_obs_2).^2 - (py_itr' - py_obs_2).^2;
    g = max([distance_1; distance_2]);

    % Blue when no point is inside an obstacle, red otherwise.
    line_spec = '--r';
    if g<=0
        line_spec = '--b';
    end
    plot(px_itr, py_itr, line_spec)
end
hold off

% ---- Panel (a): cosmetics -----------------------------------------------
xlabel('$x$','Interpreter','Latex')
ylabel("$y$",'Interpreter','Latex')
set(gca,'FontSize',16,'fontname','Times New Roman')
title('(a) Trajectories by deterministic policy','Interpreter','Latex');
xlim([0,20])
ylim([0,20])
grid on

% ---- Panel (b): background ----------------------------------------------
% Same scene as panel (a): two red obstacle discs and the black goal disc,
% drawn as translucent patches. The axes stay on hold afterwards so the
% trajectories plotted later land in this panel.
subplot(1,3,2)

fill(xCoord_obs_1, yCoord_obs_1, 'r', 'FaceAlpha', .1, 'EdgeAlpha', .1)
hold on
fill(xCoord_obs_2, yCoord_obs_2, 'r', 'FaceAlpha', .1, 'EdgeAlpha', .1)
fill(xCoord_g, yCoord_g, 'k', 'FaceAlpha', .1, 'EdgeAlpha', .1)

% ---- Statistics of the original (deterministic) policy ------------------
% For every test file data_test_<k>.mat, k = 1..n_group, and every entry of
% s_all_ratio, compute
%   * the mean over trajectories of the time-averaged reward
%     1/(squared distance to the goal centre + 0.1), and
%   * the fraction of trajectories with at least one point strictly inside
%     an obstacle disc.
% Results are stored column-wise, one column per group.
n_group = 5;

% Only s_all_ratio is needed from the data files. Loading it into a struct
% prevents any other variable stored in a file from overwriting script
% variables (n_group, i_group, pxg, ...) and raises an error, instead of
% silently reusing a stale s_all_ratio, if the variable is missing.
loaded_data = load('data_test_1', 's_all_ratio');
s_all_ratio = loaded_data.s_all_ratio;
n_ratio = length(s_all_ratio);
cost_mean_ratio_orig_group = zeros(n_ratio,n_group);
prob_vio_ratio_orig_group = zeros(n_ratio,n_group);

for i_group=1:n_group
    datasetname = sprintf('data_test_%d', i_group);
    loaded_data = load(datasetname, 's_all_ratio');
    s_all_ratio = loaded_data.s_all_ratio;
    n_ratio = length(s_all_ratio);
    cost_mean_ratio = zeros(n_ratio,1);
    prob_vio_ratio = zeros(n_ratio,1);
    for i_ratio=1:n_ratio
        s_all = s_all_ratio{i_ratio};
        n_tr = length(s_all);
        cost_tr = zeros(n_tr,1);
        collision_tr = zeros(n_tr,1);
        for i_tr=1:n_tr
            % Rows 1 and 2 of each trajectory hold the x and y positions.
            px_itr = s_all{i_tr}(1,:);
            py_itr = s_all{i_tr}(2,:);

            % Time-averaged reward of this trajectory.
            cost_tr(i_tr) = sum( 1./((px_itr-pxg).^2 + (py_itr-pyg).^2 + 0.1) ) /length(px_itr);

            % r^2 - |p - c|^2 is positive exactly where a point lies
            % strictly inside the corresponding obstacle disc.
            distance_1 = r_obs_1^2 - (px_itr' - px_obs_1).^2 - (py_itr' - py_obs_1).^2;
            distance_2 = r_obs_2^2 - (px_itr' - px_obs_2).^2 - (py_itr' - py_obs_2).^2;
            g = max([distance_1;distance_2;]);
            if g>0
                collision_tr(i_tr) = 1;
            end
        end
        cost_mean_ratio(i_ratio) = mean(cost_tr);
        prob_vio_ratio(i_ratio) = mean(collision_tr);
    end

    prob_vio_ratio_orig_group(:,i_group) = prob_vio_ratio;
    cost_mean_ratio_orig_group(:,i_group) = cost_mean_ratio;

end

% ---- Panel (b): trajectories of the flipping-based policy ---------------
% Only s_all_ratio is read from the file; loading it into a struct keeps
% other stored variables from overwriting script variables such as
% nc_ratio.
loaded_data = load('data_test_flip_1.mat', 's_all_ratio');
s_all_ratio = loaded_data.s_all_ratio;

% n_ratio is refreshed here because the preallocation of the flip-group
% result matrices further down reads it.
n_ratio = length(s_all_ratio);

% Only the trajectory set selected by nc_ratio is drawn, exactly as in
% panel (a). Per-ratio reward/violation statistics are not computed here:
% the group loop that follows recomputes them for every flip data file,
% including this one, before they are read.
hold on
if any(nc_ratio == 1:n_ratio)
    s_all = s_all_ratio{nc_ratio};
    n_tr = length(s_all);
    for i_tr=1:n_tr
        % Rows 1 and 2 of each trajectory hold the x and y positions.
        px_itr = s_all{i_tr}(1,:);
        py_itr = s_all{i_tr}(2,:);

        % r^2 - |p - c|^2 is positive exactly where a point lies strictly
        % inside the corresponding obstacle disc.
        distance_1 = r_obs_1^2 - (px_itr' - px_obs_1).^2 - (py_itr' - py_obs_1).^2;
        distance_2 = r_obs_2^2 - (px_itr' - px_obs_2).^2 - (py_itr' - py_obs_2).^2;
        g = max([distance_1;distance_2;]);

        % Blue when no point is inside an obstacle, red when at least one is.
        if g<=0
            plot(px_itr,py_itr,'--b')
        elseif g>0
            plot(px_itr,py_itr,'--r')
        end
    end
end
hold off
title('(b) Trajectories by flipping-based policy','Interpreter','Latex');
xlabel('$x$','Interpreter','Latex')
ylabel("$y$",'Interpreter','Latex')
set(gca,'FontSize',16,'fontname','Times New Roman')
xlim([0,20])
ylim([0,20])
grid on

% ---- Statistics of the flipping-based policy ----------------------------
% For every test file data_test_flip_<k>.mat, k = 1..n_group, and every
% entry of s_all_ratio, compute
%   * the mean over trajectories of the time-averaged reward
%     1/(squared distance to the goal centre + 0.1), and
%   * the fraction of trajectories with at least one point strictly inside
%     an obstacle disc.
% Results are stored column-wise, one column per group. n_ratio used for
% the preallocation is the value already present in the workspace.
n_group = 5;
cost_mean_ratio_flip_group = zeros(n_ratio,n_group);
prob_vio_ratio_flip_group = zeros(n_ratio,n_group);

for i_group=1:n_group
    datasetname = sprintf('data_test_flip_%d', i_group);

    % Only s_all_ratio is needed from the data file. Loading it into a
    % struct prevents any other stored variable from overwriting script
    % variables (n_group, i_group, pxg, ...) and raises an error, instead
    % of silently reusing a stale s_all_ratio, if the variable is missing.
    loaded_data = load(datasetname, 's_all_ratio');
    s_all_ratio = loaded_data.s_all_ratio;
    n_ratio = length(s_all_ratio);
    cost_mean_ratio = zeros(n_ratio,1);
    prob_vio_ratio = zeros(n_ratio,1);
    for i_ratio=1:n_ratio
        s_all = s_all_ratio{i_ratio};
        n_tr = length(s_all);
        cost_tr = zeros(n_tr,1);
        collision_tr = zeros(n_tr,1);
        for i_tr=1:n_tr
            % Rows 1 and 2 of each trajectory hold the x and y positions.
            px_itr = s_all{i_tr}(1,:);
            py_itr = s_all{i_tr}(2,:);

            % Time-averaged reward of this trajectory.
            cost_tr(i_tr) = sum( 1./((px_itr-pxg).^2 + (py_itr-pyg).^2 + 0.1) ) /length(px_itr);

            % r^2 - |p - c|^2 is positive exactly where a point lies
            % strictly inside the corresponding obstacle disc.
            distance_1 = r_obs_1^2 - (px_itr' - px_obs_1).^2 - (py_itr' - py_obs_1).^2;
            distance_2 = r_obs_2^2 - (px_itr' - px_obs_2).^2 - (py_itr' - py_obs_2).^2;
            g = max([distance_1;distance_2;]);
            if g>0
                collision_tr(i_tr) = 1;
            end
        end
        cost_mean_ratio(i_ratio) = mean(cost_tr);
        prob_vio_ratio(i_ratio) = mean(collision_tr);
    end

    prob_vio_ratio_flip_group(:,i_group) = prob_vio_ratio;
    cost_mean_ratio_flip_group(:,i_group) = cost_mean_ratio;

end


% ---- Aggregate over groups ----------------------------------------------
% Each *_group matrix has one row per ratio and one column per group.
% Reduce along dimension 2 (across groups) to get a column vector with the
% mean / max / min per ratio. The explicit dimension argument keeps the
% result correct when there is a single group column; the transpose-based
% form mean(X')' collapses to a scalar in that case.
prob_vio_ratio_flip_mean = mean(prob_vio_ratio_flip_group, 2);
cost_mean_ratio_flip_mean = mean(cost_mean_ratio_flip_group, 2);
prob_vio_ratio_flip_max = max(prob_vio_ratio_flip_group, [], 2);
cost_mean_ratio_flip_max = max(cost_mean_ratio_flip_group, [], 2);
prob_vio_ratio_flip_min = min(prob_vio_ratio_flip_group, [], 2);
cost_mean_ratio_flip_min = min(cost_mean_ratio_flip_group, [], 2);

prob_vio_ratio_orig_mean = mean(prob_vio_ratio_orig_group, 2);
cost_mean_ratio_orig_mean = mean(cost_mean_ratio_orig_group, 2);
prob_vio_ratio_orig_max = max(prob_vio_ratio_orig_group, [], 2);
cost_mean_ratio_orig_max = max(cost_mean_ratio_orig_group, [], 2);
prob_vio_ratio_orig_min = min(prob_vio_ratio_orig_group, [], 2);
cost_mean_ratio_orig_min = min(cost_mean_ratio_orig_group, [], 2);

% ---- Panel (c): mean reward against violation probability ---------------
subplot(1,3,3)

% Group means: black squares for the original policy, blue circles for the
% flipping-based one. These two series are plotted first so that the
% two-entry legend below attaches to them.
plot(prob_vio_ratio_orig_mean, cost_mean_ratio_orig_mean, 'sk', 'LineWidth', 1, 'MarkerSize', 3)
hold on
plot(prob_vio_ratio_flip_mean, cost_mean_ratio_flip_mean, 'ob', 'LineWidth', 1, 'MarkerSize', 3)

% Min-to-max range bars through each mean point: a horizontal bar for the
% violation probability and a vertical bar for the reward.
for i_ratio = 1:n_ratio
    % Original policy (black).
    x_mid = prob_vio_ratio_orig_mean(i_ratio);
    y_mid = cost_mean_ratio_orig_mean(i_ratio);
    plot([prob_vio_ratio_orig_min(i_ratio) prob_vio_ratio_orig_max(i_ratio)], [y_mid y_mid], '-k', 'LineWidth', 1)
    plot([x_mid x_mid], [cost_mean_ratio_orig_min(i_ratio) cost_mean_ratio_orig_max(i_ratio)], '-k', 'LineWidth', 1)

    % Flipping-based policy (blue).
    x_mid = prob_vio_ratio_flip_mean(i_ratio);
    y_mid = cost_mean_ratio_flip_mean(i_ratio);
    plot([prob_vio_ratio_flip_min(i_ratio) prob_vio_ratio_flip_max(i_ratio)], [y_mid y_mid], '-b', 'LineWidth', 1)
    plot([x_mid x_mid], [cost_mean_ratio_flip_min(i_ratio) cost_mean_ratio_flip_max(i_ratio)], '-b', 'LineWidth', 1)
end

hold off
title('(c) Reward v.s. violation probability','Interpreter','Latex');
legend('Original','Flipping-based')
xlabel('Violation probability $\alpha$','Interpreter','Latex')
ylabel("Mean reward",'Interpreter','Latex')
set(gca,'FontSize',16,'fontname','Times New Roman')
xlim([0,0.45])
grid on