%% Real data analysis script
% This script fits DTI, NNLS, CBP, and EBP mixture models
% to voxel data from a section of a brain slice.
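% The particular region of interest (ROI) is known to contain fiber crossings.
% The measurements in each voxel are split into two disjoint sets
%  (indexed by splits1 and splits2).
% Cross-validation with NNLS is used to choose the regularization parameter tb:
%  for each candidate tb, NNLS is fit on set 1, and the tb with the smallest
%  squared prediction error on set 2 is kept.
% All reported test errors come from models trained on set 2 and evaluated
%  on set 1 (so set 1 also takes part, as training data, in choosing tb).
% The fits push the L1 norm of the (nonnegative) weights towards tb: a row of
%  ones with target value tb is appended to the least-squares system.
% This is motivated by our prior knowledge that the L1 norm of the weights
%  should equal exp(-b*lambda2).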
%
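% The whole workspace is saved to 'results_rd6_vk_<runno>.mat'
%  (runno = 12 below, i.e. 'results_rd6_vk_12.mat').
% The matrix 'output' has one row per voxel. Columns 3-6 hold the sum of
%  squared errors on the test set for each method:
%  Column 1: voxel (iteration) number
%  Column 2: L1 target tb selected by cross-validation
%  Column 3: test error of DTI
%  Column 4: test error of NNLS
%  Column 5: test error of first-order CBP
%  Column 6: test error of EBP
% The matrix 'mse_cals' holds, per voxel, the cross-validation error for
%  every candidate tb.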

% Load the measurement files. Presumably these define bvecss, datas, splits1
% and splits2, none of which is assigned elsewhere in this script -- confirm.
load('realdata0.mat');
load('realdata6.mat');
runno = 12;   % assigned after the loads, so it overrides any stored runno
%%
% Select entry (ii, kk) of bvecss/datas to analyse.
ii = 3;  kk = 1;
grid = bvecss{ii,kk};
datas11 = datas(ii,:,kk);
% b, l1 and l2 are not referenced below; they only end up in the file
% written by save(fname), which stores the whole workspace.
b = 1;  l1 = 1;  l2 = 0;
%%
% Lattice radius grows with the run number.
rad = 0.8 + runno/10;

% Dictionary centers for NNLS/EBP, and their count.
centers = sphlattice(20, rad);
pp = size(centers, 2);

% Centers for CBP; delta is passed on to cbpmat.
[cbpcenters, delta] = sphlattice(10, rad);
ppc = size(cbpcenters, 2);


%%

% Dictionaries evaluated on the full grid and on each split.
% xs and xfinder1 are not used below; they are kept because save(fname)
% writes the whole workspace.
xs      = ste_tan0(centers, grid);
nvox    = numel(datas11);                  % datas11 is a 1-by-nvox cell row
grid1   = grid(:, splits1);
grid2   = grid(:, splits2);
xs1     = ste_tan0(centers, grid1);        % NNLS design, split 1
xs2     = ste_tan0(centers, grid2);        % NNLS design, split 2
cbpxs1  = cbpmat(cbpcenters, grid1, delta);% CBP design, split 1
cbpxs2  = cbpmat(cbpcenters, grid2, delta);% CBP design, split 2
% Handles handed to dwi_ebp (only xfinder2 is actually used).
xfinder1 = @(r) xfinder_random2(grid1, 3, 1000, 3, r);
xfinder2 = @(r) xfinder_random2(grid2, 3, 1000, 3, r);

%%

% Candidate L1 targets: 0.25, 0.30, ..., 0.75.
% (Earlier runs used tbs = .5 or tbs = [.5,1].)
tbs = (5:15)/20;
ntb = numel(tbs);
%%

% Process every voxel; preallocate the per-voxel result table
% (6 columns, see the file header) and the CV error curves.
ntrials  = nvox;
output   = zeros(ntrials, 6);
mse_cals = zeros(ntrials, ntb);

% iis (a random permutation of 1:nvox) is only referenced by a commented-out
% line in the loop. The call still draws from the global random stream, so
% removing it may change later random draws (presumably those made inside
% xfinder_random2 -- confirm).
iis = randsample(nvox, ntrials);


%%
% Per-voxel fitting loop. For voxel iii:
%  1. choose tb by fitting NNLS on split 1 and scoring it on split 2;
%  2. refit NNLS, DTI, CBP and EBP on split 2 and score each on split 1;
%  3. store the test errors in output(iii,:) and the CV curve in mse_cals(iii,:).
% Variable names are kept unchanged because save(fname) below writes the whole
% workspace (including the last iteration's intermediates) to disk.
for iii = 1:ntrials
    %%
    %ii=iii;
    %ii=iis(iii);
    mse_cal = zeros(1,ntb);
    % Force a column vector: the stacked right-hand sides [y;tb] need one.
    y = datas11{iii}(:);
    y1 = y(splits1);
    y2 = y(splits2);

    % --- Cross-validate tb with NNLS (train on split 1, score on split 2).
    % The appended row of ones with target tb pulls sum(beta) -- the L1 norm,
    % since lsqnonneg keeps beta >= 0 -- towards tb.
    for jj = 1:ntb
        tb = tbs(jj);
        betaa_0 = lsqnonneg([xs1;ones(1,pp)],[y1;tb]);
        yh_0 = xs2*betaa_0;
        mse_cal(jj) = sum((y2-yh_0).^2);
    end
    % First minimizer, so ties resolve to the earliest entry of tbs.
    tb = tbs(find(mse_cal==min(mse_cal),1));

    % --- NNLS: train on split 2 with the chosen tb, test on split 1.
    betaa_0 = lsqnonneg([xs2;ones(1,pp)],[y2;tb]);
    yh_0 = xs1*betaa_0;
    mse_0 = norm(y1-yh_0)^2;
    % Active set of the NNLS fit; used to initialize EBP below.
    spar_0 = sum(betaa_0 > 0);
    est_0 = centers(:,betaa_0 > 0);
    mags_0 = betaa_0(betaa_0 > 0);
    xs_0 = xs2(:,betaa_0 > 0);

    % --- DTI: fit on split 2, predict on split 1.
    [est_d,~,estQ,estQr] = dti_fit(y2,grid2);
    % diag(G'*Q*G) is the quadratic form g'*Q*g for each column g of G.
    yh_d0 = exp(-squeeze(diag(grid2' * estQ * grid2)));
    % Optional scale fit of the DTI prediction (disabled: bt_d is fixed to 1).
    % lsqnonneg(C,d) solves min ||C*x - d|| with x >= 0, so fitting
    % bt_d*yh_d0 ~ y2 is lsqnonneg(yh_d0,y2) (the old comment had the
    % arguments swapped).
    %bt_d = lsqnonneg(yh_d0,y2);
    bt_d=1;
    yh_d = bt_d.*exp(-squeeze(diag(grid1' * estQ * grid1)));
    mse_d = norm(y1-yh_d)^2;

    % --- First-order CBP with the same L1 target tb. The row of ones is sized
    % from the design matrix itself instead of a hard-coded 8*ppc.
    betaa_c = lsqnonneg([cbpxs2;ones(1,size(cbpxs2,2))],[y2;tb]);
    yh_c = cbpxs1 * betaa_c;
    % Total weight per CBP center (1-by-ppc); assumes cbpmat stacks its
    % columns in blocks of ppc, one block per basis term -- TODO confirm.
    temp = sum(reshape(betaa_c,ppc,[]),2)';
    spar_c = sum(temp > 0);
    mse_c = norm(y1-yh_c)^2;

    % --- EBP, initialized from the NNLS active set, fit to split 2.
    % NOTE(review): yh_0 here is xs1*betaa_0, i.e. the prediction on the
    % *test* split (split 1). If dwi_ebp expects the current fit to y2, this
    % argument should probably be xs2*betaa_0 -- confirm against dwi_ebp.
    [mags_t,est_t] = dwi_ebp( y2,tb, est_0, [xs_0;ones(1,spar_0)], mags_0, xfinder2 ,10, yh_0, est_0, mags_0);
    yh_t = ste_tan0(est_t,grid1)*mags_t;
    mse_t = norm(y1-yh_t)^2;

    output(iii,:) = [iii,tb,mse_d,mse_0,mse_c,mse_t];
    mse_cals(iii,:) = mse_cal;
end

%%
% Output file name encodes the run number, e.g. results_rd6_vk_12.mat.
fname = ['results_rd6_vk_', int2str(runno), '.mat'];
% save with no variable list writes the entire workspace (including the
% loaded data), not only output and mse_cals.
save(fname);
% Older variant that stored only the result matrices:
%save('realdata5.mat','output','mse_cals');
