1#!/usr/bin/env python3 2"""Test the program psa_constant_names. 3Gather constant names from header files and test cases. Compile a C program 4to print out their numerical values, feed these numerical values to 5psa_constant_names, and check that the output is the original name. 6Return 0 if all test cases pass, 1 if the output was not always as expected, 7or 1 (with a Python backtrace) if there was an operational error. 8""" 9 10# Copyright The Mbed TLS Contributors 11# SPDX-License-Identifier: Apache-2.0 12# 13# Licensed under the Apache License, Version 2.0 (the "License"); you may 14# not use this file except in compliance with the License. 15# You may obtain a copy of the License at 16# 17# http://www.apache.org/licenses/LICENSE-2.0 18# 19# Unless required by applicable law or agreed to in writing, software 20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22# See the License for the specific language governing permissions and 23# limitations under the License. 
import argparse
from collections import namedtuple
import os
import re
import subprocess
import sys
from typing import Iterable, List, Optional, Tuple

import scripts_path # pylint: disable=unused-import
from mbedtls_dev import c_build_helper
from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
from mbedtls_dev import typing_util

def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Read the list of inputs to test psa_constant_names with.

    Instantiate ``inputs_class``, feed it every header file and every test
    suite data file, then call its ``add_numerical_values()`` and
    ``gather_arguments()`` methods before returning it.
    """
    inputs = inputs_class()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    inputs.add_numerical_values()
    inputs.gather_arguments()
    return inputs

def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[List[str]] = None,
          keep_c: bool = False) -> List[str]:
    """Generate and run a program to print out numerical values of C expressions.

    ``type_word`` selects the output format: for ``'status'`` each value is
    cast to ``long`` and printed in signed decimal; for every other type it
    is cast to ``unsigned long`` and printed as zero-padded hexadecimal.

    ``include_path`` is a list of header directories and ``keep_c`` says
    whether to keep the generated C file; both are passed unchanged to
    ``c_build_helper.get_c_expression_values``.

    Return the printed values as strings. Callers pair them with
    ``expressions`` by position.
    """
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller='test_psa_constant_names.py for {} values'.format(type_word),
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )

NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace: every run of
    whitespace is removed entirely.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)

def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[List[str]] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Return a tuple ``(expressions, values)`` of two parallel lists:
    ``expressions`` is the sorted list of generated C expressions, and
    ``values[i]`` is the string representation of the integer value of
    ``expressions[i]``, as printed by the program that ``run_c`` generates.
    """
    names = inputs.get_names(type_word)
    expressions = sorted(inputs.generate_expressions(names))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values

class Tests:
    """An object representing tests and their results.

    ``count`` is the number of expressions tested so far and ``errors``
    lists every mismatch between an expression and what
    psa_constant_names printed for its value.
    """

    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        # ``options`` is the argparse namespace built by main(); this class
        # reads its include, keep_c, program and show attributes.
        self.options = options
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Raise ``RuntimeError`` if the program does not print exactly one
        line per value, because its output then cannot be matched to the
        expressions reliably.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        output_lines = output_bytes.decode('ascii').strip().split('\n')
        # Output lines are matched to values by position. With a short
        # output, zip() would silently stop early and the unchecked values
        # would still be counted below as passing, so treat any line-count
        # mismatch as an operational error.
        if values and len(output_lines) != len(values):
            raise RuntimeError(
                '{}: expected {} lines of output for {} values, got {}'
                .format(self.options.program, len(values),
                        type_word, len(output_lines)))
        self.count += len(expressions)
        for expr, value, actual in zip(expressions, values, output_lines):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, actual))
            if normalize(expr) != normalize(actual):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=actual))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')

HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']

def main():
    """Parse the command line, run all the tests and exit with status 1 on failure."""
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # With action='append', values given with -I are added after the default
    # 'include' entry, so options.include[0] is always 'include'.
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    # Headers are parsed only from the first include directory; the whole
    # include list is passed to the C compiler via Tests.run_one.
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()