1#!/usr/bin/env python3
2#
3# Copyright (c) 2025 Nordic Semiconductor ASA
4#
5# SPDX-License-Identifier: Apache-2.0
6
7"""
8This script is used to install TLS credentials on a device via a serial connection.
9It supports both deleting and writing credentials, as well as checking for their existence.
Optionally (--check-hash), it also verifies the SHA-256 digest reported by the device against the digest computed from the local credential file.
11
12This script is based on https://github.com/nRFCloud/utils/, specifically
13"command_interface.py" and "device_credentials_installer.py".
14"""
15
16import argparse
17import base64
18import hashlib
19import logging
20import math
21import os
22import sys
23import time
24
25import serial
26
# Configure logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Line terminators that may be appended to every shell command sent to the device.
CMD_TERM_DICT = dict(NULL='\0', CR='\r', LF='\n', CRLF='\r\n')
# 'CR' is the default termination value for the at_host library in the nRF Connect SDK
cmd_term_key = 'CR'

# Credential type names used by the `cred` shell commands, indexed by the -t/--cert-type value.
TLS_CRED_TYPES = ["CA", "SERV", "PK"]
# Number of base64 characters sent per `cred buf` command.
TLS_CRED_CHUNK_SIZE = 48
# Serial read timeout (seconds); also the step used to count down prompt-wait timeouts.
serial_timeout = 1
# Serial port handle, opened in main().
ser = None
39
40
class TLSCredShellInterface:
    """Drive the Zephyr TLS credentials shell (the ``cred`` command) over an injected link.

    ``serial_write_line(command)`` sends one command line to the device, and
    ``serial_wait_for_response(ok_marker, fail_marker, store=...)`` is expected to return
    ``(ok_marker_seen, stored_line_or_None)``.
    """

    def __init__(self, serial_write_line, serial_wait_for_response, verbose):
        self.serial_write_line = serial_write_line
        self.serial_wait_for_response = serial_wait_for_response
        self.verbose = verbose

    @staticmethod
    def _is_valid_cred_type(cred_type):
        """Return True if *cred_type* indexes TLS_CRED_TYPES; log an error otherwise."""
        # Explicit lower bound: a negative index would silently select the wrong type.
        if 0 <= cred_type < len(TLS_CRED_TYPES):
            return True
        logger.error(
            f"Invalid credential type: {cred_type}. Range [0, {len(TLS_CRED_TYPES) - 1}]."
        )
        return False

    def write_raw(self, command):
        """Send *command* to the device, logging it at debug level when verbose."""
        if self.verbose:
            logger.debug(f'-> {command}')
        self.serial_write_line(command)

    def write_credential(self, sectag, cred_type, cred_text):
        """Install *cred_text* as credential *cred_type* under *sectag*.

        :param sectag: security tag to store the credential under.
        :param cred_type: index into TLS_CRED_TYPES.
        :param cred_text: credential contents (PEM text).
        :return: True if the device confirmed the credential was added, False otherwise.
        """
        # Because the Zephyr shell does not support multi-line commands,
        # we must base-64 encode our PEM strings and install them as if they were binary.
        # Yes, this does mean we are base-64 encoding a string which is already mostly base-64.
        # We could alternatively strip the ===== BEGIN/END XXXX ===== header/footer, and then pass
        # everything else directly as a binary payload (using BIN mode instead of BINT, since
        # MBedTLS uses the NULL terminator to determine if the credential is raw DER, or is a
        # PEM string). But this will fail for multi-CA installs, such as CoAP.

        # Validate up front so nothing is staged on the device for a request that cannot succeed.
        if not self._is_valid_cred_type(cred_type):
            return False
        if not cred_text:
            logger.error("Credential is empty, nothing to write.")
            return False

        # text -> bytes -> base64 bytes -> base64 text
        encoded = base64.b64encode(cred_text.encode()).decode()
        self.write_raw("cred buf clear")
        chunks = math.ceil(len(encoded) / TLS_CRED_CHUNK_SIZE)
        for c in range(chunks):
            chunk = encoded[c * TLS_CRED_CHUNK_SIZE : (c + 1) * TLS_CRED_CHUNK_SIZE]
            self.write_raw(f"cred buf {chunk}")
            result, _ = self.serial_wait_for_response("Stored", "RX ring buffer full")
            if not result:
                # Either the failure marker was seen or no confirmation arrived. In both
                # cases the staged buffer is incomplete, and adding it would install a
                # corrupt credential, so stop here.
                logger.error(
                    f"Failed to store chunk {c + 1}/{chunks} in the device "
                    "(RX ring buffer full or no response)"
                )
                return False

        self.write_raw(f"cred add {sectag} {TLS_CRED_TYPES[cred_type]} DEFAULT bint")
        result, _ = self.serial_wait_for_response("Added TLS credential", "already exists")
        time.sleep(1)
        return result

    def delete_credential(self, sectag, cred_type):
        """Delete credential *cred_type* under *sectag*.

        :return: True if the device confirmed the deletion, False otherwise.
        """
        if not self._is_valid_cred_type(cred_type):
            return False
        self.write_raw(f'cred del {sectag} {TLS_CRED_TYPES[cred_type]}')
        result, _ = self.serial_wait_for_response(
            "Deleted TLS credential", "There is no TLS credential"
        )
        time.sleep(2)
        return result

    def check_credential_exists(self, sectag, cred_type, get_hash=True):
        """Query the device for credential *cred_type* under *sectag*.

        :param get_hash: when True, also parse the digest from the ``cred list`` line.
        :return: ``(exists, cred_hash)``. ``cred_hash`` is the digest string reported by the
            device, or None when not requested or not available.
        """
        if not self._is_valid_cred_type(cred_type):
            return False, None

        self.write_raw(f'cred list {sectag} {TLS_CRED_TYPES[cred_type]}')
        _, output = self.serial_wait_for_response(
            "1 credentials found.",
            "0 credentials found.",
            store=f"{sectag},{TLS_CRED_TYPES[cred_type]}",
        )

        if not output:
            return False, None

        if not get_hash:
            return True, None

        # Line layout parsed below: "<sectag>,<type>,<digest>,<status>".
        # Decode leniently: a corrupted serial byte must not crash the check.
        text = output.decode(errors="replace")
        data = text.split(",")
        logger.debug(f"Cred list output: {data}")
        if len(data) < 4:
            logger.error("Invalid output format from device, skipping hash check.")
            return False, None
        cred_hash = data[2].strip()
        status_code = data[3].strip()

        if status_code != "0":
            logger.warning(f"Error retrieving credential hash: {text.strip()}.")
            logger.warning("Device might not support credential digests.")
            return True, None

        return True, cred_hash

    def calculate_expected_hash(self, cred_text):
        """Return the base64-encoded SHA-256 digest the device should report for *cred_text*.

        The trailing NUL mirrors the terminator that the ``bint`` install format stores
        with the text.
        """
        cred_hash = hashlib.sha256(cred_text.encode('utf-8') + b'\x00')
        return base64.b64encode(cred_hash.digest()).decode()

    def check_cred_command(self):
        """Return True if the device shell provides the ``cred`` command."""
        logger.info("Checking for 'cred' command existence...")
        self.serial_write_line("cred")
        result, output = self.serial_wait_for_response(
            "TLS Credentials Commands", "command not found", store="cred"
        )
        logger.debug(f"Result: {result}, Output: {output}")
        # Check the definitive "not found" reply first; otherwise the generic failure branch
        # below would shadow it and the configuration hint would never be shown.
        if output and b"command not found" in output:
            logger.error("Device does not support 'cred' command.")
            logger.error("Hint: Add 'CONFIG_TLS_CREDENTIALS_SHELL=y' to your prj.conf file.")
            return False
        if not result:
            logger.error("Device did not respond to 'cred' command.")
            return False
        logger.info("'cred' command found.")
        return True
146
147
def write_line(line, hidden=False):
    """Send *line* over the serial port, terminated per CMD_TERM_DICT[cmd_term_key].

    :param line: command text without a terminator.
    :param hidden: when truthy, the line is not written to the debug log.
    """
    if not hidden:
        logger.debug(f'-> {line}')
    terminated = line + CMD_TERM_DICT[cmd_term_key]
    ser.write(terminated.encode('utf-8'))
152
153
def wait_for_prompt(val1='uart:~$ ', val2=None, timeout=15, store=None):
    """Read serial lines until a success or failure marker is seen, or the wait times out.

    :param val1: success marker; a line containing it ends the wait with ``True``.
    :param val2: optional failure marker; a line containing it ends the wait with ``False``.
    :param timeout: idle budget, reduced by ``serial_timeout`` after every read that
        returns no data (so roughly seconds of silence); a negative value waits forever.
    :param store: optional substring; the last line containing it is returned as output,
        including a line that also carries *val1* or *val2*.
    :return: ``(retval, output)``. ``retval`` is True only if *val1* was seen; ``output``
        is the stored line (bytes) or None.
    """
    if not ser:
        logger.error('Serial interface not initialized')
        return False, None

    if isinstance(val1, str):
        val1 = val1.encode()
    if isinstance(val2, str):
        val2 = val2.encode()
    if isinstance(store, str):
        store = store.encode()

    wait_forever = timeout < 0
    remaining = timeout
    found = False
    retval = False
    output = None
    # Pre-bind so the error log below works even when no line is ever read.
    line = b''

    ser.flush()

    # Use "> 0" rather than "!= 0": a budget that is not a multiple of serial_timeout
    # would otherwise step past zero and wait forever.
    while not found and (wait_forever or remaining > 0):
        try:
            line = ser.readline()
        except serial.SerialException as e:
            logger.error(f"Error reading from serial interface: {e}")
            return False, None
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            return False, None

        if not line:
            # Read timed out with no data: spend one serial_timeout of the budget.
            if not wait_forever:
                remaining -= serial_timeout
            continue

        if line == b'\r\n':
            continue

        logger.debug(f'<- {line.decode("utf-8", errors="replace")}')

        # Checked independently of the markers so a stored line is kept even when
        # it is also the line that ends the wait.
        if store is not None and store in line:
            output = line

        if val1 in line:
            found = True
            retval = True
        elif val2 is not None and val2 in line:
            found = True
            retval = False

    ser.flush()
    if store is not None and output is None:
        logger.error(f'String {store} not detected in line {line}')

    if not found:
        logger.error('Serial timeout waiting for prompt')

    return retval, output
214
215
def parse_args(in_args):
    """Parse the installer's command-line arguments.

    :param in_args: argument list without the program name (e.g. ``sys.argv[1:]``).
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Device Credentials Installer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-p", "--port", type=str, help="Specify which serial port to open", default="/dev/ttyACM1"
    )
    parser.add_argument(
        "-x",
        "--xonxoff",
        help="Enable software flow control for serial connection",
        action='store_true',
        default=False,
    )
    parser.add_argument(
        "-r",
        "--rtscts-off",
        help="Disable hardware (RTS/CTS) flow control for serial connection",
        action='store_true',
        default=False,
    )
    parser.add_argument(
        "-f",
        "--dsrdtr",
        help="Enable hardware (DSR/DTR) flow control for serial connection",
        action='store_true',
        default=False,
    )
    parser.add_argument(
        "-d", "--delete", help="Delete sectag from device first", action='store_true', default=False
    )
    parser.add_argument(
        "-l",
        "--local-cert-file",
        type=str,
        help="Filepath to a local certificate (PEM) to use for the device",
        required=True,
    )
    parser.add_argument(
        "-t",
        "--cert-type",
        type=int,
        # Reject out-of-range types at parse time. Otherwise they surface later as a
        # misleading "failed to write credential" error after talking to the device.
        choices=range(len(TLS_CRED_TYPES)),
        help="Certificate type to use for the device ("
        + ", ".join(f"{index}={name}" for index, name in enumerate(TLS_CRED_TYPES))
        + ")",
        default=1,
    )
    parser.add_argument(
        "-S", "--sectag", type=int, help="integer: Security tag to use", default=16842753
    )
    parser.add_argument(
        "-H",
        "--check-hash",
        help="Check hash of the credential after writing",
        action='store_true',
        default=False,
    )

    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    args = parser.parse_args(in_args)
    return args
273
274
def main(in_args):
    """Install a PEM credential on the device, then exit the process.

    Exit status: 0 on success, 1 if the device lacks the ``cred`` shell command,
    2 if the serial port cannot be opened, 3 if the local certificate file is missing,
    4 if the credential cannot be found after writing, 5 if writing fails, and
    6 if the device-reported hash does not match the local file.

    :param in_args: command-line arguments without the program name.
    """
    global ser

    args = parse_args(in_args)

    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if not os.path.isfile(args.local_cert_file):
        logger.error(f'Local certificate file {args.local_cert_file} does not exist')
        sys.exit(3)

    # Read the credential before opening the port, so file problems are reported before
    # any device interaction. Explicit UTF-8 matches the encoding used when the text is
    # sent (str.encode default) and hashed, independent of the host locale.
    with open(args.local_cert_file, encoding='utf-8') as f:
        cred_text = f.read()

    logger.info(f'Opening port {args.port}')
    try:
        ser = serial.Serial(
            args.port,
            115200,
            xonxoff=args.xonxoff,
            rtscts=(not args.rtscts_off),
            dsrdtr=args.dsrdtr,
            timeout=serial_timeout,
        )
        ser.reset_input_buffer()
        ser.reset_output_buffer()
    except FileNotFoundError:
        logger.error(f'Specified port {args.port} does not exist or cannot be accessed')
        sys.exit(2)
    except serial.SerialException as e:
        logger.error(f'Failed to open serial port {args.port}: {e}')
        sys.exit(2)

    # The finally clause closes the port on every exit path, including sys.exit() and Ctrl-C.
    try:
        cred_if = TLSCredShellInterface(write_line, wait_for_prompt, args.verbose)
        if not cred_if.check_cred_command():
            sys.exit(1)

        if args.delete:
            logger.info(f'Deleting sectag {args.sectag}...')
            # Best effort: the credential may legitimately not exist yet.
            cred_if.delete_credential(args.sectag, args.cert_type)

        logger.info(f'Writing sectag {args.sectag}...')
        if not cred_if.write_credential(args.sectag, args.cert_type, cred_text):
            logger.error(
                f'Failed to write credential for sectag {args.sectag}, it may already exist'
            )
            sys.exit(5)

        if args.check_hash:
            logger.debug(f'Checking hash for sectag {args.sectag}...')
        exists, cred_hash = cred_if.check_credential_exists(
            args.sectag, args.cert_type, args.check_hash
        )
        if not exists:
            logger.error(f'Failed to check credential existence for sectag {args.sectag}')
            sys.exit(4)

        if cred_hash:
            logger.debug(f'Credential hash: {cred_hash}')
            expected_hash = cred_if.calculate_expected_hash(cred_text)
            if cred_hash != expected_hash:
                logger.error(
                    f'Hash mismatch for sectag {args.sectag}. Exp: {expected_hash}, got: {cred_hash}'
                )
                sys.exit(6)
    finally:
        ser.close()

    logger.info(f'Credential for sectag {args.sectag} written successfully')
    sys.exit(0)
345
346
def run():
    """Script entry point: run main() on the process arguments and exit with 1 on Ctrl-C."""
    cli_args = sys.argv[1:]
    try:
        main(cli_args)
    except KeyboardInterrupt:
        logger.info("Execution interrupted by user (Ctrl-C). Exiting...")
        sys.exit(1)


if __name__ == '__main__':
    run()
357