% FIT mixed sources simulation (mixing Z1 and Z2 with noise balancing signal increase)
% reproduce Fig.S8C-D


clear all; %close all;

rng(1) % For reproducibility
save_results = 0; % set to 1 to save results file
save_plots = 0; % set to 1 to save the figure as a PNG
results_path = ''; % path to the directory to save results
scenario_name = 'twoMixedSources_figS8CD';

% Build the output directory with fullfile so the separator matches the
% platform (a hard-coded '\' only works on Windows) and so that an empty
% base path yields a folder relative to the current directory instead of
% the drive root.
results_path = fullfile(results_path,scenario_name);

% Only touch the file system when something is actually going to be written.
if (save_results || save_plots) && ~exist(results_path,'dir')
    mkdir(results_path)
end

% ---- Simulation parameters ----
nTrials_per_stim = 500; % trials drawn per stimulus value
simReps = 50;           % how many times the whole simulation is repeated
nShuff = 10;            % permutations per null (shared by both FIT permutation tests)

alpha_range = 0:0.1:1; % values of alpha (weight of Z2 mixed into X)
beta_range = 0:0.1:1; % values of beta (weight of Z1 mixed into Y)

epsNeural = 1;       % standard deviation of the gaussian noise added to X and Y
ratio_noise_Z = 0.2; % noise std of the Z sources, as a fraction of epsNeural

% ---- Information-estimation options ----
opts = struct( ...
    'verbose', false, ...
    'method',  "dr", ...
    'bias',    'naive', ...
    'btsp',    0, ...
    'n_binsX', 2, ...
    'n_binsY', 2, ...
    'n_binsZ', 2, ...
    'n_binsS', 2);      % n_binsS is the number of stimulus values

% ---- Null-hypothesis settings ----
% 'cond' shuffles X at fixed S; 'simple' shuffles S across all trials
shuff_types = {'cond','simple'};
null2plot = 'max';  % null hypothesis to use (maximum between 'simple' and 'cond')
n_boot = 500;       % number of samples of the null distribution
prctile_plt = 99;   % percentile used to determine significance

Z_encoding = [-1 1];

xy_comm = 1; % c parameter

% ---- Preallocate result containers (NaN-filled) ----
% Un-shuffled measures are rep x alpha x beta; shuffled ones carry a
% trailing permutation dimension.
nanGrid = nan([simReps, numel(alpha_range), numel(beta_range)]);
nanGridSh = nan([simReps, numel(alpha_range), numel(beta_range), nShuff]);

% X -> Y direction
fit_xy = nanGrid;
di_xy = nanGrid;
dfi_xy = nanGrid;
fitSh_xy.simple = nanGridSh;
diSh_xy.simple = nanGridSh;
dfiSh_xy.simple = nanGridSh;
fitSh_xy.cond = nanGridSh;
diSh_xy.cond = nanGridSh;
dfiSh_xy.cond = nanGridSh;

% Y -> X direction
fit_yx = nanGrid;
di_yx = nanGrid;
dfi_yx = nanGrid;
fitSh_yx.simple = nanGridSh;
diSh_yx.simple = nanGridSh;
dfiSh_yx.simple = nanGridSh;
fitSh_yx.cond = nanGridSh;
diSh_yx.cond = nanGridSh;
dfiSh_yx.cond = nanGridSh;

% Stimulus information carried by each (discretized) signal
infoX1 = nanGrid;
infoX2 = nanGrid;
infoY1 = nanGrid;
infoY2 = nanGrid;
infoZ11 = nanGrid;
infoZ12 = nanGrid;
infoZ21 = nanGrid;
infoZ22 = nanGrid;

% Drop the templates so the workspace holds exactly the variables above.
clear nanGrid nanGridSh
%% Run simulation
% For every repetition and every (alpha, beta) pair:
%   1. draw a stimulus per trial and simulate two sources at two time
%      points (Z11/Z12 = source 1 at t=1/t=2, Z21/Z22 = source 2 at t=1/t=2);
%   2. mix them into the observed signals X (source 1 + alpha*source 2) and
%      Y (source 2 + beta*source 1), adding gaussian noise;
%   3. discretize every signal into equipopulated bins;
%   4. compute DI/DFI/FIT in both directions, plus their shuffled nulls;
%   5. compute the stimulus information carried by each signal.
% NOTE(review): the argument order of compute_FIT_TE is presumably
% (stimulus, sender past, receiver present, receiver past) — confirm
% against the helper.

nTrials = nTrials_per_stim*opts.n_binsS; % total number of trials (loop-invariant)

% Measures between the unmixed sources: one value per repetition.
% Stored as row vectors, matching how implicit growth shaped them before.
di_z1z2 = nan(1,simReps); dfi_z1z2 = di_z1z2; fit_z1z2 = di_z1z2;
di_z2z1 = di_z1z2; dfi_z2z1 = di_z1z2; fit_z2z1 = di_z1z2;

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);
    for aIdx = 1:numel(alpha_range)
        % The second grid dimension is indexed by beta, so iterate over
        % beta_range (the loop bound used to be numel(alpha_range)).
        for bIdx = 1:numel(beta_range)

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);
            Sval = unique(S); % stimulus values present in this draw

            % simulate mixed sources
            Z11 = encoding_function(S,Z_encoding,1,0) + ratio_noise_Z*epsNeural*randn(1,nTrials); % source 1, t=1: stimulus encoding + noise
            Z12 = encoding_function(S,Z_encoding,1,0) + ratio_noise_Z*epsNeural*randn(1,nTrials); % source 1, t=2: stimulus encoding + noise

            Z21 = ratio_noise_Z*epsNeural*randn(1,nTrials); % source 2, t=1: noise only (no stimulus signal)
            Z22 = Z11 + ratio_noise_Z*epsNeural*randn(1,nTrials); % source 2, t=2: copy of source 1 at t=1 + noise
            
            X1 = Z11 + alpha_range(aIdx)*Z21 + (1+eps)*epsNeural*randn(1,nTrials); % balance X1 --> signal in X1 does not depend on mixing, since Z21 does not carry signal
            X2 = Z12 + alpha_range(aIdx)*Z22 + (1+alpha_range(aIdx)+eps)*epsNeural*randn(1,nTrials); %  balance X2 --> signal in X2 increases linearly with alpha --> we increase noise as 1+alpha to balance the SNR
             
            Y1 = Z21 + beta_range(bIdx)*Z11 + (1+beta_range(bIdx)+eps)*epsNeural*randn(1,nTrials); %  balance Y1 --> signal in Y1 increases linearly with beta --> we increase noise as 1+beta to balance the SNR
            Y2 = Z22 + beta_range(bIdx)*Z12 + (1+beta_range(bIdx)+eps)*epsNeural*randn(1,nTrials); % balance Y2 --> signal in Y2 increases linearly with beta --> we increase noise as 1+beta to balance the SNR
            
            % Discretize neural activity (bin edges from eqpop, bin index from histc)
            edgs = eqpop(Z11, opts.n_binsX);
            [~,bZ11] = histc(Z11, edgs);
            edgs = eqpop(Z12, opts.n_binsX);
            [~,bZ12] = histc(Z12, edgs);

            edgs = eqpop(Z21, opts.n_binsX);
            [~,bZ21] = histc(Z21, edgs);
            edgs = eqpop(Z22, opts.n_binsX);
            [~,bZ22] = histc(Z22, edgs);

            edgs = eqpop(X1, opts.n_binsX);
            [~,bX1] = histc(X1, edgs);
            edgs = eqpop(X2, opts.n_binsX);
            [~,bX2] = histc(X2, edgs);

            edgs = eqpop(Y1, opts.n_binsY);
            [~,bY1] = histc(Y1, edgs);
            edgs = eqpop(Y2, opts.n_binsY);
            [~,bY2] = histc(Y2, edgs);

            % Measures on the un-shuffled data, X -> Y then Y -> X
            [di_xy(repIdx,aIdx,bIdx),dfi_xy(repIdx,aIdx,bIdx),fit_xy(repIdx,aIdx,bIdx)]=...
                compute_FIT_TE(S, bX1, bY2, bY1);
            
            [di_yx(repIdx,aIdx,bIdx),dfi_yx(repIdx,aIdx,bIdx),fit_yx(repIdx,aIdx,bIdx)]=...
                compute_FIT_TE(S, bY1, bX2, bX1);

            for shIdx = 1:nShuff

                % conditioned shuff (shuffle the sender's past at fixed S)
                % Preallocate explicitly instead of relying on implicit
                % growth / leftovers from the previous permutation; every
                % entry is overwritten below because Sval covers all trials.
                X1Sh = zeros(1,nTrials);
                Y1Sh = zeros(1,nTrials);
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    ridx = randperm(sum(idx)); % same permutation for X and Y within this stimulus

                    tmpX = bX1(idx);
                    X1Sh(1,idx) = tmpX(ridx);
                    tmpY = bY1(idx);
                    Y1Sh(1,idx) = tmpY(ridx);
                end

                [diSh_xy.cond(repIdx,aIdx,bIdx,shIdx),dfiSh_xy.cond(repIdx,aIdx,bIdx,shIdx),fitSh_xy.cond(repIdx,aIdx,bIdx,shIdx)]=...
                    compute_FIT_TE(S, X1Sh, bY2, bY1);
                
                [diSh_yx.cond(repIdx,aIdx,bIdx,shIdx),dfiSh_yx.cond(repIdx,aIdx,bIdx,shIdx),fitSh_yx.cond(repIdx,aIdx,bIdx,shIdx)]=...
                    compute_FIT_TE(S, Y1Sh, bX2, bX1);
                
                % simple shuff: one permutation of the trials, applied to S
                % for the DFI/FIT null and to the sender's past for the DI null
                idx = randperm(nTrials);
                Ssh = S(idx);
                X1Sh = bX1(idx);
                Y1Sh = bY1(idx);
                
                [~,dfiSh_xy.simple(repIdx,aIdx,bIdx,shIdx),fitSh_xy.simple(repIdx,aIdx,bIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bX1, bY2, bY1);
                [diSh_xy.simple(repIdx,aIdx,bIdx,shIdx)]=...
                    DI_infToolBox(X1Sh, bY2, bY1, 'naive', 0);

                [~,dfiSh_yx.simple(repIdx,aIdx,bIdx,shIdx),fitSh_yx.simple(repIdx,aIdx,bIdx,shIdx)]=...
                    compute_FIT_TE(Ssh, bY1, bX2, bX1);
                [diSh_yx.simple(repIdx,aIdx,bIdx,shIdx)]=...
                    DI_infToolBox(Y1Sh, bX2, bX1, 'naive', 0);
            end

            % Compute info in signals, used in previous checks
            % t = 1
            [M_x1s, nt] = buildr(S,bX1);
            opts.nt = nt;
            infoX1(repIdx,aIdx,bIdx) = information(M_x1s, opts, 'I');
            [M_y1s, nt] = buildr(S,bY1);
            opts.nt = nt;
            infoY1(repIdx,aIdx,bIdx) = information(M_y1s, opts, 'I');
            [M_z11s, nt] = buildr(S,bZ11);
            opts.nt = nt;
            infoZ11(repIdx,aIdx,bIdx) = information(M_z11s, opts, 'I');
            [M_z21s, nt] = buildr(S,bZ21);
            opts.nt = nt;
            infoZ21(repIdx,aIdx,bIdx) = information(M_z21s, opts, 'I');
            % t = 2
            [M_x2s, nt] = buildr(S,bX2);
            opts.nt = nt;
            infoX2(repIdx,aIdx,bIdx) = information(M_x2s, opts, 'I');
            [M_y2s, nt] = buildr(S,bY2);
            opts.nt = nt;
            infoY2(repIdx,aIdx,bIdx) = information(M_y2s, opts, 'I');
            % Use the t=2 sources here (bZ12 / bZ22); these two used to be
            % computed from the t=1 sources by mistake.
            [M_z12s, nt] = buildr(S,bZ12);
            opts.nt = nt;
            infoZ12(repIdx,aIdx,bIdx) = information(M_z12s, opts, 'I');
            [M_z22s, nt] = buildr(S,bZ22);
            opts.nt = nt;
            infoZ22(repIdx,aIdx,bIdx) = information(M_z22s, opts, 'I');
        end
    end
    % Measures between the unmixed sources, using the S and Z samples left
    % over from the last (alpha, beta) iteration; the sources themselves do
    % not depend on alpha or beta. All three outputs are indexed by
    % repetition only (dfi used to be indexed with the stale bIdx).
    [di_z1z2(repIdx),dfi_z1z2(repIdx),fit_z1z2(repIdx)]=...
        compute_FIT_TE(S, bZ11, bZ22, bZ21);
    [di_z2z1(repIdx),dfi_z2z1(repIdx),fit_z2z1(repIdx)]=...
        compute_FIT_TE(S, bZ21, bZ12, bZ11);
end

% Save the whole workspace to <results_path>/<scenario_name>_<date>.mat
if save_results
    % Make sure the destination exists even if the flag was switched on
    % after the setup section ran.
    if ~exist(results_path,'dir')
        mkdir(results_path)
    end
    fname = [scenario_name '_' date '.mat'];
    % fullfile inserts the platform's separator (a literal '\' is Windows-only)
    save(fullfile(results_path,fname))
end

%% Plot panels S8C and S8D
% 2x3 grid of trial-averaged maps over the (alpha, beta) grid:
% top row FIT, bottom row TE; columns X->Y, Y->X and their difference.
maxZ = 0.175; % max color axis

% One row per panel: subplot slot, data (rep x alpha x beta), title.
panels = { ...
    1, fit_xy,          'FIT(X->Y)'; ...
    4, di_xy,           'TE(X->Y)'; ...
    2, fit_yx,          'FIT(Y->X)'; ...
    5, di_yx,           'TE(Y->X)'; ...
    3, fit_xy-fit_yx,   'Delta FIT'; ...
    6, di_xy-di_yx,     'Delta TE'};

figure('Position',[146,159,926,420])
hAx = gobjects(1,size(panels,1));
for pIdx = 1:size(panels,1)
    hAx(pIdx) = subplot(2,3,panels{pIdx,1});
    % The averaged map is alpha (rows) x beta (columns). imagesc(x,y,C)
    % maps x to columns and y to rows, so beta goes on x and alpha on y,
    % matching the axis labels (the two vectors used to be swapped, which
    % was only harmless because both ranges are identical).
    imagesc(beta_range,alpha_range,squeeze(mean(panels{pIdx,2},1)))
    set(gca,'YDir','normal')
    ylabel('alpha')
    xlabel('beta')
    c=colorbar();
    % NOTE(review): the Delta panels share this [0,maxZ] range, so negative
    % differences are clipped to the lowest color — confirm this is intended.
    caxis([0,maxZ])
    title(panels{pIdx,3})
end
ax = hAx(5:6); % handles of the Delta FIT / Delta TE axes

% fit_z1z2(:) averages over all repetitions regardless of vector orientation
sgtitle(['X = Z1 + alpha*Z2  Y = beta*Z1+Z2, FIT(Z1->Z2)=' num2str(mean(fit_z1z2(:)),2)])
if save_plots 
    % Make sure the destination exists even if the flag was switched on
    % after the setup section ran.
    if ~exist(results_path,'dir')
        mkdir(results_path)
    end
    fname = [scenario_name '_FIT_TE_c' num2str(xy_comm) '_noise' num2str(epsNeural) '_' date '.png'];
    % fullfile inserts the platform's separator (a literal '\' is Windows-only)
    print(fullfile(results_path,fname),'-dpng');
end
