# coding=utf-8
# Copyright 2023 The Soar Neurips2023 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Looks at KMR benefit for different dataset sizes."""
import collections
import os

from absl import app
from absl import flags
import h5py
import matplotlib.pyplot as plt
import numpy as np
import utils

# Command-line flags. Each holder's `.value` is read at run time, after absl
# has parsed the command line.

# Path of the HDF5 file opened in main().
_HDF5 = flags.DEFINE_string(
    name="hdf5",
    default=None,
    help="Path to hdf5 of dataset.",
)

# Eta value forwarded to utils.compute_avq_centers() by calc_kmr().
_ETA = flags.DEFINE_float(
    name="eta",
    default=2.5,
    help="AVQ eta.",
)

# Lambda value forwarded to utils.soar_assign() by calc_kmr().
_LAMBDA = flags.DEFINE_float(
    name="lambda",
    default=1,
    help="SOAR lambda.",
)


def calc_kmr(ds, qs, sizes):
  """Computes KMR curves for several dataset sizes and saves them to disk.

  For every requested size a random subset of `ds` is drawn without
  replacement (or all of `ds` is used when the size is not smaller than the
  dataset). The subset is clustered with k-means, the centers are updated with
  `utils.compute_avq_centers`, and three results of `utils.kmr` are written as
  Numpy arrays under `npys/`:

    * `s{size}_kmr1.npy`: first assignment from `utils.redo_assignment` only.
    * `s{size}_kmr2.npy`: both assignments from `utils.redo_assignment`.
    * `s{size}_kmr3.npy`: first assignment plus the `utils.soar_assign` result.

  The AVQ eta and SOAR lambda come from the `--eta` and `--lambda` flags.

  Args:
    ds: Dataset to subsample; must support `len()` and indexing by an integer
      index array (presumably a 2-D Numpy array of datapoints -- confirm
      against the caller).
    qs: Queries forwarded to `utils.compute_ground_truth` and `utils.kmr`.
    sizes: Iterable of integer dataset sizes to evaluate.
  """
  eta = _ETA.value
  soar_l = _LAMBDA.value

  # np.save does not create missing directories. Create the output directory
  # up front so a long computation is not lost to a FileNotFoundError.
  os.makedirs("npys", exist_ok=True)

  for size in sizes:
    if size < len(ds):
      # Passing the population size directly samples from arange(len(ds))
      # without first converting a Python range into an array.
      subset_idx = np.random.choice(len(ds), size, replace=False)
      cur_ds = ds[subset_idx]
    else:
      cur_ds = ds
    print("Computing ground truth for size", size, flush=True)
    gt = utils.compute_ground_truth(cur_ds, qs, 100)
    print("Training k-means...", flush=True)
    # Derive the number of centers from the data actually used, so a requested
    # size larger than the dataset does not inflate the number of partitions.
    num_centers = len(cur_ds) // 400
    # The assignment returned by training is discarded; it is recomputed below.
    orig_centers, _ = utils.train_kmeans(cur_ds, num_centers)
    print("Performing no-SOAR spilled assignment...", flush=True)
    toke, toke2 = utils.redo_assignment(orig_centers, cur_ds)
    print("# empty partitions:", num_centers - len(set(toke)))
    print("Updating centers...", flush=True)
    centers = utils.compute_avq_centers(cur_ds, orig_centers, toke, eta)
    print("SOAR...", flush=True)
    toke3 = utils.soar_assign(cur_ds, centers, toke, soar_l, True)

    print("KMR...", flush=True)
    # File names are keyed by the requested size because main() loads them
    # back using the same `size` values.
    np.save(f"npys/s{size}_kmr1.npy", utils.kmr(centers, toke, None, qs, gt))
    np.save(f"npys/s{size}_kmr2.npy", utils.kmr(centers, toke, toke2, qs, gt))
    np.save(f"npys/s{size}_kmr3.npy", utils.kmr(centers, toke, toke3, qs, gt))


def main(argv):
  """Plots the KMR efficiency gain from SOAR for several dataset sizes.

  Loads the `train` and `test` datasets from the HDF5 file given by `--hdf5`,
  then reads the previously saved `npys/s{size}_kmr{1,2,3}.npy` arrays for each
  dataset size. For every recall target it finds the first index at which each
  array exceeds the target and records the ratio of the `kmr1` index to the
  `kmr3` index (SOAR) and to the `kmr2` index (no SOAR). The SOAR ratios are
  drawn as a grouped bar chart saved to `out/size_plot.pdf` and
  `out/size_plot.png`.

  Args:
    argv: Positional command-line arguments from absl; unused.
  """
  del argv  # Unused.
  # `[()]` reads each dataset fully into memory, so the file can be closed as
  # soon as both splits are loaded instead of leaking the handle.
  with h5py.File(_HDF5.value, "r") as hdf5:
    ds = utils.normalize(hdf5["train"][()])
    qs = utils.normalize(hdf5["test"][()])
  print("Dataset shape:", ds.shape)
  print("Query shape:", qs.shape)

  sizes = [10**5, 3 * 10**5, 10**6, 3 * 10**6, len(ds)]
  size_names = ["100K", "300K", "1M", "3M", "9.99M"]
  # Uncomment to (re)compute the arrays under npys/ that are loaded below.
  # calc_kmr(ds, qs, sizes)

  recall_targs = [0.85, 0.9, 0.95]
  # Dataset size -> improvement at different recall targets
  soar_dict = collections.defaultdict(list)
  nosoar_dict = collections.defaultdict(list)  # Collected but not plotted.

  for size in sizes:
    kmr1 = np.load(f"npys/s{size}_kmr1.npy")
    kmr2 = np.load(f"npys/s{size}_kmr2.npy")
    kmr3 = np.load(f"npys/s{size}_kmr3.npy")
    for targ in recall_targs:
      # First index at which each array exceeds the target.
      # NOTE(review): np.argmax returns 0 when no entry exceeds the target (or
      # when the first entry already does), which makes the ratios below
      # divide by zero -- confirm the saved arrays always reach every target.
      num_datapoints = [np.argmax(k > targ) for k in [kmr1, kmr2, kmr3]]
      soar_dict[size].append(num_datapoints[0] / num_datapoints[2])
      nosoar_dict[size].append(num_datapoints[0] / num_datapoints[1])
      print(targ, *num_datapoints, sep=",")

  # One group of bars per recall target, spaced two units apart.
  x = 2 * np.arange(len(recall_targs))
  width = 0.35

  _, ax = plt.subplots(layout="constrained", figsize=(5, 3))
  # Reference line at 1x, i.e. no change relative to the kmr1 baseline.
  ax.axhline(1, c="black", linestyle="--", alpha=0.7, linewidth=1)
  for i, improvements in enumerate(soar_dict.values()):
    ax.bar(x + width * i, improvements, width, label=size_names[i])

  ax.set_xlabel("Recall Target (R@100)")
  ax.set_ylabel("KMR Efficiency Gain")
  # Center each tick under its group of bars (2 * width for five sizes).
  tick_offset = width * (len(soar_dict) - 1) / 2
  ax.set_xticks(x + tick_offset, [f"{100*t}%" for t in recall_targs])
  ax.legend(loc="upper left", ncols=2)
  ax.set_ylim(0.5, 1.5)

  # savefig does not create missing directories.
  os.makedirs("out", exist_ok=True)
  plt.savefig("out/size_plot.pdf", bbox_inches="tight", pad_inches=0.02)
  plt.savefig("out/size_plot.png", dpi=300, bbox_inches="tight", pad_inches=0.1)
  plt.close()


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
  app.run(main)
