1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
5
6import logging
7import os
8import re
9import shlex
10import shutil
11from dataclasses import dataclass
12from pathlib import Path
13from subprocess import check_output, getstatusoutput
14
# Module-level logger used by the MCUmgr wrappers below to report executed
# commands, successful uploads and tool/image-lookup failures.
logger = logging.getLogger(__name__)
16
17
class MCUmgrException(Exception):
    """Generic error raised by the MCUmgr wrapper.

    Used, for example, when no image with the required flags can be found
    in the ``mcumgr image list`` output.
    """
20
21
@dataclass
class MCUmgrImage:
    """A single entry parsed from ``mcumgr image list`` output.

    ``image`` and ``slot`` identify the entry; the remaining attributes are
    filled in from the lines that follow it and stay empty strings when the
    corresponding line was not found.
    """
    # Image number, taken from ``image=<n>``.
    image: int
    # Slot number, taken from ``slot=<m>``.
    slot: int
    # Text after ``version:``.
    version: str = ''
    # Raw text after ``flags:`` (space-separated flag words, presumably).
    flags: str = ''
    # Word characters after ``hash:``.
    hash: str = ''
29
30
class MCUmgr:
    """Sample wrapper for the ``mcumgr`` command-line tool.

    Every command is run as ``<mcumgr_exec> <connection options> <cmd>``.
    The command line is tokenized with ``shlex.split``, so any value spliced
    into it (file paths, port names) is shell-quoted first.
    """
    mcumgr_exec = 'mcumgr'

    def __init__(self, connection_options: str):
        """Store the connection options added to every mcumgr invocation.

        :param connection_options: command-line fragment selecting the
            transport, e.g. ``--conntype serial --connstring=...``.
        """
        self.conn_opts = connection_options

    @classmethod
    def create_for_serial(cls, serial_port: str) -> MCUmgr:
        """Create an instance that talks to the device over *serial_port*.

        The port name is shell-quoted so that names containing spaces or
        backslashes survive the ``shlex.split`` done in :meth:`run_command`.
        """
        port = shlex.quote(str(serial_port))
        return cls(connection_options=f'--conntype serial --connstring={port}')

    @classmethod
    def is_available(cls) -> bool:
        """Return True if ``<mcumgr_exec> version`` exits with status 0.

        The check runs through the shell; on failure the combined output of
        the command is logged as a warning.
        """
        exitcode, output = getstatusoutput(f'{cls.mcumgr_exec} version')
        if exitcode != 0:
            logger.warning('mcumgr tool not available: %s', output)
            return False
        return True

    def run_command(self, cmd: str) -> str:
        """Run mcumgr with the connection options and *cmd*; return its stdout.

        Raises subprocess.CalledProcessError if mcumgr exits with a non-zero
        status.
        """
        command = f'{self.mcumgr_exec} {self.conn_opts} {cmd}'
        logger.info('CMD: %s', command)
        return check_output(shlex.split(command), text=True)

    def reset_device(self):
        """Send the ``reset`` command to the device."""
        self.run_command('reset')

    def image_upload(self, image: Path | str, slot: int | None = None, timeout: int = 30):
        """Upload the image file *image* to the device.

        :param image: path to the image file. It is shell-quoted so paths with
            spaces, or Windows paths with backslashes, reach mcumgr intact.
        :param slot: when given, ``-e -n <slot>`` is appended to the command.
        :param timeout: value passed to mcumgr's ``-t`` option.
        """
        command = f'-t {timeout} image upload {shlex.quote(str(image))}'
        if slot is not None:
            command += f' -e -n {slot}'
        self.run_command(command)
        logger.info('Image successfully uploaded')

    def get_image_list(self) -> list[MCUmgrImage]:
        """Run ``image list`` on the device and return the parsed entries."""
        output = self.run_command('image list')
        return self._parse_image_list(output)

    @staticmethod
    def _parse_image_list(cmd_output: str) -> list[MCUmgrImage]:
        """Parse the text printed by ``mcumgr image list``.

        A line containing ``image=<n> slot=<m>`` starts a new entry. The
        ``version:``, ``flags:`` and ``hash:`` lines that follow fill in that
        entry. Lines seen before the first entry are ignored.
        """
        image_list = []
        re_image = re.compile(r'image=(\d+)\s+slot=(\d+)')
        re_version = re.compile(r'version:\s+(\S+)')
        re_flags = re.compile(r'flags:\s+(.+)')
        re_hash = re.compile(r'hash:\s+(\w+)')
        for line in cmd_output.splitlines():
            if m := re_image.search(line):
                image_list.append(
                    MCUmgrImage(
                        image=int(m.group(1)),
                        slot=int(m.group(2))
                    )
                )
            elif image_list:
                if m := re_version.search(line):
                    image_list[-1].version = m.group(1)
                elif m := re_flags.search(line):
                    # Strip so that trailing padding after an empty flags
                    # field does not leave whitespace behind.
                    image_list[-1].flags = m.group(1).strip()
                elif m := re_hash.search(line):
                    image_list[-1].hash = m.group(1)
        return image_list

    def _get_hash_of_image_without_flag(self, flag: str) -> str:
        """Return the hash of the first listed image whose flags lack *flag*.

        Raises MCUmgrException (after logging the full image list) when
        every image carries *flag*.
        """
        image_list = self.get_image_list()
        for image in image_list:
            if flag not in image.flags:
                return image.hash
        logger.warning('Images returned by mcumgr (no not %s):\n%s', flag, image_list)
        raise MCUmgrException(f'No not {flag} image found')

    def get_hash_to_test(self) -> str:
        """Return the hash of the first image not flagged ``active``."""
        return self._get_hash_of_image_without_flag('active')

    def get_hash_to_confirm(self) -> str:
        """Return the hash of the first image not flagged ``confirmed``."""
        return self._get_hash_of_image_without_flag('confirmed')

    def image_test(self, hash: str | None = None):
        """Run ``image test`` for *hash* (default: :meth:`get_hash_to_test`)."""
        if not hash:
            hash = self.get_hash_to_test()
        self.run_command(f'image test {hash}')

    def image_confirm(self, hash: str | None = None):
        """Run ``image confirm`` for *hash* (default: :meth:`get_hash_to_confirm`)."""
        if not hash:
            hash = self.get_hash_to_confirm()
        self.run_command(f'image confirm {hash}')
118
119
class MCUmgrBle(MCUmgr):
    """MCUmgr wrapper for BLE connection"""

    @classmethod
    def create_for_ble(cls, hci_index: int, peer_name: str) -> MCUmgrBle:
        """Create an MCUmgr instance for a BLE connection.

        :param hci_index: value passed to mcumgr's ``--hci`` option.
        :param peer_name: name of the BLE peer to connect to. It is
            shell-quoted so names containing spaces or quote characters
            survive the ``shlex.split`` applied when commands are run.
        """
        connection_string = (
            f'--conntype ble --hci {hci_index} '
            f'--connstring peer_name={shlex.quote(peer_name)}'
        )
        return cls(connection_options=connection_string)

    @classmethod
    def is_available(cls) -> bool:
        """Check if mcumgr is available. For BLE, it requires root privileges.

        When not running as root, ``mcumgr_exec`` is switched (once) to
        ``sudo <absolute path of mcumgr>`` before the base-class check runs.
        Returns False without modifying ``mcumgr_exec`` if the executable
        cannot be found on PATH.
        """
        # os.getuid exists only on POSIX; elsewhere no sudo prefix is applied.
        needs_sudo = hasattr(os, 'getuid') and os.getuid() != 0
        if needs_sudo and 'sudo' not in cls.mcumgr_exec:
            mcumgr_path = shutil.which(cls.mcumgr_exec)
            if mcumgr_path is None:
                # Without this guard the command would become 'sudo None'.
                logger.warning('mcumgr tool not available: %s not found in PATH', cls.mcumgr_exec)
                return False
            cls.mcumgr_exec = f'sudo {shlex.quote(mcumgr_path)}'
        return super().is_available()
139