% Plot script to compute results in Fig.3 (all panels) and S6 (panels B-E) of the paper

% Different time-delay windows are used for the different panels, depending
% on whether you want to reproduce results for the V1-V3A pair (i.e. plot
% maps of all time-delay points) or for the full network (i.e. compute
% values in the FIT-significant window).
% See the comments in the parameter section below for details on how to set
% the parameters for the different panels

% To reproduce DFI results (Fig. S15D), change the definition of tmp_params.info_type in the 'Plot results' section as explained in the comment on that line

clear all; clc;

% ---- General analysis settings ----
nShuff = 10;                    % nShuff (subselect if larger)
codeVersion = 'NIPS_paper';     % results sub-folder / version tag
bands_order = {'band_40_75'};   % frequency band (gamma)
time_methods = {'maximum'};     % information values are the maximum over time

% The code up to the 'Plot results' section (i.e. up to the
% 'END OF FIRST SECTION' marker) should be run for both the 'conditioned'
% and the 'simple' shuffling, so that the maximum between the two can be
% taken at the plotting stage. However, almost identical results are
% obtained using only the 'conditioned' shuffling.

% ---- Panel-specific settings ----

% Time window (seconds) used to compute TE and FIT results
%   FIT-specific window (Fig. 3D,E,F,G,H and Fig. S6B): bands_TW = [0.2, 0.4]
%   Whole window        (Fig. 3C and Fig. S6C,D):       bands_TW = [-0.1, 0.5]
%   TE-specific box     (Fig. S6C,E):                   bands_TW = [0.15 0.35]
bands_TW = [0.2, 0.4];

% Delay window (in samples) used to compute TE and FIT results
%   FIT-specific window (Fig. 3D,E,F,G,H and Fig. S6B):
%       minSelDelay = 3 (50ms), maxSelDelay = 15 (250ms)
%   Whole window (Fig. 3C and Fig. S6C,D):
%       minSelDelay = 3 (50ms), maxSelDelay = 15 (250ms)
%   TE-specific window (Fig. 3C and Fig. S6C,D):
%       minSelDelay = 1 (16ms), maxSelDelay = 4 (67ms)
%   NOTE(review): the figure list of the TE-specific delay window is
%   identical to the whole-window one and differs from the TE-specific
%   time box (Fig. S6C,E) - confirm against the paper.
params = struct( ...
    'neural_feature', '1D', ...          % neural feature label (used in the results path)
    'delay_type',     'mean', ...        % information values are the mean over delays
    'shuffType',      'conditioned', ... % either 'simple' or 'conditioned'
    'minSelDelay',    3, ...
    'maxSelDelay',    15);


%% Compute info values from time-delay maps and get a network for each session
% How we pool information across ROIs when computing forward/feedback
% indices: either 'mean' (not sensitive to #rois fwd/fbk in the hierarchy)
% or 'sum' (sensitive to #rois fwd/fbk in the hierarchy)
params.fwd_fbk_pooling = 'mean';

for tIdx = 1:numel(time_methods) % either 'mean' or 'maximum'
    params.info_time_pooling = time_methods{tIdx};

    params.band_label = bands_order{1};
    disp(['Computing ', params.band_label, '; ', params.info_time_pooling, ' time']);
    save_results = 0;
    baseline_subtract = 0;
    params.zeros_with_nan = 0;

    params.error_type='subj'; % Either 'rois' to compute errors across ROIs or 'subj' to compute errors across subjects
    correrr_labels = {'all','corr','err'};
    %correrr_labels = {'all'};

    % Define paths. fullfile inserts the file separator that plain
    % concatenation with pwd (which has no trailing separator) would miss.
    paths.mainPath = pwd;
    paths.resultsPath = pwd;
    paths.scriptsPath = fullfile(paths.mainPath,'Scripts');
    addpath(genpath(paths.scriptsPath))
    paths.dataPath = fullfile(paths.mainPath,'data');
    paths.loadPath = fullfile(paths.resultsPath,'Results',codeVersion,params.band_label,params.neural_feature);
    addpath(genpath(paths.loadPath))
    paths.save = ''; % path where results will be saved (empty: current folder)

    % Create the output folder only when one has been specified
    if ~isempty(paths.save) && ~exist(paths.save, 'dir')
       mkdir(paths.save)
    end

    % Loading meta data
    % NOTE(review): 'times' is expected to come from the first file and
    % 'rois', 'ordered_rois' and 'gs_all' presumably from the second one -
    % confirm against the data files.
    load(fullfile(paths.dataPath,'times'))
    load(fullfile(paths.dataPath,'rois_info'))
    totROIs = numel(rois);

    % initialize structures
    infoTransf_maps.all = []; infoTransf_maps.corr = []; infoTransf_maps.err = [];
    infoTransfSh_maps.all = []; infoTransfSh_maps.corr = []; infoTransfSh_maps.err = [];
    sigTimes=[]; % cell array containing significant time points (considered as putative emitting-receiving times for each ROI
    computedPairs = zeros(totROIs,totROIs);

    % Select ROIs pairs files to load
    matFiles=dir(fullfile(paths.loadPath,'*.mat'));
    if isempty(matFiles)
        error('visNet:noResultFiles','No .mat result files found in %s',paths.loadPath);
    end
    tmpROIs = load(fullfile(paths.loadPath,matFiles(1).name));
    tmp_params = tmpROIs.params;
    tmp_params.nShuff = tmpROIs.opts.nShuff;
    tmp_params.band_label = params.band_label;
    bandIdx = find(strcmp(tmp_params.freqs_labels,params.band_label));
    all_sig_areas = {'V1','V3A','LO3'}; % areas to load

    for k = 1:length(matFiles)

        pairFile = fullfile(paths.loadPath,matFiles(k).name);
        % Light-weight load first: only the labels are needed to decide
        % whether the pair belongs to the selected areas
        tmpROIs = load(pairFile,'Xlab','Ylab','params');
        tmp_params = tmpROIs.params;
        % Check that both ROIs belong to the selected ones
        if (sum(strcmp(all_sig_areas,tmpROIs.Xlab))>0) && (sum(strcmp(all_sig_areas,tmpROIs.Ylab))>0)
            tmpROIs = load(pairFile);
            tmp_params = tmpROIs.params;
            params.nShuff = tmpROIs.opts.nShuff;

            % Minimum fraction of significant time points. params only
            % receives the fields stored in the result files after this
            % loop, so fall back to the value stored with the pair itself.
            if isfield(params,'min_time_points')
                minTimeFrac = params.min_time_points;
            else
                minTimeFrac = tmp_params.min_time_points;
            end
            minSigPoints = minTimeFrac*numel(tmp_params.timePoints);

            % Check that both emitter and receiver ROIs carry information for a
            % minimum number of time points
            if (numel(tmpROIs.sigTime.X) > minSigPoints) && (numel(tmpROIs.sigTime.Y) > minSigPoints)
                computedPairs(tmpROIs.roiX,tmpROIs.roiY)=1;
                pairLab = ['pair_',tmpROIs.Xlab,'_',tmpROIs.Ylab];
                pairLab = matlab.lang.makeValidName(pairLab); % replaces characters that are not allowed as field names such as '-'

                for ceIdx = 1:numel(correrr_labels)
                %for ceIdx = 1:3
                    ceLab = correrr_labels{ceIdx};

                    tmpInfoTransf.(ceLab) = tmpROIs.infoTransf_maps.(ceLab);
                    if strcmp(ceLab,'all') % shufflings computed only for all trials, we want to test significance there
                        tmpInfoTransfSh.(ceLab) = tmpROIs.infoTransfSh_maps.(ceLab);
                    end

                    % Rearrange the info measures from cell to mat arrays
                    infoTransf_maps.(ceLab).(pairLab) = [];
                    infoTransfSh_maps.(ceLab).(pairLab) = [];
                    for infoIdx = 1:numel(tmp_params.info_type)
                        infoLab = tmp_params.info_type{infoIdx};
                        for featIdx = 1:numel(tmp_params.selected_features)
                            featLab = tmp_params.selected_features{featIdx};
                            if isfield(tmpInfoTransf.(ceLab),params.band_label) % some pairs of frequencies are not computed (no significant time points in a frequency band for the emitter or the receiver)
                                if isfield(tmpInfoTransf.(ceLab).(params.band_label).(featLab),infoLab) % corr and err don't have FIT_C
                                    for subjIdx = 1:tmp_params.nSubj
                                        infoTransf_maps.(ceLab).(pairLab).(params.band_label).(featLab).(infoLab)(subjIdx,:,:,:) = tmpInfoTransf.(ceLab).(params.band_label).(featLab).(infoLab){subjIdx};
                                        if strcmp(ceLab,'all')
                                            infoTransfSh_maps.(ceLab).(pairLab).(params.band_label).(featLab).(infoLab)(subjIdx,:,:,:,:) = tmpInfoTransfSh.(ceLab).(params.band_label).(featLab).(infoLab){subjIdx};
                                        end
                                    end
                                end
                            end
                        end
                    end
                    validXroi = matlab.lang.makeValidName(tmpROIs.Xlab);
                    sigTimes.(validXroi) = tmpROIs.sigTime.X;
                end
            end
        end
    end
    % rmfield errors on a missing field, so only drop it when present
    if isfield(tmpROIs.params,'rois_sel_method')
        tmpROIs.params=rmfield(tmpROIs.params,'rois_sel_method');
    end

    tmp_params = tmpROIs.params;
    % Copy fields of tmp_params (parameters of the last loaded file) into params
    for fn = fieldnames(tmp_params)'
       params.(fn{1}) = tmp_params.(fn{1});
    end
    params.nShuff = nShuff; % the value set at the top of the script wins

    computedTPoints = params.timePoints;
    % I cut time point 1 since it's always zero. A tolerance is used instead
    % of exact equality because the time axis is floating point.
    baselineTime = 2:find(abs(times(computedTPoints) - 0.1) < 1e-6, 1);
    computedROIs = rois(sum(computedPairs)>0);
    % Colors for the 4 visual subgroups
    cols = [11, 13, 186; 0, 119, 182; 0, 180, 216; 144, 224, 239]./255;

    % Replace manual selection of color with computedROIsGroups and
    % computedROIsCols = groupColor(computedROIsGroups)
    computedROIsGroups = [];
    orderedComputedROIs = [];
    for roi = 1:numel(computedROIs)
        roiIdx = find(strcmp(ordered_rois,computedROIs{roi}));
        orderedComputedROIs(roi) = roiIdx;
        computedROIsGroups = [computedROIsGroups gs_all(roiIdx)-1]; % -1 because we don't consider the auditory cortical group 1
    end
    [orderedComputedROIs,sort_idxs] = sort(orderedComputedROIs);
    computedROIsCols = [cols(computedROIsGroups,:)];

    % Compute connectivity strengths between all pairs
    sel_time_window = bands_TW; % seconds from stimulus presentation
    params.frequency_bands = {params.band_label};

    disp('Starting computing connectivity networks')
    for ceIdx = 1:numel(correrr_labels)
         ceLab = correrr_labels{ceIdx};
         % Shuffled (null) maps are only available for the 'all' trials
         if strcmp(ceLab,'all')
             doShuff = 1;
         else
             doShuff = 0;
         end
         % v3 to compute separate left and right networks
        [connStrenghts.(ceLab),sessConnStrenghts.(ceLab),subjConnStrenghts.(ceLab),staticNetworks.(ceLab),sessStaticNetworks.(ceLab),staticNetworks_vis.(ceLab),sessConnStrenghtsSh.(ceLab),timeDelayMaps.(ceLab),timeDelayMapsSh.(ceLab),sessTimeDelayMaps.(ceLab),sessTimeDelayMapsSh.(ceLab)] = ...
           compute_connectivity_net(infoTransf_maps.(ceLab),infoTransfSh_maps.(ceLab),sigTimes,computedROIs,sel_time_window,times,computedPairs,computedROIsGroups,rois,params,doShuff);
    end

    % Build null hypothesis across subj: average the shuffled session
    % values in consecutive groups of nSessPerSubj rows, one group per subject
    nSessPerSubj = 4;
    startSess = 1;
    for subj = 1:params.nSubj
        sessRows = startSess:startSess+nSessPerSubj-1;
        for infoIdx = 1:numel(params.info_type)
            infoLab = params.info_type{infoIdx};
            subjConnStrenghtsSh.all.(infoLab).(params.band_label)(subj,:,:,:) = nanmean(sessConnStrenghtsSh.all.(infoLab).(params.band_label)(sessRows,:,:,:),1);
        end
        startSess = startSess + nSessPerSubj;
    end

    % Output file name: encodes band, shuffling type, time window, pooling
    % methods, delay range and the current date
    time_window_str{1} = strrep(num2str(sel_time_window(1)),'.','');
    time_window_str{2} = strrep(num2str(sel_time_window(2)),'.','');
    fname = ['visNet_timeDelay_',params.band_label,'_',params.shuffType,'Shuff_T_',time_window_str{1},'_',time_window_str{2},'_',params.info_time_pooling,'Time_',params.delay_type,'Del_',num2str(params.minSelDelay),'-',num2str(params.maxSelDelay),'_',date];
    disp(['Saving ', fname])
    save(fullfile(paths.save,fname),'connStrenghts','sessConnStrenghts','staticNetworks','sessStaticNetworks','params','sel_time_window','paths','computedROIs','computedROIsCols','sort_idxs','sigTimes','timeDelayMaps','timeDelayMapsSh','sessTimeDelayMaps','sessTimeDelayMapsSh')
end

%%% END OF FIRST SECTION OF THE SCRIPT %%%

%% Plot results
% Parameters of the results to load
clear all; close all;

rng(1) % fixed seed for reproducibility

tmp_params.bands_labels = {'Gamma band'};
bands_order = {'band_40_75'};

tmp_params.shuffType = 'conditioned'; % either 'simple' or 'conditioned'
tmp_params.communication_win = 'full'; % either 'full' or 'reduced'
tmp_params.version = 'NIPS_paper';
tmp_params.simpleShuffVersion = 'NIPS_paper';
tmp_params.save_folder = 'NIPS_paper_results';
tmp_params.doGLMEstat = 0;
tmp_params.plotGroups = 1;
tmp_params.clusterParams = [0.01,0.99];
tmp_params.nShuff = 10; % number of shufflings
tmp_params.nBoot = 500; % bootstrap to use for the null hypothesis

subsel_ROIs = {'V1','V3A','LO3'}; % {'V1','V3A','LO3'} for Fig.3D-H and FigS6B; {'V1','V3A'} for Fig.3C and FigS6C,D;
% Time window files to load
% FIT specific window: (Fig. 3D,E,F,G,H and Fig. S6B)
% load_time_window = [0.2, 0.4]
% Whole window: (Figure 3C and Fig. S6C,D)
% load_time_window = [-0.1, 0.5]
% TE-specific box: (Figure S6C,E)
% load_time_window = [0.15 0.35]
load_time_window = [0.2 0.4];

% Label used in the figures folder name: the selected ROI names joined by
% underscores. cellstr() lets subsel_ROIs also be given as a plain char
% vector (e.g. 'all'), which brace indexing would reject.
rois_save_label = strjoin(cellstr(subsel_ROIs),'_');
for bIdx = 1 % i.e. gamma band
    tmp_params.selected_band = bands_order{bIdx};
    tmp_params.bandIdx = find(strcmp(bands_order, tmp_params.selected_band) == 1);
    tmp_params.delay_method  = 'mean';
    tmp_params.time_method = 'maximum';
    tmp_params.across_subj_method = 'mean'; % either 'mean' or 'median'
    tmp_params.save_figures = 1;
    % To plot DFI results (Fig. S15D) replace with
    % tmp_params.info_type = {'DI','DFI_S','DFI_C'}; after doing this change,
    % 'FIT' labels in plots should be read as 'DFI'
    tmp_params.info_type = {'DI','FIT_S','FIT_C'};

    params.version = tmp_params.version;
    correrr_labels = {'all','corr','err'};
    local_paths.figuresPath = ''; % path to save figures

    time_window_str{1} = strrep(num2str(load_time_window(1)),'.','');
    time_window_str{2} = strrep(num2str(load_time_window(2)),'.','');

    %% Load DI and FIT results
    local_paths.resultsPath = ''; % Path where results computed in the first section of this script were saved
    filename = fullfile(tmp_params.version,['visNet_timeDelay_',tmp_params.selected_band,'_',tmp_params.shuffType,'Shuff_T_',time_window_str{1},'_',time_window_str{2},'_',tmp_params.time_method,'Time_',tmp_params.delay_method,'Del']);
    matFile = dir(fullfile(local_paths.resultsPath,[filename,'_*']));
    try
        if isempty(matFile)
            error('visNet:missingResults','No result file matching %s',filename);
        end
        % File names end with the date they were saved on, so several files
        % can match: load the most recent one. The load brings the saved
        % variables (including 'params' and 'paths') into the workspace.
        [~,newestIdx] = max([matFile.datenum]);
        load(fullfile(local_paths.resultsPath,tmp_params.version,matFile(newestIdx).name));
        disp(['Plotting and saving ',tmp_params.selected_band, ' ',tmp_params.time_method,' time ',tmp_params.delay_method, ' delay in time window ',time_window_str{1},' ',time_window_str{2}])
    catch
        disp(['Non existent ',tmp_params.selected_band, ' ',tmp_params.time_method,' time ',tmp_params.delay_method, ' delay in time window ',time_window_str{1},' ',time_window_str{2}])
        continue
    end

    % Copy fields of tmp_params into params (overriding the loaded values)
    for fn = fieldnames(tmp_params)'
       params.(fn{1}) = tmp_params.(fn{1});
    end

    %% Load simple shuff data
    % Changed after the copy above, so params.shuffType is unaffected
    tmp_params.shuffType = 'simple';
    filename = fullfile(tmp_params.simpleShuffVersion,['visNet_timeDelay_',tmp_params.selected_band,'_',tmp_params.shuffType,'Shuff_T_',time_window_str{1},'_',time_window_str{2},'_',tmp_params.time_method,'Time_',tmp_params.delay_method,'Del']);
    matFile = dir(fullfile(local_paths.resultsPath,[filename,'_*']));
    try
        if isempty(matFile)
            error('visNet:missingResults','No result file matching %s',filename);
        end
        [~,newestIdx] = max([matFile.datenum]); % most recent matching file
        simpleShuff = load(fullfile(local_paths.resultsPath,tmp_params.simpleShuffVersion,matFile(newestIdx).name));
        disp(['Plotting and saving ',tmp_params.selected_band, ' ',tmp_params.time_method,' time ',tmp_params.delay_method, ' delay in time window ',time_window_str{1},' ',time_window_str{2}])
    catch
        disp(['Non existent simple shuff ',tmp_params.selected_band, ' ',tmp_params.time_method,' time ',tmp_params.delay_method, ' delay in time window ',time_window_str{1},' ',time_window_str{2}])
        continue
    end
    %%

    % Minimum delay considered when picking values over delays: the
    % character following 'minD' in the version string (empty when the
    % version string contains no 'minD' tag)
    idxMinD = strfind(tmp_params.version,'minD');
    minD = tmp_params.version(idxMinD+4);

    local_paths.figuresPath = fullfile(local_paths.figuresPath,tmp_params.save_folder,rois_save_label, ...
        [tmp_params.across_subj_method,'_band',tmp_params.selected_band(6:end)], ...
        [tmp_params.communication_win,'Window'], ...
        ['T_',time_window_str{1},'_',time_window_str{2},'_',tmp_params.time_method,'T_',params.rec_delay_type,'_',tmp_params.delay_method,'D_minD',minD]);
     % Create figures directory if missing
    if ~exist(local_paths.figuresPath, 'dir')
       mkdir(local_paths.figuresPath)
    end

    paths.mainPath = pwd;
    paths.scriptsPath = fullfile(paths.mainPath,'Scripts');

    % Compute ROIs cols
    % NOTE(review): these files are expected to provide 'ordered_rois',
    % 'ordered_groups', 'groups', 't1_t2_ratio' and 'times' - confirm
    % against the metadata files.
    load(fullfile(paths.mainPath,'metadata','metadata'))
    load(fullfile(paths.mainPath,'metadata','t1_t2_ratio'))
    load(fullfile(paths.mainPath,'metadata','times'))
    [~,idx]=sort(t1_t2_ratio,'descend');
    t1_t2_ordered_rois = ordered_rois(idx);

    % Locate the window edges on the time axis; a tolerance is used because
    % exact equality on floating-point times can miss the sample
    timeTol = 1e-6;
    tMinIdx = find(abs(times - load_time_window(1)) < timeTol, 1);
    tMaxIdx = find(abs(times - load_time_window(2)) < timeTol, 1);

    plot_times = times(tMinIdx:tMaxIdx)*1000;

    groupsName = groups;
    cols_groups = zeros(21,3);

    % ROIs subselection ('all' or {'all'} keeps every computed ROI)
    if isequal(subsel_ROIs,'all') || isequal(subsel_ROIs,{'all'})
        subsel_ROIs = computedROIs;
    end
    subselIdx = find(ismember(computedROIs,subsel_ROIs));

    computedROIs = computedROIs(subselIdx);
    for ceIdx = 1:numel(correrr_labels)
        ceLab = correrr_labels{ceIdx};
        for infoIdx = 1:numel(params.info_type)
            infoLab = params.info_type{infoIdx};

            totROIs = numel(staticNetworks.(ceLab).(infoLab).(params.band_label).Nodes);
            % subsample network graphs
            staticNetworks.(ceLab).(infoLab).(params.band_label) = rmnode(staticNetworks.(ceLab).(infoLab).(params.band_label),setdiff(1:totROIs,subselIdx));

            % Restrict connectivity values, time-delay maps and their nulls
            % to the selected ROIs
            connStrenghts.(ceLab).(infoLab).(params.band_label) = connStrenghts.(ceLab).(infoLab).(params.band_label)(subselIdx,subselIdx);
            sessConnStrenghts.(ceLab).(infoLab).(params.band_label) = sessConnStrenghts.(ceLab).(infoLab).(params.band_label)(:,subselIdx,subselIdx);
            timeDelayMaps.(ceLab).(infoLab).(params.band_label) = timeDelayMaps.(ceLab).(infoLab).(params.band_label)(subselIdx,subselIdx,:,:);
            simpleShuff.timeDelayMaps.(ceLab).(infoLab).(params.band_label) = simpleShuff.timeDelayMaps.(ceLab).(infoLab).(params.band_label)(subselIdx,subselIdx,:,:);
            if strcmp(ceLab,'all')
                sessTimeDelayMapsSh.(ceLab).(infoLab).(params.band_label) = sessTimeDelayMapsSh.(ceLab).(infoLab).(params.band_label)(:,subselIdx,subselIdx,:,:,:);
                simpleShuff.sessTimeDelayMapsSh.(ceLab).(infoLab).(params.band_label) = simpleShuff.sessTimeDelayMapsSh.(ceLab).(infoLab).(params.band_label)(:,subselIdx,subselIdx,:,:,:);
            end
        end
    end
    % Replace manual selection of color with computedROIsGroups and
    % computedROIsCols = groupColor(computedROIsGroups)
    computedROIsGroups = [];
    orderedComputedROIs = [];

    for roi = 1:numel(computedROIs)
        roiIdx = find(strcmp(ordered_rois,computedROIs{roi}));
        orderedComputedROIs(roi) = roiIdx;
        computedROIsGroups = [computedROIsGroups ordered_groups(roiIdx)]; % group index of each selected ROI
    end
    [orderedComputedROIs,sort_idxs] = sort(orderedComputedROIs);
    computedROIsCols = [cols_groups(computedROIsGroups,:)];
    computedROIsGroupsLab = {groupsName{computedROIsGroups}};

    rng(1);

    % Compute connectivity and evaluate p-values
    for ceIdx = 1:numel(correrr_labels)
        ceLab = correrr_labels{ceIdx};
        for infoIdx = 1:numel(params.info_type)
            infoLab = params.info_type{infoIdx};

            % Compute null by combining permutations across subjects; the
            % final null is the element-wise maximum of the 'conditioned'
            % and 'simple' shuffling nulls
            if strcmp(ceLab,'all')
                timeDelayMapsBootCondSh.all.(infoLab).(params.band_label) = btstrp_shuff(sessTimeDelayMapsSh.all.(infoLab).(params.band_label),params.nBoot);
                timeDelayMapsBootSimpleSh.all.(infoLab).(params.band_label) = btstrp_shuff(simpleShuff.sessTimeDelayMapsSh.all.(infoLab).(params.band_label),params.nBoot);
                timeDelayMapsBootMaxSh.all.(infoLab).(params.band_label) = squeeze(max(cat(6,timeDelayMapsBootCondSh.all.(infoLab).(params.band_label),timeDelayMapsBootSimpleSh.all.(infoLab).(params.band_label)),[],6));

                 % Clean the RAM: drop the session-level nulls that are no
                 % longer needed for this measure
                sessTimeDelayMapsSh.all = rmfield(sessTimeDelayMapsSh.all,infoLab);
                simpleShuff.sessTimeDelayMapsSh.all = rmfield(simpleShuff.sessTimeDelayMapsSh.all,infoLab);
            end

            % Session values rebinned with temporal_rebinning (bin size 4,
            % 'movmean') along the first dimension
            % NOTE(review): presumably 4 sessions per subject - confirm
            subjConnStrenghts.(ceLab).(infoLab).(params.band_label) = [];
            y = sessConnStrenghts.(ceLab).(infoLab).(params.band_label);
            y = permute(y,[2, 1, 3]);
            y = temporal_rebinning(y,4,'movmean');
            y = permute(y,[2, 1, 3]);
            subjConnStrenghts.(ceLab).(infoLab).(params.band_label) = y;
            % Explicit dimension so that a single row is still reduced
            % along the first dimension
            medianConnStrenghts.(ceLab).(infoLab).(params.band_label) = squeeze(median(subjConnStrenghts.(ceLab).(infoLab).(params.band_label),1));

        end
    end

    if strcmp(params.across_subj_method, 'mean')
        plot_connectivity_matrix_visNet_v3(connStrenghts,sessConnStrenghts,staticNetworks,computedROIs,computedROIsCols,sel_time_window,params,sort_idxs,8,local_paths);
    elseif strcmp(params.across_subj_method, 'median')
        plot_connectivity_matrix_visNet_v3(medianConnStrenghts,sessConnStrenghts,staticNetworks,computedROIs,computedROIsCols,sel_time_window,params,sort_idxs,8,local_paths);
    end

    % get relative order of ROIs in the hierarchy
    relativeOrder.all = get_forward_feedback_ROIs(computedROIs,ordered_rois);

    % Compute directionality idxs
    for ceIdx = 1:numel(correrr_labels)
        ceLab = correrr_labels{ceIdx};
        if strcmp(params.across_subj_method, 'mean')
            [directTransf.(ceLab), sessDirectTransf.(ceLab), subjDirectTransf.(ceLab)] = compute_fwd_fbk_transfer_v4(connStrenghts.(ceLab),sessConnStrenghts.(ceLab),relativeOrder,computedROIs,computedROIsGroupsLab,params,sort_idxs);

        elseif strcmp(params.across_subj_method, 'median')
            [directTransf.(ceLab), sessDirectTransf.(ceLab), subjDirectTransf.(ceLab)] = compute_fwd_fbk_transfer_v4(medianConnStrenghts.(ceLab),sessConnStrenghts.(ceLab),relativeOrder,computedROIs,computedROIsGroupsLab,params,sort_idxs);

        end
    end

    % Time-delay maps
    plot_pair = 'all_all';
    [FIT_td_map,DI_td_map,FIT_td_mapSh]=plot_time_delay_maps_visNet_v4(timeDelayMaps,timeDelayMapsBootMaxSh,computedROIs,sort_idxs,plot_times,'all',plot_pair,params,local_paths);

    plot_time_delay_profiles_visNet_v1(FIT_td_map,DI_td_map,plot_times,plot_pair,params);
    plot_fwd_vs_bwd_GLME_visNet_v4(sessDirectTransf,load_time_window,params,computedROIsGroupsLab,sort_idxs,local_paths)

    % FIT_S vs FIT_C
    plot_FIT_S_vs_C_GLME_visNet_v4(sessDirectTransf,load_time_window,params,local_paths)
    % Corr vs err
    plot_corr_err_GLME_visNet_v5(sessDirectTransf,staticNetworks,sort_idxs,load_time_window,params,local_paths)

end