1#!/usr/bin/env python3
2#
3# Copyright (c) 2024 STMicroelectronics
4# SPDX-License-Identifier: Apache-2.0
5
6"""Injects SLIDs in LLEXT ELFs' symbol tables.
7
8When Kconfig option CONFIG_LLEXT_EXPORT_BUILTINS_BY_SLID is enabled,
9all imports from the Zephyr kernel & application are resolved using
SLIDs instead of symbol names. This script stores the SLID of each
imported symbol in the 'st_value' field of its entry in the ELF symbol
table, so that the LLEXT subsystem can link the extension at runtime.
13
Note that this script is idempotent in theory. However, to guard against
any catastrophic mistake, the script aborts if the 'st_value' field of
an `ElfX_Sym` structure is found to be non-zero, which is the case after
one invocation. In practice, therefore, the script cannot be executed
twice on the same ELF file.
19"""
20
21import argparse
22import logging
23import shutil
24import sys
25
26from elftools.elf.elffile import ELFFile
27from elftools.elf.sections import SymbolTableSection
28
29import llext_slidlib
30
class LLEXTSymtabPreparator():
    """Write SLIDs into the 'st_value' field of an LLEXT ELF's imported symbols.

    The ELF file is opened for in-place read/write on construction and is
    closed by :meth:`prepare_llext`, whether or not processing succeeds.
    """

    def __init__(self, elf_path, log):
        """Open *elf_path* read/write and parse it as an ELF file.

        Args:
            elf_path: path of the LLEXT ELF file to patch in place.
            log: logger used for all diagnostics.
        """
        self.log = log
        self.elf_path = elf_path
        self.elf_fd = open(elf_path, "rb+")
        try:
            self.elf = ELFFile(self.elf_fd)
        except Exception:
            # Don't leak the file handle when the file cannot be parsed.
            self.elf_fd.close()
            raise

    def _find_symtab(self):
        """Return the symbol table to patch, or None if there is none.

        '.dynsym' is used for ET_DYN files and '.symtab' for ET_REL files;
        any other ELF type, or a missing/non-symbol-table section, yields None.
        """
        e_type = self.elf.header['e_type']
        if e_type == 'ET_DYN':
            symtab_name = ".dynsym"
        elif e_type == 'ET_REL':
            symtab_name = ".symtab"
        else:
            self.log.error(f"unexpected ELF file type {e_type}")
            return None

        symtab = self.elf.get_section_by_name(symtab_name)
        if not isinstance(symtab, SymbolTableSection):
            self.log.debug(f"section {symtab_name} not found.")
            return None

        self.log.info(f"processing symbol table from '{symtab_name}'...")
        self.log.debug(f"(symbol table is at file offset 0x{symtab['sh_offset']:X})")
        return symtab

    def _find_imports_in_symtab(self, symtab):
        """Return a list of (index, symbol) pairs for every imported symbol.

        An import is an untyped (STT_NOTYPE), global (STB_GLOBAL), undefined
        (SHN_UNDEF) symbol; *index* is its position in *symtab*.
        """
        imports = []
        for index, sym in enumerate(symtab.iter_symbols()):
            entry = sym.entry
            if (entry['st_info']['type'] == 'STT_NOTYPE'
                    and entry['st_info']['bind'] == 'STB_GLOBAL'
                    and entry['st_shndx'] == 'SHN_UNDEF'):
                self.log.debug(f"found imported symbol '{sym.name}' at index {index}")
                imports.append((index, sym))
        return imports

    def _make_stvalue_accessors(self, symtab):
        """Return (reader, writer) callables for the 'st_value' of symbols.

        reader(index) returns the raw 'st_value' of symbol *index* in
        *symtab*; writer(index, value) overwrites it. Both access the file
        directly, using the ElfN_Sym layout for this ELF's class and byte order.
        """
        byteorder = "little" if self.elf.little_endian else "big"
        if self.elf.elfclass == 32:
            sizeof_elf_sym = 0x10     # sizeof(Elf32_Sym)
            offsetof_st_value = 0x4   # offsetof(Elf32_Sym, st_value)
            sizeof_st_value = 0x4     # sizeof(Elf32_Sym.st_value)
        else:
            sizeof_elf_sym = 0x18     # sizeof(Elf64_Sym)
            offsetof_st_value = 0x8   # offsetof(Elf64_Sym, st_value)
            sizeof_st_value = 0x8     # sizeof(Elf64_Sym.st_value)

        base = symtab['sh_offset'] + offsetof_st_value

        def seek(symidx):
            self.elf_fd.seek(base + symidx * sizeof_elf_sym)

        def reader(symidx):
            seek(symidx)
            return int.from_bytes(self.elf_fd.read(sizeof_st_value), byteorder)

        def writer(symidx, st_value):
            seek(symidx)
            self.elf_fd.write(int.to_bytes(st_value, sizeof_st_value, byteorder))

        return reader, writer

    def _prepare_inner(self):
        """Patch the SLID of every import into the symbol table.

        Returns:
            0 on success, 1 on failure (no symbol table, or an import whose
            'st_value' is already non-zero). On failure nothing is written.
        """
        # 1) Locate the symbol table
        symtab = self._find_symtab()
        if symtab is None:
            self.log.error("no symbol table found in file")
            return 1

        # 2) Find imported symbols in symbol table
        imports = self._find_imports_in_symtab(symtab)
        self.log.info(f"LLEXT has {len(imports)} import(s)")

        rd_st_val, wr_st_val = self._make_stvalue_accessors(symtab)

        # 3) Make sure we're not overwriting something actually important.
        #    Every import is checked before anything is written, so a failure
        #    cannot leave the file partially patched.
        for index, symbol in imports:
            if rd_st_val(index) != 0:
                self.log.error(f"unexpected non-zero st_value for symbol {symbol.name}")
                return 1

        # 4) Write SLIDs in each symbol's 'st_value' field. The SLID width
        #    (4 or 8 bytes) matches sizeof(st_value) for the ELF class.
        slid_size = self.elf.elfclass // 8
        for index, symbol in imports:
            slid = llext_slidlib.generate_slid(symbol.name, slid_size)
            slid_as_str = llext_slidlib.format_slid(slid, slid_size)
            self.log.info(f"{symbol.name} -> {slid_as_str}")
            wr_st_val(index, slid)

        return 0

    def prepare_llext(self):
        """Inject SLIDs into the ELF file and close it.

        Returns:
            0 on success, non-zero on failure. The file is closed even if
            processing raises.
        """
        try:
            return self._prepare_inner()
        finally:
            self.elf_fd.close()
135
136# Disable duplicate code warning for the code that follows,
137# as it is expected for these functions to be similar.
138# pylint: disable=duplicate-code
139def _parse_args(argv):
140    """Parse the command line arguments."""
141    parser = argparse.ArgumentParser(
142        description=__doc__,
143        formatter_class=argparse.RawDescriptionHelpFormatter,
144        allow_abbrev=False)
145
146    parser.add_argument("-f", "--elf-file", required=True,
147                        help="LLEXT ELF file to process")
148    parser.add_argument("-o", "--output-file",
149                        help=("Additional output file where processed ELF "
150                        "will be copied"))
151    parser.add_argument("-sl", "--slid-listing",
152                        help="write the SLID listing to a file")
153    parser.add_argument("-v", "--verbose", action="count",
154                        help=("enable verbose output, can be used multiple times "
155                              "to increase verbosity level"))
156    parser.add_argument("--always-succeed", action="store_true",
157                        help="always exit with a return code of 0, used for testing")
158
159    return parser.parse_args(argv)
160
161def _init_log(verbose):
162    """Initialize a logger object."""
163    log = logging.getLogger(__file__)
164
165    console = logging.StreamHandler()
166    console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
167    log.addHandler(console)
168
169    if verbose and verbose > 1:
170        log.setLevel(logging.DEBUG)
171    elif verbose and verbose > 0:
172        log.setLevel(logging.INFO)
173    else:
174        log.setLevel(logging.WARNING)
175
176    return log
177
def main(argv=None):
    """Inject SLIDs into --elf-file and optionally copy the result.

    Args:
        argv: command-line arguments without the program name, or None to
            use sys.argv[1:].

    Returns:
        Process exit code: 0 on success, non-zero on failure. It is always
        0 when --always-succeed is given.
    """
    args = _parse_args(argv)

    log = _init_log(args.verbose)

    log.info(f"inject_slids_in_llext: {args.elf_file}")

    preparator = LLEXTSymtabPreparator(args.elf_file, log)

    res = preparator.prepare_llext()

    # Copy the processed ELF before honouring --always-succeed. That option
    # only masks the exit code; it must not suppress the requested output.
    if res == 0 and args.output_file:
        shutil.copy(args.elf_file, args.output_file)

    if args.always_succeed:
        return 0

    return res
196
if __name__ == "__main__":
    # Script entry: forward the CLI arguments (minus the program name) to
    # main() and use its return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
199