clc;
clear;

%% Experiment settings
n         = 500;   % number of samples; one sample is consumed per Oja step
d         = 500;   % ambient dimension
d1        = 20;    % passed to generate_data_jing_lei (meaning defined there)
d2        = 10;    % passed to generate_data_jing_lei (meaning defined there)
s         = d1;    % row budget for row_truncation; presumably the support size — confirm
noise_var = 1;     % noise level used in the theoretical-bound constants
num_runs  = 1;     % number of independent repetitions

%% Preallocate per-run storage (one column per run)
errors_all_runs = zeros(n+1, num_runs);   % entry 1 is the error before any update
[norm_bn_all_runs, norm_bn_v1_all_runs, ...
 theoretical_bound_bn_all_runs, theoretical_bound_bn_1_all_runs] = deal(zeros(n, num_runs));

for run = 1:num_runs
    fprintf("Run %d/%d\n", run, num_runs);

    %% Generate data
    % generate_data_jing_lei is defined elsewhere. From its use below, X holds
    % one d-dimensional sample per row and data_cov is a d-by-d covariance
    % (TODO confirm against its definition).
    [data_cov, X]  = generate_data_jing_lei(n,d,d1,d2);
    [U,Lambda]     = eig(data_cov);
    U              = real(U);
    Lambda         = real(Lambda);
    % Sort eigenpairs in ascending order, so the leading pair is at index d.
    [Lambda, ind]  = sort(diag(Lambda));
    U              = U(:, ind);
    sample_cov     = cov(X);
    v1             = U(:,d);          % leading eigenvector
    lambda1        = Lambda(d);       % largest eigenvalue
    lambda2        = Lambda(d-1);     % second-largest eigenvalue

    %% Run algorithm
    k     = 1;                        % number of tracked components
    gamma = s;                        % row budget passed to row_truncation
    init  = randn(d,k);
    init  = init/norm(init);

    fprintf("Inner product of Initialisation and True Eigenvector = %.5f\n", init'*v1);

    % Traditional Oja with truncation. The un-truncated iterate Q is what
    % gets propagated. A truncated, re-orthonormalised copy Q_prime is formed
    % at every step only to measure the error.
    fprintf("Traditional Oja's Algorithm with truncation\n");
    error_tradoja_trunc = zeros(n+1,1);
    Q     = init;
    Q     = Q/norm(Q);
    [Q,~] = qr(Q,"econ");
    % Squared Frobenius mass of Q on the bottom d-k eigenvectors (sin^2 distance).
    error_tradoja_trunc(1) = norm(U(:,1:d-k)'*Q,"fro")^2;
    lr_init         = 0.25*log(n)/n;
    Bn_u0           = v1;             % running B_t * v1
    Bn              = eye(d,d);       % running product of (I + lr*x*x') factors

    % The 1/sqrt(d) scaling of the theoretical entry values applies only when
    % Bn_u0 starts at the random init. The comparison is exact, so with
    % Bn_u0 = v1 above it is practically never true.
    factor = 1;
    if(all(Bn_u0 == init))
        factor = (1/sqrt(d));
    end

    in_support_index = 1;
    outside_support_index = 2*s;

    alpha = 1 - (0.5/(lambda1-noise_var))*lr_init*noise_var*trace(data_cov);
    eta_2_coff = (lambda1-noise_var)^2*(1+1/alpha) + 3*(1+1/alpha)*(lambda1-noise_var)*noise_var + (3 + 2/alpha)*noise_var^2;

    p = floor(log(d));
    norm_bn          = ones(n,1);     % ||B_t||_2^2
    norm_bn_v1       = ones(n,1);     % ||B_t v1||_2^2
    theoretical_bound_bn = ones(n,1);
    % Use the same per-step factor as the t >= 2 recurrence below, so that
    % theoretical_bound_bn(t) = (1 + 2*lr*lambda1 + lr^2*eta_2_coff)^t.
    % (Previously the factor 2 was missing at t = 1 only.)
    theoretical_bound_bn(1) = (1 + 2*lr_init*lambda1 + lr_init^2*eta_2_coff);
    theoretical_bound_bn_1 = ones(n,1);
    x_1 = X(1,:)';
    % B_1 = I + lr*x_1*x_1' exactly. The step-1 value is therefore its squared
    % spectral norm, which matches norm_bn(1) (a squared norm). The unsquared
    % value fell below the quantity it bounds.
    theoretical_bound_bn_1(1) = norm(eye(d) + lr_init*(x_1*x_1'), 2)^2;

    itr = 1;
    % Diagnostics on the next sample; not referenced later in this script.
    norm_xx_bn = ones(n-1,1);
    norm_xx_bn_1 = ones(n-1,1);
    norm_sigma_bn = ones(n-1,1);

    entry_in_support = zeros(n,1);
    entry_outside_support = zeros(n,1);
    theoretical_entry_in_support = ones(n,1);
    theoretical_entry_outside_support = ones(n,1);
    theoretical_entry_in_support(1) = factor*v1(in_support_index);
    theoretical_entry_outside_support(1) = factor*16*lr_init*sqrt(lambda1-data_cov(outside_support_index,outside_support_index));

    for t=1:n
        x_t = X(t,:)';
        lr = lr_init;
        % Oja update matrix for this sample. It is built once and applied to
        % all three iterates.
        step_mat = eye(d) + lr*(x_t*x_t');
        Q = step_mat*Q;
        Bn = step_mat*Bn;
        Bn_u0 = step_mat*Bn_u0;
        norm_bn(t) = norm(Bn,2)^2;
        norm_bn_v1(t) = norm(Bn_u0,2)^2;
        if(t >= 2)
            theoretical_bound_bn(t) = theoretical_bound_bn(t-1)*(1 + 2*lr*lambda1 + lr^2*eta_2_coff);
            theoretical_bound_bn_1(t) = theoretical_bound_bn_1(t-1)*((1 + lr*lambda1)^2 + (p-1)*lr^2*norm(x_t)^2);
        end
        if(t < n)
            x_t_1 = X(t+1,:)';
            xx_next = x_t_1*x_t_1';
            norm_xx_bn(itr) = norm((xx_next-data_cov)*Bn);
            norm_xx_bn_1(itr) = norm(xx_next*Bn);
            norm_sigma_bn(itr) = norm(noise_var*d*Bn);
            itr = itr + 1;
        end
        [Q,~] = qr(Q,"econ");
        % row_truncation is defined elsewhere; presumably it keeps gamma rows
        % of Q and zeroes the rest — confirm.
        Q_prime = row_truncation(Q, gamma);
        [Q_prime,~] = qr(Q_prime,"econ");
        error_tradoja_trunc(t+1) = norm(U(:,1:d-k)'*Q_prime,"fro")^2;
        if(mod(t,100) == 0)
            fprintf("Error after %d iterations = %.5f\n", t, error_tradoja_trunc(t+1));
        end
    end
    Q_traditional_trunc = Q_prime;
    fprintf("Final error Traditional Oja's Algorithm with truncation : %.5f\n", error_tradoja_trunc(end));
    fprintf("============================\n");

    %% Store results for current run
    errors_all_runs(:, run) = error_tradoja_trunc;
    norm_bn_all_runs(:, run) = norm_bn;
    norm_bn_v1_all_runs(:, run) = norm_bn_v1;
    theoretical_bound_bn_all_runs(:, run) = theoretical_bound_bn;
    theoretical_bound_bn_1_all_runs(:, run) = theoretical_bound_bn_1;
end

%% Aggregate over runs (dimension 2): per-timestep mean, then sample std
mean_errors                 = mean(errors_all_runs, 2);
mean_norm_bn                = mean(norm_bn_all_runs, 2);
mean_norm_bn_v1             = mean(norm_bn_v1_all_runs, 2);
mean_theoretical_bound_bn   = mean(theoretical_bound_bn_all_runs, 2);
mean_theoretical_bound_bn_1 = mean(theoretical_bound_bn_1_all_runs, 2);

% Weight 0 => normalise by (num_runs - 1); yields zeros when num_runs == 1.
std_errors                 = std(errors_all_runs, 0, 2);
std_norm_bn                = std(norm_bn_all_runs, 0, 2);
std_norm_bn_v1             = std(norm_bn_v1_all_runs, 0, 2);
std_theoretical_bound_bn   = std(theoretical_bound_bn_all_runs, 0, 2);
std_theoretical_bound_bn_1 = std(theoretical_bound_bn_1_all_runs, 0, 2);

%% Plot the natural log of the per-timestep means
% (The std_* values computed above are not drawn.)
figure;
hold on;
curves = [mean_norm_bn, mean_norm_bn_v1, mean_theoretical_bound_bn, mean_theoretical_bound_bn_1];
for col = 1:size(curves, 2)
    plot(1:n, log(curves(:, col)), 'LineWidth', 2);
end

xlabel('Timesteps', 'FontSize', 20);
ylabel('Value', 'FontSize', 20);
title("Comparison of B_{n} with timesteps", 'FontSize', 20);

% Legend entries follow the column order of `curves`.
legend("||B_{n}B_{n}^{T}||", ...
       "v_{1}^{T}B_{n}B_{n}^{T}v_{1}", ...
       "(1 + \eta\lambda_{1})^{2n}", ...
       "Jain et al. (2016)", 'FontSize', 20);

grid on;
set(gca, 'GridLineStyle', '-', 'GridAlpha', 0.5, 'FontSize', 20);
hold off;
