/**
 * Copyright (c) 2018-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "gloo/math.h"
#include "gloo/scatter.h"
#include "gloo/test/base_test.h"

namespace gloo {
namespace test {
namespace {

// Test parameterization: (context size, data size).
// Element 0 is the size handed to spawn(); element 1 is the element count
// used for every input and output buffer in the test body.
using Param = std::tuple<int, size_t>;

// Test fixture for the scatter collective. Derives from BaseTest (presumably
// the provider of spawn() and Fixture<> -- see gloo/test/base_test.h) and from
// gtest's WithParamInterface so that test bodies can read a Param through
// GetParam().
class ScatterTest : public BaseTest,
                    public ::testing::WithParamInterface<Param> {};

// Runs scatter once per possible root and checks, on every process, that the
// output buffer holds the values expected for (root, this rank).
//
// NOTE(review): the expected-value formula below assumes that
// Fixture::assignValues() fills input buffer i on rank r with
// (r * size + i) + j * size * size for element j -- confirm against
// gloo/test/base_test.h.
TEST_P(ScatterTest, Default) {
  auto contextSize = std::get<0>(GetParam());
  auto dataSize = std::get<1>(GetParam());

  spawn(contextSize, [&](std::shared_ptr<Context> context) {
    // One input buffer per rank, and a single output buffer, each holding
    // dataSize elements.
    auto input = Fixture<uint64_t>(context, contextSize, dataSize);
    auto output = Fixture<uint64_t>(context, 1, dataSize);

    ScatterOptions opts(context);

    // Multiple inputs (one per rank)
    opts.setInputs(input.getPointers(), dataSize);

    // Single output
    opts.setOutput(output.getPointer(), dataSize);

    // Convert once to size_t so that loop bounds and the expected-value
    // arithmetic never mix signed and unsigned operands (and never multiply
    // in a narrower signed type before widening).
    const auto numRanks = static_cast<size_t>(context->size);
    const auto thisRank = static_cast<size_t>(context->rank);

    // Take turns being root
    for (size_t root = 0; root < numRanks; root++) {
      // Reset both sides so a stale result from the previous root cannot
      // satisfy the assertions below.
      input.assignValues();
      output.clear();
      opts.setRoot(root);
      scatter(opts);

      // Validate result on all processes
      const auto ptr = output.getPointer();
      const size_t base = (root * numRanks) + thisRank;
      const size_t stride = numRanks * numRanks;
      for (size_t j = 0; j < dataSize; j++) {
        ASSERT_EQ(j * stride + base, ptr[j]) << "Mismatch at index " << j;
      }
    }
  });
}

// Returns the data sizes fed into the parameterized test as the second
// element of Param.
std::vector<size_t> genMemorySizes() {
  return std::vector<size_t>{1, 10, 100};
}

// Instantiate ScatterTest under the "ScatterDefault" prefix for the full
// cross product of context sizes {2, 4, 7} and the data sizes returned by
// genMemorySizes().
INSTANTIATE_TEST_CASE_P(
    ScatterDefault,
    ScatterTest,
    ::testing::Combine(
        ::testing::Values(2, 4, 7),
        ::testing::ValuesIn(genMemorySizes())));

} // namespace
} // namespace test
} // namespace gloo
