using ArgParse

import CSV
import DataFrames
import Distributions
import LinearAlgebra
import Random
import StatsBase

import RobustDistributedPCA

include("util.jl")

# Short alias for `LinearAlgebra.opnorm`: `import LinearAlgebra` (unlike
# `using LinearAlgebra`) does not bring its exported names into scope.
const opnorm = LinearAlgebra.opnorm

"""
    tau_optimal(Ws::AbstractArray{<:Real, 3}, skip_first::Integer)

Compute the oracle `τ` parameter for the filter algorithm: the operator norm of
the (`1/N`-normalized) sample covariance of all but the first `skip_first`
samples, which are known to be outliers.

The `i`-th of the `m = size(Ws, 2)` samples is `Ws[:, i, :]`; each sample is
flattened into a vector of length `size(Ws, 1) * size(Ws, 3)` before the
covariance is formed. The operator norm does not depend on the flattening
order, since reordering coordinates is an orthogonal similarity.

Throws an `ArgumentError` unless `0 <= skip_first < size(Ws, 2)`, i.e. at least
one inlier sample remains.
"""
function tau_optimal(Ws::AbstractArray{<:Real, 3}, skip_first::Integer)
  d, m, r = size(Ws)
  0 <= skip_first < m || throw(ArgumentError(
    "skip_first must satisfy 0 <= skip_first < $(m), got $(skip_first)",
  ))
  # Column i of `X` is the i-th sample `Ws[:, i, :]`, flattened. (Calling
  # `scattermat` on the 3-D array directly has no method.)
  X = reshape(permutedims(Ws, (1, 3, 2)), d * r, m)
  inliers = @view X[:, (skip_first + 1):m]
  num_inliers = size(inliers, 2)
  centered = inliers .- sum(inliers, dims=2) ./ num_inliers
  # ‖C‖₂ for C = centered * centered' / N equals σ_max(centered)² / N; this
  # avoids forming the (d r) × (d r) covariance matrix explicitly.
  return LinearAlgebra.opnorm(centered)^2 / num_inliers
end

"""
    TrialSolution

Result of a single trial as built by `run_trial`. Each `d_*` field is the
operator norm `‖W_truth - Q * (Q' * W_truth)‖`, where `Q` is the Q factor of a
QR decomposition of the corresponding estimate.
"""
struct TrialSolution
  d_naive::Float64       # estimate: machine-average of naively Procrustes-fixed samples
  d_procrustes::Float64  # estimate: machine-average of robustly Procrustes-fixed samples
  d_robust::Float64      # estimate: filter output on the robustly fixed samples
  num_samples::Int       # samples per machine (`n` in `run_trial`)
  num_machines::Int      # number of machines (`m` in `run_trial`)
end

"""
    error_fun(p_fail::Float64)

Return the closure `λ -> sqrt(λ * p_fail)`; `compute_guess` hands it to
`RobustDistributedPCA.filter_adaptive` as its error function.
"""
error_fun(p_fail::Float64) = λ -> sqrt(λ * p_fail)

"""
    compute_guess(Ws, p_fail, use_randomization, use_oracle_tau)

Run the filter algorithm on the aligned samples `Ws` and return its output.

- `use_oracle_tau = false`: call `RobustDistributedPCA.filter_adaptive` with the
  constant `0.5` and the error function `error_fun(p_fail)`.
- `use_oracle_tau = true`: compute `τ` with `tau_optimal`, treating the first
  `floor(p_fail * size(Ws, 2))` samples as the outliers, and call
  `RobustDistributedPCA.filter` with it.

`use_randomization` is forwarded unchanged to whichever filter is called.
"""
function compute_guess(
  Ws::Array{Float64, 3},
  p_fail::Float64,
  use_randomization::Bool,
  use_oracle_tau::Bool,
)
  if !use_oracle_tau
    return RobustDistributedPCA.filter_adaptive(
      Ws,
      0.5,
      error_fun(p_fail),
      use_randomization,
    )
  end
  num_outliers = floor(Int, p_fail * size(Ws, 2))
  oracle_τ = tau_optimal(Ws, num_outliers)
  return RobustDistributedPCA.filter(Ws, oracle_τ, use_randomization)
end

# Operator-norm residual of `W_truth` after projecting it onto the column span
# of `A`, using the (thin) Q factor of `A` as an orthonormal basis.
function _span_residual(W_truth::AbstractMatrix, A::AbstractMatrix)
  Q = Matrix(LinearAlgebra.qr(A).Q)
  return LinearAlgebra.opnorm(W_truth - Q * (Q' * W_truth))
end

"""
    run_trial(m, n, d, r, p_fail, D, W_truth)

Run one trial: generate samples from `D` for `m` machines with `n` samples each
(via `generate_samples`), corrupt them with `contaminate_samples!`, align them
with naive and robust Procrustes fixing, and measure each resulting estimate
against `W_truth`.

Returns a `TrialSolution` whose distances are
`‖W_truth - Q * (Q' * W_truth)‖₂`, with `Q` an orthonormal basis for:
- `d_naive`: the machine-average of the naively fixed samples;
- `d_procrustes`: the machine-average of the robustly fixed samples;
- `d_robust`: the filter output (`compute_guess`, adaptive τ, no
  randomization) on the robustly fixed samples.
"""
function run_trial(
  m::Int,
  n::Int,
  d::Int,
  r::Int,
  p_fail::Float64,
  D::Distributions.Sampleable,
  W_truth::Matrix{Float64},
)
  Vs = generate_samples(D, m, n, r)
  Ws = contaminate_samples!(copy(Vs), p_fail)
  # Option 1: Naive Procrustes fixing.
  fixed_naive = RobustDistributedPCA.procrustes_fixing(Ws, n_iter=1)
  @assert size(fixed_naive) == (d, m, r) "procrustes_fixing: dimension wrong"
  # Option 2: Robust Procrustes fixing.
  fixed_robust = RobustDistributedPCA.procrustes_fixing_robust(Ws, n_iter=1)
  @assert size(fixed_robust) == (d, m, r) "procrustes_fixing_robust: dimension wrong"
  # Run filter algorithm (without randomization, without oracle τ).
  robust_guess = compute_guess(
    fixed_robust,
    p_fail,
    false,  # use_randomization
    false,  # use_oracle_tau
  )
  @assert size(robust_guess) == (d, r) "Incorrect dimensions"
  # Average the aligned samples over machines (dimension 2): (d, m, r) -> (d, r).
  mean_naive = dropdims(StatsBase.mean(fixed_naive, dims=2), dims=2)
  mean_fixed = dropdims(StatsBase.mean(fixed_robust, dims=2), dims=2)
  return TrialSolution(
    _span_residual(W_truth, mean_naive),
    _span_residual(W_truth, mean_fixed),
    _span_residual(W_truth, robust_guess),
    n,
    m,
  )
end

"""
    generate_solutions(num_samples, num_machines, d, r, stable_rank, gap,
                       num_repeats, p_fail)

Run `num_repeats` trials of `run_trial` for every combination of `m` in
`num_machines` and `n` in `num_samples`, and collect them into a `DataFrame`
with one row per trial (columns are the `TrialSolution` fields).

All trials share one distribution `MvNormal(Symmetric(V * Λ * V'))` and the
ground truth `W_truth = V[:, 1:r]`, where `V, Λ` come from
`covariance_geom_decay(d, r, gap, stable_rank)`.
"""
function generate_solutions(
  num_samples::Vector{Int},
  num_machines::Vector{Int},
  d::Int,
  r::Int,
  stable_rank::Float64,
  gap::Float64,
  num_repeats::Int,
  p_fail::Float64,
)
  V, Λ = covariance_geom_decay(d, r, gap, stable_rank)
  # Qualified explicitly: `import LinearAlgebra` does not bring `Symmetric`
  # into scope. `Symmetric` reads only the upper triangle, so round-off
  # asymmetry in `V * Λ * V'` is ignored.
  D = Distributions.MvNormal(LinearAlgebra.Symmetric(V * Λ * V'))
  W_truth = V[:, 1:r]
  distances = TrialSolution[]
  for m in num_machines
    for n in num_samples
      @info "Trying m = $(m), n = $(n)"
      for i in 1:num_repeats
        @info "Running trial [$(i) / $(num_repeats)]"
        push!(distances, run_trial(m, n, d, r, p_fail, D, W_truth))
      end
    end
  end
  return DataFrames.DataFrame(distances)
end

"""
    postprocess_data(df::DataFrames.DataFrame, output_file::String)

Summarize per-trial distances: for every `(num_samples, num_machines)` group of
`df`, compute the mean, median and standard deviation of `d_naive`,
`d_procrustes` and `d_robust`. The summary is written to `output_file` as CSV
and also returned. Groups containing a single trial get a `NaN` standard
deviation.
"""
function postprocess_data(df::DataFrames.DataFrame, output_file::String)
  distance_cols = [:d_naive, :d_procrustes, :d_robust]
  # A row of statistics broadcast against a column of names yields a 3 × 3
  # matrix of `column => statistic` pairs, i.e. every combination.
  statistics = [StatsBase.mean StatsBase.median StatsBase.std]
  by_config = DataFrames.groupby(df, [:num_samples, :num_machines])
  summary_df = DataFrames.combine(by_config, distance_cols .=> statistics)
  CSV.write(output_file, summary_df)
  return summary_df
end

# Command-line interface of the experiment; arguments are added below.
settings = ArgParseSettings(
  description = join((
    "Compare the performance of the non-robust and robust ",
    "versions of Procrustes fixing on Gaussian data with ",
    "adversarial corruptions.",
  )),
)

@add_arg_table! settings begin
  "--num-samples"
    help = "The number of samples per machine."
    arg_type = Int
    nargs = '+'
    # There is no default; omitting this option used to make
    # `generate_solutions` fail with an opaque error, so require it up front.
    required = true
  "--num-nodes"
    help = "The number of machines."
    arg_type = Int
    nargs = '+'
    required = true
  "--dim"
    help = "The problem dimension d."
    arg_type = Int
    default = 100
  "--nvec"
    help = "The number of eigenvectors sought."
    arg_type = Int
    default = 5
  "--stable-rank"
    help = "The stable rank of the unknown covariance matrix."
    arg_type = Float64
    default = 10.0
  "--gap"
    help = "The eigenvalue gap of the unknown covariance matrix."
    arg_type = Float64
    default = 0.25
  "--num-repeats"
    help = "The number of times to repeat each experiment."
    arg_type = Int
    default = 10
  # `--p-fail` matches the hyphenated spelling of every other option; the
  # original `--p_fail` flag and the `parsed["p_fail"]` key are kept.
  "--p_fail", "--p-fail"
    help = "The fraction of corruptions."
    arg_type = Float64
    default = 0.1
    dest_name = "p_fail"
    range_tester = x -> 0.0 <= x <= 1.0
  "--seed"
    help = "Random seed for reproducibility."
    arg_type = Int
    default = 123
  "--output-file"
    help = "The output file for the trial statistics."
    arg_type = String
    default = ""
end

parsed = parse_args(settings)
output_file = parsed["output-file"]
if isempty(output_file)
  # The statistics are only written and shown when an output file is given;
  # say so before spending time on the trials rather than silently
  # discarding their results afterwards.
  @warn "No --output-file given; trial statistics will not be saved or shown."
end
Random.seed!(parsed["seed"])
distances = generate_solutions(
  parsed["num-samples"],
  parsed["num-nodes"],
  parsed["dim"],
  parsed["nvec"],
  parsed["stable-rank"],
  parsed["gap"],
  parsed["num-repeats"],
  parsed["p_fail"],
)
if !isempty(output_file)
  gdf = postprocess_data(distances, output_file)
  @show gdf
end
