% FIT one mixed source, null model (Fig. S8B)

% Start from a clean workspace before running the simulation.
clear all; 

rng(0) % For reproducibility
save_results = 1; % set to 1 to save results file
save_plots = 0;
results_path = ''; % path to the directory to save results
scenario_name = 'mixedSources_S8B';

% Build the scenario directory with fullfile so that the separator of the
% current platform is used. A hard-coded '\' is only valid on Windows and,
% with an empty base path, produced '\mixedSources_S8B' instead of a folder
% relative to the current directory.
results_path = fullfile(results_path,scenario_name);
if ~exist(results_path,'dir')
    mkdir(results_path)
end

% ---- Simulation size ----
nTrials_per_stim = 500; % trials generated for each stimulus value
simReps = 50;           % independent repetitions of the simulation
nShuff = 10;            % permutations drawn to build a null distribution

% ---- Parameter sweep ----
alphaY_range = 0:0.1:1; % values taken by the alpha y parameter
betaX_range = 0;        % values taken by the beta x parameter

% ---- Noise levels ----
epsNeural = 1; % standard deviation of the gaussian noise added to the neural signals
epsZ = 0.2;    % noise in the source

% ---- Options passed to the information routines ----
opts = struct( ...
    'verbose', false, ...
    'method',  "dr", ...
    'bias',    'naive', ...
    'btsp',    0, ...
    'n_binsX', 2, ...
    'n_binsY', 2, ...
    'n_binsZ', 2, ...
    'n_binsS', 2);       % n_binsS is the number of stimulus values

% ---- Null-hypothesis settings ----
% 'condX'  : X is shuffled at fixed S
% 'simple' : S is shuffled across all trials
shuff_types = {'condX','simple'};
null2plot = 'max';  % null hypothesis to use (maximum between 'simple' and 'cond')
n_boot = 500;       % number of samples of the null distribution
prctile_plot = 99;  % percentile used to determine significance

% ---- Stimulus encoding and coupling ----
Z_encoding = [-1 1];
K_encoding = [-1 1];

% Communication parameter: zero reproduces Fig.8B; the real communication
% line in Fig.8A was obtained by setting it to 1.
xy_comm = 0;
z_signal = 1; % set to 1 if the mixed source Z should encode S-info
alpha_X = 1;

% Preallocate every result container with NaN.
% Measured quantities are indexed (repetition, alpha_Y, beta_X); shuffled
% quantities carry a fourth dimension that indexes the permutation.
fit = nan(simReps,numel(alphaY_range),numel(betaX_range));
[di, dfi] = deal(fit);

% One field per shuffling type, in the order 'simple' then 'condX'
fitSh = struct( ...
    'simple', nan(simReps,numel(alphaY_range),numel(betaX_range),nShuff), ...
    'condX',  nan(simReps,numel(alphaY_range),numel(betaX_range),nShuff));
diSh = fitSh;
dfiSh = fitSh;

% Stimulus information carried by each signal at the two time points
[infoX1, infoX2, infoY1, infoY2, infoZ1, infoZ2] = deal(fit);
infoX1_Sh = fitSh.simple;

% Additional containers with the measured / shuffled layouts
[info_XpYp, info_XpYt] = deal(fit);
[info_shXpYp, info_shXpYt] = deal(infoX1_Sh);
%% Run simulation
% For every repetition and every (alpha_Y, beta_X) pair this section:
%   1. draws the stimulus and simulates the mixed source Z, the extra term K
%      and the signals X and Y at two time points;
%   2. discretizes Z, X and Y;
%   3. calls compute_FIT_TE on (S, X1, Y2, Y1) and stores its three outputs
%      in di, dfi and fit;
%   4. repeats step 3 nShuff times with X1 shuffled at fixed S (condX null);
%   5. stores the information about S carried by each discretized signal.
% The order of the random draws is left untouched so that results obtained
% with a given seed do not change.

% The number of trials does not depend on any loop variable
nTrials = nTrials_per_stim*opts.n_binsS;

for repIdx = 1:simReps
    disp(['Repetition number ',num2str(repIdx)]);
    for aIdx = 1:numel(alphaY_range)
        for bIdx = 1:numel(betaX_range)

            % Draw the stimulus value for each trial
            S = randi(opts.n_binsS,1,nTrials);

            % Simulate the mixed source at the two time points.
            % Z1 uses noise level epsZ; Z2 doubles the stimulus term and
            % uses noise level epsNeural.
            Z1 = z_signal*encoding_function(S,Z_encoding,1,0) + epsZ*randn(1,nTrials);
            Z2 = 2*z_signal*encoding_function(S,Z_encoding,1,0) + epsNeural*randn(1,nTrials);

            % other term of information encoded in X and transmitted to Y
            % (not used in Fig.S8B)
            K1 = encoding_function(S,K_encoding,1,0) + epsNeural*randn(1,nTrials);
            K2 = encoding_function(S,K_encoding,1,0) + epsNeural*randn(1,nTrials);

            % X mixes the source Z and the term K; the noise in X is treated
            % as measurement noise: not transmitted to Y
            X1 = alpha_X*Z1 + betaX_range(bIdx)*K1 + epsNeural*randn(1,nTrials);
            X2 = alpha_X*Z2 + betaX_range(bIdx)*K2 + epsNeural*randn(1,nTrials);

            % Y receives the source Z scaled by alpha_Y; Y2 additionally
            % receives K1 when xy_comm is non-zero
            Y1 = alphaY_range(aIdx)*Z1 + epsNeural*randn(1,nTrials);
            Y2 = alphaY_range(aIdx)*Z2 + xy_comm*betaX_range(bIdx)*K1 + epsNeural*randn(1,nTrials);

            % Discretize neural activity.
            % Z uses its own bin count (opts.n_binsZ) rather than the one of X
            edgs = eqpop(Z1, opts.n_binsZ);
            [~,bZ1] = histc(Z1, edgs);
            edgs = eqpop(Z2, opts.n_binsZ);
            [~,bZ2] = histc(Z2, edgs);

            edgs = eqpop(X1, opts.n_binsX);
            [~,bX1] = histc(X1, edgs);
            edgs = eqpop(X2, opts.n_binsX);
            [~,bX2] = histc(X2, edgs);

            edgs = eqpop(Y1, opts.n_binsY);
            [~,bY1] = histc(Y1, edgs);
            edgs = eqpop(Y2, opts.n_binsY);
            [~,bY2] = histc(Y2, edgs);

            % Measured values
            [di(repIdx,aIdx,bIdx),dfi(repIdx,aIdx,bIdx),fit(repIdx,aIdx,bIdx)]=...
                compute_FIT_TE(S, bX1, bY2, bY1);

            % The stimulus values are the same for every permutation
            Sval = unique(S);
            for shIdx = 1:nShuff

                % conditioned shuff (shuffle X at fixed S)
                % Start from bX1 so the buffer always has the right size;
                % every trial is overwritten by the per-stimulus permutation
                X1Sh = bX1;
                for Ss = 1:numel(Sval)
                    idx = (S == Sval(Ss));
                    ridx = randperm(sum(idx));

                    tmpX = bX1(idx);
                    X1Sh(1,idx) = tmpX(ridx);
                end

                [diSh.condX(repIdx,aIdx,bIdx,shIdx),dfiSh.condX(repIdx,aIdx,bIdx,shIdx),fitSh.condX(repIdx,aIdx,bIdx,shIdx)]=...
                    compute_FIT_TE(S, X1Sh, bY2, bY1);

            end

            % Compute mutual info encoded in signals (used in previous versions of the script to check results)
            % opts.nt is refreshed from buildr before each call to information
            % t = 1
            [M_x1s, nt] = buildr(S,bX1);
            opts.nt = nt;
            infoX1(repIdx,aIdx,bIdx) = information(M_x1s, opts, 'I');
            [M_y1s, nt] = buildr(S,bY1);
            opts.nt = nt;
            infoY1(repIdx,aIdx,bIdx) = information(M_y1s, opts, 'I');
            [M_z1s, nt] = buildr(S,bZ1);
            opts.nt = nt;
            infoZ1(repIdx,aIdx,bIdx) = information(M_z1s, opts, 'I');
            % t = 2
            [M_x2s, nt] = buildr(S,bX2);
            opts.nt = nt;
            infoX2(repIdx,aIdx,bIdx) = information(M_x2s, opts, 'I');
            [M_y2s, nt] = buildr(S,bY2);
            opts.nt = nt;
            infoY2(repIdx,aIdx,bIdx) = information(M_y2s, opts, 'I');
            [M_z2s, nt] = buildr(S,bZ2);
            opts.nt = nt;
            infoZ2(repIdx,aIdx,bIdx) = information(M_z2s, opts, 'I');
            
        end
    end
end

% Save the whole workspace to a dated .mat file inside results_path.
if save_results
    fname = [scenario_name '_' date '.mat'];
    % fullfile inserts the separator of the current platform; a literal '\'
    % is only a valid path separator on Windows
    save(fullfile(results_path,fname))
end

%% Plot panel S8B
% Compute the null of the mean FIT, TE and DFI for every shuffling type by
% passing the shuffled values to btstrp_shuff (n_boot samples each).
rng(0)
for k = 1:numel(shuff_types)
    lab = shuff_types{k};
    pooledFITsh.(lab) = btstrp_shuff(fitSh.(lab),n_boot);
    pooledTEsh.(lab) = btstrp_shuff(diSh.(lab),n_boot);
    pooledDFIsh.(lab) = btstrp_shuff(dfiSh.(lab),n_boot);
end
plot_bX = 1; % index of the beta X value to plot
Ymax = 0.15;

% The conditioned null hypothesis (shuffling X at fixed S) is used here. It
% is more permissive than the element-wise maximum of the nulls, therefore
% non-significance under the conditioned null implies non-significance
% under the full null.

% Curves to draw: mean over repetitions of the measured values, and the
% significance thresholds taken from the conditioned null
fitMeasured = mean(squeeze(fit(:,:,plot_bX)),1);
teMeasured = mean(squeeze(di(:,:,plot_bX)),1);
fitThreshold = prctile(squeeze(pooledFITsh.condX(:,plot_bX,:)),prctile_plot,2);
teThreshold = prctile(squeeze(pooledTEsh.condX(:,plot_bX,:)),prctile_plot,2);

% FIT and TE
figure('Position',[540,260,559,330])
hold on
hLines(1) = plot(alphaY_range,fitMeasured,'b');
hLines(2) = plot(alphaY_range,fitThreshold,'--');
hLines(3) = plot(alphaY_range,teMeasured,'g');
hLines(4) = plot(alphaY_range,teThreshold,'--');
legend(hLines,'FIT measured','FIT 99th prc.','TE measured','TE 99th prc.')
ylim([0,Ymax])

xlabel('alpha')
ylabel('[bits]')

title('One mixed source (null model)')
