# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.from_sparse_tensor_slices()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class FromSparseTensorSlicesTest(test_base.DatasetTestBase):
  """Graph-mode tests for `Dataset.from_sparse_tensor_slices()`.

  Each test feeds a `SparseTensorValue` into a float64 sparse placeholder and
  runs the initializer of an initializable iterator over that placeholder's
  slices.
  """

  def _make_sparse_slices_iterator(self):
    """Builds an initializable iterator over slices of a sparse placeholder.

    Returns:
      A `(placeholder, iterator)` pair. `placeholder` is a float64 sparse
      placeholder that must be fed when running `iterator.initializer`.
    """
    st = array_ops.sparse_placeholder(dtypes.float64)
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.from_sparse_tensor_slices(st))
    return st, iterator

  def testFromSparseTensorSlices(self):
    """Test a dataset based on slices of a `tf.SparseTensor`."""
    st, iterator = self._make_sparse_slices_iterator()
    init_op = iterator.initializer
    # The components returned by `get_next()` are unpacked into the
    # `SparseTensor` constructor so `sess.run` yields a value exposing
    # `.indices`, `.values` and `.dense_shape`.
    get_next = sparse_tensor.SparseTensor(*iterator.get_next())

    with self.cached_session() as sess:
      # Row i of the fed sparse tensor holds the values in slices[i]; an empty
      # list is a row with no stored entries.
      slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]

      # Test with sparse tensor in the appropriate (row-major) order.
      indices = np.array(
          [[i, j] for i, s in enumerate(slices) for j in range(len(s))])
      values = np.array([val for s in slices for val in s])
      # The `+ 1` makes the second dimension larger than the longest row;
      # every slice must still report dense_shape[1:] as its shape.
      dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
      sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
                                                    dense_shape)
      sess.run(init_op, feed_dict={st: sparse_feed})
      for s in slices:
        results = sess.run(get_next)
        self.assertAllEqual(s, results.values)
        # Each slice keeps only the column coordinate of its entries. The
        # reshape gives empty rows an index array of shape (0, 1).
        expected_indices = np.array([[j] for j in range(len(s))]).reshape(
            [-1, 1])
        self.assertAllEqual(expected_indices, results.indices)
        self.assertAllEqual(dense_shape[1:], results.dense_shape)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test with sparse tensor in the reverse order, which is not
      # currently supported.
      reverse_order_indices = indices[::-1, :]
      reverse_order_values = values[::-1]
      sparse_feed = sparse_tensor.SparseTensorValue(
          reverse_order_indices, reverse_order_values, dense_shape)
      with self.assertRaises(errors.UnimplementedError):
        sess.run(init_op, feed_dict={st: sparse_feed})

      # Test with an empty sparse tensor: no indices or values and a
      # zero-sized first dimension, so the dataset yields no elements.
      empty_indices = np.empty((0, 4), dtype=np.int64)
      empty_values = np.empty((0,), dtype=np.float64)
      empty_dense_shape = [0, 4, 37, 9]
      sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
                                                    empty_dense_shape)
      sess.run(init_op, feed_dict={st: sparse_feed})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testEmptySparseTensorSlicesInvalid(self):
    """Test that values without matching indices are rejected."""
    st, iterator = self._make_sparse_slices_iterator()
    init_op = iterator.initializer

    with self.cached_session() as sess:
      # Test with an empty sparse tensor but with non empty values: there are
      # zero index rows but four values.
      empty_indices = np.empty((0, 4), dtype=np.int64)
      non_empty_values = [1, 2, 3, 4]
      empty_dense_shape = [0, 4, 37, 9]
      sparse_feed = sparse_tensor.SparseTensorValue(empty_indices,
                                                    non_empty_values,
                                                    empty_dense_shape)
      # Initializing the iterator with this inconsistent feed should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={st: sparse_feed})

  def testEmptySparseTensorSlicesInvalid2(self):
    """Test that indices inconsistent with values/dense_shape are rejected."""
    st, iterator = self._make_sparse_slices_iterator()
    init_op = iterator.initializer

    with self.cached_session() as sess:
      # `indices` is one index row with zero coordinates, but there are no
      # values and `dense_shape` has rank 2. Both the number of indices and
      # their width disagree with the other components.
      empty_indices = [[]]
      empty_values = []
      dense_shape = [1, 1]
      sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
                                                    dense_shape)
      # Initializing the iterator with this inconsistent feed should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={st: sparse_feed})


if __name__ == "__main__":
  # Runs only when this file is executed directly, not when it is imported;
  # hands control to `test.main()` from tensorflow.python.platform.
  test.main()
