/*
 * Copyright (c) 2017-present, XXX, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "synctrainer.h"
#include "batcher.h"
#include "checkpointer.h"
#include "distributed.h"
#include "sampler.h"

#include "common/autograd.h"

namespace {
// String keys used to look up named entries in dictionary-style variants.
// In the code visible here only kActionKey and kPActionKey are read (by
// SyncTrainer::makeFrame, on the trainer output); the remaining keys are not
// referenced in this part of the file.
const std::string kValueKey = "V";
const std::string kQKey = "Q";
const std::string kPiKey = "Pi";
const std::string kSigmaKey = "std";
const std::string kActionQKey = "actionQ";
// Entry copied into SingleFrame::action by makeFrame.
const std::string kActionKey = "action";
// Optional entry; copied into SingleFrame::pAction by makeFrame when present.
const std::string kPActionKey = "pAction";

} // namespace
namespace cpid {

void BatchedFrame::toCuda() {
  auto cudaify = [](torch::Tensor t) { return t.to(torch::kCUDA); };
  state = common::applyTransform(state, cudaify);
  // forwarded_state = common::applyTransform(forwarded_state, cudaify);
  reward = reward.to(torch::kCUDA);
  action = action.to(torch::kCUDA);
  if (pAction.defined()) {
    pAction = pAction.to(torch::kCUDA);
  }
}

std::shared_ptr<SyncFrame> SingleFrame::batch(
    const std::vector<std::shared_ptr<SyncFrame>>& list,
    std::unique_ptr<AsyncBatcher>& batcher) {
  torch::NoGradGuard g_;
  if (list.empty())
    return nullptr;

  size_t batchSize = list.size();
  auto batched = std::make_shared<BatchedFrame>();
  batched->reward = torch::zeros({int64_t(batchSize)});
  auto rewAcc = batched->reward.accessor<float, 1>();
  std::vector<ag::Variant> actions, pActions;
  std::vector<ag::Variant> states;
  actions.reserve(batchSize);
  states.reserve(batchSize);

  for (size_t i = 0; i < batchSize; ++i) {
    auto curFrame = std::dynamic_pointer_cast<SingleFrame>(list[i]);
    rewAcc[i] = curFrame->reward;
    if (curFrame->pAction.defined()) {
      pActions.push_back(curFrame->pAction);
    }
    actions.push_back(curFrame->action);
    states.push_back(curFrame->state);
  }

  batched->action = batcher->makeBatch(actions, -42).get();
  if (pActions.size() != 0) {
    batched->pAction = batcher->makeBatch(pActions, -42).get();
    batched->pAction.set_requires_grad(false);
  }
  batched->state = batcher->makeBatch(states);
  batched->action.set_requires_grad(false);
  batched->reward.set_requires_grad(false);
  return batched;
}

// Builds a synchronous trainer.
//
// The model, optimizer, sampler and batcher are forwarded to the Trainer base
// class; the remaining arguments are stored as configuration. A thread pool is
// created with 10 threads and a default Checkpointer bound to this trainer is
// installed (it can be replaced later through setCheckpointer()).
//
// Throws std::runtime_error when returnsLength is smaller than 2.
SyncTrainer::SyncTrainer(
    ag::Container model,
    ag::Optimizer optim,
    std::unique_ptr<BaseSampler> sampler,
    std::unique_ptr<AsyncBatcher> batcher,
    int returnsLength,
    int updateFreq,
    int trainerBatchSize,
    bool overlappingUpdates,
    bool forceOnPolicy,
    bool gpuMemoryEfficient,
    bool reduceGradients,
    float maxGradientNorm)
    : Trainer(model, optim, std::move(sampler), std::move(batcher)),
      returnsLength_(returnsLength),
      updateFreq_(updateFreq),
      trainerBatchSize_(trainerBatchSize),
      overlappingUpdates_(overlappingUpdates),
      forceOnPolicy_(forceOnPolicy),
      gpuMemoryEfficient_(gpuMemoryEfficient),
      reduceGradients_(reduceGradients),
      threads_(10),
      stepMutex_(1),
      checkpointer_(new Checkpointer(this)),
      max_gradient_norm_(maxGradientNorm) {
  // With an update frequency of 1 the batcher works directly on the trained
  // model; for any other value it is handed a clone of it instead.
  if (updateFreq_ != 1) {
    batcher_->setModel(ag::clone(model_));
  } else {
    batcher_->setModel(model_);
  }

  const bool returnsTooShort = returnsLength < 2;
  if (returnsTooShort) {
    throw std::runtime_error("SyncTrainer: the return size must be at least 2");
  }
}

// Records one frame for the game identified by `key`.
//
// Does nothing when the trainer is not in training mode. Otherwise the call
// blocks (on the low-priority side of stepMutex_) until fewer than
// trainerBatchSize_ games are marked ready, appends the frame together with
// its `isDone` flag to that game's buffer and, once the buffer holds at least
// returnsLength_ frames, marks the game as ready for an update with the
// current timestamp. Waiters on batchCV_ are woken up afterwards.
void SyncTrainer::step(
    GameUID const& key,
    EpisodeKey const&,
    std::shared_ptr<ReplayBufferFrame> value,
    bool isDone) {
  if (!train_) {
    return;
  }
  auto syncFrame = std::static_pointer_cast<SyncFrame>(value);

  priority_lock stepLock(stepMutex_, 0); // low priority lock
  stepLock.lock();
  batchCV_.wait(stepLock, [this]() {
    return static_cast<int>(readyToUpdate_.size()) < trainerBatchSize_;
  });

  auto& gameFrames = frameBuffers_[key];
  gameFrames.emplace_back(std::move(syncFrame), isDone);
  if (static_cast<int>(gameFrames.size()) >= returnsLength_) {
    // readyToUpdate_ is also read by forward() under forwardMutex_.
    std::unique_lock<std::mutex> forwardLock(forwardMutex_);
    readyToUpdate_.insert({key, hires_clock::now()});
  }

  stepLock.unlock();
  batchCV_.notify_all();
}

// Packs one transition into a SingleFrame.
//
// The action (and the optional "pAction" entry, when the trainer output
// contains it) as well as every tensor of `state` are moved to the storage
// device: CUDA when the model lives on CUDA and gpuMemoryEfficient_ is off,
// CPU otherwise. No gradients are recorded while doing so.
std::shared_ptr<ReplayBufferFrame> SyncTrainer::makeFrame(
    ag::Variant trainerOutput,
    ag::Variant state,
    float reward) {
  torch::NoGradGuard noGrad;

  const bool modelOnCuda = model_->options().device().is_cuda();
  const auto storage =
      (modelOnCuda && !gpuMemoryEfficient_) ? torch::kCUDA : torch::kCPU;

  auto result = std::make_shared<SingleFrame>();
  result->reward = reward;
  result->action = trainerOutput[kActionKey].to(storage);
  if (trainerOutput.getDict().count(kPActionKey) != 0) {
    result->pAction = trainerOutput[kPActionKey].to(storage);
  }
  result->state = common::applyTransform(
      std::move(state), [storage](torch::Tensor t) { return t.to(storage); });
  return result;
}

// Creates a default-constructed SingleFrame, handed out through the SyncFrame
// interface. createBatch() uses such a frame to reach the batch() method.
std::shared_ptr<SyncFrame> SyncTrainer::makeEmptyFrame() {
  std::shared_ptr<SyncFrame> blank = std::make_shared<SingleFrame>();
  return blank;
}

void SyncTrainer::createBatch(
    const std::vector<GameUID>& selectedGames,
    std::vector<std::shared_ptr<SyncFrame>>& seq,
    torch::Tensor& terminal) {
  MetricsContext::Timer batchTimer(metricsContext_, "trainer:batch_creation");

  seq.resize(returnsLength_);

  std::vector<std::future<std::shared_ptr<SyncFrame>>> futures;
  auto combinedFrame = makeEmptyFrame();
  auto batchOneFrame =
      [&](std::vector<std::shared_ptr<SyncFrame>> currentFrame) {
        return combinedFrame->batch(currentFrame, batcher_);
      };
  for (int i = 0; i < returnsLength_; ++i) {
    std::vector<std::shared_ptr<SyncFrame>> currentFrame;
    for (size_t j = 0; j < selectedGames.size(); ++j) {
      auto& f = frameBuffers_[selectedGames[j]][i];
      currentFrame.push_back(f.first);
      terminal[i][j] = f.second;
    }
    futures.emplace_back(
        threads_.enqueue(batchOneFrame, std::move(currentFrame)));
  }
  for (int i = 0; i < returnsLength_; ++i) {
    futures[i].wait();
    seq[i] = futures[i].get();
    if (model_->options().device().is_cuda()) {
      seq[i]->toCuda();
    }
  }
}

// Performs one synchronous model update.
//
// Blocks until enough games are ready (see shouldDoUpdate below), batches
// their frames, runs doUpdate() on the batch, trims the consumed frames from
// the per-game buffers and finally wakes up the threads waiting in step() and
// forward(). stepMutex_ is held (high priority) from entry until the cleanup
// is finished, except while waiting on batchCV_.
//
// Always returns true.
bool SyncTrainer::update() {
  priority_lock lk(stepMutex_, 1); // high priority lock
  lk.lock();

  // An update is due when a full batch of games is ready, or when at least one
  // game is ready and the oldest ready entry was marked more than 5 seconds
  // ago.
  auto shouldDoUpdate = [this]() {
    if ((int)readyToUpdate_.size() >= trainerBatchSize_) {
      return true;
    }
    if ((int)readyToUpdate_.size() > 0) {
      hires_clock::time_point now = hires_clock::now();
      // Smallest (i.e. oldest) timestamp among the ready games.
      hires_clock::time_point oldest = std::accumulate(
          readyToUpdate_.begin(),
          readyToUpdate_.end(),
          now,
          [](hires_clock::time_point t, const auto& it) {
            return std::min(t, it.second);
          });
      std::chrono::duration<double, std::milli> dur = now - oldest;
      // dur is in milliseconds; compare in seconds.
      if (dur.count() / 1000. > 5)
        return true;
    }
    return false;
  };

  // Re-check the predicate at least every 2 seconds so that the age-based
  // condition can fire even when nobody notifies batchCV_.
  while (
      !batchCV_.wait_for(lk, std::chrono::milliseconds(2000), shouldDoUpdate)) {
  }
  int actualBatchSize = readyToUpdate_.size();
  updateCount_++;
  metricsContext_->incCounter("sampleCount", actualBatchSize);

  // Pick at most trainerBatchSize_ of the ready games.
  // NOTE(review): `terminal` below is sized with actualBatchSize, not with
  // selectedGames.size(). step() only marks a game ready while fewer than
  // trainerBatchSize_ games are ready, so the two are presumably always
  // equal -- confirm.
  std::vector<GameUID> selectedGames;
  for (const auto& g : readyToUpdate_) {
    selectedGames.push_back(g.first);
    if ((int)selectedGames.size() == trainerBatchSize_)
      break;
  }

  std::vector<std::shared_ptr<SyncFrame>> seq;
  // terminal[i][j] is filled by createBatch with the done flag of game j at
  // time step i.
  auto terminal = torch::zeros({returnsLength_, actualBatchSize}, at::kByte);

  createBatch(selectedGames, seq, terminal);

  // With updateFreq_ == 1 the batcher shares the trained model (see the
  // constructor), so it is locked for the duration of the update.
  if (updateFreq_ == 1) {
    batcher_->lockModel();
  }
  {
    MetricsContext::Timer batchTimer(metricsContext_, "trainer:doUpdate");
    doUpdate(seq, terminal);
  }

  if (updateFreq_ == 1) {
    batcher_->unlockModel();
  } else {
    // Otherwise the batcher works on a clone, which is refreshed here.
    batcher_->setModel(ag::clone(model_));
  }

  // Clean up frameBuffers_ and readyToUpdate_. forwardMutex_ is taken because
  // forward() reads readyToUpdate_ under that mutex.
  std::unique_lock<std::mutex> lk2(forwardMutex_);
  if (forceOnPolicy_) {
    // Drop everything, including frames of games that were not selected.
    frameBuffers_.clear();
    readyToUpdate_.clear();
  } else {
    // Overlapping updates slide the window by one frame; otherwise all but the
    // last frame of the consumed window are dropped.
    size_t to_delete = overlappingUpdates_ ? 1 : returnsLength_ - 1;
    for (const auto& g : selectedGames) {
      auto& buffer = frameBuffers_[g];
      buffer.erase(buffer.begin(), buffer.begin() + to_delete);
      // A game stays ready only if it still holds a full window of frames.
      if ((int)buffer.size() < returnsLength_) {
        readyToUpdate_.erase(g);
      }
    }
  }
  lk2.unlock();
  lk.unlock();
  // Wake up forward() callers waiting for their game to leave readyToUpdate_
  // and step() callers waiting for room in readyToUpdate_.
  forwardCV_.notify_all();
  batchCV_.notify_all();

  if (checkpointer_) {
    checkpointer_->updateDone(updateCount_);
  }
  return true;
}

// Runs the model on a single, unbatched input: the input is wrapped into a
// batch of one, forwarded through the model, and the first (only) element of
// the un-batched output is returned.
ag::Variant SyncTrainer::forwardUnbatched(ag::Variant in) {
  auto singleBatch = batcher_->makeBatch({in});
  auto batchedOut = model_->forward(std::move(singleBatch));
  auto perSample = batcher_->unBatch(batchedOut, false, -1);
  return perSample[0];
}

// Switches the model between training and evaluation mode and remembers the
// choice in train_ (step() ignores incoming frames while train_ is false).
void SyncTrainer::setTrain(bool train) {
  if (!train) {
    model_->eval();
  } else {
    model_->train();
  }
  train_ = train;
}

void SyncTrainer::computeAllForward(
    const std::vector<std::shared_ptr<SyncFrame>>& seq,
    int batchSize) {
  MetricsContext::Timer forwardTimer(
      metricsContext_, "trainer:computeAllForward");

  if (gpuMemoryEfficient_) {
    bool isCuda = model_->options().device().is_cuda();
    auto cudaify = [](torch::Tensor t) { return t.to(torch::kCUDA); };

    for (size_t i = 0; i < seq.size(); ++i) {
      ag::Variant input = seq[i]->state;
      if (isCuda) {
        input = common::applyTransform(input, cudaify);
      }
      seq[i]->forwarded_state = model_->forward(std::move(input));
    }
  } else {
    // To save computation, we can do one forward for all element of the seq at
    // once
    std::vector<ag::Variant> allStates;
    for (const auto& frame : seq) {
      allStates.emplace_back(frame->state);
    }
    ag::Variant batch = common::makeBatchVariant(allStates);
    auto collapseFirstDim = [](torch::Tensor t) {
      std::vector<int64_t> sizes(t.sizes().vec()),
          newSizes(t.sizes().size() - 1);
      newSizes[0] = sizes[0] * sizes[1];
      std::copy(sizes.begin() + 2, sizes.end(), newSizes.begin() + 1);
      return t.view(newSizes);
    };
    batch = common::applyTransform(std::move(batch), collapseFirstDim);
    ag::Variant composedResult = model_->forward(std::move(batch));
    std::vector<ag::Variant> result =
        common::unBatchVariant(composedResult, batchSize, false);

    for (size_t i = 0; i < seq.size(); ++i) {
      seq[i]->forwarded_state = std::move(result[i]);
    }
  }
}

// Applies one optimizer step to the model.
//
//  1. When reduceGradients_ is set, every defined gradient is all-reduced and
//     divided by the size of the global distributed context.
//  2. The infinity norm (largest absolute entry over all defined gradients) is
//     computed and reported as the "grad_inf_norm" event.
//  3. When max_gradient_norm_ is positive and the norm exceeds it, all
//     gradients are scaled down by max_gradient_norm_ / (norm + 1e-5).
//  4. The optimizer steps and the gradients are zeroed.
void SyncTrainer::doOptimStep() {
  namespace dist = distributed;

  if (reduceGradients_) {
    for (auto& param : model_->parameters()) {
      if (param.grad().defined()) {
        {
          // Only the communication itself is attributed to network time.
          MetricsContext::Timer timeAllreduce(
              metricsContext_, "trainer:network_time");
          dist::allreduce(param.grad());
        }
        param.grad().div_(dist::globalContext()->size);
      }
    }
  }

  // Compute the infinity norm of the gradients.
  float infNorm = 0;
  for (auto& param : model_->parameters()) {
    if (param.grad().defined()) {
      const float largest =
          torch::max(torch::abs(param.grad())).item<float>();
      infNorm = std::max(infNorm, largest);
    }
  }
  metricsContext_->pushEvent("grad_inf_norm", infNorm);

  if (max_gradient_norm_ > 0) {
    // The small epsilon keeps the division well-defined for a zero norm.
    const float scale = max_gradient_norm_ / (infNorm + 1.e-5);
    if (scale < 1) {
      for (auto& param : model_->parameters()) {
        if (param.grad().defined()) {
          param.grad().mul_(scale);
        }
      }
    }
  }

  optim_->step();
  optim_->zero_grad();
}

// Replaces the trainer's checkpointer (a default one is created in the
// constructor) and takes ownership of the new one. update() calls
// updateDone() on it after every update; passing a null pointer disables that
// call, since update() checks the pointer first.
void SyncTrainer::setCheckpointer(std::unique_ptr<Checkpointer> checkpointer) {
  checkpointer_ = std::move(checkpointer);
}

// Forwards `inp` through Trainer::forward, but first blocks for as long as the
// given game is marked ready for an update (i.e. is present in
// readyToUpdate_). The forward mutex is released before the actual forward
// call is made.
ag::Variant SyncTrainer::forward(
    ag::Variant inp,
    GameUID const& gameUID,
    EpisodeKey const& key) {
  {
    std::unique_lock<std::mutex> guard(forwardMutex_);
    auto notPending = [this, gameUID]() {
      return readyToUpdate_.find(gameUID) == readyToUpdate_.end();
    };
    forwardCV_.wait(guard, notPending);
  }
  return Trainer::forward(inp, gameUID, key);
}

// Discards everything the trainer holds for game `id`: its buffered frames
// and its "ready for update" mark.
//
// Locks are taken in the same order as in step() and update() (stepMutex_
// first, then forwardMutex_). forwardMutex_ is required because forward()
// reads readyToUpdate_ while holding only that mutex. Both condition
// variables are notified afterwards: removing the game can make the predicate
// of a forward() call waiting on this game, or of a step() call waiting for
// room in readyToUpdate_, become true.
void SyncTrainer::forceStopEpisode(GameUID const& id, EpisodeKey const&) {
  priority_lock lk(stepMutex_, 0); // low priority lock
  lk.lock();
  {
    std::unique_lock<std::mutex> forwardLock(forwardMutex_);
    readyToUpdate_.erase(id);
    frameBuffers_.erase(id);
  }
  lk.unlock();
  forwardCV_.notify_all();
  batchCV_.notify_all();
}

// Builds an evaluator for this trainer's model through evaluatorFactory,
// forwarding `n` and the sampler. The evaluator's forward function runs the
// model on one unbatched input with gradient recording disabled; the game and
// episode identifiers it receives are ignored.
std::shared_ptr<Evaluator> SyncTrainer::makeEvaluator(
    size_t n,
    std::unique_ptr<BaseSampler> sampler) {
  auto noGradForward =
      [this](ag::Variant inp, GameUID const&, EpisodeKey const&) {
        torch::NoGradGuard noGrad;
        return this->forwardUnbatched(inp);
      };
  return evaluatorFactory(
      model_, std::move(sampler), n, std::move(noGradForward));
}

} // namespace cpid
