% Benchmark setup: clears the workspace, puts the FastICA package on the
% path, preallocates one score vector per (metric, method) pair and builds
% the mixing matrix shared by all experiments.
clc; clear;

addpath(genpath('./FastICA_25/'));

num_experiments    = 5;      % number of independent experiments
maxruns            = 20;     % forwarded to ICA_power / ICA_power_meta
verbose_flag       = false;  % forwarded to ICA_power / ICA_power_meta

% One entry per experiment for every method. The *_kurtosis_prime vectors
% are preallocated like all the others so that they are column vectors of
% the same length instead of being grown one element at a time in the loop.
corner_scores_cgf            = zeros(num_experiments,1);
corner_scores_chf            = zeros(num_experiments,1);
corner_scores_kurtosis       = zeros(num_experiments,1);
corner_scores_kurtosis_prime = zeros(num_experiments,1);
corner_scores_aiyou          = zeros(num_experiments,1);
corner_scores_jade           = zeros(num_experiments,1);
corner_scores_fastica        = zeros(num_experiments,1);
corner_scores_meta           = zeros(num_experiments,1);

amari_scores_cgf             = zeros(num_experiments,1);
amari_scores_chf             = zeros(num_experiments,1);
amari_scores_kurtosis        = zeros(num_experiments,1);
amari_scores_kurtosis_prime  = zeros(num_experiments,1);
amari_scores_aiyou           = zeros(num_experiments,1);
amari_scores_jade            = zeros(num_experiments,1);
amari_scores_fastica         = zeros(num_experiments,1);
amari_scores_meta            = zeros(num_experiments,1);

sinr_scores_cgf              = zeros(num_experiments,1);
sinr_scores_chf              = zeros(num_experiments,1);
sinr_scores_kurtosis         = zeros(num_experiments,1);
sinr_scores_kurtosis_prime   = zeros(num_experiments,1);
sinr_scores_aiyou            = zeros(num_experiments,1);
sinr_scores_jade             = zeros(num_experiments,1);
sinr_scores_fastica          = zeros(num_experiments,1);
sinr_scores_meta             = zeros(num_experiments,1);

% k-by-k mixing matrix assembled as U*Lambda*V' from two random orthogonal
% factors (gallery('qmult',k)) and a diagonal with entries drawn uniformly
% from (1,3), which therefore are the singular values of B.
k = 14;
U      = gallery('qmult',k);
Lambda = diag(rand(k,1)*2 + 1);
V      = gallery('qmult',k);
B      = U*Lambda*V'; % Mixing matrix

% Alternative mixing matrices (disabled):
% B = rand(k,k);
% B = B./sum(B,2);
% B = eye(k);

% Main benchmark loop. Every experiment draws a fresh dataset, estimates the
% mixing matrix with each method and records corner, Amari and SINR scores.

% Experiment configuration. These values are the same for every experiment,
% so they are assigned once instead of on every iteration.
dataset            = "zero_kurtosis_bernoulli";
n                  = 100000;
% noise_power        = 0.01;
noise_power        = 0.1;
num_trials         = 200;

% The loop index is called exp_idx (not exp) so that the script does not
% leave a workspace variable that shadows the built-in exp() function.
for exp_idx=1:num_experiments
    fprintf("Experiment %d\n", exp_idx);

    % Random projections forwarded to ICA_power_meta; redrawn per experiment.
    projections        = randn(2*num_trials,k);

    [Sigma,B,X,z,e] = generate_data(dataset,B,n,k,noise_power);
    % [Sigma,B,X,z,e] = generate_data_belkin(n,k,noise_power);

    % Baseline estimators, all run on X'. A_aiyou is the method reported as
    % "PFICA" in the printed output.
    A_aiyou       = inv(PCF_ica(X'));
    A_jade        = jade(X');
    [A_fastica,~] = fastica(X');

    % Contrast-function objects and their own estimates of C.
    cgf_fn          = cgf(X);
    chf_fn          = symmetric_chf(X);
    kurtosis_fn     = kurtosis(X);
    C_cgf           = cgf_fn.estimate_C(20);
    fprintf("C matrix generation for CGF complete\n");
    C_chf           = chf_fn.estimate_C(20);
    fprintf("C matrix generation for CHF complete\n");
    C_kurtosis      = kurtosis_fn.estimate_C(20);
    fprintf("C matrix generation for Kurtosis complete\n");

    % NOTE(review): the C_cgf and C_chf estimates computed above are
    % overwritten here with A*A' built from the PFICA estimate, so only
    % C_kurtosis keeps its own estimate. The estimate_C calls are kept
    % because it is not visible from this script whether they have side
    % effects -- confirm whether the overwrite is intended.
    A = A_aiyou;
    C_pfica          = A*A';
    C_cgf            = C_pfica;
    C_chf            = C_pfica;
    C_kurtosis_prime = C_pfica;

    % CGF, CHF and "kurtosis prime" share the same C, so one pseudo-inverse
    % serves all three.
    pinv_C_pfica          = pinv(C_pfica);
    pinv_C_cgf            = pinv_C_pfica;
    pinv_C_chf            = pinv_C_pfica;
    pinv_C_kurtosis       = pinv(C_kurtosis);
    pinv_C_kurtosis_prime = pinv_C_pfica;

    % Per-method accumulators: column i of A_* holds the i-th recovered
    % direction, row i of B1_* the matching normalized row computed from
    % pinv(C). Unfilled columns/rows are zero.
    A_cgf       = zeros(k,k);
    B1_cgf      = zeros(k,k);
    A_chf       = zeros(k,k);
    B1_chf      = zeros(k,k);
    A_kurtosis  = zeros(k,k);
    B1_kurtosis = zeros(k,k);
    A_kurtosis_prime  = zeros(k,k);
    B1_kurtosis_prime = zeros(k,k);
    A_meta      = zeros(k,k);
    B1_meta     = zeros(k,k);

    for i=1:k
        fprintf('Column %d\n',i);
        % All methods start column i from the same random unit vector.
        u_init=randn(k,1);
        u_init=u_init/norm(u_init);

        % Each method below follows the same pattern:
        %   M = I - A*B1 from the columns found so far (passed to the solver,
        %   presumably to deflate them -- see ICA_power), store the returned
        %   vector as column i of A, then set row i of B1 to pinv(C)*a
        %   scaled so that B1(i,:)*A(:,i) == 1.
        fprintf('CGF run\n');
        M_cgf = eye(k) - A_cgf*B1_cgf;
        [u_cgf,~,~,~] = ICA_power(X, ...
                                  maxruns, ...
                                  B, ...
                                  verbose_flag, ...
                                  cgf_fn, ...
                                  u_init, ...
                                  C_cgf, ...
                                  M_cgf);

        A_cgf(:,i)=u_cgf';
        u1=pinv_C_cgf*A_cgf(:,i);
        B1_cgf(i,:)=u1/(u1'*A_cgf(:,i));

        fprintf('CHF run\n');
        M_chf = eye(k) - A_chf*B1_chf;
        [u_chf,~,~,~] = ICA_power(X, ...
                                  maxruns, ...
                                  B, ...
                                  verbose_flag, ...
                                  chf_fn, ...
                                  u_init, ...
                                  C_chf, ...
                                  M_chf);

        A_chf(:,i)=u_chf';
        u1=pinv_C_chf*A_chf(:,i);
        B1_chf(i,:)=u1/(u1'*A_chf(:,i));

        % Kurtosis contrast with the C estimated by kurtosis_fn itself.
        fprintf('Kurtosis original run\n');
        M_kurtosis = eye(k) - A_kurtosis*B1_kurtosis;
        [u_kurtosis,~,~,~] = ICA_power(X, ...
                                       maxruns, ...
                                       B, ...
                                       verbose_flag, ...
                                       kurtosis_fn, ...
                                       u_init, ...
                                       C_kurtosis, ...
                                       M_kurtosis);

        A_kurtosis(:,i)=u_kurtosis';
        u1=pinv_C_kurtosis*A_kurtosis(:,i);
        B1_kurtosis(i,:)=u1/(u1'*A_kurtosis(:,i));

        % Kurtosis contrast with the PFICA-derived C (A*A').
        fprintf('Kurtosis with PFICA C run\n');
        M_kurtosis_prime = eye(k) - A_kurtosis_prime*B1_kurtosis_prime;
        [u_kurtosis_prime,~,~,~] = ICA_power(X, ...
                                       maxruns, ...
                                       B, ...
                                       verbose_flag, ...
                                       kurtosis_fn, ...
                                       u_init, ...
                                       C_kurtosis_prime, ...
                                       M_kurtosis_prime);

        A_kurtosis_prime(:,i)=u_kurtosis_prime';
        u1=pinv_C_kurtosis_prime*A_kurtosis_prime(:,i);
        B1_kurtosis_prime(i,:)=u1/(u1'*A_kurtosis_prime(:,i));

        % Meta run: receives all three contrast functions and their C
        % matrices and returns the C it used (C_meta), whose pseudo-inverse
        % is therefore recomputed for every column.
        fprintf('Meta run\n');
        M_meta = eye(k) - A_meta*B1_meta;
        [u_meta,C_meta] = ICA_power_meta(X, ...
                                         maxruns, ...
                                         B, ...
                                         verbose_flag, ...
                                         kurtosis_fn, ...
                                         chf_fn, ...
                                         cgf_fn, ...
                                         u_init, ...
                                         C_kurtosis, ...
                                         C_chf, ...
                                         C_cgf, ...
                                         num_trials, ...
                                         projections, ...
                                         M_meta);

        A_meta(:,i)=u_meta';
        pinv_C_meta=pinv(C_meta);
        u1=pinv_C_meta*A_meta(:,i);
        B1_meta(i,:)=u1/(u1'*A_meta(:,i));

        fprintf('=====================\n');
        fprintf('=====================\n');
    end

    fprintf('Metric Computation\n');

    % Corner-score threshold: average, over a few random k-by-k matrices, of
    % the mean column-wise maximum of the match matrix against B.
    avg_thresh = 0;
    num_repetitions = 5;
    for itr=1:num_repetitions
        [~,match_matrix] = compute_corner_score(randn(k,k),B,-1);
        avg_thresh = avg_thresh + mean(max(match_matrix,[],1));
    end
    thresh_random = avg_thresh/num_repetitions;
    fprintf("Random threshold for corner score = %.5f\n", thresh_random);

    % FastICA is only scored when it returned a k-by-k matrix. The check is
    % done once and a warning is raised instead of skipping silently.
    fastica_ok = isequal(size(A_fastica), [k k]);
    if ~fastica_ok
        warning('ICAExperiment:fasticaSize', ...
                ['FastICA returned a %dx%d matrix instead of %dx%d in ' ...
                 'experiment %d; its scores for this experiment keep ' ...
                 'their preallocated value (0) and are still included ' ...
                 'in the averages.'], ...
                size(A_fastica,1), size(A_fastica,2), k, k, exp_idx);
    end

    [corner_scores_cgf(exp_idx), match_matrix_cgf] = ...
                          compute_corner_score(A_cgf,B,thresh_random);
    [corner_scores_chf(exp_idx), match_matrix_chf] = ...
                          compute_corner_score(A_chf,B,thresh_random);
    [corner_scores_kurtosis(exp_idx), match_matrix_kurtosis] = ...
                          compute_corner_score(A_kurtosis,B,thresh_random);
    [corner_scores_kurtosis_prime(exp_idx), match_matrix_kurtosis_prime] = ...
                  compute_corner_score(A_kurtosis_prime,B,thresh_random);
    [corner_scores_aiyou(exp_idx), match_matrix_aiyou] = ...
                          compute_corner_score(A_aiyou,B,thresh_random);
    [corner_scores_jade(exp_idx), match_matrix_jade] = ...
                          compute_corner_score(A_jade,B,thresh_random);
    [corner_scores_meta(exp_idx), match_matrix_meta] = ...
                          compute_corner_score(A_meta,B,thresh_random);
    if fastica_ok
        [corner_scores_fastica(exp_idx), match_matrix_fastica] = ...
                          compute_corner_score(A_fastica,B,thresh_random);
    end

    amari_scores_cgf(exp_idx)      = compute_amari_score(A_cgf,B);
    amari_scores_chf(exp_idx)      = compute_amari_score(A_chf,B);
    amari_scores_kurtosis(exp_idx) = compute_amari_score(A_kurtosis,B);
    amari_scores_kurtosis_prime(exp_idx) = compute_amari_score(A_kurtosis_prime,B);
    amari_scores_aiyou(exp_idx)    = compute_amari_score(A_aiyou,B);
    amari_scores_jade(exp_idx)     = compute_amari_score(A_jade,B);
    amari_scores_meta(exp_idx)     = compute_amari_score(A_meta,B);
    if fastica_ok
        amari_scores_fastica(exp_idx)  = compute_amari_score(A_fastica,B);
    end

    sinr_scores_cgf(exp_idx)      = compute_sinr(A_cgf,X,B,z,e,Sigma);
    sinr_scores_chf(exp_idx)      = compute_sinr(A_chf,X,B,z,e,Sigma);
    sinr_scores_kurtosis(exp_idx) = compute_sinr(A_kurtosis,X,B,z,e,Sigma);
    sinr_scores_kurtosis_prime(exp_idx) = compute_sinr(A_kurtosis_prime,X,B,z,e,Sigma);
    sinr_scores_aiyou(exp_idx)    = compute_sinr(A_aiyou,X,B,z,e,Sigma);
    sinr_scores_jade(exp_idx)     = compute_sinr(A_jade,X,B,z,e,Sigma);
    sinr_scores_meta(exp_idx)     = compute_sinr(A_meta,X,B,z,e,Sigma);
    if fastica_ok
        sinr_scores_fastica(exp_idx)  = compute_sinr(A_fastica,X,B,z,e,Sigma);
    end

    % Per-experiment report.
    fprintf("Corner Score for CGF = %.5f\n", ...
             corner_scores_cgf(exp_idx));
    fprintf("Corner Score for CHF = %.5f\n", ...
             corner_scores_chf(exp_idx));
    fprintf("Corner Score for Kurtosis = %.5f\n", ...
             corner_scores_kurtosis(exp_idx));
    fprintf("Corner Score for Kurtosis with PFICA C = %.5f\n", ...
             corner_scores_kurtosis_prime(exp_idx));
    fprintf("Corner Score for Meta = %.5f\n", ...
             corner_scores_meta(exp_idx));
    fprintf("Corner Score for PFICA = %.5f\n", ...
             corner_scores_aiyou(exp_idx));
    fprintf("Corner Score for JADE = %.5f\n", ...
             corner_scores_jade(exp_idx));
    fprintf("Corner Score for FASTICA = %.5f\n", ...
             corner_scores_fastica(exp_idx));

    fprintf("Amari Score for CGF = %.5f\n", ...
             amari_scores_cgf(exp_idx));
    fprintf("Amari Score for CHF = %.5f\n", ...
             amari_scores_chf(exp_idx));
    fprintf("Amari Score for Kurtosis = %.5f\n", ...
             amari_scores_kurtosis(exp_idx));
    fprintf("Amari Score for Kurtosis prime = %.5f\n", ...
             amari_scores_kurtosis_prime(exp_idx));
    fprintf("Amari Score for Meta = %.5f\n", ...
             amari_scores_meta(exp_idx));
    % The PFICA and JADE lines now carry the same " = " separator as every
    % other score line.
    fprintf("Amari Score for PFICA = %.5f\n", ...
             amari_scores_aiyou(exp_idx));
    fprintf("Amari Score for JADE = %.5f\n", ...
             amari_scores_jade(exp_idx));
    fprintf("Amari Score for FASTICA = %.5f\n", ...
             amari_scores_fastica(exp_idx));

    fprintf("SINR Score for CGF = %.5f\n", ...
             sinr_scores_cgf(exp_idx));
    fprintf("SINR Score for CHF = %.5f\n", ...
             sinr_scores_chf(exp_idx));
    fprintf("SINR Score for Kurtosis = %.5f\n", ...
             sinr_scores_kurtosis(exp_idx));
    fprintf("SINR Score for Kurtosis prime = %.5f\n", ...
             sinr_scores_kurtosis_prime(exp_idx));
    fprintf("SINR Score for Meta = %.5f\n", ...
             sinr_scores_meta(exp_idx));
    fprintf("SINR Score for PFICA = %.5f\n", ...
             sinr_scores_aiyou(exp_idx));
    fprintf("SINR Score for JADE = %.5f\n", ...
             sinr_scores_jade(exp_idx));
    fprintf("SINR Score for FASTICA = %.5f\n", ...
             sinr_scores_fastica(exp_idx));

    fprintf('=====================\n');
    fprintf('=====================\n');
    fprintf('=====================\n');
end

% Final report: mean and standard deviation of every score vector over all
% experiments. Each table row pairs a format string with the vector it
% summarizes; tables are printed in order with a separator line between them.
corner_summary = {
    "Average Corner Score for CGF = %.5f +- %.5f\n",            corner_scores_cgf
    "Average Corner Score for CHF = %.5f +- %.5f\n",            corner_scores_chf
    "Average Corner Score for Kurtosis = %.5f +- %.5f\n",       corner_scores_kurtosis
    "Average Corner Score for Kurtosis prime = %.5f +- %.5f\n", corner_scores_kurtosis_prime
    "Average Corner Score for Meta = %.5f +- %.5f\n",           corner_scores_meta
    "Average Corner Score for PFICA = %.5f +- %.5f\n",          corner_scores_aiyou
    "Average Corner Score for JADE = %.5f +- %.5f\n",           corner_scores_jade
    "Average Corner Score for FASTICA = %.5f +- %.5f\n",        corner_scores_fastica
};
amari_summary = {
    "Average Amari Score for CGF = %.5f +- %.5f\n",            amari_scores_cgf
    "Average Amari Score for CHF = %.5f +- %.5f\n",            amari_scores_chf
    "Average Amari Score for Kurtosis = %.5f +- %.5f\n",       amari_scores_kurtosis
    "Average Amari Score for Kurtosis prime = %.5f +- %.5f\n", amari_scores_kurtosis_prime
    "Average Amari Score for Meta = %.5f +- %.5f\n",           amari_scores_meta
    "Average Amari Score for PFICA = %.5f +- %.5f\n",          amari_scores_aiyou
    "Average Amari Score for JADE = %.5f +- %.5f\n",           amari_scores_jade
    "Average Amari Score for FASTICA = %.5f +- %.5f\n",        amari_scores_fastica
};
sinr_summary = {
    "Average SINR Score for CGF = %.5f +- %.5f\n",            sinr_scores_cgf
    "Average SINR Score for CHF = %.5f +- %.5f\n",            sinr_scores_chf
    "Average SINR Score for Kurtosis = %.5f +- %.5f\n",       sinr_scores_kurtosis
    "Average SINR Score for Kurtosis prime = %.5f +- %.5f\n", sinr_scores_kurtosis_prime
    "Average SINR Score for Meta = %.5f +- %.5f\n",           sinr_scores_meta
    "Average SINR Score for PFICA = %.5f +- %.5f\n",          sinr_scores_aiyou
    "Average SINR Score for JADE = %.5f +- %.5f\n",           sinr_scores_jade
    "Average SINR Score for FASTICA = %.5f +- %.5f\n",        sinr_scores_fastica
};

summary_tables = {corner_summary, amari_summary, sinr_summary};
num_tables     = numel(summary_tables);
for table_idx = 1:num_tables
    table_rows = summary_tables{table_idx};
    for row_idx = 1:size(table_rows,1)
        score_vec = table_rows{row_idx,2};
        fprintf(table_rows{row_idx,1}, mean(score_vec), std(score_vec));
    end
    % A separator follows every table except the last one.
    if table_idx < num_tables
        fprintf('=====================\n');
    end
end