1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for strided_slice operators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
def _make_shape_to_strided_slice_test(options,
                                      test_parameters,
                                      expected_tf_failures=0):
  """Generates the zipped "shape op feeding strided_slice" tests.

  Each generated graph takes a float placeholder whose static shape is
  `dynamic_input_shape` (which may contain None), computes `tf.shape` of it,
  and slices that 1-D shape vector with `tf.strided_slice` using the
  `begin`, `end`, `strides`, `begin_mask` and `end_mask` parameters.

  Args:
    options: Options object, forwarded unchanged to `make_zip_of_tests`.
    test_parameters: List of dicts mapping parameter names to lists of
      candidate values, forwarded unchanged to `make_zip_of_tests`.
    expected_tf_failures: Forwarded unchanged to `make_zip_of_tests`.
  """

  def build_graph(parameters):
    """Builds placeholder -> tf.shape -> tf.strided_slice.

    Returns:
      A pair (input tensors, output tensors).
    """
    # `tf` is tensorflow.compat.v1, so this is the v1 placeholder API.
    input_tensor = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["dynamic_input_shape"])
    # Slice the runtime shape vector, not the tensor data itself.
    shape_vector = tf.shape(input_tensor)
    sliced_shape = tf.strided_slice(
        shape_vector,
        parameters["begin"],
        parameters["end"],
        parameters["strides"],
        begin_mask=parameters["begin_mask"],
        end_mask=parameters["end_mask"])
    return [input_tensor], [sliced_shape]

  def build_inputs(parameters, sess, inputs, outputs):
    """Creates feed data and runs the graph to get reference outputs.

    The concrete `input_shape` is fed into the placeholder. Its static shape
    may be only partially known.

    Returns:
      A pair (list of fed input values, results of `sess.run(outputs)`).
    """
    feed_values = [
        create_tensor_data(
            parameters["dtype"],
            parameters["input_shape"],
            min_value=-1,
            max_value=1)
    ]
    feed_dict = dict(zip(inputs, feed_values))
    return feed_values, sess.run(outputs, feed_dict=feed_dict)

  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=expected_tf_failures)
67
68
@register_make_test_function()
def make_shape_to_strided_slice_tests(options):
  """Registers tests that feed a shape op into strided_slice.

  Both parameter sets take the first element of the shape of a
  [None, 2, 2, 5] float32 placeholder (fed as [12, 2, 2, 5]). They differ
  only in the quantization flags they exercise:
    * fully_quantize in {False, True} with dynamic_range_quantize=False.
    * dynamic_range_quantize=True with fully_quantize=False.

  Args:
    options: Options object, forwarded to the shared test builder.
  """

  def _slice_first_dim_params(**quantize_flags):
    """Returns a fresh parameter dict; quantize flags are appended last."""
    params = {
        "dtype": [tf.float32],
        "dynamic_input_shape": [[None, 2, 2, 5]],
        "input_shape": [[12, 2, 2, 5]],
        "strides": [[1]],
        "begin": [[0]],
        "end": [[1]],
        "begin_mask": [0],
        "end_mask": [0],
    }
    # kwargs keep their call-site order, so the key order matches a literal
    # dict written with fully_quantize before dynamic_range_quantize.
    params.update(quantize_flags)
    return params

  test_parameters = [
      # Test dynamic shape into strided slice quantization works.
      _slice_first_dim_params(
          fully_quantize=[False, True], dynamic_range_quantize=[False]),
      _slice_first_dim_params(
          fully_quantize=[False], dynamic_range_quantize=[True]),
  ]
  _make_shape_to_strided_slice_test(
      options, test_parameters, expected_tf_failures=0)
102