/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include <resource_manager.hpp>

namespace HugeCTR {

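/**
 * @brief Core resource manager holding the minimal, essential set of CPU and GPU resources.
 *
 * Holds one CPU resource and the GPU resources of this process, indexed by local GPU id,
 * together with the DeviceMap that translates between local and global GPU ids.
 */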
class ResourceManagerCore : public ResourceManager {
 private:
  int num_process_;       // value reported by get_num_process()
  int process_id_;        // value reported by get_process_id()
  DeviceMap device_map_;  // source of local/global GPU id translation and the local device list
  std::shared_ptr<CPUResource> cpu_resource_;
  std::vector<std::shared_ptr<GPUResource>> gpu_resources_; /**< GPU resources, indexed by local GPU id */
  // NOTE(review): presumably p2p_matrix_[src][dst] records whether peer access is possible
  // between two devices (backing p2p_enabled()); confirm against the .cpp implementation.
  std::vector<std::vector<bool>> p2p_matrix_;

  // Defined out of line; semantics not visible from this header.
  void all2all_warmup();
  void enable_all_peer_accesses();

#ifndef DISABLE_CUDF
  std::vector<std::shared_ptr<rmm::mr::device_memory_resource>> base_cuda_mr_;
  std::vector<std::shared_ptr<rmm::mr::device_memory_resource>> memory_resource_;
  std::vector<rmm::mr::device_memory_resource*> original_device_resource_;

  void initialize_rmm_resources();
#endif

 public:
  ResourceManagerCore(int num_process, int process_id, DeviceMap&& device_map,
                      unsigned long long seed);

  /**
   * @brief Factory for a ResourceManager built from per-process visible device lists.
   * @param visible_devices visible device ids (one inner list per process — TODO confirm)
   * @param seed seed forwarded to resource construction
   * @param layout device layout used for the DeviceMap (defaults to LOCAL_FIRST)
   */
  static std::shared_ptr<ResourceManager> create(
      const std::vector<std::vector<int>>& visible_devices, unsigned long long seed,
      DeviceMap::Layout layout = DeviceMap::LOCAL_FIRST);

  HCTR_DISALLOW_COPY_AND_MOVE(ResourceManagerCore);
  ~ResourceManagerCore();

  // from ResourceManagerBase

  /**
   * @brief Installs the GPU resource for one local GPU slot.
   * @param gpu_resource resource to store; ownership is shared with the manager
   * @param local_gpu_id slot index; must be < get_local_gpu_count()
   * Raises Error_t::WrongInput (via HCTR_OWN_THROW) if the index is out of range or the
   * slot already holds a resource.
   */
  void set_local_gpu(std::shared_ptr<GPUResource> gpu_resource, size_t local_gpu_id) override {
    if (local_gpu_id >= get_local_gpu_count()) {
      HCTR_OWN_THROW(Error_t::WrongInput, "Error: Invalid local_gpu_id");
    }
    if (gpu_resources_[local_gpu_id] != nullptr) {
      HCTR_OWN_THROW(Error_t::WrongInput, "Error: Already initialized");
    }
    // The parameter is a by-value sink: move it instead of paying for another refcount bump.
    gpu_resources_[local_gpu_id] = std::move(gpu_resource);
  }

  /**
   * @brief Returns the GPU resource stored for a local GPU id.
   * Raises Error_t::WrongInput (via HCTR_OWN_THROW) if local_gpu_id is out of range,
   * instead of indexing past the end of the vector.
   */
  const std::shared_ptr<GPUResource>& get_local_gpu(size_t local_gpu_id) const override {
    if (local_gpu_id >= gpu_resources_.size()) {
      HCTR_OWN_THROW(Error_t::WrongInput, "Error: Invalid local_gpu_id");
    }
    return gpu_resources_[local_gpu_id];
  }

  /**
   * @brief Returns the GPU resource of the local GPU whose device id equals device_id.
   * Raises Error_t::WrongInput (via HCTR_OWN_THROW) if device_id is not in the local list.
   */
  const std::shared_ptr<GPUResource>& get_local_gpu_from_device_id(
      size_t device_id) const override {
    const auto& device_ids = get_local_gpu_device_id_list();
    // Compare in the unsigned domain only for non-negative ids, so the int/size_t mix is
    // explicit and a negative id can never alias a huge device_id after conversion.
    const auto iter = std::find_if(device_ids.begin(), device_ids.end(), [device_id](int id) {
      return id >= 0 && static_cast<size_t>(id) == device_id;
    });
    if (iter == device_ids.end()) {
      HCTR_OWN_THROW(Error_t::WrongInput,
                     "Error: device_id does not exist in the local_gpu_device_id_list");
    }
    const auto local_gpu_id = static_cast<size_t>(iter - device_ids.begin());
    return get_local_gpu(local_gpu_id);
  }

  /// Number of GPUs owned by this process (size of the local device list).
  size_t get_local_gpu_count() const override { return device_map_.get_device_list().size(); }
  /// Number of GPUs across all processes, as reported by the DeviceMap.
  size_t get_global_gpu_count() const override { return device_map_.size(); }

  // from ResourceManager
  int get_num_process() const override { return num_process_; }
  int get_process_id() const override { return process_id_; }
  int get_master_process_id() const override { return 0; }
  // Defined in terms of get_master_process_id() so the two can never disagree.
  bool is_master_process() const override { return process_id_ == get_master_process_id(); }

  const std::shared_ptr<CPUResource>& get_local_cpu() const override { return cpu_resource_; }

  /// All local GPU resources, indexed by local GPU id.
  const std::vector<std::shared_ptr<GPUResource>>& get_local_gpus() const override {
    return gpu_resources_;
  }

  /// Device ids of the local GPUs; position in this list is the local GPU id.
  const std::vector<int>& get_local_gpu_device_id_list() const override {
    return device_map_.get_device_list();
  }

  int get_process_id_from_gpu_global_id(size_t global_gpu_id) const override {
    return device_map_.get_pid(global_gpu_id);
  }

  size_t get_gpu_local_id_from_global_id(size_t global_gpu_id) const override {
    return device_map_.get_local_id(global_gpu_id);
  }

  size_t get_gpu_global_id_from_local_id(size_t local_gpu_id) const override {
    return device_map_.get_global_id(local_gpu_id);
  }

  // Defined out of line; presumably backed by p2p_matrix_ — confirm in the .cpp.
  bool p2p_enabled(int src_dev, int dst_dev) const override;
  bool all_p2p_enabled() const override;

  DeviceMap::Layout get_device_layout() const override { return device_map_.get_device_layout(); }

#ifndef DISABLE_CUDF
  const std::shared_ptr<rmm::mr::device_memory_resource>& get_device_rmm_device_memory_resource(
      int local_gpu_id) const override;
#endif
};
}  // namespace HugeCTR
