/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

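// NOTE(review): every line of this test file below the license header is commented out,
// so nothing here is compiled or run. The reason it was disabled is not recorded in this
// file - confirm against the current loss/tensor APIs before re-enabling or deleting it.

// #include <gtest/gtest.h>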

// #include <cstdlib>
// #include <loss.hpp>
// #include <regularizers/no_regularizer.hpp>
// #include <utest/test_utils.hpp>
// #include <vector>

// using namespace HugeCTR;
// using namespace HugeCTR::test;

// void multi_cross_entropy_loss(size_t label_dim, size_t batch_size) {
//   std::shared_ptr<GeneralBuffer2<CudaAllocator>> buff = GeneralBuffer2<CudaAllocator>::create();

//   Tensor2<float> input_tensor;
//   buff->reserve({batch_size, label_dim}, &input_tensor);
//   Tensor2<float> label_tensor;
//   buff->reserve({batch_size, label_dim}, &label_tensor);
//   Tensor2<float> loss_tensor;
//   buff->reserve({1, 1}, &loss_tensor);

//   std::shared_ptr<BufferBlock2<float>> weight_buff = buff->create_block<float>();
//   std::shared_ptr<BufferBlock2<float>> wgrad_buff = buff->create_block<float>();

//   buff->allocate();

//   const std::vector<float> target_weight(label_dim, 1.0);

//   std::shared_ptr<NoRegularizer<float>> no_regularizer(new NoRegularizer<float>(
//       weight_buff->as_tensor(), wgrad_buff->as_tensor(), batch_size, test::get_default_gpu()));

//   MultiCrossEntropyLoss<float> mel(label_tensor, input_tensor, loss_tensor, no_regularizer,
//                                    target_weight, test::get_default_gpu(), 1);

//   std::unique_ptr<float[]> h_input(new float[batch_size * label_dim]);
//   std::unique_ptr<float[]> h_label(new float[batch_size * label_dim]);

//   float *d_input = input_tensor.get_ptr();
//   float *d_label = label_tensor.get_ptr();
//   float *d_loss = loss_tensor.get_ptr();

//   for (size_t i = 0; i < batch_size * label_dim; ++i) h_input[i] = rand() % 100 * 0.01f;
//   for (size_t i = 0; i < batch_size * label_dim; ++i) h_label[i] = rand() % 3 - 1;
//   HCTR_LIB_THROW(cudaMemcpy(d_input, h_input.get(), sizeof(float) * batch_size * label_dim,
//                             cudaMemcpyHostToDevice));
//   HCTR_LIB_THROW(cudaMemcpy(d_label, h_label.get(), sizeof(float) * batch_size * label_dim,
//                             cudaMemcpyHostToDevice));

//   mel.compute_and_init(true);

//   int scaler = 1;
// #ifdef SCALE_128
//   scaler = 128;
// #elif SCALE_256
//   scaler = 256;
// #elif SCALE_512
//   scaler = 512;
// #elif SCALE_1024
//   scaler = 1024;
// #endif

//   const float MIN_ = 1e-6;
//   float cpu_loss = 0.f;
//   for (size_t i = 0; i < batch_size * label_dim; i++) {
//     float x = h_input[i];
//     float y = h_label[i];
//     float val = 1.f / (1.f + exp(-x));
//     int target_weight_idx = i % label_dim;
//     float loss = y * log(val + MIN_) + (1.0f - y) * log(1.0f - val + MIN_);
//     cpu_loss += (h_label[i] < -0.5) ? 0.f : (target_weight[target_weight_idx] * loss);
//     float grad = -1.0f * val * (y - val) * exp(-x) / (1.0f - val + MIN_);
//     h_input[i] =
//         (h_label[i] < -0.5)
//             ? 0.f
//             : (target_weight[target_weight_idx] * grad / (batch_size * label_dim) * scaler);

//     // if(i == 0){
//     //   HCTR_LOG(INFO, WORLD, "i=%d, x=%f, y=%f, target_weight[target_weight_idx]=%f, loss=%f,
//     //   h_input=%f\n", i, x, y, target_weight[target_weight_idx], loss, h_input[i]);
//     // }
//   }
//   cpu_loss = -cpu_loss / (batch_size * label_dim);
//   ASSERT_EQ(true, cpu_gpu_cmp(h_input.get(), d_input, batch_size * label_dim))
//       << " CSE Gradient calculation failed" << std::endl;
//   ASSERT_EQ(true, cpu_gpu_cmp(&cpu_loss, d_loss, 1)) << " CSE Loss calculation failed" <<
//   std::endl;
// }

// TEST(loss_test, MultiCrossEntropyLoss11_1024) { multi_cross_entropy_loss(11, 1024); }
