/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * License); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*
 * Copyright (c) 2020, OPEN AI LAB
 * Author: lswang@openailab.com
 */

#include "algorithm/centerface.hpp"

#include "utilities/cmdline.hpp"
#include "utilities/timer.hpp"

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <numeric>
#include <string>
#include <vector>

// Device identifier string "VX"; not referenced elsewhere in this file.
#define VXDEVICE "VX"

// Default detection score threshold (overridable with -s/--score_threshold).
constexpr float DET_THRESHOLD = 0.45f;
// Default NMS IoU threshold (overridable with -o/--iou_threshold).
constexpr float NMS_THRESHOLD = 0.30f;

// Network input resolution, handed to CenterFace::Load as cv::Size(width, height).
constexpr int MODEL_WIDTH  = 640;
constexpr int MODEL_HEIGHT = 384;


int main(int argc, char* argv[])
{
    cmdline::parser cmd;

    cmd.add<std::string>("model", 'm', "model config file", true, "");
    cmd.add<std::string>("image", 'i', "image to infer", true, "");
    cmd.add<std::string>("device", 'd', "device", false, "CPU");
    cmd.add<float>("score_threshold", 's', "score threshold", false, DET_THRESHOLD);
    cmd.add<float>("iou_threshold", 'o', "iou threshold", false, NMS_THRESHOLD);
    cmd.add<int>("skip_nms", 'k', "skip nms", false, 0);
    cmd.add<int>("repeat", 'r', "repeat time", false, 20);
    cmd.add<int>("benchmark", 'b', "benchmark time", false, 0);

    cmd.parse_check(argc, argv);

    auto model_path = cmd.get<std::string>("model");
    auto image_path = cmd.get<std::string>("image");
    auto device = cmd.get<std::string>("device");
    auto score_threshold = cmd.get<float>("score_threshold");
    auto iou_threshold = cmd.get<float>("iou_threshold");
    auto skip_nms = cmd.get<int>("skip_nms") != 0;
    auto repeat = cmd.get<int>("repeat");
    auto benchmark = cmd.get<int>("benchmark") != 0;

    cv::Mat image = cv::imread(image_path);

    if (image.empty())
    {
        fprintf(stderr, "Reading image was failed.\n");
        return -1;
    }

    init_tengine();

    cv::Size input_shape(MODEL_WIDTH, MODEL_HEIGHT);

    CenterFace detector;
    detector.Load(model_path, input_shape, device);

    std::vector<Region> faces;
    detector.Detect(image, faces, score_threshold, iou_threshold);
    for (auto& face : faces)
    {
        // box
        cv::Rect2f rect(face.box.x, face.box.y, face.box.width, face.box.height);

        // calculate cosine distance


        // draw box
        cv::rectangle(image, rect, cv::Scalar(0, 0, 255), 2);
        std::string box_confidence = "DET: " + std::to_string(face.confidence);
        cv::putText(image, box_confidence, rect.tl() + cv::Point2f(5, -10), cv::FONT_HERSHEY_TRIPLEX, 0.6f, cv::Scalar(255, 255, 0));
    }

    cv::imwrite("demo.png", image);

    release_tengine();

    return 0;
}
