% Start from a clean session and a fixed RNG seed so every run draws the
% same ground truth, data, and initialization.
clear; clc; close all;
rng(0)


%% parameter setup
% Ground-truth model: one shared filter w_star of length r is applied to each
% of the k consecutive, non-overlapping length-r blocks of an input, and the
% k responses are combined with output weights v_star.
r = 5;        % filter / patch length
k = 15;       % number of patches per input
n = 1000;     % number of samples
d = k*r;      % input dimension

% Ground-truth filter, normalized to unit l_2 norm.
w_star = randn(r, 1);
w_star = w_star ./ norm(w_star);

% Ground-truth output weights, sign-flipped if needed so that
% sum(v_star) >= 0 (the v_0 initialization scales by sum(v_star)).
v_star = randn(k, 1);
if sum(v_star) < 0, v_star = -v_star; end

nu = 0.08;    % standard deviation of the additive label noise


%% Data generation
% Each column of x is one sample: d = k*r standard-normal entries, read as k
% consecutive length-r patches. The clean label is the v_star-weighted sum
% over patches of CNN_ReLU(w_star.' * patch).
x = randn(d, n);

y = zeros(1, n);
for patch = 1:k
    rows = (patch - 1)*r + (1:r);   % row indices of this patch in x
    y = y + v_star(patch) * CNN_ReLU(w_star.' * x(rows, :));
end

% Corrupt the labels with zero-mean Gaussian noise of standard deviation nu.
y = y + nu * randn(1, n);





%% Initialization
% Starting filter: random direction, unit l_2 norm.
w_0 = randn(r, 1);
w_0 = w_0 ./ norm(w_0);

% Starting output weights: random, then projected onto the l_2 ball of
% radius 0.5*k^(-1/2)*sum(v_star) if they fall outside it. The rescaling
% line keeps the original left-to-right product so results are bit-identical.
v_0 = randn(k, 1);
v_radius = 0.5 * k^(-1/2) * sum(v_star);
if norm(v_0) > v_radius
    v_0 = v_0 / norm(v_0) * 0.5 * k^(-1/2) * sum(v_star);
end



%% Training
T = 100;       % iteration count handed to the solvers (rows of the error traces)
alpha = 0.04;  % presumably the step size of the solvers -- confirm in ApproximateGD

% ApproximateGD is run from four sign-flipped starting points. Row s of
% init_signs holds the multipliers applied to (w_0, v_0) in run s:
% (+w_0,+v_0), (-w_0,+v_0), (+w_0,-v_0), (-w_0,-v_0).
init_signs = [ 1  1;
              -1  1;
               1 -1;
              -1 -1];
n_runs = size(init_signs, 1);

error_w_all = zeros(T, n_runs);
error_v_all = zeros(T, n_runs);

what_all = zeros(r, n_runs);
vhat_all = zeros(k, n_runs);   % final v estimate of each run, one per column

for s = 1:n_runs
    [what_all(:,s), vhat_all(:,s), error_w_all(:,s), error_v_all(:,s)] = ...
        ApproximateGD(init_signs(s,1)*w_0, init_signs(s,2)*v_0, x, y, alpha, T, w_star, v_star);
end







% DoubleConvotron is run from the same four sign-flipped starting points.
% Row s of init_signs holds the multipliers applied to (w_0, v_0) in run s.
init_signs = [ 1  1;
              -1  1;
               1 -1;
              -1 -1];
n_runs = size(init_signs, 1);

error_w_all_DoubleConvotron = zeros(T, n_runs);
error_v_all_DoubleConvotron = zeros(T, n_runs);

what_all_DoubleConvotron = zeros(r, n_runs);
vhat_all_DoubleConvotron = zeros(k, n_runs);   % final v estimate of each run

for s = 1:n_runs
    [what_all_DoubleConvotron(:,s), vhat_all_DoubleConvotron(:,s), ...
     error_w_all_DoubleConvotron(:,s), error_v_all_DoubleConvotron(:,s)] = ...
        DoubleConvotron(init_signs(s,1)*w_0, init_signs(s,2)*v_0, x, y, alpha, T, w_star, v_star);
end







% For each solver, keep the run (column) whose final ||w - w*||_2 is smallest.
[~, jbest] = min(error_w_all(end,:));

[~, jbest_DoubleConvotron] = min(error_w_all_DoubleConvotron(end,:));


figure(1);
semilogy(1:T, error_w_all(:,jbest), 'r.-')
hold on
semilogy(1:T, error_v_all(:,jbest), 'b.-')
semilogy(1:T, error_w_all_DoubleConvotron(:,jbest_DoubleConvotron), '.-', 'color', [0 .5 0])
semilogy(1:T, error_v_all_DoubleConvotron(:,jbest_DoubleConvotron), '.-', 'color', [0 .75 0.75])
hold off

legend('ApproxGD: || w - w* ||_2','ApproxGD: || v - v* ||_2','DoubleConvotron: || w - w* ||_2','DoubleConvotron: || v - v* ||_2' );

% title(sprintf('ReLU, k = %d, r = %d, l_2 distance', k, r))

xlabel('epoch');
ylabel('l_2 error')

% Make the PDF page exactly the size of the on-screen figure.
set(gcf, 'Units', 'inches');
screenposition = get(gcf, 'Position');
set(gcf, ...
    'PaperPosition', [0 0 screenposition(3:4)], ...
    'PaperSize', screenposition(3:4));

% The file name is derived from k and r so it cannot drift from the run's
% parameters; with k = 15, r = 5 it is errorplot_ReLU_k15r5_log(.pdf).
print('-dpdf', '-painters', sprintf('errorplot_ReLU_k%dr%d_log', k, r))





