1load(
2    "//tensorflow:tensorflow.bzl",
3    "tf_cc_test",
4)
5load(
6    "//tensorflow/lite:build_def.bzl",
7    "generated_test_models",
8)
9
10# This is forked from `tensorflow/lite/build_def.bzl`.
11# TODO(b/136499575): Merge this back to TFLite codebase when open sourcing.
def mlir_generated_test_denylisted_models():
    """Returns the generated-test models excluded from MLIR-based testing.

    Any model listed here is filtered out by `mlir_generated_test_models()`.

    Returns:
      A new list of model names (str).
    """
    denylist = [
        # TODO(b/150647400): This test passes in TF2 with tf.compat.v1 but
        # fails in TF1 with tf.compat.v1. Due to the testing environments
        # changing on 3/3, this will only be disabled temporarily.
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
    ]
    return denylist
20
21# Test cases which only work with MLIR-based conversion now.
def mlir_only_generated_test_models():
    """Returns the generated-test models that only work with MLIR conversion.

    These are appended to `generated_test_models()` by
    `mlir_generated_test_models()`. The entries between the copybara markers
    are commented out here and only become active in the exported codebase.

    Returns:
      A new list of model names (str).
    """
    mlir_only = [
        "batchmatmul",
        "broadcast_to",
        "broadcast_gradient_args",
        "cond",
        "complex_abs",
        "control_dep",
        "conv_bias_relu6",
        "conv3d",
        "cumsum",
        # TODO(b/186563810): Enable after resolving tensorflow_addons dep issue
        # that causes test failures in the exported codebase.
        # copybara:uncomment_begin
        # "dense_image_warp",
        # copybara:uncomment_end
        "dynamic_rnn",
        "einsum",
        "identify_dilated_conv",
        "identify_dilated_conv1d",
        "imag",
        "irfft2d",
        "is_finite",
        "max_pool_with_argmax",
        "parse_example",
        "real",
        "reciprocal",
        "reduce_all",
        "rfft",
        "rfft2d",
        "segment_sum",
        "shape_to_strided_slice",
        "softplus",
        "static_hashtable",
        "static_rnn_with_control_flow_v2",
        "stft",
        "tensor_list_concat",
        "tensor_list_get_item",
        "tensor_list_length",
        "tensor_list_resize",
        "tensor_list_set_item",
        "tensor_list_dynamic_shape",
        "where_v2",
        "while",
    ]
    return mlir_only
67
68# Test cases which only work internally now.
def no_oss_generated_test_models():
    """Returns the generated-test models that currently only pass internally.

    `generated_test_models_all()` tags every test built from one of these
    models with "no_oss".

    Returns:
      A new list of model names (str).
    """
    internal_only = [
        "cond",
        "equal",
        "fill",
        "gather",
        "gather_nd",
        "not_equal",
        "parse_example",
        "slice",
        "sparse_to_dense",
        "squeeze",
        "static_hashtable",
        "strided_slice",
        "tile",
        "while",
    ]
    return internal_only
86
# Models whose generated tests currently fail under the given conversion mode.
# If you need to disable a test, add it here together with a link to the
# tracking bug or issue.
def generated_test_models_failing(conversion_mode):
    """Returns the models known to fail generated tests for `conversion_mode`.

    Tests built from these models are tagged "notap" and "manual" by
    `generated_test_models_all()`.

    Args:
      conversion_mode: str. One of `generated_test_conversion_modes()`.
        Currently unused, because no failures are registered for any mode.

    Returns:
      A new list of model names (str); currently always empty.
    """
    known_failures = []
    return known_failures
92
def mlir_generated_test_models():
    """Returns a list of models to be tested with MLIR-based conversion.

    The result is `generated_test_models()` followed by
    `mlir_only_generated_test_models()`, with every model named in
    `mlir_generated_test_denylisted_models()` removed. Relative order is kept,
    and no de-duplication is performed.
    """
    excluded = mlir_generated_test_denylisted_models()
    candidates = generated_test_models() + mlir_only_generated_test_models()
    return [model for model in candidates if model not in excluded]
101
def generated_test_conversion_modes():
    """Returns the conversion modes each generated test is instantiated for.

    The empty string is the default mode and adds no suffix or extra flags.
    `gen_zip_test` maps "forward-compat" to `--make_forward_compat_test` and
    "mlir-quant" to `--mlir_quantizer`.

    Returns:
      A new list of mode names (str).
    """
    modes = ["forward-compat", "", "mlir-quant"]
    return modes
106
def generated_test_models_all():
    """Generates the full matrix of (conversion mode, model) test combinations.

    Every model from `mlir_generated_test_models()` is paired with every mode
    from `generated_test_conversion_modes()`, mode-major.

    Returns:
      List of tuples representing:
            (conversion mode, name of test, test tags, test args).
      The test name is the model name, suffixed with "_<conversion mode>" when
      the mode is non-empty. Test args are currently always an empty list.
    """
    conversion_modes = generated_test_conversion_modes()
    no_oss_tests = no_oss_generated_test_models()

    # The model list does not depend on the conversion mode, so compute it
    # once rather than once per mode.
    models = mlir_generated_test_models()

    options = []
    for conversion_mode in conversion_modes:
        failing_tests = generated_test_models_failing(conversion_mode)
        for model in models:
            tags = []
            args = []

            # TODO(b/187992093): Exclude tests that are failing in OSS for now.
            if model in no_oss_tests:
                tags.append("no_oss")

            # Forward-compat coverage testing is largely redundant, and
            # contributes to coverage test bloat.
            if conversion_mode == "forward-compat":
                tags.append("nozapfhahn")

            # Known failures keep their targets but are tagged so they are not
            # picked up automatically ("manual" excludes them from wildcard
            # target patterns).
            if model in failing_tests:
                tags.append("notap")
                tags.append("manual")

            # Build the target name without mutating the loop variable; the
            # membership checks above use the bare model name.
            test_name = model
            if conversion_mode:
                test_name = "%s_%s" % (model, conversion_mode)
            options.append((conversion_mode, test_name, tags, args))

    return options
140
def gen_zip_test(name, test_name, conversion_mode, **kwargs):
    """Defines a zipped-example cc test together with the zip file it uses.

    A zip target named "zip_<test_name>" producing "<test_name>.zip" is created
    via `gen_zipped_test_file`, then a `tf_cc_test` named `name` is defined.

    Args:
      name: str. Resulting cc_test target name.
      test_name: str. Name of the test whose examples are zipped. Comes from
        the lists above.
      conversion_mode: str. Which conversion mode to run with. Comes from
        `generated_test_conversion_modes()`. "forward-compat" and "mlir-quant"
        add extra generator flags; any other value adds none.
      **kwargs: Forwarded unchanged to tf_cc_test.
    """

    # Extra :generate_examples flags per non-default conversion mode.
    mode_flags = {
        "forward-compat": " --make_forward_compat_test",
        "mlir-quant": " --mlir_quantizer",
    }

    gen_zipped_test_file(
        name = "zip_%s" % test_name,
        file = "%s.zip" % test_name,
        flags = mode_flags.get(conversion_mode, ""),
    )
    tf_cc_test(name, **kwargs)
164
def gen_zipped_test_file(name, file, flags = ""):
    """Generates a zip file of tests by running :generate_examples.

    Defines two targets:
      * a genrule named "`file`.files" whose single output is `file`;
      * a filegroup named `name` whose only source is that output.

    Args:
      name: str. Name of the filegroup exposing the generated zip.
      file: str. Output file name, also passed to `--zip_to_output`
        (e.g. "transpose.zip" as produced by gen_zip_test). Presumably
        :generate_examples derives the model to generate from this name;
        confirm against the tool.
      flags: str. Any additional flags to pass to :generate_examples.
    """
    zip_cmd = "$(locations :generate_examples)  --zip_to_output %s %s $(@D)" % (
        file,
        flags,
    )

    native.genrule(
        name = "%s.files" % file,
        cmd = zip_cmd,
        outs = [file],
        # `exec_tools` is required for PY3 compatibility in place of `tools`.
        exec_tools = [":generate_examples"],
    )

    native.filegroup(
        name = name,
        srcs = [file],
    )
188