# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for beam unit tests."""

import os
from typing import Optional, Union, List
import tempfile
from absl.testing import absltest

from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to
from temporian.core.data.dtype import DType

from temporian.test.utils import assertEqualEventSet
from temporian.beam.io.csv import from_csv as beam_from_csv
from temporian.beam.io.csv import to_csv as beam_to_csv
from temporian.beam.evaluation import run_multi_io
from temporian.io.csv import to_csv, from_csv
from temporian.io.tensorflow import to_tensorflow_record, from_tensorflow_record
from temporian.core.data.node import EventSetNode
from temporian.implementation.numpy.data.event_set import EventSet
from temporian.beam.io.tensorflow import (
    from_tensorflow_record as beam_from_tensorflow_record,
)
from temporian.beam.io.tensorflow import (
    to_tensorflow_record as beam_to_tensorflow_record,
)


def check_beam_implementation(
    test: absltest.TestCase,
    input_data: Union[EventSet, List[EventSet]],
    output_node: EventSetNode,
    cast: Optional[DType] = None,
):
    """Checks the result of the Numpy backend against the Beam backend.

    The input event sets are written to TensorFlow Record files. The graph
    leading to `output_node` is run with the Beam backend inside a
    `TestPipeline`, which reads those files. The Beam result is written to a
    TensorFlow Record file, loaded back, and compared with the result of
    running the same graph with the Numpy backend.

    All intermediate files are written to a temporary directory that is
    deleted before the comparison is made.

    Args:
        test: The absl test case used to report assertion failures.
        input_data: An event set, or a list of event sets, to feed to the
            graph. Each event set is fed through its own node
            (`input_evset.node()`).
        output_node: Output of the graph.
        cast: DType to cast Beam's output to after loading it back from
            disk. Useful for comparing outputs that are expected to be int32,
            for example, since after the file round-trip those may be loaded
            back as int64.
    """

    if isinstance(input_data, EventSet):
        input_data = [input_data]

    # Self-cleaning directory: repeated test runs must not leak files.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = os.path.join(tmp_dir, "output.tfrecord")

        # Export each input event set to its own TensorFlow Record file.
        input_paths = []
        for input_idx, input_evset in enumerate(input_data):
            input_path = os.path.join(tmp_dir, f"input_{input_idx}.tfrecord")
            to_tensorflow_record(input_evset, path=input_path)
            input_paths.append(input_path)

        # Run the Temporian program using the Beam backend. The pipeline is
        # executed when the `with TestPipeline()` block exits, so the output
        # file is only readable after that block.
        with TestPipeline() as p:
            input_pcollections = {}
            for input_path, input_evset in zip(input_paths, input_data):
                input_node = input_evset.node()
                input_pcollections[input_node] = p | beam_from_tensorflow_record(
                    input_path, input_node.schema
                )

            output_pcollections = run_multi_io(
                inputs=input_pcollections, outputs=[output_node]
            )

            test.assertLen(output_pcollections, 1)

            # With an empty shard name template the single output file is
            # expected to be exactly `output_path`; the write step emits the
            # written file names, which is checked below.
            written_files = output_pcollections[
                output_node
            ] | beam_to_tensorflow_record(
                output_path, output_node.schema, shard_name_template=""
            )

            assert_that(
                written_files,
                equal_to([output_path]),
            )

        # Load the Beam result before the temporary directory is removed.
        beam_output = from_tensorflow_record(output_path, output_node.schema)

    if cast is not None:
        beam_output = beam_output.cast(cast)

    # Run the Temporian program using the Numpy backend.
    expected_output = output_node.run(input_data)

    assertEqualEventSet(test, beam_output, expected_output)
