load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/lite:build_def.bzl", "generated_test_models")

# This is forked from `tensorflow/lite/build_def.bzl`.
# TODO(b/136499575): Merge this back to TFLite codebase when open sourcing.
def mlir_generated_test_denylisted_models():
    """Returns generated-test model names excluded from MLIR-based conversion.

    Returns:
      A list of model-name strings. mlir_generated_test_models() drops every
      name listed here from its result.
    """
    excluded = [
        # TODO(b/150647400): These tests pass in TF2 with tf.compat.v1 but
        # fail in TF1 with tf.compat.v1. They were disabled temporarily when
        # the testing environments changed on 3/3.
        "unidirectional_sequence_lstm",
        "unidirectional_sequence_rnn",
    ]
    return excluded

# Test cases which only work with MLIR-based conversion now.
def mlir_only_generated_test_models():
    """Returns generated-test model names that only MLIR-based conversion supports.

    Returns:
      A list of model-name strings. mlir_generated_test_models() appends these
      after the shared generated_test_models() list.
    """
    mlir_only = [
        "batchmatmul",
        "broadcast_to",
        "broadcast_gradient_args",
        "cond",
        "complex_abs",
        "control_dep",
        "conv_bias_relu6",
        "conv3d",
        "cumsum",
        # TODO(b/186563810): Enable after resolving tensorflow_addons dep issue
        # that causes test failures in the exported codebase.
        # copybara:uncomment_begin
        # "dense_image_warp",
        # copybara:uncomment_end
        "dynamic_rnn",
        "einsum",
        "identify_dilated_conv",
        "identify_dilated_conv1d",
        "imag",
        "irfft2d",
        "is_finite",
        "max_pool_with_argmax",
        "parse_example",
        "real",
        "reciprocal",
        "reduce_all",
        "rfft",
        "rfft2d",
        "segment_sum",
        "shape_to_strided_slice",
        "softplus",
        "static_hashtable",
        "static_rnn_with_control_flow_v2",
        "stft",
        "tensor_list_concat",
        "tensor_list_get_item",
        "tensor_list_length",
        "tensor_list_resize",
        "tensor_list_set_item",
        "tensor_list_dynamic_shape",
        "where_v2",
        "while",
    ]
    return mlir_only

# Test cases which only work internally for now; their tests get the
# "no_oss" tag.
def no_oss_generated_test_models():
    """Returns generated-test model names whose tests currently fail in OSS.

    generated_test_models_all() adds the "no_oss" tag to every test generated
    for a model in this list.
    """
    return [
        "cond",
        "equal",
        "fill",
        "gather",
        "gather_nd",
        "not_equal",
        "parse_example",
        "slice",
        "sparse_to_dense",
        "squeeze",
        "static_hashtable",
        "strided_slice",
        "tile",
        "while",
    ]

# List of models that fail generated tests for the conversion mode.
# If you have to disable a test, please add here with a link to the appropriate
# bug or issue.
def generated_test_models_failing(conversion_mode):
    """Returns model names whose generated tests fail for `conversion_mode`.

    Args:
      conversion_mode: str. One of generated_test_conversion_modes(). Currently
        unused because no failing tests are listed.

    Returns:
      A list of model names. generated_test_models_all() tags their tests
      "notap" and "manual".
    """
    return []

def mlir_generated_test_models():
    """Returns a list of models to be tested with MLIR-based conversion.

    The result is generated_test_models() followed by
    mlir_only_generated_test_models(), in that order. Any name returned by
    mlir_generated_test_denylisted_models() is removed.
    """

    # A dict serves as a set so that each membership check is constant-time.
    denylisted = {model: True for model in mlir_generated_test_denylisted_models()}
    return [
        model
        for model in generated_test_models() + mlir_only_generated_test_models()
        if model not in denylisted
    ]

def generated_test_conversion_modes():
    """Returns a list of conversion modes.

    "" selects the default converter flags. gen_zip_test() maps
    "forward-compat" and "mlir-quant" to extra generate_examples flags.
    """
    return ["forward-compat", "", "mlir-quant"]

def generated_test_models_all():
    """Generates a list of all tests with the different converters.

    Returns:
      List of tuples representing:
      (conversion mode, name of test, test tags, test args).
      For a non-empty conversion mode the test name is
      "<model>_<conversion mode>"; otherwise it is the bare model name.
    """
    no_oss_tests = no_oss_generated_test_models()

    # The model list does not depend on the conversion mode, so compute it
    # once instead of once per mode.
    models = mlir_generated_test_models()

    options = []
    for conversion_mode in generated_test_conversion_modes():
        failing_tests = generated_test_models_failing(conversion_mode)
        for model in models:
            tags = []
            args = []

            # TODO(b/187992093): Exclude tests that are failing in OSS for now.
            if model in no_oss_tests:
                tags.append("no_oss")

            # Forward-compat coverage testing is largely redundant, and
            # contributes to coverage test bloat.
            if conversion_mode == "forward-compat":
                tags.append("nozapfhahn")

            if model in failing_tests:
                tags.append("notap")
                tags.append("manual")

            # The default ("") mode keeps the bare model name; other modes get
            # a suffix so each (model, mode) pair has a distinct test name.
            test_name = "%s_%s" % (model, conversion_mode) if conversion_mode else model
            options.append((conversion_mode, test_name, tags, args))

    return options

def gen_zip_test(name, test_name, conversion_mode, **kwargs):
    """Generate a zipped-example test and its dependent zip files.

    Args:
      name: str. Resulting cc_test target name.
      test_name: str. Test name as produced by generated_test_models_all().
        It also names the zip file ("<test_name>.zip") and its filegroup
        ("zip_<test_name>").
      conversion_mode: str. Which conversion mode to run with; one of
        generated_test_conversion_modes().
      **kwargs: tf_cc_test kwargs.
    """
    flags = ""
    if conversion_mode == "forward-compat":
        flags += " --make_forward_compat_test"
    elif conversion_mode == "mlir-quant":
        flags += " --mlir_quantizer"

    gen_zipped_test_file(
        name = "zip_%s" % test_name,
        file = "%s.zip" % test_name,
        flags = flags,
    )
    tf_cc_test(name, **kwargs)

def gen_zipped_test_file(name, file, flags = ""):
    """Generate a zip file of tests by using :generate_examples.

    Declares two targets: a genrule named "<file>.files" that writes `file`,
    and a filegroup named `name` whose only source is that output.

    Args:
      name: str. Name of the filegroup that exposes the generated zip.
      file: str. Output zip file name, e.g. "transpose.zip". It is passed to
        :generate_examples via --zip_to_output.
      flags: str. Any additional flags to pass to :generate_examples.
    """
    native.genrule(
        name = file + ".files",
        cmd = (("$(locations :generate_examples) " +
                " --zip_to_output {0} {1} $(@D)").format(file, flags)),
        outs = [file],
        # `exec_tools` is required for PY3 compatibility in place of `tools`.
        # NOTE(review): newer Bazel releases document `exec_tools` as
        # deprecated in favor of `tools`; confirm against the Bazel version in
        # use before migrating.
        exec_tools = [
            ":generate_examples",
        ],
    )

    native.filegroup(
        name = name,
        srcs = [file],
    )