function w = ProjectOntoIntersection(x, R, a, b)
% PROJECTONTOINTERSECTION Euclidean projection onto {w : norm(w) <= R, a'*w <= b}.
%   w = ProjectOntoIntersection(x, R, a, b) returns the point of the
%   intersection of the L2 ball of radius R centred at the origin and the
%   halfspace a'*w <= b that is closest to x in the Euclidean norm.
%
%   Inputs:
%     x - point to project (vector)
%     R - ball radius, must be non-negative
%     a - normal of the halfspace; same number of elements as x (row or
%         column orientation is accepted)
%     b - offset of the halfspace (scalar)
%   Output:
%     w - the projection, same size as x
%
%   Raises an error if R < 0 or if the intersection is empty.
%
%   The projection is found by testing the KKT cases in order:
%     1. no constraint active (x already feasible),
%     2. only the ball constraint active,
%     3. only the halfspace constraint active,
%     4. both active: w lies on the (n-2)-sphere where the hyperplane
%        a'*w = b cuts the ball's boundary.
if (R < 0)
    error('Radius of L2 ball is negative: %2.3f\n', R);
end
% Give a the same orientation as x so that x - t*a never broadcasts into
% a matrix when one is a row and the other a column.
a = reshape(a, size(x));
na = norm(a, 2);
% Distance from the origin to the hyperplane a'*w = b.
d = abs(b)/na;
if d > R && b < 0
    % The origin violates the halfspace and the hyperplane does not reach
    % the ball, so no point satisfies both constraints.
    error('Empty intersection');
end

% Case 1: x is already in the intersection.
nx = norm(x, 2);
if nx <= R && dot(a, x) <= b
    w = x;
    return;
end

% Case 2: only the ball is active -- scale x radially onto the sphere and
% accept it if it satisfies the halfspace.
if nx > R
    x1 = R*x/nx;
    if dot(a, x1) <= b
        w = x1;
        return;
    end
end

% Case 3: only the halfspace is active -- orthogonal projection onto the
% hyperplane a'*w = b, accepted if it lies inside the ball.
x2 = x - (dot(x, a) - b)/na^2*a;
if norm(x2, 2) <= R
    w = x2;
    return;
end

% Case 4: both constraints are active. The boundary of the intersection
% here is the sphere of radius r centred at c (the point of the hyperplane
% nearest the origin), lying inside the hyperplane. The nearest point of
% that sphere to x is c + r*v/norm(v), where v is x with its component
% along a removed; since x2 is x projected onto the hyperplane, v = x2 - c.
% Clamp at zero: rounding can make d marginally exceed R when the
% hyperplane is (nearly) tangent to the ball.
r = sqrt(max(R^2 - d^2, 0));
c = (b/na^2)*a;
v = x2 - c;
nv = norm(v, 2);
if nv > 0
    w = c + r*v/nv;
else
    % v == 0 means norm(x2) == norm(c) == d > R, which (past the emptiness
    % check) only happens through rounding in the tangent case, where the
    % intersection is the single point c.
    w = c;
end

