%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% optimize_probitsem_entropy_bound:
%
% Choose a subset of variables using a different approximation for H(Y) by
% Globerson et al.
%
% If an ILP solution is required, the problem can be solved through the
% Gurobi/MEX interface. See:
%
% - http://www.convexoptimization.com/wikimization/index.php/Gurobi_Mex:_A_MATLAB_interface_for_Gurobi
% - http://www.gurobi.com/
%
% Input:
%
% - K_width: largest number of parents to keep in the entropy approximation
% - L, S: parameters of probit model
% - sel_K_start: initialization vector
% - solve_ilp: if true, solve by integer linear programming
% - use_gurobi: use Gurobi/MEX tools
%
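% Output:
%
% - sel_K: the selection of K variables
% - optim_found: true only when the Gurobi solver reports status 2
%   (optimal); always false for the bintprog and local-search paths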
%
% Created by: Ricardo Silva, London, 26/04/2011
% University College London
%
% Current version: 26/04/2011

function [sel_K optim_found] = optimize_probitsem_entropy_bound(K_width, ...
                               L, S, sel_K_start, solve_ilp, use_gurobi)
% Select a subset of |sel_K_start| variables maximizing
%   sum_{y selected} W(y) + entropy-bound score (Globerson et al.),
% either exactly via an integer linear program (solve_ilp = true) or by a
% greedy pairwise-swap local search started at sel_K_start.

[num_y num_x] = size(L); num_x = num_x - 1; % last column of L is the intercept

% "for v = V" iterates over the COLUMNS of V. A column-vector sel_K_start
% would make the swap loop below run once with the whole vector as loop
% variable, so normalize to a row here.
sel_K_start = sel_K_start(:)';
if any(sel_K_start < 1 | sel_K_start > num_y | sel_K_start ~= fix(sel_K_start))
  error('optimize_probitsem_entropy_bound: sel_K_start must contain integer indices in 1..%d', num_y);
end

% Calculate univariate weights:
%   W(y) = E_x[ p1 log p1 + p0 log p0 ],  p0 = normcdf(-m), p1 = 1 - p0,
% where m = L(y, 1:num_x) * x + L(y, end). The expectation is estimated by
% Monte Carlo with x = chol(S)' * randn, i.e. zero-mean samples with
% covariance S.

W = zeros(num_y, 1);
M = 100000;

x = chol(S)' * randn(num_x, M);

for y = 1:num_y
   m = L(y, 1:num_x) * x + L(y, end);
   py0 = normcdf(-m); py1 = 1 - py0; 
   % Convention 0 * log(0) = 0 (avoids 0 * -Inf = NaN)
   log_py1 = log(py1); log_py1(py1 == 0) = 0;
   log_py0 = log(py0); log_py0(py0 == 0) = 0;
   W(y) = mean(py1 .* log_py1 + py0 .* log_py0);
end

clear('x');

% Get entropy approximation

fprintf('Caching entropy information...\n');
entropy_order = optimize_entropy_order(K_width, L, S, true);
entropy_sets = get_entropy_sets(L, S, entropy_order);

% Optimize choice of variables. z is the 0/1 indicator of the selection.

sel_K = sel_K_start;
z = zeros(1, num_y);
z(sel_K) = 1;
not_sel_K = find(z == 0);

% Score the indicator vector (not the index list) so duplicated entries in
% sel_K_start are not double-counted; same form as in the swap loop below.
best_score = sum(W(z == 1)) + get_entropy_score(z, entropy_sets);
fprintf('Initial score = %f\n', best_score);

if solve_ilp % Solve by integer linear programming
      
  % Variables: z(1:num_y) plus one auxiliary binary per (y, subset of the
  % parents of y). The auxiliary equals the product of z over {y} and that
  % subset (linearized by the inequality constraints below).
  num_extra_vars = 0;
  for y = 1:num_y
    num_extra_vars = num_extra_vars + 2^length(entropy_sets{y}.parents);
  end
  num_vars = num_y + num_extra_vars;
  ilp_w = zeros(num_vars, 1);
  ilp_w(1:num_y) = W;

  % Equality constraints: the number of selected variables is fixed, and
  % the round(|sel_K|/2) variables with the largest latent variance
  % diag(L S L') are forced into the solution.
  
  num_fix = round(length(sel_K) / 2);
  [~, idx] = sort(diag(L(:, 1:num_x) * S * L(:, 1:num_x)'), 'descend');
  
  sel_K = sort(idx(1:length(sel_K_start)))'; % Ignore sel_K_start
  z = zeros(1, num_y);
  z(sel_K) = 1;
  
  fix_these = sort(idx(1:num_fix))';
  Aeq = sparse(1 + num_fix, num_vars); beq = zeros(1 + num_fix, 1);
  Aeq(1, 1:num_y) = 1; beq(1) = sum(z);
  fix_pos = 2;
  for x = fix_these
    Aeq(fix_pos, x) = 1; beq(fix_pos) = 1; %#ok<SPRIX>
    fix_pos = fix_pos + 1;
  end
  
  % Inequality constraints: count rows first so A can be preallocated.
  % Each (y, parent subset) contributes |subset| + 1 upper-bound rows and
  % one lower-bound row.

  num_constraints = 0;
  for y = 1:num_y
    num_p = length(entropy_sets{y}.parents);
    num_comb_p = 2^num_p;  
    p_sel = zeros(1, num_p);
    for i = 1:num_comb_p 
       num_constraints = num_constraints + sum(p_sel) + 2;
       p_sel = advance_bits(p_sel);
    end
  end
  
  A = sparse(num_constraints, num_vars);
  b = sparse(num_constraints, 1);
  row_pos = 1; col_pos = num_y;
  
  z_start = zeros(1, num_vars); z_start(1:num_y) = z;
  
  for y = 1:num_y

    num_p = length(entropy_sets{y}.parents);
    p_in_pos = zeros(num_y, 1);
    p_in_pos(entropy_sets{y}.parents) = 1:num_p;
    num_comb_p = 2^num_p;  
    p_sel = zeros(1, num_p);

    for i = 1:num_comb_p % All subsets of parents (including the empty one)

      p_in  = entropy_sets{y}.parents(p_sel == 1);

      % Enter constraints extra_z(col_pos) == 1 --> z(c) == 1, where 
      % c in {y} \union {"parents" of y})
      %
      % (That is, extra_z(col_pos) - z(c) <= 0)
      
      col_pos = col_pos + 1;
      
      A(row_pos, [col_pos y]) = [1 -1];  %#ok<SPRIX>
      row_pos = row_pos + 1;
      for p = p_in
        A(row_pos, [col_pos p]) = [1 -1];    %#ok<SPRIX>
        row_pos = row_pos + 1;  
      end    
      
      % Enter constraint extra_z(col_pos) == 0 --> at least one z(c) is 0,
      % with c in {y} \union {"parents" of y})
      %
      % (That is, sum (1 - z(c)) >= 1 - extra_z(col_pos))
      
      A(row_pos, [y p_in col_pos]) = [ones(1, length(p_in) + 1) -1]; %#ok<SPRIX>
      b(row_pos) = length(p_in);    
      row_pos = row_pos + 1;  

      % Now, enter corresponding weights into objective function: an
      % inclusion-exclusion sum over all subsets of p_in, with sign
      % (-1)^(number of elements of p_in left out).

      num_p_in = length(p_in);
      p_in_sel = zeros(1, num_p_in);
      num_p_in_comb = 2^num_p_in;
      p_in_alt = zeros(1, num_p); 
      
      for j = 1:num_p_in_comb
        p_in_alt(p_in_pos(p_in)) = 0; p_in_alt(p_in_pos(p_in(p_in_sel == 1))) = 1;
        ilp_w(col_pos) = ilp_w(col_pos) + ...
                         entropy_sets{y}.value(parent_entry(p_in_alt)) * ...
                         (-1)^sum(p_in_sel == 0);
        p_in_sel = advance_bits(p_in_sel);  
      end
            
      % Adjust initial solution so the warm start is feasible
      
      z_start(col_pos) = prod(z([y p_in]));
      
      % Advance
      
      p_sel = advance_bits(p_sel);

    end

  end
  
  if use_gurobi
    objtype = -1; % 1 for minimize, -1 for maximize
    contypes = [repmat('<', 1, size(A, 1)) repmat('=', 1, size(Aeq, 1))];
    A = [A; Aeq];
    b = [b; beq];
    lb = [];
    ub = [];
    vtypes = repmat('B', 1, size(A, 2));

    clear opts
    opts.IterationLimit = 150000;
    opts.FeasibilityTol = 1e-6;
    opts.IntFeasTol = 1e-5;
    opts.OptimalityTol = 1e-6;
    opts.Method = 1; 
    opts.Presolve = -1;
    opts.Display = 1;
    opts.Start = z_start;
    opts.LogFile = 'test_gurobi_mex_MIP.log';

    [z, ~, flag] = gurobi_mex(ilp_w, objtype, A, b, contypes, lb, ub, vtypes, opts);          
    z = z';
    optim_found = flag == 2; % Gurobi status 2 = OPTIMAL
  else % Use MATLAB's optimization toolbox. Good luck with that!
    % NOTE: bintprog was removed from newer MATLAB releases (intlinprog
    % replaces it); this branch needs an older Optimization Toolbox.
    options = optimset('Display', 'iter', 'MaxTime', 3600);
    z = bintprog(-ilp_w, A, b, Aeq, beq, z_start, options)';  
    optim_found = false;
  end
  
  % The solver only guarantees integrality up to a tolerance (IntFeasTol
  % above), so an exact "== 1" test can silently drop selected variables.
  z = round(z);
  sel_K = find(z(1:num_y) == 1);
  fprintf('Best score = %f\n', z * ilp_w);
  return
end

% Local search: each iteration evaluates every swap (remove one selected
% variable, add one unselected variable) and applies the best strictly
% improving swap; stop when no swap improves the score.

iter = 1;

fprintf('Initial score = %f: \n', best_score);

while true
  fprintf('Iteration [%d]: \n', iter);

  changed = false;
  for y1 = sel_K
    fprintf('Attempting to remove %d\n', y1);  
    for y2 = not_sel_K     
     z(y1) = 0; z(y2) = 1;
     score = sum(W(z == 1)) + get_entropy_score(z, entropy_sets);
     if score > best_score
       best_score = score;
       best_pair = [y1 y2];
       changed = true;
     end
     z(y1) = 1; z(y2) = 0; % undo the trial swap
    end
  end
  
  if ~changed
    break
  end
  
  % Apply the best swap found in this iteration
  z(best_pair(1)) = 1 - z(best_pair(1));
  z(best_pair(2)) = 1 - z(best_pair(2));
  sel_K = find(z == 1);
  not_sel_K = find(z == 0);

  fprintf('Current solution [score = %f]:', best_score);
  disp(sel_K);
  iter = iter + 1;   
end
fprintf('Score found: %f\n', best_score);
optim_found = false;

