1"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
import itertools
import re
from typing import IO, Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
22
23
class ReadFileLineException(Exception):
    """Error annotating a failure with the file and line where it happened.

    The message reads ``in FILENAME at LINE``. Both parts are also kept as
    the ``filename`` and ``line_number`` attributes. ``line_number`` may be
    a string marker instead of a number, for failures outside any line.
    """
    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        self.filename = filename
        self.line_number = line_number
        super().__init__('in {} at {}'.format(filename, line_number))
30
31
class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    If ``binary`` is true, the file is opened in binary mode and the lines
    are ``bytes`` objects instead of ``str``.

    Line numbers are 1-based. Before the first line has been read, the
    reported position is ``'entry'``. After the last line has been read, it
    is ``'exit'``. The file is closed when the ``with`` block is left, whether
    normally or through an exception.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        self.file = None #type: Optional[IO]
        self.line_number = 'entry' #type: Union[int, str]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary
    def __enter__(self) -> 'read_file_lines':
        # Keep a reference to the file object so that __exit__ can close it.
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        # Count from 1 so that reported positions match editors/compilers.
        self.generator = enumerate(self.file, 1)
        return self
    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Release the file handle whether or not the body raised.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
70
71
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """
    #pylint: disable=too-many-instance-attributes

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.lifetimes = set() #type: Set[str]
        self.locations = set() #type: Set[str]
        self.persistence_levels = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.pake_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        self.sign_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # Whether to include intermediate macros in enumerations. Intermediate
        # macros serve as category headers and are not valid values of their
        # type. See `is_internal_name`.
        # Always false in this class, may be set to true in derived classes.
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending in ``_FLAG`` or ``_MASK`` are always internal. Unless
        `self.include_intermediate` is true, names ending in ``_BASE`` or
        ``_NONE`` and names containing ``_CATEGORY_`` are internal as well.
        """
        if not self.include_intermediate:
            if name.endswith('_BASE') or name.endswith('_NONE'):
                return True
            if '_CATEGORY_' in name:
                return True
        return name.endswith('_FLAG') or name.endswith('_MASK')

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs. Each list is sorted so that
        the enumeration order is deterministic.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)
        self.arguments_for['persistence'] = sorted(self.persistence_levels)
        self.arguments_for['location'] = sorted(self.locations)
        self.arguments_for['lifetime'] = sorted(self.lifetimes)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, dropping spaces at commas."""
        return re.split(cls._argument_split_re, arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with an empty argument list, yield "name()".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once. The first call uses the first value of every
        argument. Then each argument in turn is varied over its remaining
        values while all the others keep their first value.

        Any error (for example an argument name with no entry in
        `self.arguments_for`, or an empty value list) is re-raised as an
        `Exception` naming the macro, chained to the original error.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore *this* argument's own first value before moving on
                # to vary the next argument.
                arguments[i] = values[0]
        except Exception as e:
            # Catch Exception rather than BaseException: GeneratorExit (sent
            # by close() when a consumer stops early) and KeyboardInterrupt
            # must propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def distribute_arguments_without_duplicates(
            self, seen: Set[str], name: str
    ) -> Iterator[str]:
        """Same as `distribute_arguments`, but don't repeat seen results.

        `seen` is updated in place with every result that is yielded.
        """
        for result in self.distribute_arguments(name):
            if result not in seen:
                seen.add(result)
                yield result

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names. Each distinct
        expression is generated only once across all the names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        seen = set() #type: Set[str]
        return itertools.chain(*(
            self.distribute_arguments_without_duplicates(seen, name)
            for name in names
        ))
216
217
class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        # constructor name -> predicate name (PSA_KEY_TYPE_xxx -> ..._IS_xxx)
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        # constructor name -> predicate name, see algorithm_tester()
        self.algorithms_from_hash = {} #type: Dict[str, str]

    @staticmethod
    def algorithm_tester(name: str) -> str:
        """The predicate for whether an algorithm is built from the given constructor.

        The given name must be the name of an algorithm constructor of the
        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
        an algorithm value. Return the corresponding predicate macro which
        is used as ``predicate(alg)`` to test whether ``alg`` can be built
        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
        ``PSA_ALG_IS_xxx``.
        """
        prefix = 'PSA_ALG_'
        assert name.startswith(prefix)
        midfix = 'IS_'
        suffix = name[len(prefix):]
        if suffix in ['DSA', 'ECDSA']:
            midfix += 'RANDOMIZED_'
        elif suffix == 'RSA_PSS':
            suffix += '_STANDARD_SALT'
        return prefix + midfix + suffix

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set. The subtype is mostly inferred from the hexadecimal encoding
        that appears in the expansion.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        # NOTE(review): re.match anchors at the start of `name`, and read_line
        # only calls this for names starting with 'PSA_ALG_', so the two
        # name-based branches below never match. The intended pattern is
        # unclear (perhaps 'PSA_ALG_MAC...'); confirm before changing it.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line):
        """Parse a C header line and record the PSA identifier it defines if any.

        Only ``#define`` lines with zero or one parameter and a non-empty
        expansion are considered; all other lines are skipped. The parameter
        name of any one-parameter macro is recorded in `self.argspecs`.
        Deprecated definitions and internal names are then skipped. The
        remaining recognized ``PSA_xxx`` families are added to their
        respective sets or maps.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        # Strip C comments from the expansion before classifying it.
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # 13 == len('PSA_KEY_TYPE_'): PSA_KEY_TYPE_xxx -> PSA_KEY_TYPE_IS_xxx
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file):
        """Parse a C header given as an iterable of ``bytes`` lines.

        `header_file` is typically a file opened in binary mode, but any
        iterable of byte strings works. Lines ending with a backslash
        continuation are joined with the following line. Non-ASCII bytes are
        dropped, and each logical line is passed to `read_line`.
        """
        # Use one shared iterator so that the continuation handling below
        # consumes lines from the same stream as the outer loop. This also
        # lets plain iterables (e.g. lists) be passed in.
        lines = iter(header_file)
        for line in lines:
            m = re.search(self._continued_line_re, line)
            while m:
                # A continuation on the very last line has nothing to join:
                # end the logical line there instead of leaking StopIteration.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)
350
351
class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable. Names are gathered from C headers (`parse_header`) and
    from test case data files (`parse_test_cases`).
    """

    def __init__(self) -> None:
        super().__init__()
        # Every PSA_xxx macro name seen in a header, including excluded ones.
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'pake_algorithm': [self.pake_algorithms],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        self.arguments_for['mac_length'] += ['1', '63']
        self.arguments_for['min_mac_length'] += ['1', '63']
        self.arguments_for['tag_length'] += ['1', '63']
        self.arguments_for['min_tag_length'] += ['1', '63']

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        self.pake_algorithms.add('0x0a0000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type.

        Raises KeyError if `type_word` is not one of the supported types.
        """
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    }
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching name is recorded in `self.all_declared`. Unless the
        name is excluded or internal, it is also added to the set for its
        type prefix (see `self.table_by_prefix`), together with its
        parameter list if it is a function-like macro.
        """
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        arguments = m.group(3)
        if arguments is not None:
            # A function-like macro with an empty parameter list must still
            # be invoked as "NAME()"; distribute_arguments does that for an
            # empty argspec, whereas splitting '' would yield [''].
            if arguments.strip():
                self.argspecs[name] = self._argument_split(arguments)
            else:
                self.argspecs[name] = []

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx".

        Non-ASCII bytes are dropped before each line is parsed.
        """
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        """Yield identifier-like words of `expr` that no header declared.

        Candidate words start with an uppercase letter. They are checked
        against `self.all_declared`, so call this after parsing the headers.
        """
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        """Check a test case argument before it is recorded.

        Raise an exception listing any undeclared names in `argument`;
        otherwise return True. Derived classes may override this to filter
        test cases.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        All spaces are removed, then a single space is inserted after each
        comma. The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        return re.sub(r',', r', ', re.sub(r' +', r'', argument))

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests.

        Raises KeyError if `function` is not a known test function (see
        `self.table_by_test_function`).
        """
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))
532