1"""Knowledge about the PSA key store as implemented in Mbed TLS.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
19import re
20import struct
21from typing import Dict, List, Optional, Set, Union
22import unittest
23
24from mbedtls_dev import c_build_helper
25
26
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An `Expr` is built either from an integer, whose value is known at
    once, or from a string containing a C expression. The values of string
    expressions that are not plain integer literals are obtained lazily by
    compiling a C program (see `update_cache`). They are stored in a cache
    shared by all instances.
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Render as zero-padded hex: 4 digits up to 0xffff, 8 above.
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            # Register the expression so that the next `update_cache` call
            # evaluates it in the same C program as other pending ones.
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in a single C program that
        includes <psa/crypto.h> from the `include` directory. The set of
        pending expressions is then cleared.
        """
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _literal_value(string: str) -> Optional[int]:
        """Return the value of `string` if it is a plain C integer literal.

        Hexadecimal (``0x...``), octal (leading ``0``) and decimal literals
        without suffixes are recognized. Return None for anything else; such
        expressions must be evaluated by the C compiler.
        """
        if re.match(r'0x[0-9a-f]+\Z', string, re.I):
            return int(string, 16)
        if re.match(r'0[0-7]*\Z', string):
            # In C a leading 0 is an octal prefix. Python's int(s, 0)
            # rejects such literals (e.g. '010'), so parse in base 8.
            return int(string, 8)
        if re.match(r'[1-9][0-9]*\Z', string):
            return int(string, 10)
        return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        Plain integer literals are parsed directly. Any other expression is
        looked up in `value_cache`. If it is not cached yet, all pending
        expressions are evaluated with the C compiler first.
        """
        if self.value_if_known is None:
            literal = self._literal_value(self.string)
            if literal is not None:
                self.value_if_known = literal
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
80
# Argument type accepted by `as_expr`: an `Expr` is used as is, while an
# integer or a string containing a C expression is wrapped in a new `Expr`.
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
83
def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` to an `Expr` object.

    An argument that is already an `Expr` is returned unchanged (no copy is
    made). An integer, or a string holding a C expression, is wrapped in a
    newly constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
95
96
class Key:
    """A PSA crypto key object, together with its encoding in storage."""

    LATEST_VERSION = 0
    """Most recent storage format version; used when no version is given."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        # Default to the newest storage format.
        if version is None:
            version = self.LATEST_VERSION
        self.version = version
        # The identifier is recorded but is not part of the `bytes()` output.
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        # Fields that may be written as C expressions are wrapped in `Expr`.
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        # Raw key material, emitted verbatim after its length.
        self.material = material #type: bytes

    # Every storage encoding starts with this 8-byte magic string.
    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Serialize `args` according to the `struct` format string `fmt`.

        Unlike a plain `struct.pack` call:
        * Byte order is always little-endian with standard field sizes, so
          `fmt` must not start with an endianness prefix.
        * Any argument may be an `Expr`, which is replaced by its value.
        * Only integer-valued fields are supported.
        """
        layout = '<' + fmt # little-endian, standard sizes
        numbers = [arg.value() if isinstance(arg, Expr) else arg
                   for arg in args]
        return struct.pack(layout, *numbers)

    def bytes(self) -> bytes:
        """Encode the key as the byte string kept in PSA storage.

        The result is the content of the PSA storage file. It excludes any
        wrapping that a PSA-storage-over-stdio-file layer adds.

        Raises NotImplementedError for storage versions other than 0.
        """
        header = self.MAGIC + self.pack('L', self.version)
        if self.version != 0:
            raise NotImplementedError
        attributes = self.pack('LHHLLL',
                               self.lifetime, self.type, self.bits,
                               self.usage, self.alg, self.alg2)
        # The material is prefixed by its length in bytes.
        material = self.pack('L', len(self.material)) + self.material
        return b''.join([header, attributes, material])

    def hex(self) -> str:
        """Return the storage encoding from `bytes()` as a hexadecimal string."""
        encoded = self.bytes()
        return encoded.hex()

    def location_value(self) -> int:
        """Return the location part of the lifetime (bits 8 and above)."""
        lifetime = self.lifetime.value()
        return lifetime >> 8
170
171
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def check_encoding(self, key, expected_hex):
        """Assert that both the binary and hex encodings of `key` match."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self.check_encoding(key, expected_hex)

    def test_names(self):
        # Symbolic names are evaluated by compiling a C program.
        length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * length)
        expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
        self.check_encoding(key, expected_hex)

    def test_defaults(self):
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
        self.check_encoding(key, expected_hex)

    def test_location_value(self):
        # The location is the part of the lifetime above the low 8 bits.
        key = Key(lifetime=0x00000301,
                  type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        self.assertEqual(key.location_value(), 3)

    def test_unsupported_version(self):
        # Only storage format version 0 can currently be encoded.
        key = Key(version=1,
                  lifetime=0x00000001,
                  type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        with self.assertRaises(NotImplementedError):
            key.bytes()
204