% Start from a clean command window and an empty workspace.
clc;
clear;

%% Experiment configuration
% --- Data ---
n         = 1000;   % Number of datapoints
d         = 1000;   % Dimensionality
s         = 10;     % Number of dimensions turned on
noise_var = 1;

% --- Experiment / algorithm settings ---
num_runs                   = 50;                     % independent repetitions
k                          = 1;                      % number of columns of each iterate
block_size                 = 100;                    % samples consumed per block
T                          = floor(n / block_size);  % number of complete blocks
soft_thresholding_constant = 0.25 / d;

%% Preallocation
% Per-run error curves: one row per run. The streaming methods get n+1
% entries per run, the block method gets T+1.
error_oja_const_eta_all      = zeros(num_runs, n + 1);
error_johnstone_lu_all       = zeros(num_runs, n + 1);
error_wang_lu_all            = zeros(num_runs, n + 1);
error_block_power_method_all = zeros(num_runs, T + 1);

% Accumulators for the across-run averages (column vectors).
avg_error_oja_const_eta      = zeros(n + 1, 1);
avg_error_johnstone_lu       = zeros(n + 1, 1);
avg_error_wang_lu            = zeros(n + 1, 1);
avg_error_block_power_method = zeros(T + 1, 1);

% x-axis positions for the block-method curve.
x_vals_block                 = zeros(T + 1, 1);

% Running per-coordinate estimates, one entry per dimension.
var_estimates                = zeros(d, 1);

% Main experiment loop. Each run draws a fresh data set, starts all four
% methods from the same random initialisation, streams the samples through
% them and records the error of every iterate. The error of an iterate Q is
% norm(U(:,1:d-k)'*Q,"fro")^2, i.e. the squared Frobenius norm of its
% component along the d-k eigenvectors with the smallest eigenvalues.
for r = 1:num_runs
    %% Generate data
    % [A, data_cov, X] = generate_data(n,d,s,1,noise_var);
    d1 = 20; d2 = 10; s = d1;
    [data_cov, X]  = generate_data_jing_lei(n,d,d1,d2);
    gamma            = s;   % truncation level handed to the sparse methods

    % Eigendecomposition of the population covariance, sorted so that
    % eigenvalues are ascending: column d of U is the leading eigenvector.
    [U,Lambda]       = eig(data_cov);
    U                = real(U);
    Lambda           = real(Lambda);
    [Lambda, ind]    = sort(diag(Lambda));
    U                = U(:, ind);
    sample_cov       = cov(X);
    v1               = U(:,d);
    vp               = U(:,1:d-1);
    lambda1          = Lambda(d);
    lambda2          = Lambda(d-1);
    tr_sigma         = trace(data_cov);

    % Transposed basis of the d-k trailing eigenvectors. It is invariant
    % over the whole run, so it is sliced and transposed once instead of
    % on every error evaluation.
    U_perp_t         = U(:,1:d-k)';

    %% Run algorithm
    init  = randn(d,k);
    init  = init/norm(init);

    fprintf("Inner product of Initialisation and True Eigenvector = %.5f\n", ...
        init'*v1);

    error_oja_const_eta      = zeros(n+1,1);
    error_johnstone_lu       = zeros(n+1,1);
    error_wang_lu            = zeros(n+1,1);
    error_block_power_method = zeros(T+1,1);

    % All four methods start from the same orthonormalised initial iterate,
    % so it (and its error) is computed once and shared.
    Q_init     = init/norm(init);
    [Q_init,~] = qr(Q_init,"econ");
    Q_1 = Q_init;   % Oja iterate
    Q_2 = Q_init;   % Johnstone and Lu estimate
    Q_3 = Q_init;   % block power method iterate
    Q_4 = Q_init;   % Wang and Lu iterate

    initial_error               = norm(U_perp_t*Q_init,"fro")^2;
    error_oja_const_eta(1)      = initial_error;
    error_johnstone_lu(1)       = initial_error;
    error_block_power_method(1) = initial_error;
    error_wang_lu(1)            = initial_error;

    % lr                          = 0.3*log(n)/n;
    lr_init = 0.5;

    for t=1:n
        x_t = X(t,:)';

        % Running mean of the squared coordinates over the first t samples
        % (vectorised over all d coordinates). The t == 1 case assigns
        % directly so that values left over from a previous run are dropped.
        x_sq = x_t.^2;
        if(t == 1)
            var_estimates = x_sq;
        else
            var_estimates = var_estimates*(t-1)/t + x_sq/t;
        end

        % Oja update with step size lr_init/(t+1).
        lr     = lr_init/(t + 1);
        Q_1    = Q_1 + lr*x_t*(x_t'*Q_1);

        % NOTE(review): the helper receives the complete X at every step,
        % not only the first t rows - confirm that this look-ahead is
        % intended for the Johnstone and Lu baseline.
        Q_2    = get_variance_truncated_eigenvector(X, var_estimates, gamma);

        % Wang and Lu: rank-one update with step 1/d, then a shift of every
        % entry towards zero by soft_thresholding_constant.
        Q_4    = Q_4 + (1/d)*x_t*(x_t'*Q_4);
        Q_4    = Q_4 - soft_thresholding_constant*sign(Q_4);

        [Q_4,~]                         = qr(Q_4,"econ");

        % The Oja iterate itself is left untouched; only a truncated,
        % re-orthonormalised copy is scored.
        Q_prime                         = row_truncation(Q_1, gamma);
        [Q_prime,~]                     = qr(Q_prime,"econ");
        error_oja_const_eta(t+1)        = norm(U_perp_t*Q_prime,"fro")^2;
        error_wang_lu(t+1)              = norm(U_perp_t*Q_4,"fro")^2;
        error_johnstone_lu(t+1)         = norm(U_perp_t*Q_2,"fro")^2;

        % The error above was measured on the orthonormalised iterate; the
        % iterate carried into the next step is rescaled by sqrt(d).
        Q_4 = Q_4 * sqrt(d);

        if(mod(t,100) == 0)
            fprintf("Error after %d iterations with Oja = %.5f\n", t, error_oja_const_eta(t+1));
            fprintf("Error after %d iterations with Wang and Lu = %.5f\n", t, error_wang_lu(t+1));
            fprintf("Error after %d iterations with Johnstone and Lu = %.5f\n", t, error_johnstone_lu(t+1));
        end
    end

    %% Stochastic block power method
    % NOTE: results are stored at index tau+1 with tau starting at 0, so
    % slot 1 (the initial error / x-position set above and below) is
    % overwritten by the first block and slot T+1 is never written. The
    % final-error print below and the plotting code rely on this layout by
    % using end-1, so it is kept as is.
    itr = 1;
    x_vals_block(1) = 1;
    for tau=0:(T-1)
        rows  = (block_size*tau + 1):(block_size*(tau+1));
        X_blk = X(rows,:);

        % Block average of x_t*(x_t'*Q_3). Multiplying X_blk*Q_3 first keeps
        % the cost at O(block_size*d*k) and avoids forming a d-by-d outer
        % product for every sample.
        S_tau = X_blk'*(X_blk*Q_3)/block_size;
        itr   = itr + block_size;

        S_tau = row_truncation(S_tau, gamma);
        [Q_3,~] = qr(S_tau,"econ");
        error_block_power_method(tau+1) = norm(U_perp_t*Q_3,"fro")^2;
        x_vals_block(tau+1) = itr;
        fprintf("Error after %d iterations with Stochastic Block Power Method = %.5f\n", tau, error_block_power_method(tau+1));
    end

    % Store errors for this run
    error_block_power_method_all(r, :) = error_block_power_method;
    error_oja_const_eta_all(r, :) = error_oja_const_eta;
    error_johnstone_lu_all(r, :) = error_johnstone_lu;
    error_wang_lu_all(r, :) = error_wang_lu;

    % Running averages over the runs.
    avg_error_block_power_method = avg_error_block_power_method + error_block_power_method/num_runs;
    avg_error_oja_const_eta      = avg_error_oja_const_eta + error_oja_const_eta/num_runs;
    avg_error_johnstone_lu       = avg_error_johnstone_lu + error_johnstone_lu/num_runs;
    avg_error_wang_lu            = avg_error_wang_lu + error_wang_lu/num_runs;

    fprintf("Final error Traditional Oja's Algorithm with truncation : %.5f\n", error_oja_const_eta(end));
    % end-1 because the last slot of the block-method array is never written.
    fprintf("Final error  Stochastic Block Power Method : %.5f\n", error_block_power_method(end-1));
    fprintf("Final error  Wang and Lu : %.5f\n", error_wang_lu(end));
    fprintf("Final error Johnstone and Lu : %.5f\n", error_johnstone_lu(end));
    fprintf("============================\n");
end

% Theoretical rate
% Reference curve factor*s*log(d)/i for i = 1..n (column vector), built in
% one vectorised expression and then rescaled so its first entry equals 1.
% lambda1, lambda2 and s are whatever the last run left in the workspace.
factor           = (lambda1*lambda2)/((lambda1-lambda2)^2);
theoretical_rate = factor*((s*log(d))./(1:n)');
theoretical_rate = theoretical_rate/theoretical_rate(1);

% Summarise the per-run error curves across runs. Every *_all matrix holds
% one run per row, so both statistics are taken along dimension 1 and come
% out as row vectors. These assignments replace any avg_* values already
% present in the workspace.

% Across-run means
avg_error_oja_const_eta      = mean(error_oja_const_eta_all, 1);
avg_error_johnstone_lu       = mean(error_johnstone_lu_all, 1);
avg_error_wang_lu            = mean(error_wang_lu_all, 1);
avg_error_block_power_method = mean(error_block_power_method_all, 1);

% Across-run standard deviations (weight flag 0, MATLAB's default
% normalisation)
std_error_oja_const_eta      = std(error_oja_const_eta_all, 0, 1);
std_error_johnstone_lu       = std(error_johnstone_lu_all, 0, 1);
std_error_wang_lu            = std(error_wang_lu_all, 0, 1);
std_error_block_power_method = std(error_block_power_method_all, 0, 1);

% Plot the mean error of every method with +/- one standard deviation
% across runs, plus the normalised theoretical reference curve.
figure;

% Draw the curves first. The first plotting call on fresh axes runs with
% the axes in their default 'replace' mode, which resets axes properties -
% labels or a title set before that call would be discarded.
errorbar(1:n, avg_error_oja_const_eta(1:n), std_error_oja_const_eta(1:n));
hold on;
errorbar(1:n, avg_error_wang_lu(1:n), std_error_wang_lu(1:n));

% Block-method arrays have T+1 slots of which the run loop fills only the
% first T, hence 1:end-1. The three vectors are forced to a common
% (column) orientation because x_vals_block is a column while the
% statistics are rows.
blk_x   = x_vals_block(1:end-1);
blk_avg = avg_error_block_power_method(1:end-1);
blk_std = std_error_block_power_method(1:end-1);
errorbar(blk_x(:), blk_avg(:), blk_std(:));

errorbar(1:n, avg_error_johnstone_lu(1:n), std_error_johnstone_lu(1:n));

% Fifth curve, matching the fifth legend entry. Without it the legend
% entry has no plotted object and is dropped with a warning.
plot(1:n, theoretical_rate(1:n));
hold off;

% Annotate after plotting so the text survives.
xlabel('Timesteps');
ylabel('Error');
title("sin-squared error with timesteps");
legend("Oja's Algorithm with Cardinality Truncation (Our Algorithm)", ...
       "Oja's Algorithm with Iterative Soft Thresholding (Wang and Lu, 2016)", ...
       "Stochastic Power Method with Truncation (Yang and Zu, 2015)", ...
       "Diagonal Thresholding (Johnstone and Lu, 2009)",...
       "Information Theoretical Rate, O(slog(d)/n)");

% figure;
% xlabel('Timesteps');
% ylabel('Error');
% title("sin-squared error with timesteps");
% plot(1:n,avg_error_oja_const_eta(1:n)); hold on;
% plot(1:n, avg_error_wang_lu(1:n)); hold on;
% plot(x_vals_block(1:end-1),avg_error_block_power_method(1:end-1)); hold on;
% plot(1:n, avg_error_johnstone_lu(1:n)); hold on;
% plot(1:n,theoretical_rate(1:n)); hold on;
% legend("Oja's Algorithm with Cardinality Truncation (Our Algorithm)", ...
%        "Oja's Algorithm with Iterative Soft Thresholding (Wang and Lu, 2016)", ...
%        "Stochastic Power Method with Truncation (Yang and Zu, 2015)", ...
%        "Diagonal Thresholding (Johnstone and Lu, 2009)",...
%        "Information Theoretical Rate, O(slog(d)/n)");