function [f_vec1,g_vec1,time_vec1,sample_vec,x,acc_vec] = DBGD_Sto(fun_f,grad_f,grad_g,fun_g,TSA,param,x0,A1,A2,b1,b2)
%DBGD_STO Stochastic dynamic-barrier gradient descent with projection via
% ProjectOntoL1Ball.
%
% The update rule is given by
%   x_{k+1}   = \Pi_{Z}(x_k - \gamma (\nabla f(x_k) + \lambda_k \nabla g(x_k))),
%   \lambda_k = \max\{ (\phi(x_k) - \nabla f(x_k)' \nabla g(x_k)) / \|\nabla g(x_k)\|^2, 0 \},
%   \phi(x)   = \min\{ \alpha (g(x) - \hat{g}), \beta \|\nabla g(x)\|^2 \},
% with a constant step size \gamma = param.stepsize.
% In this implementation:
%   * \hat{g} is taken to be 0, i.e. \phi uses \alpha g(x);
%   * \nabla f, \nabla g and g inside the update are single-row stochastic
%     estimates built from one uniformly sampled row of (A1,b1) and (A2,b2),
%     scaled by the number of rows (presumably f(x)=||A1*x-b1||^2/2 and
%     g(x)=||A2*x-b2||^2/2 -- confirm against fun_f / fun_g);
%   * \Pi_Z is ProjectOntoL1Ball(., param.lam).
%
% Inputs
%   fun_f, fun_g : handles evaluated at each iterate, for logging only
%   grad_f, grad_g : accepted but not used by this stochastic variant
%   TSA          : handle evaluated at each iterate and logged in acc_vec
%                  (a test-set accuracy per the original comment); assumed
%                  to return a scalar
%   param        : struct with fields stepsize, alpha, beta, lam, maxiter,
%                  maxtime (maxtime is read but the time limit is disabled)
%   x0           : initial point
%   A1, b1       : data defining the stochastic estimate of grad f
%   A2, b2       : data defining the stochastic estimates of g and grad g
%
% Outputs (one entry per iteration; the loop runs floor(maxiter)+1 times)
%   f_vec1, g_vec1 : fun_f / fun_g at each iterate
%   time_vec1      : elapsed toc time at each iterate
%   sample_vec     : cumulative number of sampled rows (2 per iteration)
%   x              : final iterate
%   acc_vec        : TSA at each iterate

stepsize = param.stepsize;
alpha    = param.alpha;
beta     = param.beta;
lambda   = param.lam;       % radius passed to ProjectOntoL1Ball
maxiter  = param.maxiter;
maxtime  = param.maxtime;   %#ok<NASGU> time-limit check is disabled (see loop end)

n1 = height(A1);
n2 = height(A2);

% "while iter <= maxiter" starting from iter = 0 runs floor(maxiter)+1
% iterations; preallocate exactly that many log entries instead of growing
% the arrays by concatenation (which is quadratic in the iteration count).
if isfinite(maxiter)
    nAlloc = max(floor(maxiter) + 1, 0);
else
    nAlloc = 0;             % unbounded run: fall back to growth on assignment
end
f_vec1     = zeros(nAlloc,1);
g_vec1     = zeros(nAlloc,1);
time_vec1  = zeros(nAlloc,1);
acc_vec    = zeros(nAlloc,1);
sample_vec = zeros(nAlloc,1);

x = x0;

tic;
iter = 0;
while iter <= maxiter
    iter = iter + 1;

    % Uniformly sample one row for each level (upper first, then lower,
    % so the random stream is consumed in the same order as before).
    upperidx = randsample(n1,1);
    loweridx = randsample(n2,1);
    a1 = A1(upperidx,:);
    a2 = A2(loweridx,:);

    % Single-row residuals and stochastic estimates, scaled by the row counts.
    r1 = a1*x - b1(upperidx,:);
    r2 = a2*x - b2(loweridx,:);
    grad_f_x = (n1)*a1'*r1;
    grad_g_x = (n2)*a2'*r2;
    g_x      = (n2)*sum(r2.^2)/2;   % plain MATLAB; no CVX sum_square needed

    % Barrier weight lambda_k. When the sampled lower-level gradient is zero
    % the ratio is 0/0; the weight is then defined as 0 explicitly rather than
    % relying on max() discarding NaN.
    gg  = grad_g_x'*grad_g_x;
    phi = min(alpha*g_x, beta*gg);
    if gg > 0
        weight = max((phi - grad_f_x'*grad_g_x)/gg, 0);
    else
        weight = 0;
    end

    % Gradient step, then projection via ProjectOntoL1Ball (radius lambda).
    v = grad_f_x + weight*grad_g_x;
    x = x - stepsize*v;
    x = ProjectOntoL1Ball(x,lambda);

    % Logging. toc is taken before this iteration's evaluations, but it does
    % include the evaluation time of all earlier iterations.
    time_vec1(iter)  = toc;
    f_vec1(iter)     = fun_f(x);
    g_vec1(iter)     = fun_g(x);
    acc_vec(iter)    = TSA(x);   % test set accuracy
    sample_vec(iter) = iter*2;   % two rows sampled per iteration

    if mod(iter,5000) == 1
        fprintf('Iteration: %d\n',iter)
    end
    % Time limit (disabled):
    % if time_vec1(iter) > maxtime
    %     break
    % end
end

% Drop any unused preallocated entries (only relevant if the loop exits early).
f_vec1     = f_vec1(1:iter);
g_vec1     = g_vec1(1:iter);
time_vec1  = time_vec1(1:iter);
acc_vec    = acc_vec(1:iter);
sample_vec = sample_vec(1:iter);
end

