/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

// Backward pass of LayerNorm, delegating to nvte_layernorm_bwd.
//
// Every tensor argument is made contiguous before being handed to the kernel. `sm_margin`
// is subtracted from the current device's multiprocessor count and the difference is
// passed to the kernel.
//
// Returns {dx, dgamma, dbeta}: dx is allocated like x, dgamma and dbeta like gamma.
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma) {
  // Contiguous versions of the inputs.
  const auto &dz_c = dz.contiguous();
  const auto &x_c = x.contiguous();
  const auto &mu_c = mu.contiguous();
  const auto &rsigma_c = rsigma.contiguous();
  const auto &gamma_c = gamma.contiguous();

  // Gradient buffers that are returned to the caller.
  auto grad_x = at::empty_like(x_c);
  auto grad_gamma = at::empty_like(gamma_c);
  auto grad_beta = at::empty_like(gamma_c);

  // Transformer Engine wrappers around the inputs and outputs.
  auto dz_te = makeTransformerEngineTensor(dz_c);
  auto x_te = makeTransformerEngineTensor(x_c);
  auto mu_te = makeTransformerEngineTensor(mu_c);
  auto rsigma_te = makeTransformerEngineTensor(rsigma_c);
  auto gamma_te = makeTransformerEngineTensor(gamma_c);
  auto grad_x_te = makeTransformerEngineTensor(grad_x);
  auto grad_gamma_te = makeTransformerEngineTensor(grad_gamma);
  auto grad_beta_te = makeTransformerEngineTensor(grad_beta);

  // Starts out empty; the first kernel call records the shape and dtype it requires.
  transformer_engine::TensorWrapper ws;

  const int num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
  auto stream = at::cuda::getCurrentCUDAStream();

  // First call: query the workspace configuration.
  nvte_layernorm_bwd(dz_te.data(), x_te.data(), mu_te.data(), rsigma_te.data(), gamma_te.data(),
                     grad_x_te.data(), grad_gamma_te.data(), grad_beta_te.data(), ws.data(),
                     num_sms, zero_centered_gamma, stream);

  // Back the workspace with real memory of the requested shape and dtype.
  auto ws_storage = allocateSpace(ws.shape(), ws.dtype());
  ws = makeTransformerEngineTensor(ws_storage.data_ptr(), ws.shape(), ws.dtype());

  // Second call: run the backward kernel.
  nvte_layernorm_bwd(dz_te.data(), x_te.data(), mu_te.data(), rsigma_te.data(), gamma_te.data(),
                     grad_x_te.data(), grad_gamma_te.data(), grad_beta_te.data(), ws.data(),
                     num_sms, zero_centered_gamma, stream);

  return {grad_x, grad_gamma, grad_beta};
}

// LayerNorm forward pass that allocates its own output tensor.
//
// Creates a CUDA output tensor shaped like the contiguous input, with the ATen dtype that
// corresponds to `otype`, then forwards all arguments to layernorm_fwd_fp8_noalloc and
// returns that function's result unchanged.
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
                                          const at::Tensor &bias, float eps, at::Tensor scale,
                                          at::Tensor amax, at::Tensor scale_inv,
                                          transformer_engine::DType otype, const int sm_margin,
                                          const bool zero_centered_gamma, const int scale_offset,
                                          const int amax_offset, const int scale_inv_offset) {
  using namespace transformer_engine;

  const auto &contig_input = input.contiguous();

  // Output buffer: same shape as the input, dtype selected by `otype`.
  auto output = at::empty_like(contig_input, at::CUDA(GetATenDType(otype)));

  return layernorm_fwd_fp8_noalloc(contig_input, weight, bias, eps, scale, output, amax,
                                   scale_inv, otype, sm_margin, zero_centered_gamma,
                                   scale_offset, amax_offset, scale_inv_offset);
}

// LayerNorm forward pass that writes into a caller-provided output tensor.
//
// `input`, `weight` and `bias` are made contiguous before use. The output is described to
// the kernel as an {N, H} tensor of type `otype` living at `ln_out.data_ptr()`, where N and
// H are the first two dimension sizes of the input. The scale / amax / scale-inverse
// pointers are obtained through getDataPtr with their respective offsets.
// `sm_margin` is subtracted from the current device's multiprocessor count and the
// difference is passed to the kernel.
//
// Returns {ln_out, mu, rsigma}, where mu and rsigma are freshly allocated float CUDA
// tensors of length N.
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
    const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps,
    at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv,
    transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma,
    const int scale_offset, const int amax_offset, const int scale_inv_offset) {
  using namespace transformer_engine;

  const auto &input_ = input.contiguous();
  const auto &weight_ = weight.contiguous();
  const auto &bias_ = bias.contiguous();

  // Tensor dimensions, read from the same tensor that is handed to the kernel.
  const size_t N = static_cast<size_t>(input_.size(0));
  const size_t H = static_cast<size_t>(input_.size(1));

  // Get pointers for FP8 scale, amax, scale-inverse
  void *scale_dptr = getDataPtr(scale, scale_offset);
  void *amax_dptr = getDataPtr(amax, amax_offset);
  void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

  // Per-row statistics produced by the kernel.
  auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));

  // Construct Transformer Engine tensors
  auto input_cu = makeTransformerEngineTensor(input_);
  auto gamma_cu = makeTransformerEngineTensor(weight_);
  auto beta_cu = makeTransformerEngineTensor(bias_);
  auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
                                          scale_inv_dptr);
  auto mu_cu = makeTransformerEngineTensor(mu);
  auto rsigma_cu = makeTransformerEngineTensor(rsigma);

  // Both kernel calls below use the same SM budget.
  const int num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;

  // Query workspace sizes: the workspace starts out empty and the first call records the
  // shape and dtype it requires.
  transformer_engine::TensorWrapper workspace;
  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sms,
                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Allocate workspaces
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sms,
                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  return {ln_out, mu, rsigma};
}

// Inference variant of layernorm_fwd_fp8: runs the same forward pass and returns only the
// first tensor of its result (the normalized output).
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight,
                                 const at::Tensor &bias, float eps, at::Tensor scale,
                                 at::Tensor amax, at::Tensor scale_inv,
                                 transformer_engine::DType otype, const int sm_margin,
                                 const bool zero_centered_gamma, const int scale_offset,
                                 const int amax_offset, const int scale_inv_offset) {
  const std::vector<at::Tensor> fwd_results =
      layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin,
                        zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);

  // The remaining entries of the result are dropped.
  return fwd_results.front();
}

// LayerNorm forward pass whose output has the same dtype as the input.
//
// Allocates a CUDA output tensor shaped like the contiguous input and delegates to
// layernorm_fwd_noalloc, returning its result unchanged.
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor &weight,
                                      const at::Tensor &bias, float eps, const int sm_margin,
                                      const bool zero_centered_gamma) {
  using namespace transformer_engine;

  // Round-trip the input dtype through the Transformer Engine enum to pick the output dtype.
  const DType in_dtype = GetTransformerEngineDType(input.scalar_type());
  const auto &contig_input = input.contiguous();
  auto output = at::empty_like(contig_input, at::CUDA(GetATenDType(in_dtype)));

  return layernorm_fwd_noalloc(contig_input, weight, bias, output, eps, sm_margin,
                               zero_centered_gamma);
}

// LayerNorm forward pass into a caller-provided output, without FP8 metadata.
//
// Delegates to layernorm_fwd_fp8_noalloc with default-constructed tensors for scale, amax
// and scale-inverse, and with the output type set to the input's dtype. The trailing
// offset arguments are left to their declared defaults.
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
                                              const at::Tensor &bias, at::Tensor ln_out, float eps,
                                              const int sm_margin, const bool zero_centered_gamma) {
  using namespace transformer_engine;

  const DType out_dtype = GetTransformerEngineDType(input.scalar_type());

  // Stand-in passed for each of the three FP8 metadata tensors.
  const at::Tensor no_fp8_meta;

  return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, no_fp8_meta, ln_out, no_fp8_meta,
                                   no_fp8_meta, out_dtype, sm_margin, zero_centered_gamma);
}

// Inference variant of layernorm_fwd: runs the same forward pass and returns only the
// first tensor of its result (the normalized output).
at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
                             const at::Tensor &bias, float eps, const int sm_margin,
                             const bool zero_centered_gamma) {
  const std::vector<at::Tensor> fwd_results =
      layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma);

  // The remaining entries of the result are dropped.
  return fwd_results.front();
}

// Backward pass of RMSNorm, delegating to nvte_rmsnorm_bwd.
//
// Every tensor argument is made contiguous before being handed to the kernel. `sm_margin`
// is subtracted from the current device's multiprocessor count and the difference is
// passed to the kernel.
//
// Returns {dx, dgamma}: dx is allocated like x, dgamma like gamma.
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma) {
  // Contiguous versions of the inputs.
  const auto &dz_c = dz.contiguous();
  const auto &x_c = x.contiguous();
  const auto &rsigma_c = rsigma.contiguous();
  const auto &gamma_c = gamma.contiguous();

  // Gradient buffers that are returned to the caller.
  auto grad_x = at::empty_like(x_c);
  auto grad_gamma = at::empty_like(gamma_c);

  // Transformer Engine wrappers around the inputs and outputs.
  auto dz_te = makeTransformerEngineTensor(dz_c);
  auto x_te = makeTransformerEngineTensor(x_c);
  auto rsigma_te = makeTransformerEngineTensor(rsigma_c);
  auto gamma_te = makeTransformerEngineTensor(gamma_c);
  auto grad_x_te = makeTransformerEngineTensor(grad_x);
  auto grad_gamma_te = makeTransformerEngineTensor(grad_gamma);

  // Starts out empty; the first kernel call records the shape and dtype it requires.
  transformer_engine::TensorWrapper ws;

  const int num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
  auto stream = at::cuda::getCurrentCUDAStream();

  // First call: query the workspace configuration.
  nvte_rmsnorm_bwd(dz_te.data(), x_te.data(), rsigma_te.data(), gamma_te.data(),
                   grad_x_te.data(), grad_gamma_te.data(), ws.data(), num_sms,
                   zero_centered_gamma, stream);

  // Back the workspace with real memory of the requested shape and dtype.
  auto ws_storage = allocateSpace(ws.shape(), ws.dtype());
  ws = makeTransformerEngineTensor(ws_storage.data_ptr(), ws.shape(), ws.dtype());

  // Second call: run the backward kernel.
  nvte_rmsnorm_bwd(dz_te.data(), x_te.data(), rsigma_te.data(), gamma_te.data(),
                   grad_x_te.data(), grad_gamma_te.data(), ws.data(), num_sms,
                   zero_centered_gamma, stream);

  return {grad_x, grad_gamma};
}

// RMSNorm forward pass that allocates its own output tensor.
//
// Makes `input` and `weight` contiguous, creates a CUDA output tensor shaped like the
// input with the ATen dtype that corresponds to `otype`, then forwards everything to
// rmsnorm_fwd_fp8_noalloc and returns that function's result unchanged.
std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
                                        float eps, at::Tensor scale, at::Tensor amax,
                                        at::Tensor scale_inv, transformer_engine::DType otype,
                                        const int sm_margin, const bool zero_centered_gamma,
                                        const int scale_offset, const int amax_offset,
                                        const int scale_inv_offset) {
  using namespace transformer_engine;

  const auto &contig_input = input.contiguous();
  const auto &contig_weight = weight.contiguous();

  // Output buffer: same shape as the input, dtype selected by `otype`.
  auto output = at::empty_like(contig_input, at::CUDA(GetATenDType(otype)));

  return rmsnorm_fwd_fp8_noalloc(contig_input, contig_weight, eps, scale, output, amax,
                                 scale_inv, otype, sm_margin, zero_centered_gamma, scale_offset,
                                 amax_offset, scale_inv_offset);
}

// RMSNorm forward pass that writes into a caller-provided output tensor.
//
// `input` and `weight` are made contiguous before use, matching what
// layernorm_fwd_fp8_noalloc and rmsnorm_bwd do with their inputs, so callers that reach
// this function directly do not have to pre-normalize the layout themselves. The output is
// described to the kernel as an {N, H} tensor of type `otype` living at
// `ln_out.data_ptr()`, where N and H are the first two dimension sizes of the input. The
// scale / amax / scale-inverse pointers are obtained through getDataPtr with their
// respective offsets. `sm_margin` is subtracted from the current device's multiprocessor
// count and the difference is passed to the kernel.
//
// Returns {ln_out, rsigma}, where rsigma is a freshly allocated float CUDA tensor of
// length N.
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight,
                                                float eps, at::Tensor scale, at::Tensor ln_out,
                                                at::Tensor amax, at::Tensor scale_inv,
                                                transformer_engine::DType otype,
                                                const int sm_margin, const bool zero_centered_gamma,
                                                const int scale_offset, const int amax_offset,
                                                const int scale_inv_offset) {
  using namespace transformer_engine;

  const auto &input_ = input.contiguous();
  const auto &weight_ = weight.contiguous();

  // Tensor dimensions, read from the same tensor that is handed to the kernel.
  const size_t N = static_cast<size_t>(input_.size(0));
  const size_t H = static_cast<size_t>(input_.size(1));

  // Get pointers for FP8 scale, amax, scale-inverse
  void *scale_dptr = getDataPtr(scale, scale_offset);
  void *amax_dptr = getDataPtr(amax, amax_offset);
  void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

  // Per-row statistic produced by the kernel.
  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));

  // Construct Transformer Engine tensors
  auto input_cu = makeTransformerEngineTensor(input_);
  auto gamma_cu = makeTransformerEngineTensor(weight_);
  auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
                                          scale_inv_dptr);
  auto rsigma_cu = makeTransformerEngineTensor(rsigma);

  // Both kernel calls below use the same SM budget.
  const int num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;

  // Query workspace sizes: the workspace starts out empty and the first call records the
  // shape and dtype it requires.
  transformer_engine::TensorWrapper workspace;
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   workspace.data(), num_sms, zero_centered_gamma,
                   at::cuda::getCurrentCUDAStream());

  // Allocate workspaces
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
                   workspace.data(), num_sms, zero_centered_gamma,
                   at::cuda::getCurrentCUDAStream());

  return {ln_out, rsigma};
}

// Inference variant of rmsnorm_fwd_fp8: runs the same forward pass and returns only the
// first tensor of its result (the normalized output).
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
                               at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                               transformer_engine::DType otype, const int sm_margin,
                               const bool zero_centered_gamma, const int scale_offset,
                               const int amax_offset, const int scale_inv_offset) {
  const std::vector<at::Tensor> fwd_results =
      rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin,
                      zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);

  // The remaining entries of the result are dropped.
  return fwd_results.front();
}

// RMSNorm forward pass whose output has the same dtype as the input.
//
// Makes `input` and `weight` contiguous, allocates a CUDA output tensor shaped like the
// input, and delegates to rmsnorm_fwd_noalloc, returning its result unchanged.
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps,
                                    const int sm_margin, const bool zero_centered_gamma) {
  using namespace transformer_engine;

  const auto &contig_input = input.contiguous();
  const auto &contig_weight = weight.contiguous();

  // Round-trip the input dtype through the Transformer Engine enum to pick the output dtype.
  const DType in_dtype = GetTransformerEngineDType(input.scalar_type());
  auto output = at::empty_like(contig_input, at::CUDA(GetATenDType(in_dtype)));

  return rmsnorm_fwd_noalloc(contig_input, contig_weight, output, eps, sm_margin,
                             zero_centered_gamma);
}

// RMSNorm forward pass into a caller-provided output, without FP8 metadata.
//
// Delegates to rmsnorm_fwd_fp8_noalloc with default-constructed tensors for scale, amax
// and scale-inverse, and with the output type set to the input's dtype. The trailing
// offset arguments are left to their declared defaults.
std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
                                            at::Tensor ln_out, float eps, const int sm_margin,
                                            const bool zero_centered_gamma) {
  using namespace transformer_engine;

  const DType out_dtype = GetTransformerEngineDType(input.scalar_type());

  // Stand-in passed for each of the three FP8 metadata tensors.
  const at::Tensor no_fp8_meta;

  return rmsnorm_fwd_fp8_noalloc(input, weight, eps, no_fp8_meta, ln_out, no_fp8_meta,
                                 no_fp8_meta, out_dtype, sm_margin, zero_centered_gamma);
}

// Inference variant of rmsnorm_fwd: runs the same forward pass and returns only the first
// tensor of its result (the normalized output).
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
                           const int sm_margin, const bool zero_centered_gamma) {
  const std::vector<at::Tensor> fwd_results =
      rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);

  // The remaining entries of the result are dropped.
  return fwd_results.front();
}
