1#!/usr/bin/env python3
2"""Generate test data for PSA cryptographic mechanisms.
3
4With no arguments, generate all test data. With non-option arguments,
5generate only the specified files.
6"""
7
8# Copyright The Mbed TLS Contributors
9# SPDX-License-Identifier: Apache-2.0
10#
11# Licensed under the Apache License, Version 2.0 (the "License"); you may
12# not use this file except in compliance with the License.
13# You may obtain a copy of the License at
14#
15# http://www.apache.org/licenses/LICENSE-2.0
16#
17# Unless required by applicable law or agreed to in writing, software
18# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
19# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20# See the License for the specific language governing permissions and
21# limitations under the License.
22
23import argparse
24import os
25import posixpath
26import re
27import sys
28from typing import Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional, TypeVar
29
30import scripts_path # pylint: disable=unused-import
31from mbedtls_dev import build_tree
32from mbedtls_dev import crypto_knowledge
33from mbedtls_dev import macro_collector
34from mbedtls_dev import psa_storage
35from mbedtls_dev import test_case
36
# Generic type variable for type annotations.
T = TypeVar('T') #pylint: disable=invalid-name
38
39
def psa_want_symbol(name: str) -> str:
    """Map a PSA crypto feature macro to its PSA_WANT_xxx configuration symbol.

    For example ``PSA_ALG_SHA_256`` maps to ``PSA_WANT_ALG_SHA_256``.

    Raises ValueError if ``name`` does not start with ``PSA_``.
    """
    prefix = 'PSA_'
    if not name.startswith(prefix):
        raise ValueError('Unable to determine the PSA_WANT_ symbol for ' + name)
    return prefix + 'WANT_' + name[len(prefix):]
46
def finish_family_dependency(dep: str, bits: int) -> str:
    """Complete ``dep`` with the key size if it is a family dependency prefix.

    A family dependency prefix is a PSA_WANT_ symbol containing ``_FAMILY_``
    that is only meaningful once qualified by a key size. Such a symbol has
    its ``_FAMILY_`` infix collapsed to ``_`` and ``_<bits>`` appended, e.g.
    ``PSA_WANT_ECC_FAMILY_SECP_R1`` with 256 bits becomes
    ``PSA_WANT_ECC_SECP_R1_256``. Any other symbol is returned as is.
    """
    replacement = r'_\1_{}'.format(bits)
    return re.sub(r'_FAMILY_(.*)', replacement, dep)
55
def finish_family_dependencies(dependencies: List[str], bits: int) -> List[str]:
    """Return a new list in which every family dependency prefix is completed.

    Each element of `dependencies` is passed through `finish_family_dependency`
    with the key size `bits`; the input list itself is left untouched.
    """
    finished = [finish_family_dependency(symbol, bits)
                for symbol in dependencies]
    return finished
62
# Symbols that look like mechanisms but never need a PSA_WANT_ dependency:
# modifiers, policy-only wildcards and chaining constructors.
SYMBOLS_WITHOUT_DEPENDENCY = frozenset([
    'PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG', # modifier, only in policies
    'PSA_ALG_AEAD_WITH_SHORTENED_TAG', # modifier
    'PSA_ALG_ANY_HASH', # only in policies
    'PSA_ALG_AT_LEAST_THIS_LENGTH_MAC', # modifier, only in policies
    'PSA_ALG_KEY_AGREEMENT', # chaining
    'PSA_ALG_TRUNCATED_MAC', # modifier
])
def automatic_dependencies(*expressions: str) -> List[str]:
    """Infer the PSA_WANT_ dependencies of a test case from its C expressions.

    Every PSA_ALG_xxx, PSA_ECC_FAMILY_xxx and PSA_KEY_TYPE_xxx identifier found
    in `expressions` (except those in SYMBOLS_WITHOUT_DEPENDENCY) contributes
    its PSA_WANT_ symbol. The result is sorted and free of duplicates.

    The scan is purely textual: string literals and comments inside the
    expressions are not skipped, so do not pass any.
    """
    mechanism_re = re.compile(r'PSA_(?:ALG|ECC_FAMILY|KEY_TYPE)_\w+')
    found = {symbol
             for expr in expressions
             for symbol in mechanism_re.findall(expr)}
    return sorted(psa_want_symbol(symbol)
                  for symbol in found - SYMBOLS_WITHOUT_DEPENDENCY)
83
# A temporary hack: at the time of writing, not all dependency symbols
# are implemented yet. Skip test cases for which the dependency symbols are
# not available. Once all dependency symbols are available, this hack must
# be removed so that a bug in the dependency symbols properly leads to a test
# failure.
def read_implemented_dependencies(filename: str) -> FrozenSet[str]:
    """Return every PSA_WANT_xxx symbol mentioned in the given file.

    Each whole-word occurrence of ``PSA_WANT_<identifier>`` on any line of
    `filename` is collected, whether it appears in a ``#define``, in a
    comment or anywhere else.
    """
    # Use a context manager so the file is closed once it has been scanned
    # (the previous version left the file object open).
    with open(filename) as header:
        return frozenset(symbol
                         for line in header
                         for symbol in re.findall(r'\bPSA_WANT_\w+\b', line))
_implemented_dependencies = None #type: Optional[FrozenSet[str]] #pylint: disable=invalid-name
def hack_dependencies_not_implemented(dependencies: List[str]) -> None:
    """Flag `dependencies` if it relies on a PSA_WANT_ symbol that is not implemented.

    The set of implemented symbols is read from include/psa/crypto_config.h
    on first use and cached in a module-level variable. If any element of
    `dependencies` contains ``PSA_WANT`` and, once leading ``!`` characters
    are stripped, is not in that set, the list is modified in place by
    appending ``DEPENDENCY_NOT_IMPLEMENTED_YET``.
    """
    global _implemented_dependencies #pylint: disable=global-statement,invalid-name
    if _implemented_dependencies is None:
        _implemented_dependencies = \
            read_implemented_dependencies('include/psa/crypto_config.h')
    uses_unimplemented = any(
        'PSA_WANT' in dep and dep.lstrip('!') not in _implemented_dependencies
        for dep in dependencies
    )
    if uses_unimplemented:
        dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET')
102
103
class Information:
    """Collect the PSA constructor macros that the test generators draw on."""

    def __init__(self) -> None:
        self.constructors = self.read_psa_interface()

    @staticmethod
    def remove_unwanted_macros(
            constructors: macro_collector.PSAMacroEnumerator
    ) -> None:
        """Discard key types for which no test case must be generated.

        Mbed TLS doesn't support finite-field DH yet and will not support
        finite-field DSA, so the DH and DSA key types are dropped.
        """
        for unwanted in ('PSA_KEY_TYPE_DH_KEY_PAIR',
                         'PSA_KEY_TYPE_DH_PUBLIC_KEY',
                         'PSA_KEY_TYPE_DSA_KEY_PAIR',
                         'PSA_KEY_TYPE_DSA_PUBLIC_KEY'):
            constructors.key_types.discard(unwanted)

    def read_psa_interface(self) -> macro_collector.PSAMacroEnumerator:
        """Parse the PSA headers and metadata tests into a macro collection.

        Parses the known key types, algorithms, etc. from the PSA headers and
        the metadata test suite, removes unwanted key types, then gathers
        the constructor arguments.
        """
        constructors = macro_collector.InputsForTest()
        for header in ('include/psa/crypto_values.h',
                       'include/psa/crypto_extra.h'):
            constructors.parse_header(header)
        for suite in ('tests/suites/test_suite_psa_crypto_metadata.data',):
            constructors.parse_test_cases(suite)
        self.remove_unwanted_macros(constructors)
        constructors.gather_arguments()
        return constructors
134
135
def test_case_for_key_type_not_supported(
        verb: str, key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        param_descr: str = ''
) -> test_case.TestCase:
    """Build a test case checking that `verb` fails for an unsupported key type or size.

    The test function is ``<verb>_not_supported`` and its arguments are
    `key_type` followed by `args`. `dependencies` is first passed to
    `hack_dependencies_not_implemented`, which may append to it in place.
    The description says "never supported" when no dependency remains and
    "not supported" otherwise, optionally preceded by `param_descr`.
    """
    hack_dependencies_not_implemented(dependencies)
    # Checked after the hack above, since it may have extended the list.
    qualifier = 'not' if dependencies else 'never'
    if param_descr:
        qualifier = '{} {}'.format(param_descr, qualifier)
    short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type)
    tc = test_case.TestCase()
    tc.set_description('PSA {} {} {}-bit {} supported'
                       .format(verb, short_key_type, bits, qualifier))
    tc.set_dependencies(dependencies)
    tc.set_function(verb + '_not_supported')
    tc.set_arguments([key_type, *args])
    return tc
157
class NotSupported:
    """Generate test cases for when something is not supported."""

    def __init__(self, info: Information) -> None:
        self.constructors = info.constructors

    # Key types supported in every configuration: their "not supported"
    # test cases would always be skipped, so none are generated.
    ALWAYS_SUPPORTED = frozenset([
        'PSA_KEY_TYPE_DERIVE',
        'PSA_KEY_TYPE_RAW_DATA',
    ])
    def test_cases_for_key_type_not_supported(
            self,
            kt: crypto_knowledge.KeyType,
            param: Optional[int] = None,
            param_descr: str = '',
    ) -> Iterator[test_case.TestCase]:
        """Return test cases exercising key creation when the given type is unsupported.

        If param is present and not None, emit test cases conditioned on this
        parameter not being supported. If it is absent or None, emit test cases
        conditioned on the base type not being supported.

        An "import" test case is emitted for every size to test. A "generate"
        test case is emitted as well, except for public key types.
        """
        if kt.name in self.ALWAYS_SUPPORTED:
            # Don't generate test cases for key types that are always supported.
            # They would be skipped in all configurations, which is noise.
            return
        import_dependencies = [('!' if param is None else '') +
                               psa_want_symbol(kt.name)]
        if kt.params is not None:
            import_dependencies += [('!' if param == i else '') +
                                    psa_want_symbol(sym)
                                    for i, sym in enumerate(kt.params)]
        # Public keys can never be generated, whatever the implementation
        # supports. The library then returns INVALID_ARGUMENT, which is
        # covered by the KeyGenerate class, so no "generate" test case is
        # emitted for them here.
        can_generate = not kt.name.endswith('_PUBLIC_KEY')
        for bits in kt.sizes_to_test():
            yield test_case_for_key_type_not_supported(
                'import', kt.expression, bits,
                finish_family_dependencies(import_dependencies, bits),
                test_case.hex_string(kt.key_material(bits)),
                param_descr=param_descr,
            )
            if can_generate:
                yield test_case_for_key_type_not_supported(
                    'generate', kt.expression, bits,
                    finish_family_dependencies(import_dependencies, bits),
                    str(bits),
                    param_descr=param_descr,
                )
            # To be added: derive

    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')

    def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the creation of keys of unsupported types.

        Non-ECC key types are handled once each. Each ECC key type is handled
        once per curve family, both with the type unsupported and with the
        curve family unsupported.
        """
        for key_type in sorted(self.constructors.key_types):
            if key_type in self.ECC_KEY_TYPES:
                continue
            kt = crypto_knowledge.KeyType(key_type)
            yield from self.test_cases_for_key_type_not_supported(kt)
        for curve_family in sorted(self.constructors.ecc_curves):
            for constr in self.ECC_KEY_TYPES:
                kt = crypto_knowledge.KeyType(constr, [curve_family])
                yield from self.test_cases_for_key_type_not_supported(
                    kt, param_descr='type')
                yield from self.test_cases_for_key_type_not_supported(
                    kt, 0, param_descr='curve')
234
def test_case_for_key_generation(
        key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        result: str = ''
) -> test_case.TestCase:
    """Build one test case for the ``generate_key`` test function.

    The test case arguments are `key_type`, then `args`, then `result`
    (which is always appended, even when empty). `dependencies` is first
    passed to `hack_dependencies_not_implemented`, which may append to it.
    """
    hack_dependencies_not_implemented(dependencies)
    description = 'PSA {} {}-bit'.format(
        re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type), bits)
    tc = test_case.TestCase()
    tc.set_description(description)
    tc.set_dependencies(dependencies)
    tc.set_function('generate_key')
    tc.set_arguments([key_type, *args, result])
    return tc
253
class KeyGenerate:
    """Generate positive and negative (invalid argument) test cases for key generation."""

    def __init__(self, info: Information) -> None:
        self.constructors = info.constructors

    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')

    @staticmethod
    def test_cases_for_key_type_key_generation(
            kt: crypto_knowledge.KeyType
    ) -> Iterator[test_case.TestCase]:
        """Return test cases exercising key generation.

        All key types can be generated except for public keys. For public key
        PSA_ERROR_INVALID_ARGUMENT status is expected.
        """
        import_dependencies = [psa_want_symbol(kt.name)]
        if kt.params is not None:
            import_dependencies += [psa_want_symbol(sym) for sym in kt.params]
        if kt.name.endswith('_PUBLIC_KEY'):
            # The library checks whether the key type is a public key generically,
            # before it reaches a point where it needs support for the specific key
            # type, so it returns INVALID_ARGUMENT for unsupported public key types.
            generate_dependencies = [] #type: List[str]
            result = 'PSA_ERROR_INVALID_ARGUMENT'
        else:
            # Copy, so that the RSA-specific append below does not also
            # modify import_dependencies through aliasing.
            generate_dependencies = list(import_dependencies)
            if kt.name == 'PSA_KEY_TYPE_RSA_KEY_PAIR':
                generate_dependencies.append('MBEDTLS_GENPRIME')
            result = 'PSA_SUCCESS'
        for bits in kt.sizes_to_test():
            # `result` is keyword-only in test_case_for_key_generation.
            # Passing it positionally would put it into *args and leave an
            # extra empty trailing argument in the generated test case.
            yield test_case_for_key_generation(
                kt.expression, bits,
                finish_family_dependencies(generate_dependencies, bits),
                str(bits),
                result=result,
            )

    def test_cases_for_key_generation(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the generation of keys.

        Non-ECC key types are handled once each; ECC key types are handled
        once per curve family.
        """
        for key_type in sorted(self.constructors.key_types):
            if key_type in self.ECC_KEY_TYPES:
                continue
            kt = crypto_knowledge.KeyType(key_type)
            yield from self.test_cases_for_key_type_key_generation(kt)
        for curve_family in sorted(self.constructors.ecc_curves):
            for constr in self.ECC_KEY_TYPES:
                kt = crypto_knowledge.KeyType(constr, [curve_family])
                yield from self.test_cases_for_key_type_key_generation(kt)
307
class StorageKey(psa_storage.Key):
    """A key for storage format testing, with optional usage flag extension."""

    # Maps each usage flag to the usage flag that it implies.
    IMPLICIT_USAGE_FLAGS = {
        'PSA_KEY_USAGE_SIGN_HASH': 'PSA_KEY_USAGE_SIGN_MESSAGE',
        'PSA_KEY_USAGE_VERIFY_HASH': 'PSA_KEY_USAGE_VERIFY_MESSAGE'
    } #type: Dict[str, str]

    def __init__(
            self,
            usage: str,
            without_implicit_usage: Optional[bool] = False,
            **kwargs
    ) -> None:
        """Prepare to generate a key.

        * `usage`                 : the usage flags of the key (a C expression).
        * `without_implicit_usage`: if true, keep `usage` exactly as given.
                                    Otherwise, for each flag of
                                    IMPLICIT_USAGE_FLAGS present in `usage`,
                                    add the flag it implies if it is missing.
        * other keyword arguments are forwarded to psa_storage.Key.
        """
        super().__init__(usage=usage, **kwargs)
        if without_implicit_usage:
            return
        for implying, implied in self.IMPLICIT_USAGE_FLAGS.items():
            if not self.usage.value() & psa_storage.Expr(implying).value():
                continue
            if self.usage.value() & psa_storage.Expr(implied).value():
                continue
            self.usage = psa_storage.Expr(self.usage.string + ' | ' + implied)
335
class StorageTestData(StorageKey):
    """A storage format test key together with the metadata of its test case."""

    def __init__(
            self,
            description: str,
            expected_usage: Optional[str] = None,
            **kwargs
    ) -> None:
        """Prepare to generate test data.

        * `description`   : used to build the test case name.
        * `expected_usage`: the usage flags that the test case expects. They
                            can differ from the usage flags stored in the key
                            because of the usage flag extension. When omitted
                            (or empty), the stored usage flags are expected.
        * other keyword arguments are forwarded to StorageKey.
        """
        super().__init__(**kwargs)
        self.description = description #type: str
        if not expected_usage:
            expected_usage = self.usage.string
        self.expected_usage = expected_usage #type: str
355
class StorageFormat:
    """Storage format stability test cases."""

    def __init__(self, info: Information, version: int, forward: bool) -> None:
        """Prepare to generate test cases for storage format stability.

        * `info`: information about the API. See the `Information` class.
        * `version`: the storage format version to generate test cases for.
        * `forward`: if true, generate forward compatibility test cases which
          save a key and check that its representation is as intended. Otherwise
          generate backward compatibility test cases which inject a key
          representation and check that it can be read and used.
        """
        self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator
        self.version = version #type: int
        self.forward = forward #type: bool

    def make_test_case(self, key: StorageTestData) -> test_case.TestCase:
        """Construct a storage format test case for the given key.

        If ``forward`` is true, generate a forward compatibility test case:
        create a key and validate that it has the expected representation.
        Otherwise generate a backward compatibility test case: inject the
        key representation into storage and validate that it can be read
        correctly.
        """
        verb = 'save' if self.forward else 'read'
        tc = test_case.TestCase()
        tc.set_description('PSA storage {}: {}'.format(verb, key.description))
        dependencies = automatic_dependencies(
            key.lifetime.string, key.type.string,
            key.expected_usage, key.alg.string, key.alg2.string,
        )
        dependencies = finish_family_dependencies(dependencies, key.bits)
        tc.set_dependencies(dependencies)
        tc.set_function('key_storage_' + verb)
        if self.forward:
            extra_arguments = []
        else:
            flags = []
            # Some test keys have the RAW_DATA type and attributes that don't
            # necessarily make sense. We do this to validate numerical
            # encodings of the attributes.
            # Raw data keys have no useful exercise anyway so there is no
            # loss of test coverage.
            if key.type.string != 'PSA_KEY_TYPE_RAW_DATA':
                flags.append('TEST_FLAG_EXERCISE')
            if 'READ_ONLY' in key.lifetime.string:
                flags.append('TEST_FLAG_READ_ONLY')
            extra_arguments = [' | '.join(flags) if flags else '0']
        tc.set_arguments([key.lifetime.string,
                          key.type.string, str(key.bits),
                          key.expected_usage, key.alg.string, key.alg2.string,
                          '"' + key.material.hex() + '"',
                          '"' + key.hex() + '"',
                          *extra_arguments])
        return tc

    def key_for_lifetime(
            self,
            lifetime: str,
    ) -> StorageTestData:
        """Construct a test key for the given lifetime."""
        short = lifetime
        short = re.sub(r'PSA_KEY_LIFETIME_FROM_PERSISTENCE_AND_LOCATION',
                       r'', short)
        short = re.sub(r'PSA_KEY_[A-Z]+_', r'', short)
        description = 'lifetime: ' + short
        key = StorageTestData(version=self.version,
                              id=1, lifetime=lifetime,
                              type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                              usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
                              material=b'L',
                              description=description)
        return key

    def all_keys_for_lifetimes(self) -> Iterator[StorageTestData]:
        """Generate test keys covering lifetimes."""
        lifetimes = sorted(self.constructors.lifetimes)
        expressions = self.constructors.generate_expressions(lifetimes)
        for lifetime in expressions:
            # Don't attempt to create or load a volatile key in storage
            if 'VOLATILE' in lifetime:
                continue
            # Don't attempt to create a read-only key in storage,
            # but do attempt to load one.
            if 'READ_ONLY' in lifetime and self.forward:
                continue
            yield self.key_for_lifetime(lifetime)

    def keys_for_usage_flags(
            self,
            usage_flags: List[str],
            short: Optional[str] = None,
            test_implicit_usage: Optional[bool] = False
    ) -> Iterator[StorageTestData]:
        """Construct test keys for the given key usage.

        * `usage_flags`: the usage flags of the key; an empty list means '0'.
        * `short`: the usage description used in the test case names.
          Defaults to the flags with their PSA_KEY_USAGE_ prefix removed.
        * `test_implicit_usage`: if true, also yield a second key whose
          usage flags are stored without the usage flag extension.
        """
        usage = ' | '.join(usage_flags) if usage_flags else '0'
        if short is None:
            short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage)
        extra_desc = ' with implication' if test_implicit_usage else ''
        description = 'usage' + extra_desc + ': ' + short
        key1 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               expected_usage=usage,
                               usage=usage, alg=0, alg2=0,
                               material=b'K',
                               description=description)
        yield key1

        if test_implicit_usage:
            description = 'usage without implication' + ': ' + short
            key2 = StorageTestData(version=self.version,
                                   id=1, lifetime=0x00000001,
                                   type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                                   without_implicit_usage=True,
                                   usage=usage, alg=0, alg2=0,
                                   material=b'K',
                                   description=description)
            yield key2

    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags.

        Yield keys with no usage flag, with each known flag on its own, and
        with each pair of consecutive known flags (in sorted order, the last
        flag being paired with the first). `kwargs` are forwarded to
        `keys_for_usage_flags`.
        """
        known_flags = sorted(self.constructors.key_usage_flags)
        yield from self.keys_for_usage_flags(['0'], **kwargs)
        for usage_flag in known_flags:
            yield from self.keys_for_usage_flags([usage_flag], **kwargs)
        # known_flags[:1] instead of [known_flags[0]] so that an empty set of
        # known flags produces no pair instead of raising IndexError.
        for flag1, flag2 in zip(known_flags,
                                known_flags[1:] + known_flags[:1]):
            yield from self.keys_for_usage_flags([flag1, flag2], **kwargs)

    def generate_key_for_all_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate test keys with all known usage flags set at once."""
        known_flags = sorted(self.constructors.key_usage_flags)
        yield from self.keys_for_usage_flags(known_flags, short='all known')

    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate all test keys covering usage flags."""
        yield from self.generate_keys_for_usage_flags()
        yield from self.generate_key_for_all_usage_flags()

    def keys_for_type(
            self,
            key_type: str,
            params: Optional[Iterable[str]] = None
    ) -> Iterator[StorageTestData]:
        """Generate test keys for the given key type, one per size to test.

        For key types that depend on a parameter (e.g. elliptic curve family),
        `params` holds the parameter to pass to the constructor. Only a single
        parameter is supported.
        """
        kt = crypto_knowledge.KeyType(key_type, params)
        # These do not depend on the key size.
        usage_flags = 'PSA_KEY_USAGE_EXPORT'
        alg = 0
        alg2 = 0
        short_expression = re.sub(r'\bPSA_(?:KEY_TYPE|ECC_FAMILY)_',
                                  r'',
                                  kt.expression)
        for bits in kt.sizes_to_test():
            key_material = kt.key_material(bits)
            description = 'type: {} {}-bit'.format(short_expression, bits)
            key = StorageTestData(version=self.version,
                                  id=1, lifetime=0x00000001,
                                  type=kt.expression, bits=bits,
                                  usage=usage_flags, alg=alg, alg2=alg2,
                                  material=key_material,
                                  description=description)
            yield key

    def all_keys_for_types(self) -> Iterator[StorageTestData]:
        """Generate test keys covering key types and their representations."""
        key_types = sorted(self.constructors.key_types)
        for key_type in self.constructors.generate_expressions(key_types):
            yield from self.keys_for_type(key_type)

    def keys_for_algorithm(self, alg: str) -> Iterator[StorageTestData]:
        """Generate test keys for the specified algorithm.

        Yield one key with `alg` as its primary algorithm and one key with
        `alg` as its secondary algorithm (alg2).
        """
        # For now, we don't have information on the compatibility of key
        # types and algorithms. So we just test the encoding of algorithms,
        # and not that operations can be performed with them.
        descr = re.sub(r'PSA_ALG_', r'', alg)
        descr = re.sub(r',', r', ', re.sub(r' +', r'', descr))
        usage = 'PSA_KEY_USAGE_EXPORT'
        key1 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               usage=usage, alg=alg, alg2=0,
                               material=b'K',
                               description='alg: ' + descr)
        yield key1
        key2 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               usage=usage, alg=0, alg2=alg,
                               material=b'L',
                               description='alg2: ' + descr)
        yield key2

    def all_keys_for_algorithms(self) -> Iterator[StorageTestData]:
        """Generate test keys covering algorithm encodings."""
        algorithms = sorted(self.constructors.algorithms)
        for alg in self.constructors.generate_expressions(algorithms):
            yield from self.keys_for_algorithm(alg)

    def generate_all_keys(self) -> Iterator[StorageTestData]:
        """Generate all keys for the test cases."""
        yield from self.all_keys_for_lifetimes()
        yield from self.all_keys_for_usage_flags()
        yield from self.all_keys_for_types()
        yield from self.all_keys_for_algorithms()

    def all_test_cases(self) -> Iterator[test_case.TestCase]:
        """Generate all storage format test cases."""
        # First build a list of all keys, then construct all the corresponding
        # test cases. This allows all required information to be obtained in
        # one go, which is a significant performance gain as the information
        # includes numerical values obtained by compiling a C program.
        all_keys = list(self.generate_all_keys())
        for key in all_keys:
            if key.location_value() != 0:
                # Skip keys with a non-default location, because they
                # require a driver and we currently have no mechanism to
                # determine whether a driver is available.
                continue
            yield self.make_test_case(key)
581
class StorageFormatForward(StorageFormat):
    """Forward compatibility storage format test cases.

    These test cases save a key and check that its stored representation
    is as intended.
    """

    def __init__(self, info: Information, version: int) -> None:
        super().__init__(info, version, forward=True)
587
class StorageFormatV0(StorageFormat):
    """Storage format stability test cases for version 0 compatibility."""

    def __init__(self, info: Information) -> None:
        super().__init__(info, 0, False)

    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags.

        Unlike the base class, also generate keys whose usage flags are
        stored without the usage flag extension.
        """
        yield from self.generate_keys_for_usage_flags(test_implicit_usage=True)
        yield from self.generate_key_for_all_usage_flags()

    def keys_for_implicit_usage(
            self,
            implyer_usage: str,
            alg: str,
            key_type: crypto_knowledge.KeyType
    ) -> StorageTestData:
        # pylint: disable=too-many-locals
        """Generate a test key for the specified implicit usage flag,
           algorithm and key type combination.

        The key is stored with EXPORT and `implyer_usage` but without the
        flag that `implyer_usage` implies; the test case expects the implied
        flag as well. The first size returned by `sizes_to_test()` is used.
        """
        bits = key_type.sizes_to_test()[0]
        implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage]
        usage_flags = 'PSA_KEY_USAGE_EXPORT'
        material_usage_flags = usage_flags + ' | ' + implyer_usage
        expected_usage_flags = material_usage_flags + ' | ' + implicit_usage
        alg2 = 0
        key_material = key_type.key_material(bits)
        usage_expression = re.sub(r'PSA_KEY_USAGE_', r'', implyer_usage)
        alg_expression = re.sub(r'PSA_ALG_', r'', alg)
        alg_expression = re.sub(r',', r', ', re.sub(r' +', r'', alg_expression))
        key_type_expression = re.sub(r'\bPSA_(?:KEY_TYPE|ECC_FAMILY)_',
                                     r'',
                                     key_type.expression)
        description = 'implied by {}: {} {} {}-bit'.format(
            usage_expression, alg_expression, key_type_expression, bits)
        key = StorageTestData(version=self.version,
                              id=1, lifetime=0x00000001,
                              type=key_type.expression, bits=bits,
                              usage=material_usage_flags,
                              expected_usage=expected_usage_flags,
                              without_implicit_usage=True,
                              alg=alg, alg2=alg2,
                              material=key_material,
                              description=description)
        return key

    def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
        # pylint: disable=too-many-locals
        """Match possible key types for sign algorithms.

        Return a dictionary mapping each signature algorithm expression to
        the key type expressions whose name shares a keyword with it. The
        matching is heuristic and based only on the macro names.
        """
        # To create a valid combination both the algorithms and key types
        # must be filtered. Pair them with keywords created from their names.
        incompatible_alg_keyword = frozenset(['RAW', 'ANY', 'PURE'])
        incompatible_key_type_keywords = frozenset(['MONTGOMERY'])
        keyword_translation = {
            'ECDSA': 'ECC',
            'ED[0-9]*.*' : 'EDWARDS'
        }
        exclusive_keywords = {
            'EDWARDS': 'ECC'
        }
        key_types = set(self.constructors.generate_expressions(self.constructors.key_types))
        algorithms = set(self.constructors.generate_expressions(self.constructors.sign_algorithms))
        translation_table = str.maketrans('(', '_', ')')

        # The keywords of a key type do not depend on the algorithm, so
        # compute them once per key type rather than once per pair.
        key_type_keywords_map = {} #type: Dict[str, FrozenSet[str]]
        for key_type in key_types:
            # Generate keywords from the name of the key type
            keywords = set(key_type.translate(translation_table).split(sep='_')[3:])
            # Remove ambiguous keywords. discard() rather than remove(): the
            # keyword to drop may legitimately be absent.
            for keyword1, keyword2 in exclusive_keywords.items():
                if keyword1 in keywords:
                    keywords.discard(keyword2)
            key_type_keywords_map[key_type] = frozenset(keywords)

        alg_with_keys = {} #type: Dict[str, List[str]]
        for alg in algorithms:
            # Generate keywords from the name of the algorithm
            alg_keywords = set(alg.partition('(')[0].split(sep='_')[2:])
            # Translate keywords for better matching with the key types.
            # Stop at the first matching pattern: the keyword has been
            # removed, and removing it again would raise KeyError.
            for keyword in alg_keywords.copy():
                for pattern, replace in keyword_translation.items():
                    if re.match(pattern, keyword):
                        alg_keywords.remove(keyword)
                        alg_keywords.add(replace)
                        break
            # Filter out incompatible algorithms
            if not alg_keywords.isdisjoint(incompatible_alg_keyword):
                continue

            for key_type, key_type_keywords in key_type_keywords_map.items():
                if key_type_keywords.isdisjoint(incompatible_key_type_keywords) and \
                   not key_type_keywords.isdisjoint(alg_keywords):
                    alg_with_keys.setdefault(alg, []).append(key_type)
        return alg_with_keys

    def all_keys_for_implicit_usage(self) -> Iterator[StorageTestData]:
        """Generate test keys for usage flag extensions."""
        # Generate a key type and algorithm pair for each extendable usage
        # flag to generate a valid key for exercising. The key is generated
        # without usage extension to check the extension compatibility.
        alg_with_keys = self.gather_key_types_for_sign_alg()

        for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str):
            for alg in sorted(alg_with_keys):
                for key_type in sorted(alg_with_keys[alg]):
                    # The key types must be filtered to fit the specific usage flag.
                    kt = crypto_knowledge.KeyType(key_type)
                    if kt.is_valid_for_signature(usage):
                        yield self.keys_for_implicit_usage(usage, alg, kt)

    def generate_all_keys(self) -> Iterator[StorageTestData]:
        """Generate all keys of the base class, plus usage extension keys."""
        yield from super().generate_all_keys()
        yield from self.all_keys_for_implicit_usage()
701
class TestGenerator:
    """Generate test data.

    TARGETS maps each target base name to a function that builds that
    target's test cases from an Information object; generate_target()
    writes the result to ``<directory>/<name>.data``.
    """

    def __init__(self, options) -> None:
        # Fall back to the in-tree suite directory when no directory option
        # was given (or it is None).
        suite_dir = self.get_option(options, 'directory', 'tests/suites')
        self.test_suite_directory = suite_dir
        self.info = Information()

    @staticmethod
    def get_option(options, name: str, default: T) -> T:
        """Return ``options.<name>``, or default if it is missing or None."""
        found = getattr(options, name, None)
        if found is None:
            return default
        return found

    def filename_for(self, basename: str) -> str:
        """Return the path of the .data file for the given base name."""
        data_file_name = basename + '.data'
        return posixpath.join(self.test_suite_directory, data_file_name)

    def write_test_data_file(self, basename: str,
                             test_cases: Iterable[test_case.TestCase]) -> None:
        """Write test_cases to ``basename + '.data'`` in the test suite directory."""
        test_case.write_data_file(self.filename_for(basename), test_cases)

    # Target base name -> builder of that target's test cases.
    TARGETS = {
        'test_suite_psa_crypto_generate_key.generated':
            lambda info: KeyGenerate(info).test_cases_for_key_generation(),
        'test_suite_psa_crypto_not_supported.generated':
            lambda info: NotSupported(info).test_cases_for_not_supported(),
        'test_suite_psa_crypto_storage_format.current':
            lambda info: StorageFormatForward(info, 0).all_test_cases(),
        'test_suite_psa_crypto_storage_format.v0':
            lambda info: StorageFormatV0(info).all_test_cases(),
    } #type: Dict[str, Callable[[Information], Iterable[test_case.TestCase]]]

    def generate_target(self, name: str) -> None:
        """Build the test cases of target name and write its .data file."""
        build_cases = self.TARGETS[name]
        self.write_test_data_file(name, build_cases(self.info))
742
def main(args):
    """Command line entry point.

    ``args`` is the list of command-line arguments without the program
    name (the script passes ``sys.argv[1:]``). Unknown target names are
    reported as a usage error before any file is written.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--list', action='store_true',
                        help='List available targets and exit')
    parser.add_argument('--list-for-cmake', action='store_true',
                        help='Print \';\'-separated list of available targets and exit')
    parser.add_argument('--directory', metavar='DIR',
                        help='Output directory (default: tests/suites)')
    parser.add_argument('targets', nargs='*', metavar='TARGET',
                        help='Target file to generate (default: all; "-": none)')
    options = parser.parse_args(args)
    build_tree.chdir_to_root()
    generator = TestGenerator(options)
    if options.list:
        for name in sorted(generator.TARGETS):
            print(generator.filename_for(name))
        return
    # List in a cmake list format (i.e. ';'-separated)
    if options.list_for_cmake:
        print(';'.join(generator.filename_for(name)
                       for name in sorted(generator.TARGETS)), end='')
        return
    if options.targets:
        # Allow "-" as a special case so you can run
        # ``generate_psa_tests.py - $targets`` and it works uniformly whether
        # ``$targets`` is empty or not.
        # A target may be given as a bare name or as a path to its .data
        # file: strip a trailing ".data" and any directory part.
        options.targets = [os.path.basename(re.sub(r'\.data\Z', r'', target))
                           for target in options.targets
                           if target != '-']
        # Reject unknown names up front: otherwise a typo surfaces as a
        # KeyError traceback, possibly after earlier targets were written.
        unknown = [target for target in options.targets
                   if target not in generator.TARGETS]
        if unknown:
            parser.error('unknown target(s): ' + ', '.join(unknown))
    else:
        options.targets = sorted(generator.TARGETS)
    for target in options.targets:
        generator.generate_target(target)
777
if __name__ == '__main__':
    # sys.argv[0] is the script path; main() only parses the arguments.
    main(args=sys.argv[1:])
780