//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <thread>

#include "cpu_executor.hpp"

#include "ngraph/env_util.hpp"
#include "ngraph/except.hpp"

// Upper-bound multiplier used by GetNumCores(): the requested thread count may
// not exceed this many times std::thread::hardware_concurrency().
#define MAX_PARALLELISM_THRESHOLD 2

// Determines the thread count to use, in order of precedence:
//   1. OMP_NUM_THREADS, if set to a positive value;
//   2. NGRAPH_INTRA_OP_PARALLELISM, if set to a positive value;
//   3. half of std::thread::hardware_concurrency().
//
// Returns a value >= 1.
// Throws ngraph::ngraph_error when the selected value exceeds
// MAX_PARALLELISM_THRESHOLD * hardware_concurrency().
static int GetNumCores()
{
    const auto omp_num_threads = ngraph::getenv_int("OMP_NUM_THREADS");
    const auto ngraph_intra_op_parallelism = ngraph::getenv_int("NGRAPH_INTRA_OP_PARALLELISM");

    // hardware_concurrency() is only a hint and returns 0 when the value is
    // not computable; query it once and treat 0 as "unknown".
    const int hw_threads = static_cast<int>(std::thread::hardware_concurrency());

    int count = 0;
    if (omp_num_threads > 0)
    {
        count = omp_num_threads;
    }
    else if (ngraph_intra_op_parallelism > 0)
    {
        count = ngraph_intra_op_parallelism;
    }
    else
    {
        count = hw_threads / 2;
    }

    // Only enforce the upper bound when the hardware thread count is known.
    // Otherwise the limit would be 0 and every explicit user setting would be
    // rejected with a meaningless "[1-0]" range.
    if (hw_threads > 0)
    {
        const int max_parallelism_allowed = MAX_PARALLELISM_THRESHOLD * hw_threads;
        if (count > max_parallelism_allowed)
        {
            throw ngraph::ngraph_error(
                "OMP_NUM_THREADS and/or NGRAPH_INTRA_OP_PARALLELISM is too high: "
                "(" +
                std::to_string(count) + "). Please specify a value in range [1-" +
                std::to_string(max_parallelism_allowed) + "]");
        }
    }

    // hw_threads / 2 can be 0 (single hardware thread or unknown); never
    // report fewer than one core.
    return count < 1 ? 1 : count;
}

// Reads NGRAPH_INTER_OP_PARALLELISM and returns it when it is a positive
// value; otherwise returns 1.
static int GetNumThreadPools()
{
    const auto requested_pools = ngraph::getenv_int("NGRAPH_INTER_OP_PARALLELISM");
    if (requested_pools > 0)
    {
        return requested_pools;
    }

    // Unset or non-positive: fall back to a single pool.
    return 1;
}

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace executor
            {
                // Creates `num_thread_pools` Eigen thread pools, each with a
                // matching ThreadPoolDevice (and, when NGRAPH_TBB_ENABLE is
                // defined, a TBB arena constructed with argument 1).
                //
                // Every pool gets GetNumCores() threads unless
                // NGRAPH_CPU_EIGEN_THREAD_COUNT is set to a positive value, in
                // which case that value is used instead.
                //
                // Throws ngraph_error when NGRAPH_CPU_EIGEN_THREAD_COUNT exceeds
                // GetNumCores() and at least one pool is requested.
                CPUExecutor::CPUExecutor(int num_thread_pools)
                    : m_num_thread_pools(num_thread_pools)
                {
                    const int num_cores = GetNumCores();
                    m_num_cores = num_cores;

                    if (num_thread_pools < 1)
                    {
                        // No pools to build; the override below is only
                        // validated when a pool is actually created.
                        return;
                    }

                    // The per-pool thread count does not depend on the pool
                    // index, so resolve it (and the environment lookups it
                    // needs) once, ahead of the loop.

                    // Eigen threadpool will still be used for reductions
                    // and other tensor operations that don't use a parallelFor
                    int num_threads_per_pool = num_cores;

                    // User override
                    const int32_t eigen_tp_count =
                        ngraph::getenv_int("NGRAPH_CPU_EIGEN_THREAD_COUNT");
                    if (eigen_tp_count > 0)
                    {
                        if (eigen_tp_count > num_cores)
                        {
                            throw ngraph_error(
                                "Unexpected value specified for NGRAPH_CPU_EIGEN_THREAD_COUNT "
                                "(" +
                                std::to_string(eigen_tp_count) +
                                "). Please specify a value in range [1-" +
                                std::to_string(num_cores) + "]");
                        }
                        num_threads_per_pool = eigen_tp_count;
                    }

                    for (int i = 0; i < num_thread_pools; i++)
                    {
                        m_thread_pools.push_back(std::unique_ptr<Eigen::ThreadPool>(
                            new Eigen::ThreadPool(num_threads_per_pool)));
                        // The device borrows the pool just created; ownership
                        // stays with m_thread_pools.
                        m_thread_pool_devices.push_back(
                            std::unique_ptr<Eigen::ThreadPoolDevice>(new Eigen::ThreadPoolDevice(
                                m_thread_pools[i].get(), num_threads_per_pool)));
#if defined(NGRAPH_TBB_ENABLE)
                        m_tbb_arenas.emplace_back(1);
#endif
                    }
                }

#if defined(NGRAPH_TBB_ENABLE)
                // Invokes kernel functor `f` with `ctx` and `ectx`.
                // When `use_tbb` is true the call is issued through the TBB
                // arena selected by `ectx->arena`; otherwise `f` is called
                // directly.
                void CPUExecutor::execute(CPUKernelFunctor& f,
                                          CPURuntimeContext* ctx,
                                          CPUExecutionContext* ectx,
                                          bool use_tbb)
                {
                    if (!use_tbb)
                    {
                        f(ctx, ectx);
                        return;
                    }

                    auto run_kernel = [&]() { f(ctx, ectx); };
                    m_tbb_arenas[ectx->arena].execute(run_kernel);
                }
#else
                // Invokes kernel functor `f` with `ctx` and `ectx` directly
                // (build without NGRAPH_TBB_ENABLE).
                void CPUExecutor::execute(CPUKernelFunctor& f,
                                          CPURuntimeContext* ctx,
                                          CPUExecutionContext* ectx)
                {
                    f(ctx, ectx);
                }
#endif

                // Returns the shared CPUExecutor instance. It is a
                // function-local static, built on first use with the pool
                // count reported by GetNumThreadPools() (at least one pool).
                CPUExecutor& GetCPUExecutor()
                {
                    static int pool_count = GetNumThreadPools();
                    static CPUExecutor instance(pool_count >= 1 ? pool_count : 1);
                    return instance;
                }
                // Namespace-scope DNNL engine object, constructed with engine
                // kind `cpu` and index 0.
                dnnl::engine global_cpu_engine(dnnl::engine::kind::cpu, 0);
            }
        }
    }
}
