function loss=testtt(ex,qr,P,w_imp,Phi,option)
%TESTTT Run exploration, preference queries and planning; return the loss.
%
%   loss = testtt(ex,qr,P,w_imp,Phi,option)
%
%   Inputs
%     ex      number of exploration rounds.
%     qr      query budget handed to the query routine.
%     P       num_actions-by-num_states matrix; row a is used as the
%             sampling distribution over states for first-stage action a.
%     w_imp   true weight vector; the true reward of a pair is
%             (1+sign(phi'*w_imp))/2.
%     Phi     feature matrix with num_states*3 columns. The planning phase
%             maps column s+(a2-1)*num_states to the pair (state s,
%             second-stage action a2) (column-major reshape).
%     option  1 selects acquery, anything else selects rglquery.
%
%   Output
%     loss    value of the truly best first-stage action minus the true
%             value of the action/policy chosen from the estimates.
%
%   Relies on discretesample, which is not defined in this file
%   (presumably draws indices from a discrete distribution - confirm).

%% model parameters
[num_actions,num_states]=size(P);
num_actions2=3;  %% number of second-stage actions per state

exploration_rounds=ex;
query_rounds=qr;
Pcount=zeros(num_actions,num_states);  %% transition count
Rcount=zeros(num_states,num_actions2);  %% reward count

%% exploration phase
for i=1:exploration_rounds
    % Bonus of a state shrinks with the count of its least-tried action2.
    V_value=1./sqrt(min(Rcount,[],2)+ones(num_states,1));
    hatP=empirical_tran(Pcount);
    % Bonus of a first-stage action shrinks with its total visit count.
    action_visit=Pcount * ones(num_states,1);
    A_value=1./sqrt(action_visit + ones(num_actions,1));
    Q_value=hatP*V_value + A_value;
    [~,action]=max(Q_value);
    state=discretesample(P(action,:),1);
    % In the reached state, try the least-visited second-stage action.
    [~,action2]=min(Rcount(state,:));
    Rcount(state,action2)=Rcount(state,action2)+1;
    Pcount(action,state)=Pcount(action,state)+1;
end
% Re-estimate the transition model after the final count update so that
% planning uses every observed transition (inside the loop hatP is built
% before the round's own observation is recorded). This also defines hatP
% when there are no exploration rounds.
hatP=empirical_tran(Pcount);

%% query
% Flatten column-major so that entry s+(a2-1)*num_states holds the visit
% count of pair (s,a2) - the same column layout of Phi that the reshape
% calls in the planning phase assume.
visit_counts=Rcount(:);
if option == 1
    estw=acquery(Phi,w_imp,visit_counts,query_rounds);
else
    estw=rglquery(Phi,w_imp,visit_counts,query_rounds);
end

%% planning phase
% Estimated reward per pair: 1, 0, or 0.5 when the score is exactly zero.
est_reward=(ones(num_states*num_actions2,1)+sign(Phi'*estw))./2;
est_reward=reshape(est_reward,num_states,num_actions2);
[est_V,est_policy]=max(est_reward,[],2);
est_value = hatP*est_V;
[~,est_best_action] = max(est_value);

% True reward table and the value of each first-stage action under P.
reward=(ones(num_states*num_actions2,1)+sign(Phi'*w_imp))./2;
reward=reshape(reward,num_states,num_actions2);
[V,~]=max(reward,[],2);
value=P * V;
[~,best_action] = max(value);

% True reward collected in each state by the estimated second-stage policy.
V_pi=reward(sub2ind(size(reward),(1:num_states)',est_policy));
value_pi=P * V_pi;

%% counting loss
loss=value(best_action)-value_pi(est_best_action);
end
% %% test accuracy
% false=0;
% for i=1:num_states
%     if (estw'*Phi(:,i))*(w_imp'*Phi(:,i))<0
%         false=false+1;
%     end
% end
% acc=false/num_states
%end

function estw= acquery (R,w,states_visit,budget)
%ACQUERY Greedy active querying followed by a regularised least-squares fit.
%
%   estw = acquery(R,w,states_visit,budget)
%
%   R             d-by-n feature matrix, one candidate per column.
%   w             true weight vector used to simulate the +/-1 responses.
%   states_visit  visit counts, one per column of R; columns with a zero
%                 count are not eligible for querying.
%   budget        number of queries to issue.
%
%   Each round queries the eligible column phi maximising phi'*(VN\phi),
%   where VN is the identity plus the outer products of the columns
%   queried so far. Returns estw = VN\sumphi, with sumphi the sum of
%   response-signed queried columns. Returns zeros(d,1) when nothing is
%   eligible or budget is 0.
%
%   Relies on discretesample, which is not defined in this file
%   (presumably returns index 1 or 2 drawn from the given probabilities).

    [d,d1]=size(R);

    % Eligible columns: visited at least once and not identically zero.
    % Selecting by visit count (instead of testing whether the column
    % entries sum to zero) keeps visited features such as [1;-1].
    visited=reshape(states_visit(1:d1) ~= 0,1,d1);
    keep=visited & any(R ~= 0,1);
    R=R(:,keep);

    VN=eye(d);
    sumphi=zeros(d,1);

    % Nothing to query: the regularised solution is the zero vector.
    if isempty(R)
        estw=sumphi;
        return;
    end

    for i=1:budget
        % Score every candidate at once; max returns the first maximiser,
        % matching a left-to-right scan with a strict comparison.
        scores=sum(R.*(VN\R),1);
        [~,que]=max(scores);
        phi=R(:,que);

        % Probability of a +1 response, clamped to a valid probability in
        % case |phi'*w| exceeds 1.
        prob=min(max((phi'*w+1)/2,0),1);
        VN=phi*phi'+VN;
        response=discretesample([1-prob,prob],1);
        % response is 1 or 2; 2*response-3 maps it to -1 or +1.
        sumphi=sumphi+phi*(2*response-3);
    end
    estw=VN\sumphi;
end
function estw= rglquery (R,w,states_visit,budget)
%RGLQUERY Query columns at random in proportion to visits, then fit.
%
%   estw = rglquery(R,w,states_visit,budget)
%
%   R             d-by-n feature matrix, one candidate per column.
%   w             true weight vector used to simulate the +/-1 responses.
%   states_visit  visit counts, one per column of R; a column is sampled
%                 with probability proportional to its count.
%   budget        number of queries to issue.
%
%   Returns estw = VN\sumphi, where VN is the identity plus the outer
%   products of the queried columns and sumphi is the sum of
%   response-signed queried columns. Returns zeros(d,1) when no column has
%   been visited or budget is 0.
%
%   Relies on discretesample, which is not defined in this file
%   (presumably draws an index from the given probability vector).

    [d,~]=size(R);
    VN=eye(d);
    sumphi=zeros(d,1);

    % With no visits the sampling distribution would be 0/0 = NaN; there
    % is nothing to query, so return the regularised zero solution.
    total_visits=sum(states_visit);
    if total_visits == 0
        estw=sumphi;
        return;
    end
    states_prob=states_visit./total_visits;

    for i=1:budget
        que=discretesample(states_prob,1);
        phi=R(:,que);
        % Probability of a +1 response, clamped to a valid probability in
        % case |phi'*w| exceeds 1.
        prob=min(max((phi'*w+1)/2,0),1);
        VN=phi*phi'+VN;
        response=discretesample([1-prob,prob],1);
        % response is 1 or 2; 2*response-3 maps it to -1 or +1.
        sumphi=sumphi+phi*(2*response-3);
    end
    estw=VN\sumphi;
end

function hatP = empirical_tran(countP)
%EMPIRICAL_TRAN Row-normalise a count matrix into empirical probabilities.
%
%   hatP = empirical_tran(countP) divides every row of countP by its row
%   sum. A row whose sum is zero becomes the uniform distribution over
%   the columns.

    [nRows,nCols]=size(countP);

    % Start with every row uniform; rows with observations are overwritten.
    hatP=repmat(ones(1,nCols)./nCols,nRows,1);

    for r=1:nRows
        rowCounts=countP(r,:);
        total=sum(rowCounts);
        if total ~= 0
            hatP(r,:)=rowCounts./total;
        end
    end
end
