1#!/usr/bin/env python3
2"""Test the program psa_constant_names.
3Gather constant names from header files and test cases. Compile a C program
4to print out their numerical values, feed these numerical values to
5psa_constant_names, and check that the output is the original name.
6Return 0 if all test cases pass, 1 if the output was not always as expected,
7or 1 (with a Python backtrace) if there was an operational error.
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0
12#
13# Licensed under the Apache License, Version 2.0 (the "License"); you may
14# not use this file except in compliance with the License.
15# You may obtain a copy of the License at
16#
17# http://www.apache.org/licenses/LICENSE-2.0
18#
19# Unless required by applicable law or agreed to in writing, software
20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22# See the License for the specific language governing permissions and
23# limitations under the License.
24
25import argparse
26from collections import namedtuple
27import os
28import re
29import subprocess
30import sys
31from typing import Iterable, List, Optional, Tuple
32
33import scripts_path # pylint: disable=unused-import
34from mbedtls_dev import c_build_helper
35from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
36from mbedtls_dev import typing_util
37
def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Read the list of inputs to test psa_constant_names with.

    A fresh ``inputs_class`` instance is fed every header file, then every
    test data file, in that order. It is then finalized with
    ``add_numerical_values()`` followed by ``gather_arguments()`` and
    returned.
    """
    collected = inputs_class()
    for header_path in headers:
        collected.parse_header(header_path)
    for suite_path in test_suites:
        collected.parse_test_cases(suite_path)
    # Finalization steps; the order matches the original implementation.
    collected.add_numerical_values()
    collected.gather_arguments()
    return collected
50
def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[List[str]] = None,
          keep_c: bool = False) -> List[str]:
    """Generate and run a program to print out numerical values of C expressions.

    ``type_word`` selects the output format. ``'status'`` values are cast to
    ``long`` and printed in signed decimal. Every other type is cast to
    ``unsigned long`` and printed as zero-padded hexadecimal.

    ``include_path`` is the list of directories to search for headers. The
    command-line driver passes the list collected by ``--include``.
    ``keep_c`` keeps the generated C file around for debugging.

    Return the printed values as strings, as produced by
    ``c_build_helper.get_c_expression_values``.
    """
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller='test_psa_constant_names.py for {} values'.format(type_word),
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )
71
# Runs of whitespace, which are irrelevant when comparing C expressions.
NORMALIZE_STRIP_RE = re.compile(r'\s+')

def normalize(expr: str) -> str:
    """Return ``expr`` with every whitespace character removed.

    This lets two C expressions that differ only in spacing or line
    breaks compare equal. Whitespace is currently the only kind of
    difference treated as trivial.
    """
    stripped = NORMALIZE_STRIP_RE.sub('', expr)
    return stripped
79
def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[List[str]] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Return a pair ``(expressions, values)`` of parallel lists:

    * ``expressions`` is the sorted list of C expressions generated from the
      names of type ``type_word`` known to ``inputs``;
    * ``values[i]`` is the string representation of the integer value of
      ``expressions[i]``, as printed by the C program that ``run_c``
      builds and runs.

    ``include_path`` (a list of header directories) and ``keep_c`` are
    passed through to ``run_c``.
    """
    names = inputs.get_names(type_word)
    # Sort so that the expressions, and therefore the generated C program,
    # come out in a deterministic order.
    expressions = sorted(inputs.generate_expressions(names))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values
94
class Tests:
    """An object representing tests and their results."""

    # One failed case: the C expression ``expression`` of kind ``type``
    # evaluated to ``value``, but psa_constant_names printed ``output``
    # for that value.
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        """Prepare an empty set of results.

        ``options`` is the parsed command line. This class reads its
        ``include``, ``keep_c``, ``program`` and ``show`` attributes.
        """
        self.options = options
        # Total number of test cases run so far, across all types.
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        A case fails, and is appended to ``self.errors``, when the line
        printed for its value differs from the original expression other
        than in whitespace. A missing line also counts as a failure.

        Raises ``subprocess.CalledProcessError`` if the program exits with a
        non-zero status.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        # The comparison below assumes the program prints one line per
        # value, in the order the values were given. splitlines() also
        # tolerates CRLF line endings.
        outputs = output_bytes.decode('ascii').strip().splitlines()
        # zip() stops at the shortest input. Without padding, expressions
        # for which the program printed nothing would be counted in
        # self.count but never checked, so a truncated output would pass.
        # Pad with empty outputs so that each missing line is reported.
        missing = len(values) - len(outputs)
        if missing > 0:
            outputs.extend([''] * missing)
        self.count += len(expressions)
        for expr, value, output in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
            if normalize(expr) != normalize(output):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=output))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs.

        Each supported type word is tested in turn with ``run_one``.
        """
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total: the number of test cases, followed by either
        ``PASS`` or the number of failures.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')
151
# Header files scanned for constant names, relative to the first
# directory of the include path (see main()).
HEADERS = [
    'psa/crypto.h',
    'psa/crypto_extra.h',
    'psa/crypto_values.h',
]
# Test data files scanned for additional expressions, relative to the
# current directory.
TEST_SUITES = [
    'tests/suites/test_suite_psa_crypto_metadata.data',
]
154
def main():
    """Command-line entry point.

    Parse the options, gather the inputs, run every test, and print a
    report on standard output. Exit with status 1 if any case failed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # With action='append', directories given with -I are appended after
    # the default, so 'include' always remains options.include[0]. That
    # first directory is the one the headers are read from.
    parser.add_argument('--include', '-I',
                        help='Directory for header files',
                        action='append', default=['include'])
    parser.add_argument('--keep-c', dest='keep_c',
                        help='Keep the intermediate C file',
                        action='store_true', default=False)
    parser.add_argument('--no-keep-c', dest='keep_c',
                        help="Don't keep the intermediate C file (default)",
                        action='store_false')
    parser.add_argument('--program',
                        help='Program to test',
                        default='programs/psa/psa_constant_names')
    parser.add_argument('--show',
                        help='Show tested values on stdout',
                        action='store_true')
    parser.add_argument('--no-show', dest='show',
                        help="Don't show tested values (default)",
                        action='store_false')
    options = parser.parse_args()

    header_dir = options.include[0]
    header_paths = [os.path.join(header_dir, name) for name in HEADERS]
    inputs = gather_inputs(header_paths, TEST_SUITES)

    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()
186