# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test the tensors RESTful API.
Usage:
    pytest tests/st/func/datavisual
"""
import pytest

from mindinsight.datavisual.common.enums import PluginNameEnum

from .....utils.tools import get_url
from .. import globals as gbl

BASE_URL = '/v1/mindinsight/datavisual/tensors'


class TestTensors:
    """Tests for the datavisual tensors RESTful endpoint."""

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.usefixtures("init_summary_logs")
    def test_tensors(self, client):
        """Test getting tensor data.

        Requests ``BASE_URL`` for the first tensor tag of the first train job
        and checks that each returned value's ``wall_time`` and ``step``
        match the expected entries from ``gbl.get_metadata``.

        Args:
            client: Test client fixture used to issue the GET request.
        """
        plugin_name = PluginNameEnum.TENSOR.value
        train_id = gbl.get_train_ids()[0]
        tag_name = gbl.get_tags(train_id, plugin_name)[0]
        expected_values = gbl.get_metadata(train_id, tag_name)

        params = dict(train_id=train_id, tag=tag_name)
        url = get_url(BASE_URL, params)
        response = client.get(url)
        # Values of the first (and, for a single tag, only) tensor entry.
        actual_values = response.get_json().get("tensors")[0].get("values")

        # Distinct loop names so the lists above are not rebound mid-loop.
        # Note: zip stops at the shorter sequence, so only the common prefix
        # of the two sequences is compared.
        for actual, expected in zip(actual_values, expected_values):
            assert actual.get("wall_time") == expected.get("wall_time")
            assert actual.get("step") == expected.get("step")
