//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "cpu_layout_descriptor.hpp"
#include <algorithm>
#include <numeric>

#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/dnnl_utils.hpp"

// Shorthand tokens for spelling DNNL enumerators below, e.g. dnnl::memory::F32 expands to
// dnnl::memory::data_type::f32 and dnnl::memory::FORMAT::UNDEF to dnnl::memory::FORMAT::undef.
// NOTE(review): FORMAT is not defined in this file; presumably it comes from an included
// DNNL helper header — confirm.
#define UNDEF undef
#define F32 data_type::f32

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            // Placeholder DNNL memory descriptor that every LayoutDescriptor starts with (see the
            // constructor's m_dnnl_md initializer). It has TENSOR_MAX_DIMS dimensions, f32 data
            // type and an undefined format.
            // NOTE(review): presumably is_dnnl_layout() recognises this placeholder as "no DNNL
            // layout attached" — confirm against the header.
            const dnnl::memory::desc
                LayoutDescriptor::DummyDesc(dnnl::memory::dims(TENSOR_MAX_DIMS),
                                            dnnl::memory::F32,
                                            dnnl::memory::FORMAT::UNDEF);

            /// @brief Creates a layout for tensor @p tv with native row-major strides, a zero
            ///        offset and the placeholder DNNL descriptor (DummyDesc).
            /// @param tv Tensor whose shape and element type define the layout.
            LayoutDescriptor::LayoutDescriptor(const ngraph::descriptor::Tensor& tv)
                : TensorLayout(tv)
                , m_offset(0)
                , m_dnnl_md(LayoutDescriptor::DummyDesc)
            {
                const auto& extents = get_shape();

                // Walk the dimensions from innermost to outermost. The innermost stride is 1
                // and each further stride is the product of all extents inside it. Strides are
                // collected innermost-first and flipped afterwards into outermost-first order.
                size_t running_product = 1;
                for (auto dim = extents.rbegin(); dim != extents.rend(); ++dim)
                {
                    m_strides.emplace_back(running_product);
                    running_product *= *dim;
                }
                std::reverse(m_strides.begin(), m_strides.end());

                // Without a DNNL descriptor the buffer is dense: element count * element size.
                const size_t element_bytes = tv.get_element_type().size();
                m_buffer_size = shape_size(tv.get_shape()) * element_bytes;
            }

            /// @brief Maps a multi-dimensional index to a linear element offset.
            /// @param indices One index per tensor dimension.
            /// @return Sum over i of indices[i] * m_strides[i].
            /// @throws ngraph_error if the number of indices differs from the number of strides.
            size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
            {
                if (indices.size() == m_strides.size())
                {
                    // The size_t seed keeps the accumulation in size_t instead of int.
                    return std::inner_product(
                        indices.begin(), indices.end(), m_strides.begin(), size_t{0});
                }
                throw ngraph_error("Indices have incorrect rank");
            }

            /// @brief Compares this layout with another tensor layout.
            ///
            /// Two layouts are equal only if all of the following hold:
            /// - @p other is also a LayoutDescriptor;
            /// - both have the same element type;
            /// - both or neither carry a DNNL memory descriptor;
            /// - when both carry one, the DNNL descriptors compare equal;
            /// - otherwise, both have identical strides and offsets.
            ///
            /// The result is symmetric: a == b implies b == a.
            /// @param other Layout to compare against.
            /// @return true if the layouts are considered equal.
            bool LayoutDescriptor::
                operator==(const ngraph::descriptor::layout::TensorLayout& other) const
            {
                const LayoutDescriptor* p_other = dynamic_cast<const LayoutDescriptor*>(&other);
                if (!p_other)
                {
                    return false;
                }

                if (get_element_type() != p_other->get_element_type())
                {
                    return false;
                }

                // m_strides always hold the native row-major strides computed in the
                // constructor, even after a DNNL md is attached. Comparing strides therefore
                // cannot tell a (possibly blocked or padded) DNNL layout apart from a native
                // one, so a DNNL layout must never equal a non-DNNL layout, whichever operand
                // is on the left.
                const bool this_is_dnnl = is_dnnl_layout();
                if (this_is_dnnl != p_other->is_dnnl_layout())
                {
                    return false;
                }

                if (this_is_dnnl)
                {
                    return runtime::cpu::dnnl_utils::compare_dnnl_mds(m_dnnl_md,
                                                                      p_other->get_dnnl_md());
                }

                // Both layouts are native: they match when strides and offsets agree.
                return m_strides == p_other->m_strides && m_offset == p_other->m_offset;
            }

            /// @brief Attaches a DNNL memory descriptor and updates the buffer size to match it.
            /// @param md DNNL memory descriptor describing this tensor's layout.
            /// @throws ngraph_error if DNNL fails to compute the memory size for @p md. In that
            ///         case this descriptor is left unchanged.
            void LayoutDescriptor::set_dnnl_md(const dnnl::memory::desc& md)
            {
                // DNNL may internally pad the tensor to form blocked layouts, so the memory
                // requirement must come from the DNNL memory desc, not shape * element size.
                // http://intel.github.io/mkl-dnn/understanding_memory_formats.html
                //
                // Query the size before modifying any member. If DNNL reports an error, the
                // object then keeps its previous md and buffer size together, instead of
                // holding the new md with a stale size.
                size_t buffer_size = 0;
                try
                {
                    buffer_size = md.get_size();
                }
                catch (const dnnl::error& e)
                {
                    // NOTE(review): DNNL_ERROR_MESSAGE is defined elsewhere; presumably it
                    // reads `e`, so the catch variable name must stay `e`.
                    throw ngraph_error("error in computing dnnl memory size from memory desc: " +
                                       DNNL_ERROR_MESSAGE);
                }

                m_dnnl_md = md;
                m_buffer_size = buffer_size;
            }

            /// @brief Reports whether this layout is the native row-major layout.
            /// @return true for layouts without a DNNL md. For DNNL layouts, true only if the
            ///         attached md equals a blocked md built from the shape, the native strides
            ///         and the element type.
            bool LayoutDescriptor::is_row_major_layout()
            {
                if (is_dnnl_layout())
                {
                    // m_strides are the native row-major strides computed at construction.
                    const auto native_row_major_md =
                        runtime::cpu::dnnl_utils::create_blocked_dnnl_md(
                            get_shape(), m_strides, get_element_type());
                    return runtime::cpu::dnnl_utils::compare_dnnl_mds(m_dnnl_md,
                                                                      native_row_major_md);
                }

                // Without a DNNL md, the layout is described only by the native strides.
                return true;
            }
        }
    }
}
