/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "tde/ps.h"
#include "tde/details/io.h"

namespace tde {

// Starts pulling the rows named by `ids_to_fetch` from the IO backend into the
// tensors of the local shards.
//
// `ids_to_fetch` must be a 2-D tensor; Filter() reads it as consecutive
// (global id, cache id) pairs and keeps only the pairs whose cache id is held
// by one of the local shards.
//
// For every kept id the pull callback either copies the pulled value into the
// shard's tensor views or, when nothing was pulled for that id and `reinit` is
// set, fills tensors[0] uniformly in [weight_init_min, weight_init_max] and
// zeroes the remaining optimizer-state tensors.
//
// Returns a FetchHandle built from `time` and this PS, or from `time` and an
// empty PS pointer when no id is held locally. The call holds `mu_` for its
// whole duration.
c10::intrusive_ptr<FetchHandle> PS::Fetch(
    torch::Tensor ids_to_fetch,
    int64_t time,
    bool reinit,
    double weight_init_min,
    double weight_init_max) {
  std::lock_guard<std::mutex> lock(mu_);
  torch::NoGradGuard no_grad;
  TORCH_CHECK(ids_to_fetch.dim() == 2);

  Filter(ids_to_fetch);
  if (cache_ids_to_fetch_or_evict_.empty()) {
    // Nothing to pull: the handle is not bound to any PS.
    return c10::make_intrusive<FetchHandle>(time, c10::intrusive_ptr<PS>());
  }

  // Register a notification for this fetch so that SyncFetch() can wait on
  // it; the pull callback below signals it once all rows are written.
  c10::intrusive_ptr<Notification> notification;
  {
    std::unique_lock<std::mutex> lock_fetch(fetch_notifications_mutex_);
    fetch_notifications_.emplace_back(
        time, c10::make_intrusive<Notification>());
    notification = fetch_notifications_.back().second;
  }

  std::vector<int64_t> col_ids{0};
  uint32_t num_os_ids = os_ids_.size();
  io_.Pull(
      table_name_,
      global_ids_to_fetch_or_evict_,
      col_ids,
      num_os_ids,
      torch::kF32,
      // The cache ids are moved into the callback, so the member vector is
      // left moved-from until the next Filter() call clears it.
      [=, this, cache_ids_to_fetch = std::move(cache_ids_to_fetch_or_evict_)](
          auto&& val) {
        TORCH_CHECK(val.size() == cache_ids_to_fetch.size());
        for (uint32_t row = 0; row < cache_ids_to_fetch.size(); ++row) {
          int64_t cache_id = cache_ids_to_fetch[row];
          auto& fetched = val[row];
          if (fetched.defined()) {
            // Row `state` of the pulled value goes to the matching view.
            std::vector<torch::Tensor> views = GetTensorViews(cache_id);
            for (uint32_t state = 0; state < num_os_ids; ++state) {
              views[state].copy_(fetched.slice(0, state, state + 1));
            }
          } else if (reinit) {
            std::vector<torch::Tensor> views = GetTensorViews(cache_id);
            views[0].uniform_(weight_init_min, weight_init_max);
            // optimizer states will be set to zero
            for (uint32_t state = 1; state < num_os_ids; ++state) {
              views[state].zero_();
            }
          }
          // Undefined value without `reinit`: the cached row stays untouched.
        }
        notification->Done();
      });
  // `unsafe_reclaim_from_nonowning` is the `intrusive_ptr` version of
  // `enable_shared_from_this`
  return c10::make_intrusive<FetchHandle>(
      time, c10::intrusive_ptr<PS>::unsafe_reclaim_from_nonowning(this));
}

// Rebuilds `cache_ids_to_fetch_or_evict_` / `global_ids_to_fetch_or_evict_`
// from `tensor`.
//
// `tensor` must be a contiguous int64 tensor that is read as a flat sequence
// of (global id, cache id) pairs. A pair is kept only when at least one local
// shard reports `Has(cache id)`; the two output vectors stay index-aligned.
//
// Raises (via TORCH_CHECK) when the tensor is not contiguous or when it holds
// an odd number of elements, since the trailing element would otherwise be
// paired with memory past the end of the buffer.
void PS::Filter(const torch::Tensor& tensor) {
  cache_ids_to_fetch_or_evict_.clear();
  global_ids_to_fetch_or_evict_.clear();
  TORCH_CHECK(tensor.is_contiguous());
  const int64_t* ids = tensor.data_ptr<int64_t>();
  const int64_t numel = tensor.numel();
  TORCH_CHECK(
      numel % 2 == 0,
      "expected a flat sequence of (global id, cache id) pairs, but got ",
      numel,
      " elements");
  for (int64_t i = 0; i < numel; i += 2) {
    const int64_t global_id = ids[i];
    const int64_t cache_id = ids[i + 1];
    const bool held_locally = std::any_of(
        shards_->begin(), shards_->end(), [cache_id](auto&& shard) {
          return shard.Has(cache_id);
        });
    if (held_locally) {
      cache_ids_to_fetch_or_evict_.emplace_back(cache_id);
      global_ids_to_fetch_or_evict_.emplace_back(global_id);
    }
  }
}

// Pushes the cached rows named by `ids_to_evict` to the IO backend.
//
// `ids_to_evict` must be a 2-D tensor; Filter() reads it as consecutive
// (global id, cache id) pairs and keeps only the pairs whose cache id is held
// by one of the local shards. The kept ids are pushed in chunks of at most
// `num_ids_per_chunk_` ids: the data of chunk k+1 is gathered while the push
// of chunk k may still be running, and at most one push is in flight.
//
// The call holds `mu_` throughout, first drains all pending fetches via
// SyncFetch(), and returns only after the completion callback of the last
// push has fired.
void PS::Evict(torch::Tensor ids_to_evict) {
  std::lock_guard<std::mutex> lock(mu_);
  torch::NoGradGuard no_grad;
  TORCH_CHECK(ids_to_evict.dim() == 2);
  // make sure all previous fetches are done.
  SyncFetch();

  std::vector<int64_t> col_ids{0};
  // TODO (from the original author): remove this copy!
  Filter(ids_to_evict);
  if (global_ids_to_fetch_or_evict_.empty()) {
    return;
  }

  uint32_t num_os_ids = os_ids_.size();
  uint32_t num_ids_to_evict = global_ids_to_fetch_or_evict_.size();

  // Captured by reference in the push callback, so it has to outlive every
  // in-flight push; the Wait() at the end of the function ensures that on the
  // normal path.
  details::Notification notification;
  // Done first so that the Wait after preparing the first chunk won't stuck.
  notification.Done();
  // The offsets buffer is shared by all chunks. Its storage is handed to
  // Push, so it must only be rewritten once the previous push has completed.
  std::vector<uint64_t> offsets;
  offsets.resize(num_ids_per_chunk_ * num_os_ids * col_ids.size() + 1);

  for (uint32_t i = 0; i < num_ids_to_evict; i += num_ids_per_chunk_) {
    uint32_t num_ids_in_chunk = std::min(
        static_cast<uint32_t>(num_ids_per_chunk_), num_ids_to_evict - i);
    uint32_t data_size = num_ids_in_chunk * num_os_ids * col_ids.size();
    uint32_t offsets_size = num_ids_in_chunk * num_os_ids * col_ids.size() + 1;

    // Gather the tensor views of every id in this chunk, in id order.
    std::vector<torch::Tensor> all_tensors;
    for (uint32_t j = i; j < i + num_ids_in_chunk; ++j) {
      int64_t cache_id = cache_ids_to_fetch_or_evict_[j];
      std::vector<torch::Tensor> tensors = GetTensorViews(cache_id);
      all_tensors.insert(all_tensors.end(), tensors.begin(), tensors.end());
    }
    torch::Tensor data = torch::cat(all_tensors, 0).cpu();
    TORCH_CHECK(data.numel() == data_size * col_size_);

    // to prevent the original data from being prematurely recycled
    auto data_shared_ptr = std::make_shared<torch::Tensor>(data);

    // waiting for the Push of last chunk finishes.
    notification.Wait();
    notification.Clear();

    // Fill the shared offsets only now: doing it before the Wait() above
    // would modify the buffer while the previous push may still read it.
    offsets[0] = 0;
    for (uint32_t j = 0; j < all_tensors.size(); ++j) {
      offsets[j + 1] =
          offsets[j] + all_tensors[j].numel() * all_tensors[j].element_size();
    }

    // Byte length of the whole staged buffer. It holds
    // data_size * col_size_ elements (checked above), which is also the
    // range the offsets index into.
    const size_t data_bytes =
        static_cast<size_t>(data_shared_ptr->numel()) *
        data_shared_ptr->element_size();

    io_.Push(
        table_name_,
        tcb::span{global_ids_to_fetch_or_evict_.data() + i, num_ids_in_chunk},
        col_ids,
        os_ids_,
        tcb::span{
            reinterpret_cast<uint8_t*>(data_shared_ptr->data_ptr<float>()),
            data_bytes},
        tcb::span{offsets.data(), offsets_size},
        [&notification, data_shared_ptr] { notification.Done(); });
  }
  notification.Wait();
}

// Waits for queued fetch notifications, oldest first.
//
// With a negative `time` every queued notification is popped and waited on.
// Otherwise notifications are popped only while the one at the front of the
// queue was registered with exactly `time`; the first mismatch stops the loop.
//
// `fetch_notifications_mutex_` is held only while the queue is inspected and
// popped, never while waiting on a notification.
void PS::SyncFetch(int64_t time) {
  const bool wait_for_all = time < 0;

  std::unique_lock<std::mutex> queue_guard(fetch_notifications_mutex_);
  while (!fetch_notifications_.empty()) {
    auto& oldest = fetch_notifications_.front();
    if (!wait_for_all && oldest.first != time) {
      break;
    }
    // Take our own reference before the entry is removed from the queue.
    auto pending = oldest.second;
    fetch_notifications_.pop_front();

    queue_guard.unlock();
    pending->Wait();
    queue_guard.lock();
  }
}

std::vector<torch::Tensor> PS::GetTensorViews(int64_t cache_id) {
  for (auto& shard : *shards_) {
    if (shard.Has(cache_id)) {
      return shard.GetTensorView(cache_id);
    }
  }
  TORCH_CHECK(false, "all local shards do not contain cache id ", cache_id);
}

} // namespace tde
