1#!/usr/bin/env python3
2# SPDX-License-Identifier: BSD-2-Clause
3#
4# Copyright (C) 2023, STMicroelectronics
5#
6
7try:
8    from elftools.elf.elffile import ELFFile
9    from elftools.elf.sections import SymbolTableSection
10    from elftools.elf.enums import ENUM_P_TYPE_ARM
11    from elftools.elf.enums import *
12except ImportError:
13    print("""
14***
15ERROR: "pyelftools" python module is not installed or version < 0.25.
16***
17""")
18    raise
19
20try:
21    from Cryptodome.Hash import SHA256
22    from Cryptodome.Signature import pkcs1_15
23    from Cryptodome.PublicKey import RSA
24    from Cryptodome.Signature import DSS
25    from Cryptodome.PublicKey import ECC
26except ImportError:
27    print("""
28            ***
29            ERROR: "pycryptodomex" python module should be installed.
30            ***
31        """)
32    raise
33
34import os
35import sys
36import struct
37import logging
38import binascii
39
40#  Generated file structure:
41#
42#                   -----+-------------+
43#                  /     |    Magic    |  32-bit word, magic value equal to
44#                 /      +-------------+  0x3543A468
45#                /       +-------------+
46#               /        |   version   |  32-bit word, version of the format
47#              /         +-------------+
48# +-----------+          +-------------+
49# |   Header  |          |  TLV size   |  32-bit word, size of the TLV
50# +-----------+          +-------------+  (aligned on 64-bit), in bytes.
51#              \         +-------------+
52#               \        |  sign size  |  32-bit word, size of the signature
53#                \       +-------------+  (aligned on 64-bit), in bytes.
54#                 \      +-------------+
55#                  \     | images size |  32-bit word, size of the images to
56#                   -----+-------------+  load (aligned on 64-bit), in bytes.
57#
58#                        +-------------+  Information used to authenticate the
59#                        |     TLV     |  images and boot the remote processor,
60#                        |             |  stored in Type-Length-Value format.
61#                        +-------------+  'Type' and 'Length' are 32-bit words.
62#
63#                        +-------------+
64#                        | Signature   |   Signature of the header and the TLV.
65#                        +-------------+
66#
67#                        +-------------+
68#                        |   Firmware  |
69#                        |    image 1  |
70#                        +-------------+
71#                               ...
72#                        +-------------+
73#                        |   Firmware  |
74#                        |    image n  |
75#                        +-------------+
76
# Generic TLV type identifiers, keyed by the names accepted by TLV.add().
TLV_TYPES = {
    'SIGNTYPE':  0x00000001,  # algorithm used for signature
    'HASHTYPE':  0x00000002,  # algorithm used for computing segment's hash
    'NUM_IMG':   0x00000003,  # number of images to load
    'IMGTYPE':   0x00000004,  # array of type of images to load
    'IMGSIZE':   0x00000005,  # array of the size of the images to load
    'HASHTABLE': 0x00000010,  # segment hash table for authentication
    'PKEYINFO':  0x00000011,  # information to retrieve signature key
}

# TLV identifier ranges. TLV.add_plat_tlv() only accepts identifiers that
# fall inside PLATFORM_TLV_TYPE_RANGE.
GENERIC_TLV_TYPE_RANGE = range(0x00010000)
PLATFORM_TLV_TYPE_RANGE = range(0x00010000, 0x00020000)

# Magic word stored in the first 32-bit word of the generated file.
HEADER_MAGIC = 0x3543A468

logging.basicConfig(stream=sys.stderr, level=logging.INFO)

# Value stored in the 'HASHTYPE' TLV.
ENUM_HASH_TYPE = {'SHA256': 1}

# Value stored in the 'SIGNTYPE' TLV; also used as key of the per-algorithm
# dispatch tables of this file.
ENUM_SIGNATURE_TYPE = {'RSA': 1, 'ECC': 2}

# Values stored in the 'IMGTYPE' TLV.
ENUM_BINARY_TYPE = {'ELF': 1}
107
108
def dump_buffer(buf, step=16, name="", logger=logging.debug, indent=""):
    """Emit a hexadecimal dump of *buf* through *logger*.

    A title line "<indent><name>:" is emitted first, then one line per group
    of *step* items (each item printed as two upper-case hex digits), and
    finally a line containing a single newline character.

    buf    -- sequence of integers in 0..255 (bytes, bytearray, ...)
    step   -- number of items printed per line
    name   -- label printed in the title line
    logger -- callable taking the text of one line
    indent -- prefix prepended to every line except the final one
    """
    logger("%s%s:" % (indent, name))
    line_prefix = "%s    " % (indent)
    chunks = (buf[start:start + step] for start in range(0, len(buf), step))
    for chunk in chunks:
        logger(line_prefix + " ".join("%02X" % byte for byte in chunk))
    logger("\n")
115
116
class TLV():
    """Builder for a blob of Type-Length-Value records.

    Each record is made of two little-endian 32-bit words (type, payload
    length in bytes) followed by the payload and zero padding up to the next
    64-bit boundary.

    Attributes:
        buf:  bytearray holding the concatenation of all the records added.
        tlvs: platform TLV payloads registered through add_plat_tlv(), keyed
              by their integer type identifier.
    """

    def __init__(self):
        self.buf = bytearray()
        self.tlvs = {}

    def add(self, kind, payload):
        """
        Add a TLV record. Argument kind is either the type scalar ID or a
        matching string defined in TLV_TYPES.

        payload is a bytes-like object. Raises KeyError when kind is a
        string that is not defined in TLV_TYPES.
        """
        type_id = kind if isinstance(kind, int) else TLV_TYPES[kind]

        # Explicit little-endian packing: the other structures of the
        # generated file are packed with '<', and the native 'II' format
        # would produce big-endian words on a big-endian build host.
        head = struct.pack('<II', type_id, len(payload))

        self.buf += head
        self.buf += payload
        # Ensure that each TLV is 64-bit aligned
        self.buf += bytes(-(len(head) + len(payload)) % 8)

    def add_plat_tlv(self, cust_tlv):
        """
        Add the platform TLVs given on the command line.

        cust_tlv is an iterable of (ID, value) string pairs. ID is parsed
        with int(ID, 0) and must be in PLATFORM_TLV_TYPE_RANGE. A value
        starting with '0x' is stored as a 4-byte little-endian integer, any
        other value is stored UTF-8 encoded. When an ID is given several
        times in the same call, the last value is used.

        Raises ValueError if an ID is malformed or out of the platform range
        and OverflowError if an integer value does not fit in 32 bits. In
        both cases no record is added.
        """
        new_tlvs = {}
        for tlv in cust_tlv:
            type_id = int(tlv[0], 0)

            if type_id not in PLATFORM_TLV_TYPE_RANGE:
                raise ValueError('TLV %s not in range' % hex(type_id))

            value = tlv[1]
            if value.startswith('0x'):
                new_tlvs[type_id] = int(value[2:], 16).to_bytes(4, 'little')
            else:
                new_tlvs[type_id] = value.encode('utf-8')

        # Only emit the records of this call: iterating over self.tlvs here
        # would write the records of any previous call a second time.
        self.tlvs.update(new_tlvs)
        for type_id, value in new_tlvs.items():
            self.add(type_id, value)

    def get(self):
        """
        Get an immutable bytes copy that concatenates all the TLV added.
        """
        return bytes(self.buf)
165
166
class RSA_Signature:
    """SHA-256 hash accumulator signed with an RSA key (PKCS#1 v1.5)."""

    def __init__(self, key):
        """Create an empty SHA-256 hasher and a pkcs1_15 signer for *key*."""
        self._hasher = SHA256.new()
        self.signer = pkcs1_15.new(key)

    def hash_compute(self, segment):
        """Feed the bytes of *segment* into the running hash."""
        self._hasher.update(segment)

    def sign(self):
        """Return the signature of everything hashed so far."""
        signature = self.signer.sign(self._hasher)
        return signature
178
179
class ECC_Signature:
    """SHA-256 hash accumulator signed with an ECC key (DSS, 'fips-186-3')."""

    def __init__(self, key):
        """Create an empty SHA-256 hasher and a DSS signer for *key*."""
        self._hasher = SHA256.new()
        self.signer = DSS.new(key, 'fips-186-3')

    def hash_compute(self, segment):
        """Feed the bytes of *segment* into the running hash."""
        self._hasher.update(segment)

    def sign(self):
        """Return the signature of everything hashed so far."""
        signature = self.signer.sign(self._hasher)
        return signature
191
192
# Signer class for each ENUM_SIGNATURE_TYPE value (1: RSA, 2: ECC).
Signature = {1: RSA_Signature, 2: ECC_Signature}
197
198
class SegmentHashStruct:
    """Plain attribute holder; SegmentHash.get_table() stores the current
    segment's `header` and `hash` on an instance of it."""
201
202
class SegmentHash(object):
    '''
        Hash table based on Elf program segments
    '''
    def __init__(self, img):
        '''
            img is the parsed ELF image; it must provide num_segments() and
            get_segment(index).
        '''
        self._num_segments = img.num_segments()
        self._pack_fmt = '<%dL' % 8
        self.img = img
        self.hashProgTable = bytes()
        self._offset = 0

    def get_table(self):
        '''
            Create a segment hash table containing for each segment:
                - the segments header
                - a hash of the segment

            Each entry is made of eight little-endian 32-bit words (p_type,
            p_offset, p_vaddr, p_paddr, p_filesz, p_memsz, p_flags, p_align)
            followed by the SHA-256 digest of the segment data.

            Returns the table as a bytearray (also stored in self.buf). The
            table is rebuilt from scratch on every call.
        '''
        seg = SegmentHashStruct()
        digest_size = SHA256.new().digest_size
        # One table entry: 8 header words + the raw digest.
        entry = struct.Struct('<8I%ds' % digest_size)

        self.size = entry.size * self._num_segments
        logging.debug("hash section size %d" % self.size)
        self.buf = bytearray(self.size)
        # Restart from the beginning of the new buffer: keeping the offset
        # reached by a previous call would pack past the end of self.buf.
        self._offset = 0

        for i in range(self._num_segments):
            # Parse the segment once and use it for both header and data.
            segment = self.img.get_segment(i)
            seg.header = segment.header
            logging.debug("compute hash for segment offset %s" % seg.header)
            h = SHA256.new()
            h.update(segment.data())
            seg.hash = h.digest()
            logging.debug("hash computed: %s" % seg.hash)
            # p_type is looked up in the ARM table to get its numeric value.
            entry.pack_into(self.buf, self._offset,
                            ENUM_P_TYPE_ARM[seg.header.p_type],
                            seg.header.p_offset, seg.header.p_vaddr,
                            seg.header.p_paddr, seg.header.p_filesz,
                            seg.header.p_memsz, seg.header.p_flags,
                            seg.header.p_align,
                            seg.hash)
            self._offset += entry.size
        dump_buffer(self.buf, name='hash table', indent="\t")
        return self.buf
250
251
class ImageHeader(object):
    '''
        Header written at the start of the generated file: five
        little-endian 32-bit words (magic, version, TLV length, signature
        length, images length).
    '''

    # Class-level defaults; __init__ overrides both on every instance.
    magic = 'HELF'   # SHDR_MAGIC
    version = 1

    # NOTE(review): these offsets are not used in this file and do not follow
    # the field order produced by get_packed() - confirm before relying on
    # them.
    MAGIC_OFFSET = 0
    VERSION_OFFSET = 4
    SIGN_LEN_OFFSET = 8
    IMG_LEN_OFFSET = 12
    TLV_LEN_OFFSET = 16
    PTLV_LEN_OFFSET = 20

    def __init__(self):
        # NOTE(review): get_packed() returns 20 bytes; 56 is not derived from
        # the layout packed here - confirm its meaning.
        self.size = 56

        self.magic = HEADER_MAGIC
        self.version = 1
        self.tlv_length = self.sign_length = self.img_length = 0

        # Snapshot of the header packed with all the lengths still at zero.
        initial_fields = (self.magic, self.version, self.tlv_length,
                          self.sign_length, self.img_length)
        self.shdr = struct.pack('<IIIII', *initial_fields)

    def dump(self):
        '''
            Log every header field at debug level.
        '''
        rows = (
            ("\tMAGIC\t\t= %08X", 'magic'),
            ("\tHEADER_VERSION\t= %08X", 'version'),
            ("\tTLV_LENGTH\t= %08X", 'tlv_length'),
            ("\tSIGN_LENGTH\t= %08X", 'sign_length'),
            ("\tIMAGE_LENGTH\t= %08X", 'img_length'),
        )
        for fmt, attr in rows:
            logging.debug(fmt % (getattr(self, attr)))

    def get_packed(self):
        '''
            Return the header packed from the current field values.
        '''
        fields = (self.magic, self.version,
                  self.tlv_length, self.sign_length, self.img_length)
        return struct.pack('<IIIII', *fields)
292
293
def get_args(logger):
    """Parse the command-line arguments of the script.

    Args:
        logger: not used; kept so that existing callers keep working.

    Returns:
        The argparse namespace with the attributes:
          in_file  -- list of input firmware file names (one per --in)
          out_file -- name of the signed firmware output file
          key_file -- name of the signing key file
          key_info -- name of the extra key information file, or None
          key_type -- 'RSA' (default) or 'ECC'
          plat_tlv -- list of [ID, value] string pairs, empty by default

    Exits through argparse (SystemExit) when a mandatory argument is missing
    or when --key_type is not a supported value.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description='Sign a remote processor firmware loadable by OP-TEE.',
        usage='\n   %(prog)s [ arguments ]\n\n'
        '   Generate signed loadable binary \n' +
        '   Takes arguments --in, --out --key\n' +
        '   %(prog)s --help  show available arguments\n\n')
    parser.add_argument('--in', required=True, dest='in_file',
                        help='Name of firmware input file ' +
                             '(can be used multiple times)', action='append')
    parser.add_argument('--out', required=True, dest='out_file',
                        help='Name of the signed firmware output file')
    parser.add_argument('--key', required=True,
                        help='Name of signing key file',
                        dest='key_file')
    parser.add_argument('--key_info', required=False,
                        help='Name file containing extra key information',
                        dest='key_info')
    # The choices must match the keys of ENUM_SIGNATURE_TYPE. Rejecting other
    # values here gives a usage error instead of a KeyError traceback later.
    parser.add_argument('--key_type', required=False,
                        choices=('RSA', 'ECC'),
                        help='Type of signing key: should be RSA or ECC',
                        default='RSA',
                        dest='key_type')
    parser.add_argument('--plat-tlv', required=False, nargs=2,
                        metavar=("ID", "value"), action='append',
                        help='Platform TLV that will be placed into image '
                             'plat_tlv area. Add "0x" prefix to interpret '
                             'the value as an integer, otherwise it will be '
                             'interpreted as a string. Option can be used '
                             'multiple times to add multiple TLVs.',
                        default=[], dest='plat_tlv')

    # --out is mandatory, so no default output name has to be derived from
    # the input names.
    return parser.parse_args()
336
337
def rsa_key(key_file):
    """Return the RSA key imported from the text file *key_file*."""
    # Context manager so the key file is closed as soon as it has been read.
    with open(key_file) as f:
        return RSA.importKey(f.read())
340
341
def ecc_key(key_file):
    """Return the ECC key imported from the text file *key_file*."""
    # Context manager so the key file is closed as soon as it has been read.
    with open(key_file) as f:
        return ECC.import_key(f.read())
344
345
# Key loading function for each ENUM_SIGNATURE_TYPE value (1: RSA, 2: ECC).
key_type = {1: rsa_key, 2: ecc_key}
350
351
def rsa_sig_size(key):
    """Return the signature size in bytes: the value of key.size_in_bytes()."""
    size = key.size_in_bytes()
    return size
354
355
def ecc_sig_size(key):
    """Return the size in bytes of the signature generated with *key*.

    The signature is the concatenation of two integers (r, s), each as large
    as one coordinate of the public point for the NIST P curves, so the size
    is twice key.pointQ.size_in_bytes(): 64 for a 256-bit curve, 96 for a
    384-bit one.

    Falls back to 64 (256-bit curve) when the key object does not provide
    pointQ.size_in_bytes().
    """
    try:
        coordinate_size = key.pointQ.size_in_bytes()
    except AttributeError:
        # Key object without curve size information: keep the historical
        # value, which is the size for a 256-bit curve.
        return 64
    return 2 * coordinate_size
360
361
# Signature size function for each ENUM_SIGNATURE_TYPE value (1: RSA, 2: ECC).
sig_size_type = {1: rsa_sig_size, 2: ecc_sig_size}
366
367
def _padding_to_64b(length):
    """Return the number of bytes needed to align *length* on 64 bits."""
    return -length % 8


def main():
    """Generate the signed firmware file described by the command line.

    The output file contains, in this order: the header, the TLV area, the
    signature of the header and the TLV area, then every input image. The
    TLV area, the signature and each image are followed by zero padding up
    to the next 64-bit boundary.

    Exits with status 1 if the key type is not supported, if the key has no
    private part or if an input image is not an ARM ELF file. Raises
    Exception if the generated signature does not have the expected length.
    """
    logging.basicConfig()
    logger = logging.getLogger(os.path.basename(__file__))

    args = get_args(logger)

    # Initialise the header
    s_header = ImageHeader()
    tlv = TLV()

    if args.key_type not in ENUM_SIGNATURE_TYPE:
        logger.error('Invalid key type %s, should be one of: %s',
                     args.key_type, ', '.join(ENUM_SIGNATURE_TYPE))
        sys.exit(1)
    sign_type = ENUM_SIGNATURE_TYPE[args.key_type]

    key = key_type[sign_type](args.key_file)

    if not key.has_private():
        logger.error('Provided key cannot be used for signing, ')
        sys.exit(1)

    tlv.add('SIGNTYPE', sign_type.to_bytes(1, 'little'))

    images = []
    images_type = bytearray()
    images_size = bytearray()
    hash_tlv = bytearray()

    # Firmware image
    for inputf in args.in_file:
        logging.debug("image  %s" % inputf)
        with open(inputf, 'rb') as input_file:
            # Keep the raw data: the bytes written to the output file are
            # then exactly the ones that have been hashed.
            bin_img = input_file.read()
            input_file.seek(0)
            img = ELFFile(input_file)

            # Only ARM machine has been tested and well supported yet.
            # Indeed this script uses the ENUM_P_TYPE_ARM dictionary.
            # Explicit check rather than assert, which is skipped with -O.
            arch = img.get_machine_arch()
            if arch != "ARM":
                logger.error('%s: unsupported machine architecture %s, '
                             'only ARM is supported', inputf, arch)
                sys.exit(1)

            # Compute the hash table while the ELF file is still open
            hash_tlv.extend(SegmentHash(img).get_table())

        size = len(bin_img) + _padding_to_64b(len(bin_img))
        images_size += size.to_bytes(4, 'little')
        s_header.img_length += size

        # Store image type information
        images_type += ENUM_BINARY_TYPE['ELF'].to_bytes(1, 'little')
        images.append(bin_img)

    # Add image information
    # The 'IMGTYPE' contains a byte array of the image type (ENUM_BINARY_TYPE).
    # The 'IMGSIZE' contains a byte array of the size (32-bit) of each image.
    tlv.add('NUM_IMG', len(args.in_file).to_bytes(1, 'little'))
    tlv.add('IMGTYPE', images_type)
    tlv.add('IMGSIZE', images_size)

    # Add hash type information in TLV blob
    # The 'HASHTYPE' TLV contains a byte associated to ENUM_HASH_TYPE.
    hash_type = ENUM_HASH_TYPE['SHA256']
    tlv.add('HASHTYPE', hash_type.to_bytes(1, 'little'))

    # Add hash table information in TLV blob
    # The HASHTABLE TLV contains a byte array containing all the ELF segment
    # with associated hash.
    tlv.add('HASHTABLE', hash_tlv)

    # Add optional key information to TLV
    if args.key_info:
        with open(args.key_info, 'rb') as f:
            key_info = f.read()
        tlv.add('PKEYINFO', key_info)

    # Compute custom TLV that will be passed to the platform PTA
    # Get list of custom protected TLVs from the command-line
    if args.plat_tlv:
        tlv.add_plat_tlv(args.plat_tlv)

    # Get the TLV area and compute its size (with 64 bit alignment)
    tlvs_buff = tlv.get()
    tlvs_buff += bytes(_padding_to_64b(len(tlvs_buff)))
    s_header.tlv_length = len(tlvs_buff)

    dump_buffer(tlvs_buff, name='TLVS', indent="\t")

    # Signature chunk
    # NOTE(review): the header stores the signature size without the 64-bit
    # padding written after the signature - confirm against the loader.
    sign_size = sig_size_type[sign_type](key)
    s_header.sign_length = sign_size

    # Construct the Header
    header = s_header.get_packed()

    # Generate signature
    signer = Signature[sign_type](key)

    signer.hash_compute(header)
    signer.hash_compute(tlvs_buff)
    signature = signer.sign()
    if len(signature) != sign_size:
        raise Exception("Actual signature length is not equal to "
                        "the computed one: {} != {}".
                        format(len(signature), sign_size))

    s_header.dump()

    with open(args.out_file, 'wb') as f:
        f.write(header)
        f.write(tlvs_buff)
        f.write(signature)
        f.write(bytes(_padding_to_64b(sign_size)))
        for bin_img in images:
            f.write(bin_img)
            f.write(bytes(_padding_to_64b(len(bin_img))))
505
506
# Run the signing tool only when the file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
509