/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <functional>
#include <layers/reduce_sum_layer.hpp>
#include <network_buffer_channels.hpp>
#include <utils.cuh>
#include <utils.hpp>

namespace HugeCTR {

namespace {

// Yields the element count of a fixed-size array, deduced from its type at compile time.
// The array argument itself is never read, so it is left unnamed.
template <size_t length, typename T>
__device__ int array_length(T (&)[length]) {
  const int count = static_cast<int>(length);
  return count;
}

// Sums `input` along `axis` for a row-major tensor whose extents are passed as the trailing
// variadic arguments (one size_t per dimension). Ranks 1, 2 and 3 are supported.
//
// Launch layout (as used by ReduceSumLayer::fprop): gridDim.x equals the number of output
// elements, i.e. the product of every extent except in_dims[axis]; block b writes output[b].
// Threads of a block stride over the reduced axis by blockDim.x, then blockReduceSum combines
// the per-thread partials and thread 0 stores the block's result.
// NOTE(review): blockReduceSum (utils.cuh) presumably synchronizes the whole block, so every
// thread must reach it -- do not add an early exit before that call.
//
// Partials are accumulated in float and converted to T only for the final store, so that
// T = __half neither loses precision nor stalls on long axes (in fp16, 2048 + 1 == 2048).
template <typename T, typename... Args>
__global__ void reduce_sum_kernel(const T* input, T* output, int axis, Args... args) {
  constexpr int dims_size = sizeof...(Args);
  static_assert(dims_size >= 1 && dims_size <= 3,
                "reduce_sum_kernel supports inputs of rank 1, 2 or 3");
  const size_t in_dims[] = {static_cast<size_t>(args)...};
  float local_sum = 0.0f;

  if (axis == 0) {
    // block_num = dim1 * dim2: each block reduces dim0 elements spaced one dim0-slice apart.
    if constexpr (dims_size == 1) {
      for (size_t i = threadIdx.x; i < in_dims[0]; i += blockDim.x) {
        local_sum += static_cast<float>(input[i]);
      }
    } else if constexpr (dims_size == 2) {
      for (size_t i = threadIdx.x; i < in_dims[0]; i += blockDim.x) {
        local_sum += static_cast<float>(input[i * in_dims[1] + blockIdx.x]);
      }
    } else if constexpr (dims_size == 3) {
      for (size_t i = threadIdx.x; i < in_dims[0]; i += blockDim.x) {
        local_sum += static_cast<float>(input[i * (in_dims[1] * in_dims[2]) + blockIdx.x]);
      }
    }
  } else if (axis == 1) {
    // block_num = dim0 * dim2: each block reduces the dim1 elements of one (dim0, dim2) pair.
    if constexpr (dims_size == 2) {
      for (size_t i = threadIdx.x; i < in_dims[1]; i += blockDim.x) {
        local_sum += static_cast<float>(input[blockIdx.x * in_dims[1] + i]);
      }
    } else if constexpr (dims_size == 3) {
      const size_t dim0_index = blockIdx.x / in_dims[2];
      const size_t dim2_index = blockIdx.x % in_dims[2];
      for (size_t i = threadIdx.x; i < in_dims[1]; i += blockDim.x) {
        local_sum += static_cast<float>(
            input[dim0_index * (in_dims[1] * in_dims[2]) + i * in_dims[2] + dim2_index]);
      }
    }
  } else if (axis == 2) {
    // block_num = dim0 * dim1: each block reduces one contiguous row of dim2 elements.
    if constexpr (dims_size == 3) {
      for (size_t i = threadIdx.x; i < in_dims[2]; i += blockDim.x) {
        local_sum += static_cast<float>(input[blockIdx.x * in_dims[2] + i]);
      }
    }
  }

  local_sum = blockReduceSum(local_sum);
  if (threadIdx.x == 0) {
    output[blockIdx.x] = static_cast<T>(local_sum);
  }
}

// Backward of reduce_sum_kernel: broadcasts `top_grad` (laid out like the forward output,
// i.e. the input shape with extent 1 along `axis`) back over every element of `dgrad`, which
// is laid out like the input. Trailing variadic args are the input extents (rank 1, 2 or 3).
//
// Launch layout (as used by ReduceSumLayer::bprop): 1-D grid with one thread per input element,
// rounded up to whole blocks. Surplus threads in the tail block exit before touching memory;
// the guard covers every axis.
template <typename T, typename... Args>
__global__ void reduce_sum_dgrad_kernel(const T* top_grad, T* dgrad, int axis, Args... args) {
  constexpr int dims_size = sizeof...(Args);
  static_assert(dims_size >= 1 && dims_size <= 3,
                "reduce_sum_dgrad_kernel supports inputs of rank 1, 2 or 3");
  const size_t in_dims[] = {static_cast<size_t>(args)...};

  // 64-bit flat index so that very large inputs do not overflow a 32-bit int.
  const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  size_t num_elems = 1;
  for (int d = 0; d < dims_size; ++d) {
    num_elems *= in_dims[d];
  }
  if (tid >= num_elems) {
    return;
  }

  if (axis == 0) {
    if constexpr (dims_size == 1) {
      // Output is a single scalar.
      dgrad[tid] = top_grad[0];
    } else if constexpr (dims_size == 2) {
      // Output is (1, d1): keep only the dim1 coordinate.
      dgrad[tid] = top_grad[tid % in_dims[1]];
    } else if constexpr (dims_size == 3) {
      // Output is (1, d1, d2): drop the dim0 coordinate.
      dgrad[tid] = top_grad[tid % (in_dims[1] * in_dims[2])];
    }
  } else if (axis == 1) {
    if constexpr (dims_size == 2) {
      // Output is (d0, 1): keep only the dim0 coordinate.
      dgrad[tid] = top_grad[tid / in_dims[1]];
    } else if constexpr (dims_size == 3) {
      // Output is (d0, 1, d2).
      const size_t dim0_index = tid / (in_dims[1] * in_dims[2]);
      const size_t dim2_index = tid % in_dims[2];
      dgrad[tid] = top_grad[dim0_index * in_dims[2] + dim2_index];
    }
  } else if (axis == 2) {
    if constexpr (dims_size == 3) {
      // Output is (d0, d1, 1): tid / d2 == dim0_index * d1 + dim1_index.
      dgrad[tid] = top_grad[tid / in_dims[2]];
    }
  }
}

}  // end of namespace

// Constructs a layer that sums `input_tensor` along `axis`, keeping that axis with extent 1.
//
// input_tensor:  must have rank 1, 2 or 3 -- the only ranks fprop/bprop dispatch kernels for.
// output_tensor: out-parameter; assigned a new tensor built from the input's params with the
//                reduced shape and the blobs buffer channel, and registered as this layer's
//                output.
// axis:          dimension to reduce; must satisfy 0 <= axis < rank.
// Invalid input raises via HCTR_OWN_THROW(Error_t::WrongInput, ...); any std::runtime_error
// is logged and rethrown.
template <typename T>
ReduceSumLayer<T>::ReduceSumLayer(const core23::Tensor& input_tensor, core23::Tensor& output_tensor,
                                  int axis, const std::shared_ptr<GPUResource>& gpu_resource)
    : Layer({input_tensor}, {}, gpu_resource), axis_(axis) {
  try {
    // error input checking
    const auto& in_shape = input_tensor.shape();
    if (in_shape.size() == 0) {
      HCTR_OWN_THROW(Error_t::WrongInput, "The input dims can not be 0");
    }
    // fprop/bprop only launch kernels for rank 1-3; any higher rank would leave the output
    // and the input gradient silently unwritten, so reject it up front.
    if (in_shape.dims() > 3) {
      HCTR_OWN_THROW(Error_t::WrongInput,
                     "ReduceSumLayer only supports inputs with 1, 2 or 3 dims");
    }
    if (axis >= in_shape.dims() || axis < 0) {
      HCTR_OWN_THROW(Error_t::WrongInput, "The axis is overflow");
    }

    // Output shape: same as the input, except the reduced axis collapses to 1.
    core23::Shape out_shape(in_shape.dims());
    for (auto i = 0; i < in_shape.dims(); i++) {
      if (i == axis) {
        out_shape.set(i, 1);
      } else {
        out_shape.set(i, in_shape.size(i));
      }
    }
    core23::BufferParams buf_p{.channel = GetBlobsBufferChannel()};

    output_tensor = core23::Tensor(input_tensor.my_params().shape(out_shape).buffer_params(buf_p));
    output_tensors_.push_back(output_tensor);
  } catch (const std::runtime_error& rt_err) {
    HCTR_LOG_S(ERROR, WORLD) << rt_err.what() << std::endl;
    throw;
  }
}

// Forward pass: launches reduce_sum_kernel on this layer's GPU stream, using one 256-thread
// block per output element. Input ranks other than 1-3 launch nothing. `is_train` is unused.
template <typename T>
void ReduceSumLayer<T>::fprop(bool is_train) {
  CudaDeviceContext context(get_device_id());

  auto* src = input_tensors_[0].data<T>();
  auto* dst = output_tensors_[0].data<T>();
  const auto src_shape = input_tensors_[0].shape();
  const auto dst_shape = output_tensors_[0].shape();

  const dim3 threads(256, 1, 1);
  const dim3 blocks(dst_shape.size(), 1, 1);  // one block per reduced output value

  switch (src_shape.dims()) {
    case 1:
      reduce_sum_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          src, dst, axis_, static_cast<size_t>(src_shape.size(0)));
      break;
    case 2:
      reduce_sum_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          src, dst, axis_, static_cast<size_t>(src_shape.size(0)),
          static_cast<size_t>(src_shape.size(1)));
      break;
    case 3:
      reduce_sum_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          src, dst, axis_, static_cast<size_t>(src_shape.size(0)),
          static_cast<size_t>(src_shape.size(1)), static_cast<size_t>(src_shape.size(2)));
      break;
    default:
      break;
  }
}

// Backward pass: launches reduce_sum_dgrad_kernel on this layer's GPU stream. The output
// tensor's buffer is passed as the kernel's `top_grad` and the input tensor's buffer as its
// `dgrad`. One thread per input element, in 256-thread blocks rounded up. Input ranks other
// than 1-3 launch nothing.
template <typename T>
void ReduceSumLayer<T>::bprop() {
  CudaDeviceContext context(get_device_id());

  auto* bottom_grad = input_tensors_[0].data<T>();
  auto* top_grad = output_tensors_[0].data<T>();
  const auto shape = input_tensors_[0].shape();

  const auto total = shape.size();
  const dim3 threads(256, 1, 1);
  const dim3 blocks((total + threads.x - 1) / threads.x, 1, 1);  // ceil-div

  switch (shape.dims()) {
    case 1:
      reduce_sum_dgrad_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          top_grad, bottom_grad, axis_, static_cast<size_t>(shape.size(0)));
      break;
    case 2:
      reduce_sum_dgrad_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          top_grad, bottom_grad, axis_, static_cast<size_t>(shape.size(0)),
          static_cast<size_t>(shape.size(1)));
      break;
    case 3:
      reduce_sum_dgrad_kernel<<<blocks, threads, 0, get_gpu().get_stream()>>>(
          top_grad, bottom_grad, axis_, static_cast<size_t>(shape.size(0)),
          static_cast<size_t>(shape.size(1)), static_cast<size_t>(shape.size(2)));
      break;
    default:
      break;
  }
}

template class ReduceSumLayer<float>;
template class ReduceSumLayer<__half>;

}  // namespace HugeCTR
