function [x,info]=LogisticSparseSketch(A,b,options)
%% [x,info]=LogisticSparseSketch(A,b,options)
% Minimizes the L2-regularized logistic loss
%   f: \mathbb R^d \rightarrow \mathbb R
%   f(x) = gamma/2 ||x||_2^2 - \tfrac{1}{n} \sum_{i=1}^n log( \sigmoid( b_i a_i^T x ) )
% where A is the data matrix and b contains the labels {-1,+1}:
%   A = (a_1, ... , a_n)^T \in \mathbb R^{n \times d}   (one sample per row)
%   b = (b_1, ... , b_n)^T \in {-1,1}^n
%
% Additional mandatory input:
% options  a struct with the fields
%   options.method      1 (sketched Newton)
%                       2 (plain Newton with line search)
%                       3 (gradient)
%                       4 (accelerated gradient)
%   options.sketchsize  positive integer, number of coordinates sampled per
%                       iteration (read only for method 1)
%   options.maxit       positive integer, maximum number of iterations
%   options.tol         small positive number, stop once norm(gradient)<=tol
%   options.gamma       regularization parameter, should be larger than zero
%   options.L           Lipschitz constant, needed only for methods 3 and 4
%   options.printout    0 or 1 (1 prints progress including train/test error)
%   options.Atest       test data matrix  (read only when printout==1)
%   options.btest       test labels       (read only when printout==1)
%   options.savemod     controls how often function value, gradient norm,
%                       iteration and wallclock time are recorded; a full
%                       re-evaluation happens every 5*savemod iterations
%
% Output:
%   x     the approximate optimal solution computed
%   info  a struct with supplementary data: info.Track has one row
%         [iteration, function value, gradient norm, elapsed time] per
%         record, info.iterations, info.time.totaltime, info.method,
%         info.sketchsize

invalidMethodId='LogisticSparseSketch:invalidMethod';
invalidMethodMsg='options.method must be 1 (sketched Newton), 2 (plain Newton), 3 (gradient) or 4 (accelerated gradient).';

info.time.starttime=tic;
info.method=options.method;
if options.method==1
    info.sketchsize=options.sketchsize;
else
    info.sketchsize=[];
end
info.trackk=1;
% Rows of A are samples, columns are features.
[options.numsamples,options.numfeatures]=size(A);
x=zeros(options.numfeatures,1);               %Initial point

k=0; %iteration counter
%See what we are dealing with
[info,normg,bATx,sigmoidbATx]=FullEval(x,A,b,options,info,k);

%First direction
if options.method==1
    [lambda,S,bATdx,normdxsquared,dxTx,la,laprime]=ComputeSketchedNewtonDirection(x,sigmoidbATx,A,b,options);
elseif options.method==2 %Newton
    H=(A*A.')./options.gamma; %Precomputation saves time
    [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullNewtonDirection(x,sigmoidbATx,A,b,H,options,info,k);
elseif options.method==3 || options.method==4 %Gradient / Fast Gradient
    [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullGradient(x,sigmoidbATx,A,b,options,info,k);
    if options.method==4
        theta=1;
        y=x;
    end
else
    % An empty message would make error() a no-op, so always pass a real one.
    error(invalidMethodId,invalidMethodMsg)
end

useLineSearch=true; %Set to false to turn off line-search (unit step instead)

%Main loop
while normg>options.tol && k<options.maxit
    k=k+1;

    %LineSearch
    if useLineSearch
        [t,bATxnew,sigmoidevalnew]=LineSearch(bATx,bATdx,normdxsquared,dxTx,la,laprime,options);
    else
        t=1;
        bATxnew=bATx+bATdx;
        sigmoidevalnew=sigmoid(bATxnew);
    end

    %Update Iterate
    if options.method==1
        x(S)=x(S)+t*lambda; %Only the sampled coordinates move
    elseif options.method==2 || options.method==3
        x=x+t*dx;
    elseif options.method==4
        ynew=x+t*dx;
        thetanew=(1+sqrt(4*theta^2+1))/2;
        x=ynew+((theta-1)/thetanew)*(ynew-y);
        theta=thetanew;
        y=ynew;
        bATxnew=b.*(A*x); %bATxnew from before is actually equal to b.*(A*ynew), and we can not avoid this computation
        sigmoidevalnew=sigmoid(bATxnew);
    else
        error(invalidMethodId,invalidMethodMsg)
    end

    if mod(k,options.savemod*5)==0 %See what norm of gradient looks like, evaluate again to exclude rounding errors
        [info,normg,bATx,sigmoidbATx]=FullEval(x,A,b,options,info,k);
    else
        bATx=bATxnew;
        sigmoidbATx=sigmoidevalnew;
    end

    %Compute next direction
    if options.method==1
        [lambda,S,bATdx,normdxsquared,dxTx,la,laprime]=ComputeSketchedNewtonDirection(x,sigmoidbATx,A,b,options);
    elseif options.method==2 %Newton
        [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullNewtonDirection(x,sigmoidbATx,A,b,H,options,info,k);
    elseif options.method==3 || options.method==4 %Gradient / Fast Gradient
        [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullGradient(x,sigmoidbATx,A,b,options,info,k);
    else
        error(invalidMethodId,invalidMethodMsg)
    end

end
disp('Done')
info.iterations=k;
info.time.totaltime=toc(info.time.starttime);
end

function [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullGradient(x,sigmoidbATx,A,b,options,info,k)
% Gradient step direction plus the scalars the line search needs.
%
% Inputs:
%   x            current iterate
%   sigmoidbATx  sigmoid(b.*(A*x)) at the current iterate
%   A, b         data matrix and labels
%   options      reads gamma, L, numsamples, tol, savemod, printout, maxit
%                (and Atest, btest when printing)
%   info         tracking struct; a row may be appended to info.Track
%   k            iteration counter
% Outputs:
%   dx             -gradient/(gamma+L)
%   bATdx          b.*(A*dx)
%   normdxsquared  norm(dx)^2
%   dxTx           x.'*dx
%   la             dx.'*gradient
%   laprime        gamma*norm(dx)^2 + weighted sum of bATdx.^2 over samples
%   normg          norm of the gradient
%   info           updated tracking struct

residual=b.*(sigmoidbATx-1);
g=options.gamma*x+(residual.'*A).'./options.numsamples;
normg=norm(g);

% Record on every savemod-th iteration, and also on convergence.
recordNow=(mod(k,options.savemod)==0 || normg<=options.tol); %Better resolution for gradient/ fast gradient at no extra cost
if recordNow
    fval=nan; % function value is only evaluated on every (5*savemod)-th iteration
    if mod(k,options.savemod*5)==0
        normx=norm(x);
        fval=0.5*options.gamma*normx^2-sum(log(sigmoidbATx))/options.numsamples;
        if options.printout==1
            fprintf('Iter %4.5f Value: %4.9f NormG %4.9f Normx %4.9f Train-Error %4.5f Test-Error %4.5f \n',k/options.maxit ,fval,normg,normx,sum(sign(A*x)~=b)/options.numsamples,sum(sign(options.Atest*x)~=options.btest)/size(options.btest,1))
        end
    end
    info.Track(info.trackk,:)=[k,fval,normg,toc(info.time.starttime)];
    info.trackk=info.trackk+1;
end

% Scaled steepest-descent direction.
dx=-g/(options.gamma+options.L);

% Quantities consumed by the line search.
bATdx=b.*(A*dx);
normdxsquared=norm(dx)^2;
dxTx=x.'*dx;
la=dx.'*g;
sampleweights=sigmoidbATx.*(1-sigmoidbATx);
laprime=options.gamma*normdxsquared+sum((bATdx.^2).*sampleweights)/size(A,1);

end



function [info,normg,bATx,sigmoidbATx]=FullEval(x,A,b,options,info,k)
% Recomputes b.*(A*x) and its sigmoid from scratch at the iterate x.
% For options.method==1 it additionally evaluates the gradient norm and the
% function value, appends [k,fval,normg,elapsed time] to info.Track and
% prints a progress line when options.printout==1.
% For every other method normg is returned as inf (those methods obtain the
% gradient norm from their direction routines).
bATx=b.*(A*x);
sigmoidbATx=sigmoid(bATx);
if options.method==1
    residual=b.*(sigmoidbATx-1);
    normg=norm(options.gamma*x+(residual.'*A).'/options.numsamples);
    normx=norm(x);
    fval=0.5*options.gamma*normx^2-sum(log(sigmoidbATx))/options.numsamples;
    info.Track(info.trackk,:)=[k,fval,normg,toc(info.time.starttime)];
    info.trackk=info.trackk+1;
    if options.printout==1
        fprintf('Iter %4.5f Value: %4.9f NormG %4.9f Normx %4.9f Train-Error %4.5f Test-Error %4.5f \n',k/options.maxit ,fval,normg,normx,sum(sign(A*x)~=b)/options.numsamples,sum(sign(options.Atest*x)~=options.btest)/size(options.btest,1))
    end
else
    normg=inf;
end
end

function [dx,bATdx,normdxsquared,dxTx,la,laprime,normg,info]=ComputeFullNewtonDirection(x,sigmoidbATx,A,b,H,options,info,k)
% Full Newton direction plus the scalars the line search needs.
%
% The Hessian is gamma*I + A.'*diag(w)*A with w = sigmoid.*(1-sigmoid)/n.
% Instead of solving with that d-by-d matrix, the code solves an n-by-n
% system with diag(1./w) + A*A.'/gamma (Woodbury identity), where
% H = A*A.'/gamma is precomputed by the caller.
%
% Inputs:
%   x            current iterate
%   sigmoidbATx  sigmoid(b.*(A*x)) at the current iterate
%   A, b         data matrix and labels
%   H            (A*A.')./options.gamma
%   options      reads gamma, numsamples, tol, printout, maxit
%                (and Atest, btest when printing)
%   info         tracking struct; one row is appended to info.Track
%   k            iteration counter
% Outputs:
%   dx             Newton direction (all outputs are 0 if normg<=tol)
%   bATdx          b.*(A*dx)
%   normdxsquared  norm(dx)^2
%   dxTx           x.'*dx
%   la             dx.'*g, directional derivative of f along dx at x
%   laprime        curvature of f along dx at x
%   normg          norm of the gradient
%   info           updated tracking struct
g=options.gamma*x+((b.*(sigmoidbATx-1)).'*A).'./options.numsamples;

normg=norm(g);
normx=norm(x);
fval=0.5*options.gamma*normx^2-sum(log(sigmoidbATx))/options.numsamples;
info.Track(info.trackk,:)=[k,fval,normg,toc(info.time.starttime)];
info.trackk=info.trackk+1;
if options.printout==1
    fprintf('Iter %4.5f Value: %4.9f NormG %4.9f Normx %4.9f Train-Error %4.5f Test-Error %4.5f \n',k/options.maxit ,fval,normg,normx,sum(sign(A*x)~=b)/options.numsamples,sum(sign(options.Atest*x)~=options.btest)/size(options.btest,1))
end
if normg>options.tol %Otherwise we leave
    Ag=A*g;
    M=diag((abs((1-sigmoidbATx).*sigmoidbATx)./size(A,1)).^-1)+H;
    [L,p]=chol(M,'lower');
    if p==0
        Ag=-(L.'\(L\Ag));
    else
        % Cholesky failed (L is only a partial factor): solve directly, as
        % the sketched Newton routine does.
        Ag=-(M\Ag);
    end
    dx=-g./options.gamma-(Ag.'*A).'./(options.gamma^2);
    bATdx=b.*(A*dx);
    normdxsquared=norm(dx)^2;
    dxTx=x.'*dx;
    la=dx.'*g;
    % Hessian*dx = -g for the Newton direction, so the curvature
    % dx.'*Hessian*dx along dx equals -dx.'*g = -la (positive).
    laprime=-la;
else %We still have to give something back
    dx=0;bATdx=0;normdxsquared=0;dxTx=0;la=0;laprime=0;
end
end



function [lambda,S,bATdx,normdxsquared,dxTx,la,laprime]=ComputeSketchedNewtonDirection(x,sigmoidbATx,A,b,options)
% Newton direction restricted to a random subset of coordinates.
%
% Inputs:
%   x            current iterate
%   sigmoidbATx  sigmoid(b.*(A*x)) at the current iterate
%   A, b         data matrix and labels
%   options      reads numfeatures, sketchsize (and whatever SketchLogistic reads)
% Outputs:
%   lambda         step on the sampled coordinates, solves STHS*lambda = -STg
%   S              sorted indices of the sampled coordinates
%   bATdx          b.*(A(:,S)*lambda)
%   normdxsquared  lambda.'*lambda
%   dxTx           x(S).'*lambda
%   la             lambda.'*STg, directional derivative along the step
%   laprime        curvature along the step
S=sort(randperm(options.numfeatures,options.sketchsize));                        %Using sparse sketches here

[STg,STHS]=SketchLogistic(x,sigmoidbATx,S,A,b,options);  %Sketch gradient and Hessian
[L,p]=chol(STHS,'lower');
if p==0
    lambda=-(L.'\(L\STg));                                      %Sketched Newton
else
    % Cholesky failed: fall back to a general solve.
    lambda=-STHS\STg;                                      %Sketched Newton
end

%Needed for next Line-Search
bATdx=b.*(A(:,S)*lambda);
normdxsquared=lambda.'*lambda; % This is only true if S^TS=I
dxTx=x(S).'*lambda;
la = lambda.'*STg;
% STHS*lambda = -STg, so the curvature lambda.'*STHS*lambda along the step
% equals -lambda.'*STg = -la (positive).
laprime = -la;
end

function [STg,STHS]=SketchLogistic(x,sigmoidbATx,S,A,b,options)
% Gradient and Hessian of the objective restricted to the coordinates S.
% Output: STg   sketched gradient (entries S of the full gradient)
%         STHS  sketched Hessian  A(:,S).'*W*A(:,S) + gamma*I, with
%               W = diag(sigmoid.*(1-sigmoid))/size(A,1)
nrows=size(A,1);
AS=A(:,S);

residual=b.*(sigmoidbATx-1);
STg=options.gamma*x(S,1)+(residual.'*AS).'./nrows;

% Square roots of the Hessian weights, so that the weighted Gram matrix can
% be formed as (R*AS).'*(R*AS).
rowscale=sqrt(abs((1-sigmoidbATx).*sigmoidbATx)./nrows); %abs should be irrelevant
if issparse(AS)
    % Scale only the stored nonzeros. Significant savings.
    [rowidx,colidx,vals]=find(AS);
    scaled=sparse(rowidx,colidx,vals.*rowscale(rowidx),size(AS,1),size(AS,2));
else
    scaled=bsxfun(@times,rowscale,AS);
end

ridge=diag(options.gamma*ones(size(scaled,2),1));
STHS=(scaled.'*scaled)+ridge;
end
function [t,bATxnew,sigmoidevalnew]=LineSearch(bATx,bATdx,normdxsquared,dxTx,la,laprime,options)
% Line search on the scalar function l(t) evaluated by SketchLineLogistic
% (returned there as lt, with slope ltprime).
%
% Inputs:
%   bATx, bATdx, normdxsquared, dxTx  passed through to SketchLineLogistic
%   la, laprime   value and slope of l at t=0
%   options       reads gamma and numsamples
% Outputs:
%   t               accepted step length
%   bATxnew         bATx+t*bATdx
%   sigmoidevalnew  sigmoid(bATxnew)
%
% The search stops when abs(l(t)) <= 0.05*abs(l(0)) or when the bracketing
% interval has collapsed.
lo=0;
hi=1.0;
threshold=abs(la)*0.05;
reg=options.gamma;
nsamples=options.numsamples;

[lb,lbprime,bATxnew,sigmoidevalnew]=SketchLineLogistic(hi,bATx,bATdx,normdxsquared,dxTx,reg,nsamples);

% Enlargement phase: move the bracket [lo,hi] to the right until
% l(hi) >= -threshold.
while lb<-threshold
    % Cubic Hermite interpolant through (lo,la,laprime) and (hi,lb,lbprime)
    % in Newton divided-difference form.
    width=hi-lo;
    secant=(lb-la)/width;
    dd2=(secant-laprime)/width;
    dd3=((lbprime-secant)/width-dd2)/width;

    % Extrapolate the cubic over (hi, 2*hi] and pick its first nonnegative
    % sample as the new right end; otherwise simply double.
    trial=linspace(hi,2*hi,100).';
    trial=trial(2:end);
    model=la+(trial-lo).*(laprime+(trial-lo).*(dd2+(trial-hi).*dd3));
    firsthit=find(model>=0,1,'first');
    if isempty(firsthit)
        newhi=2*hi;
    else
        newhi=trial(firsthit);
    end

    lo=hi;
    hi=newhi;
    la=lb;
    laprime=lbprime;
    [lb,lbprime,bATxnew,sigmoidevalnew]=SketchLineLogistic(hi,bATx,bATdx,normdxsquared,dxTx,reg,nsamples);
end

% Shrinking phase: find t in [lo,hi] with abs(l(t)) small.
t=hi;
lt=lb;
ltprime=lbprime;
while abs(lt)>threshold && hi-lo>10*eps
    % Replace the end of the bracket that lies on the same side as t.
    if lt<0
        lo=t;
        la=lt;
        laprime=ltprime;
    else
        hi=t;
        lb=lt;
        lbprime=ltprime;
    end

    % Cubic Hermite interpolant on the current bracket.
    width=hi-lo;
    secant=(lb-la)/width;
    dd2=(secant-laprime)/width;
    dd3=((lbprime-secant)/width-dd2)/width;

    % Greedy root estimate: sample the cubic at 20 interior points and take
    % the one with the smallest magnitude. (The root has a closed form, but
    % that would need extra safeguarding.)
    trial=linspace(lo,hi,22).';
    trial=trial(2:end-1);
    model=la+(trial-lo).*(laprime+(trial-lo).*(dd2+(trial-hi).*dd3));
    [~,best]=min(abs(model));
    t=trial(best(1));

    [lt,ltprime,bATxnew,sigmoidevalnew]=SketchLineLogistic(t,bATx,bATdx,normdxsquared,dxTx,reg,nsamples);
end
end
function [lt,ltprime,bATxnew,sigmoidevalnew]=SketchLineLogistic(t,bATx,bATdx,normdxsquared,dxTx,gamma,n)
% Evaluates the line-search function at step length t.
%   bATxnew         bATx shifted by t along bATdx
%   sigmoidevalnew  sigmoid(bATxnew)
%   lt              regularization term gamma*(t*normdxsquared+dxTx) plus
%                   the data term bATdx.'*(sigmoid-1)/n
%   ltprime         gamma*normdxsquared plus the sigmoid-weighted sum of
%                   bATdx.^2 divided by n
bATxnew=bATx+t*bATdx;
sigmoidevalnew=sigmoid(bATxnew);

regpart=gamma*(t*normdxsquared+dxTx);
datapart=bATdx.'*(sigmoidevalnew-1)/n;
lt=regpart+datapart;

weighted=(sigmoidevalnew.*(1-sigmoidevalnew)).*bATdx.^2;
ltprime=gamma*normdxsquared+sum(weighted)/n;
end
function y=sigmoid(x)
%Elementwise logistic function
%
%                 1
% sigmoid(x)= ---------
%             1+exp(-x)
%
% Only exp of a non-positive argument is ever taken, so large abs(x) cannot
% overflow: for x<=0 the equivalent form exp(x)/(1+exp(x)) is used.
nonpos=(x<=0);
decay=exp(-abs(x)); % exp(x) where x<=0, exp(-x) where x>0
y=zeros(size(x));
y(nonpos)=decay(nonpos)./(1+decay(nonpos));
y(~nonpos)=1./(1+decay(~nonpos));
end