"""Collect macro definitions from header files.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import re
from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union


class ReadFileLineException(Exception):
    """Exception raised by `read_file_lines` to annotate a processing error.

    The message has the form ``in FILENAME at LINE``. The exception that
    occurred while processing the line is chained as ``__cause__``.
    """
    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        message = 'in {} at {}'.format(filename, line_number)
        super(ReadFileLineException, self).__init__(message)
        self.filename = filename
        self.line_number = line_number


class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    Line numbers are 1-based. Before the first line is read the reported
    position is ``'entry'``, and after the last line it is ``'exit'``.
    With ``binary=True`` the file is opened in ``'rb'`` mode and lines are
    bytes. The file is closed when the ``with`` block exits.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        # Position reported in ReadFileLineException (see class docstring).
        self.line_number = 'entry' #type: Union[int, str]
        self.file = None #type: Optional[IO]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary

    def __enter__(self) -> 'read_file_lines':
        input_file = open(self.filename, 'rb' if self.binary else 'r')
        self.file = input_file
        # Count from 1 so that reported positions match what editors show.
        self.generator = enumerate(input_file, 1)
        return self

    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Always release the file handle, whether or not processing failed.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value


class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """
    #pylint: disable=too-many-instance-attributes

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.lifetimes = set() #type: Set[str]
        self.locations = set() #type: Set[str]
        self.persistence_levels = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.pake_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        self.sign_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # Whether to include intermediate macros in enumerations. Intermediate
        # macros serve as category headers and are not valid values of their
        # type. See `is_internal_name`.
        # Always false in this class, may be set to true in derived classes.
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending in ``_FLAG`` or ``_MASK`` are always internal. Unless
        `self.include_intermediate` is set, names ending in ``_BASE`` or
        ``_NONE`` and names containing ``_CATEGORY_`` are internal too.
        """
        if not self.include_intermediate:
            if name.endswith('_BASE') or name.endswith('_NONE'):
                return True
            if '_CATEGORY_' in name:
                return True
        return name.endswith('_FLAG') or name.endswith('_MASK')

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)
        self.arguments_for['persistence'] = sorted(self.persistence_levels)
        self.arguments_for['location'] = sorted(self.locations)
        self.arguments_for['lifetime'] = sorted(self.lifetimes)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, dropping surrounding spaces."""
        return re.split(cls._argument_split_re, arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        The first call uses the first known value of every argument. Then,
        for each argument position in turn, that argument runs through its
        remaining values while every other argument keeps its first value.

        Any error (for example an argument name missing from
        `self.arguments_for`, or one with no values) is re-raised as an
        Exception that mentions the macro name, chained to the original error.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this position's own first value before moving on.
                arguments[i] = values[0]
        except Exception as e:
            # Deliberately not BaseException: GeneratorExit (thrown in when a
            # consumer closes this generator early) and KeyboardInterrupt
            # must propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def distribute_arguments_without_duplicates(
            self, seen: Set[str], name: str
    ) -> Iterator[str]:
        """Same as `distribute_arguments`, but don't repeat seen results.

        Every yielded result is added to `seen`, which is mutated in place.
        """
        for result in self.distribute_arguments(name):
            if result not in seen:
                seen.add(result)
                yield result

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        seen = set() #type: Set[str]
        return itertools.chain(*(
            self.distribute_arguments_without_duplicates(seen, name)
            for name in names
        ))


class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        # Constructor macro name -> name of the matching PSA_KEY_TYPE_IS_xxx
        # predicate.
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        # Constructor macro name -> predicate name from `algorithm_tester`.
        self.algorithms_from_hash = {} #type: Dict[str, str]

    @staticmethod
    def algorithm_tester(name: str) -> str:
        """The predicate for whether an algorithm is built from the given constructor.

        The given name must be the name of an algorithm constructor of the
        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
        an algorithm value. Return the corresponding predicate macro which
        is used as ``predicate(alg)`` to test whether ``alg`` can be built
        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
        ``PSA_ALG_IS_xxx``.
        """
        prefix = 'PSA_ALG_'
        assert name.startswith(prefix)
        midfix = 'IS_'
        suffix = name[len(prefix):]
        if suffix in ['DSA', 'ECDSA']:
            midfix += 'RANDOMIZED_'
        elif suffix == 'RSA_PSS':
            suffix += '_STANDARD_SALT'
        return prefix + midfix + suffix

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        # NOTE(review): re.match anchors at the start of `name`, and read_line
        # only calls this with names that start with 'PSA_ALG_', so the first
        # two patterns never match from there; classification then relies on
        # the numerical encoding found in `expansion`. Confirm whether
        # 'PSA_ALG_MAC' / 'PSA_ALG_KDF' prefixes were intended.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines if any.

        Lines that are not a ``#define`` of a macro with zero or one parameter
        and a non-empty expansion are ignored. For a one-parameter macro the
        parameter name is stored in `self.argspecs`, whatever the macro's
        name. C comments are stripped from the expansion, deprecated
        definitions and internal names are skipped, and recognized
        ``PSA_...`` names are sorted into the matching sets or dictionaries.
        Other macros are ignored.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # 'PSA_KEY_TYPE_' is 13 characters: insert 'IS_' after it.
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: Iterable[bytes]) -> None:
        """Parse every line of a C header given as lines of bytes.

        `header_file` is typically a file opened in binary mode, but any
        iterable of byte strings is accepted. Backslash-continued lines are
        joined before parsing, and non-ASCII bytes are dropped before the line
        is decoded and passed to `read_line`.
        """
        # An explicit iterator lets the continuation loop pull extra lines
        # from the same stream as the outer loop, even for plain lists.
        lines = iter(header_file)
        for line in lines:
            m = re.search(self._continued_line_re, line)
            while m:
                # A trailing backslash on the last line has nothing to join:
                # treat the missing continuation as empty instead of letting
                # StopIteration escape.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)


class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        # Every PSA_xxx name seen in a #define, including excluded ones.
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'pake_algorithm': [self.pake_algorithms],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        self.arguments_for['mac_length'] += ['1', '63']
        self.arguments_for['min_mac_length'] += ['1', '63']
        self.arguments_for['tag_length'] += ['1', '63']
        self.arguments_for['min_tag_length'] += ['1', '63']

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        self.pake_algorithms.add('0x0a0000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type.

        Raises KeyError if `type_word` is not one of the supported types.
        """
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = set([
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    ])
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching name is added to `self.all_declared`. Names that are
        not excluded are also added to the set for their prefix, and their
        argument list, if any, is recorded in `self.argspecs`.
        """
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        """Yield identifier-like words in `expr` that no parsed header declared."""
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        """Decide whether a test case argument should be recorded.

        Raises an Exception listing the names if `argument` mentions any
        undeclared name. Derived classes may override this to filter cases.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        return re.sub(r',', r', ', re.sub(r' +', r'', argument))

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests.

        Raises KeyError if `function` is not in `self.table_by_test_function`.
        """
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))