using ArgParse

import CSV
import DataFrames
import Distributions
import LinearAlgebra
import Random
import StatsBase

import RobustDistributedPCA

include("util.jl")

const opnorm = LinearAlgebra.opnorm

"""
  tau_optimal(Ws::Array{Float64, 3}, skip_first::Int)

Compute the oracle `τ` parameter for the filter algorithm.

`τ` is the operator norm of the (uncorrected) covariance of all local solutions
except the first `skip_first`, which are known to be outliers.

`Ws[:, i, :]` is the local solution of machine `i`, so `m = size(Ws, 2)` is the
number of machines. Each local solution is flattened into a vector of length
`size(Ws, 1) * size(Ws, 3)` before the covariance is formed.

NOTE(review): assumes the filter in `RobustDistributedPCA` measures spread on
these vectorized local solutions — confirm against its implementation.

Throws an `ArgumentError` unless `0 <= skip_first < m`; otherwise no inlier
would remain and the covariance would be undefined.
"""
function tau_optimal(Ws::Array{Float64, 3}, skip_first::Int)
  d, m, r = size(Ws)
  if !(0 <= skip_first < m)
    throw(ArgumentError("skip_first must lie in 0:$(m - 1), got $(skip_first)"))
  end
  # Row i holds vec(Ws[:, i, :]). `scattermat` only accepts matrices, so the
  # 3-D array must be flattened to one observation per row.
  samples = reshape(permutedims(Ws, (2, 1, 3)), m, d * r)
  # Zero weight for the known outliers, unit weight for the inliers.
  wts = StatsBase.AnalyticWeights(
    [zeros(Int, skip_first); ones(Int, m - skip_first)],
  )
  # scattermat centers the rows at their weighted mean and returns the
  # weighted scatter matrix. Dividing by the total weight gives the
  # uncorrected covariance of the inliers.
  return opnorm(StatsBase.scattermat(samples, wts; dims=1)) / sum(wts)
end

"""
  TrialSolution

Result of a single trial of `generate_solutions`.

Each `d_*` field is the distance `opnorm(W_truth - Q * (Q' * W_truth))` between
the ground-truth basis `W_truth` and an orthonormal basis `Q` of one estimate.
The field order is also the column order of the `DataFrame` returned by
`generate_solutions`.
"""
struct TrialSolution
  # Procrustes fixing with a single iteration, followed by averaging.
  d_naive::Float64
  # Procrustes fixing with `procrustes_iters` iterations, followed by averaging.
  d_naive_iter::Float64
  # Robust Procrustes fixing (`procrustes_fixing_robust`), followed by averaging.
  d_procrustes_robust::Float64
  # Robust Procrustes fixing followed by the filter algorithm.
  d_robust::Float64
  # Corruption level passed to `contaminate_samples!` for this trial.
  p_fail::Float64
end

"""
  error_fun(p_fail::Float64)

Return the closure `λ -> sqrt(λ * p_fail)`.

It is used as the error function handed to the adaptive filter.
"""
error_fun(p_fail::Float64) = λ -> sqrt(λ * p_fail)

"""
  compute_guess(Ws, p_fail, use_randomization, use_oracle_tau)

Aggregate the aligned local solutions `Ws` into one robust estimate using the
filter algorithm from `RobustDistributedPCA`.

- Without `use_oracle_tau`, call `RobustDistributedPCA.filter_adaptive` with
  the fixed parameter `0.5` and the error function `error_fun(p_fail)`.
- With `use_oracle_tau`, compute `τ` via `tau_optimal`, treating the first
  `floor(Int, p_fail * size(Ws, 2))` machines as the known outliers. Then call
  `RobustDistributedPCA.filter` with that `τ`.

`use_randomization` is forwarded to whichever filter variant runs.

NOTE(review): the oracle path assumes the contaminated machines occupy the
first indices along dimension 2 of `Ws` — confirm against
`contaminate_samples!` and `procrustes_fixing_robust`.
"""
function compute_guess(
  Ws::Array{Float64, 3},
  p_fail::Float64,
  use_randomization::Bool,
  use_oracle_tau::Bool,
)
  if !use_oracle_tau
    return RobustDistributedPCA.filter_adaptive(
      Ws,
      0.5,
      error_fun(p_fail),
      use_randomization,
    )
  end
  num_outliers = floor(Int, p_fail * size(Ws, 2))
  oracle_tau = tau_optimal(Ws, num_outliers)
  return RobustDistributedPCA.filter(Ws, oracle_tau, use_randomization)
end

# Orthonormal basis (thin Q factor of a QR decomposition) for the column span
# of the matrix `W`.
_orthonormal_basis(W) = Matrix(LinearAlgebra.qr(W).Q)

# Orthonormal basis for the span of the machine-wise average of the aligned
# local solutions, i.e. of the mean of `Ws[:, i, :]` over machines `i`.
_averaged_basis(Ws) =
  _orthonormal_basis(dropdims(StatsBase.mean(Ws, dims=2), dims=2))

# Operator-norm distance between `W_truth` and its orthogonal projection onto
# the span of the orthonormal basis `Q`.
_subspace_distance(W_truth, Q) = opnorm(W_truth - Q * (Q' * W_truth))

"""
  generate_solutions(m, n, d, r, stable_rank, gap, use_randomization,
                     procrustes_iters, num_repeats, use_oracle_tau)

Run the robustness experiment and return a `DataFrames.DataFrame` with one
`TrialSolution` row per trial.

The ground truth `W_truth = V[:, 1:r]` comes from
`covariance_geom_decay(d, r, gap, stable_rank)`. Samples are drawn from the
zero-mean Gaussian with covariance `V * Λ * V'`.

For each `p_fail` in `0:0.05:0.45`, and for each of `num_repeats` trials:

1. Build local solutions with `generate_samples(D, m, n, r)`.
2. Contaminate a copy of them with `contaminate_samples!`.
3. Record the distance to `W_truth` of four estimates:
   - `d_naive`: Procrustes fixing with 1 iteration, then averaging.
   - `d_naive_iter`: Procrustes fixing with `procrustes_iters` iterations,
     then averaging.
   - `d_procrustes_robust`: robust Procrustes fixing, then averaging.
   - `d_robust`: robust Procrustes fixing, then `compute_guess`.

Each distance is `opnorm(W_truth - Q * (Q' * W_truth))`, where `Q` is an
orthonormal basis of the estimate.
"""
function generate_solutions(
  m::Int,
  n::Int,
  d::Int,
  r::Int,
  stable_rank::Float64,
  gap::Float64,
  use_randomization::Bool,
  procrustes_iters::Int,
  num_repeats::Int,
  use_oracle_tau::Bool,
)
  distances = Vector{TrialSolution}()
  V, Λ = covariance_geom_decay(d, r, gap, stable_rank)
  # Qualified: `import LinearAlgebra` does not bring `Symmetric` into scope.
  D = Distributions.MvNormal(LinearAlgebra.Symmetric(V * Λ * V'))
  W_truth = V[:, 1:r]
  for p_fail in 0:0.05:0.45
    @info "Trying p_fail = $(p_fail)"
    for i in 1:num_repeats
      @info "Running trial [$(i) / $(num_repeats)]"
      # Generate all local solutions and contaminate them.
      Vs = generate_samples(D, m, n, r)
      Ws = contaminate_samples!(copy(Vs), p_fail)
      # Option 1: Procrustes fixing with 1 iteration.
      fixed_naive = RobustDistributedPCA.procrustes_fixing(Ws, n_iter=1)
      @assert size(fixed_naive) == (d, m, r) "procrustes_fixing: dimension wrong"
      # Option 2: Procrustes fixing with procrustes_iters iterations.
      fixed_naive_iter = RobustDistributedPCA.procrustes_fixing(Ws, n_iter=procrustes_iters)
      @assert size(fixed_naive_iter) == (d, m, r) "procrustes_fixing_iter: dimension wrong"
      # Option 3: Procrustes fixing with robust reference selection.
      fixed_robust = RobustDistributedPCA.procrustes_fixing_robust(Ws, n_iter=procrustes_iters)
      @assert size(fixed_robust) == (d, m, r) "procrustes_fixing_robust: dimension wrong"
      # Option 4: filter algorithm on the robustly aligned solutions
      # (randomized iff `use_randomization` is set).
      robust_guess = compute_guess(
        fixed_robust,
        p_fail,
        use_randomization,
        use_oracle_tau,
      )
      @assert size(robust_guess) == (d, r) "Incorrect dimensions"
      # Compute the 4 distances.
      push!(
        distances,
        TrialSolution(
          _subspace_distance(W_truth, _averaged_basis(fixed_naive)),
          _subspace_distance(W_truth, _averaged_basis(fixed_naive_iter)),
          _subspace_distance(W_truth, _averaged_basis(fixed_robust)),
          _subspace_distance(W_truth, _orthonormal_basis(robust_guess)),
          p_fail,
        ),
      )
    end
  end
  return DataFrames.DataFrame(distances)
end

"""
  postprocess_data(df::DataFrames.DataFrame, output_file::String)

Summarize the per-trial distances in `df`.

Rows are grouped by `:p_fail`. For every distance column, the mean, median and
standard deviation are computed. The summary table is written to `output_file`
as CSV and also returned.
"""
function postprocess_data(df::DataFrames.DataFrame, output_file::String)
  metrics = [:d_naive, :d_naive_iter, :d_procrustes_robust, :d_robust]
  # A 1×3 row of functions: broadcasting `.=>` against the 4-vector of columns
  # yields a 4×3 matrix of `column => statistic` pairs.
  summaries = [StatsBase.mean StatsBase.median StatsBase.std]
  by_p_fail = DataFrames.groupby(df, :p_fail)
  summary_df = DataFrames.combine(by_p_fail, metrics .=> summaries)
  CSV.write(output_file, summary_df)
  return summary_df
end

# Command-line interface of this experiment script.
settings = let blurb = "Compare the performance of the non-robust and robust " *
                       "versions of Procrustes fixing on Gaussian data with " *
                       "adversarial corruptions."
  ArgParseSettings(description = blurb)
end

# Register every command-line option. Each entry is the option name followed by
# its settings (help text, type, default value or action).
@add_arg_table! settings begin
  # Passed to `generate_solutions` as `n`.
  "--num-samples"
  help = "The number of samples per machine."
  arg_type = Int
  default = 200
  # Passed to `generate_solutions` as `m`.
  "--num-nodes"
  help = "The number of machines."
  arg_type = Int
  default = 150
  # Passed to `generate_solutions` as `d`.
  "--dim"
  help = "The problem dimension d."
  arg_type = Int
  default = 100
  # Passed to `generate_solutions` as `r`.
  "--nvec"
  help = "The number of eigenvectors sought."
  arg_type = Int
  default = 5
  "--stable-rank"
  help = "The stable rank of the unknown covariance matrix."
  arg_type = Float64
  default = 10.0
  "--gap"
  help = "The eigenvalue gap of the unknown covariance matrix."
  arg_type = Float64
  default = 0.25
  # Boolean flags: false unless given on the command line.
  "--use-randomization"
  help = "Set to use the randomized version of the filtering algorithm."
  action = :store_true
  "--use-oracle-tau"
  help = "Set to use the oracle covariance matrix bound in the filter."
  action = :store_true
  # Used by the iterated and robust Procrustes variants; the naive variant
  # always runs a single iteration. With the default of 1, the naive and
  # iterated variants get the same iteration count.
  "--procrustes-iters"
  help = "The number of iterations used by Procrustes alignment."
  arg_type = Int
  default = 1
  "--num-repeats"
  help = "The number of times to repeat each experiment."
  arg_type = Int
  default = 10
  # Seeds the global RNG via `Random.seed!` before any trial runs.
  "--seed"
  help = "Random seed for reproducibility."
  arg_type = Int
  default = 123
  # Empty (the default) means the summary table is not computed or written.
  "--output-file"
  help = "The output file for the trial statistics."
  arg_type = String
  default = ""
end

# Script entry point: parse the command line, seed the global RNG, run all
# trials, and summarize them into a CSV file if an output path is given.
parsed = parse_args(settings)
output_file = parsed["output-file"]
if isempty(output_file)
  # Warn before the (potentially long) run, not after it, so the user can
  # abort instead of silently losing every result.
  @warn "No --output-file given; trial statistics will not be saved."
end
Random.seed!(parsed["seed"])
distances = generate_solutions(
  parsed["num-nodes"],
  parsed["num-samples"],
  parsed["dim"],
  parsed["nvec"],
  parsed["stable-rank"],
  parsed["gap"],
  parsed["use-randomization"],
  parsed["procrustes-iters"],
  parsed["num-repeats"],
  parsed["use-oracle-tau"],
)
if !isempty(output_file)
  gdf = postprocess_data(distances, output_file)
  @show gdf
end
