1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4
5from __future__ import annotations
6
7import logging
8import os
9if os.name != 'nt':
10    import pty
11import re
12import subprocess
13import time
14from pathlib import Path
15
16import serial
17from twister_harness.device.device_adapter import DeviceAdapter
18from twister_harness.exceptions import (
19    TwisterHarnessException,
20    TwisterHarnessTimeoutException,
21)
22from twister_harness.device.utils import log_command, terminate_process
23from twister_harness.twister_harness_config import DeviceConfig
24
25logger = logging.getLogger(__name__)
26
27
28class HardwareAdapter(DeviceAdapter):
29    """Adapter class for real device."""
30
31    def __init__(self, device_config: DeviceConfig) -> None:
32        super().__init__(device_config)
33        self._flashing_timeout: float = device_config.flash_timeout
34        self._serial_connection: serial.Serial | None = None
35        self._serial_pty_proc: subprocess.Popen | None = None
36        self._serial_buffer: bytearray = bytearray()
37
38        self.device_log_path: Path = device_config.build_dir / 'device.log'
39        self._log_files.append(self.device_log_path)
40
41    def generate_command(self) -> None:
42        """Return command to flash."""
43        command = [
44            self.west,
45            'flash',
46            '--skip-rebuild',
47            '--build-dir', str(self.device_config.build_dir),
48        ]
49
50        command_extra_args = []
51        if self.device_config.west_flash_extra_args:
52            command_extra_args.extend(self.device_config.west_flash_extra_args)
53
54        if self.device_config.runner:
55            runner_base_args, runner_extra_args = self._prepare_runner_args()
56            command.extend(runner_base_args)
57            command_extra_args.extend(runner_extra_args)
58
59        if command_extra_args:
60            command.append('--')
61            command.extend(command_extra_args)
62        self.command = command
63
64    def _prepare_runner_args(self) -> tuple[list[str], list[str]]:
65        base_args: list[str] = []
66        extra_args: list[str] = []
67        runner = self.device_config.runner
68        base_args.extend(['--runner', runner])
69        if self.device_config.runner_params:
70            for param in self.device_config.runner_params:
71                extra_args.append(param)
72        if board_id := self.device_config.id:
73            if runner == 'pyocd':
74                extra_args.append('--board-id')
75                extra_args.append(board_id)
76            elif runner == "esp32":
77                extra_args.append("--esp-device")
78                extra_args.append(board_id)
79            elif runner in ('nrfjprog', 'nrfutil', 'nrfutil_next'):
80                extra_args.append('--dev-id')
81                extra_args.append(board_id)
82            elif runner == 'openocd' and self.device_config.product in ['STM32 STLink', 'STLINK-V3']:
83                extra_args.append('--cmd-pre-init')
84                extra_args.append(f'hla_serial {board_id}')
85            elif runner == 'openocd' and self.device_config.product == 'EDBG CMSIS-DAP':
86                extra_args.append('--cmd-pre-init')
87                extra_args.append(f'cmsis_dap_serial {board_id}')
88            elif runner == "openocd" and self.device_config.product == "LPC-LINK2 CMSIS-DAP":
89                extra_args.append("--cmd-pre-init")
90                extra_args.append(f'adapter serial {board_id}')
91            elif runner == 'jlink':
92                base_args.append('--dev-id')
93                base_args.append(board_id)
94            elif runner == 'stm32cubeprogrammer':
95                base_args.append(f'--tool-opt=sn={board_id}')
96            elif runner == 'linkserver':
97                base_args.append(f'--probe={board_id}')
98        return base_args, extra_args
99
100    def _flash_and_run(self) -> None:
101        """Flash application on a device."""
102        if not self.command:
103            msg = 'Flash command is empty, please verify if it was generated properly.'
104            logger.error(msg)
105            raise TwisterHarnessException(msg)
106
107        if self.device_config.pre_script:
108            self._run_custom_script(self.device_config.pre_script, self.base_timeout)
109
110        if self.device_config.id:
111            logger.debug('Flashing device %s', self.device_config.id)
112        log_command(logger, 'Flashing command', self.command, level=logging.DEBUG)
113
114        process = stdout = None
115        try:
116            process = subprocess.Popen(self.command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=self.env)
117            stdout, _ = process.communicate(timeout=self._flashing_timeout)
118        except subprocess.TimeoutExpired as exc:
119            process.kill()
120            msg = f'Timeout occurred ({self._flashing_timeout}s) during flashing.'
121            logger.error(msg)
122            raise TwisterHarnessTimeoutException(msg) from exc
123        except subprocess.SubprocessError as exc:
124            msg = f'Flashing subprocess failed due to SubprocessError {exc}'
125            logger.error(msg)
126            raise TwisterHarnessTimeoutException(msg) from exc
127        finally:
128            if stdout is not None:
129                stdout_decoded = stdout.decode(errors='ignore')
130                with open(self.device_log_path, 'a+') as log_file:
131                    log_file.write(stdout_decoded)
132            if self.device_config.post_flash_script:
133                self._run_custom_script(self.device_config.post_flash_script, self.base_timeout)
134            if process is not None and process.returncode == 0:
135                logger.debug('Flashing finished')
136            else:
137                msg = f'Could not flash device {self.device_config.id}'
138                logger.error(msg)
139                raise TwisterHarnessException(msg)
140
141    def _connect_device(self) -> None:
142        serial_name = self._open_serial_pty() or self.device_config.serial
143        logger.debug('Opening serial connection for %s', serial_name)
144        try:
145            self._serial_connection = serial.Serial(
146                serial_name,
147                baudrate=self.device_config.baud,
148                parity=serial.PARITY_NONE,
149                stopbits=serial.STOPBITS_ONE,
150                bytesize=serial.EIGHTBITS,
151                timeout=self.base_timeout,
152            )
153        except serial.SerialException as exc:
154            logger.exception('Cannot open connection: %s', exc)
155            self._close_serial_pty()
156            raise
157
158        self._serial_connection.flush()
159        self._serial_connection.reset_input_buffer()
160        self._serial_connection.reset_output_buffer()
161
162    def _open_serial_pty(self) -> str | None:
163        """Open a pty pair, run process and return tty name"""
164        if not self.device_config.serial_pty:
165            return None
166
167        try:
168            master, slave = pty.openpty()
169        except NameError as exc:
170            logger.exception('PTY module is not available.')
171            raise exc
172
173        try:
174            self._serial_pty_proc = subprocess.Popen(
175                re.split(',| ', self.device_config.serial_pty),
176                stdout=master,
177                stdin=master,
178                stderr=master
179            )
180        except subprocess.CalledProcessError as exc:
181            logger.exception('Failed to run subprocess %s, error %s', self.device_config.serial_pty, str(exc))
182            raise
183        return os.ttyname(slave)
184
185    def _disconnect_device(self) -> None:
186        if self._serial_connection:
187            serial_name = self._serial_connection.port
188            self._serial_connection.close()
189            # self._serial_connection = None
190            logger.debug('Closed serial connection for %s', serial_name)
191        self._close_serial_pty()
192
193    def _close_serial_pty(self) -> None:
194        """Terminate the process opened for serial pty script"""
195        if self._serial_pty_proc:
196            self._serial_pty_proc.terminate()
197            self._serial_pty_proc.communicate(timeout=self.base_timeout)
198            logger.debug('Process %s terminated', self.device_config.serial_pty)
199            self._serial_pty_proc = None
200
201    def _close_device(self) -> None:
202        if self.device_config.post_script:
203            self._run_custom_script(self.device_config.post_script, self.base_timeout)
204
205    def is_device_running(self) -> bool:
206        return self._device_run.is_set()
207
208    def is_device_connected(self) -> bool:
209        return bool(
210            self.is_device_running()
211            and self._device_connected.is_set()
212            and self._serial_connection
213            and self._serial_connection.is_open
214        )
215
216    def _read_device_output(self) -> bytes:
217        try:
218            output = self._readline_serial()
219        except (serial.SerialException, TypeError, IOError):
220            # serial was probably disconnected
221            output = b''
222        return output
223
224    def _readline_serial(self) -> bytes:
225        """
226        This method was created to avoid using PySerial built-in readline
227        method which cause blocking reader thread even if there is no data to
228        read. Instead for this, following implementation try to read data only
229        if they are available. Inspiration for this code was taken from this
230        comment:
231        https://github.com/pyserial/pyserial/issues/216#issuecomment-369414522
232        """
233        line = self._readline_from_serial_buffer()
234        if line is not None:
235            return line
236        while True:
237            if self._serial_connection is None or not self._serial_connection.is_open:
238                return b''
239            elif self._serial_connection.in_waiting == 0:
240                time.sleep(0.05)
241                continue
242            else:
243                bytes_to_read = max(1, min(2048, self._serial_connection.in_waiting))
244                output = self._serial_connection.read(bytes_to_read)
245                self._serial_buffer.extend(output)
246                line = self._readline_from_serial_buffer()
247                if line is not None:
248                    return line
249
250    def _readline_from_serial_buffer(self) -> bytes | None:
251        idx = self._serial_buffer.find(b"\n")
252        if idx >= 0:
253            line = self._serial_buffer[:idx+1]
254            self._serial_buffer = self._serial_buffer[idx+1:]
255            return bytes(line)
256        else:
257            return None
258
259    def _write_to_device(self, data: bytes) -> None:
260        self._serial_connection.write(data)
261
262    def _flush_device_output(self) -> None:
263        if self.is_device_connected():
264            self._serial_connection.flush()
265            self._serial_connection.reset_input_buffer()
266
267    def _clear_internal_resources(self) -> None:
268        super()._clear_internal_resources()
269        self._serial_connection = None
270        self._serial_pty_proc = None
271        self._serial_buffer.clear()
272
273    @staticmethod
274    def _run_custom_script(script_path: str | Path, timeout: float) -> None:
275        with subprocess.Popen(str(script_path), stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
276            try:
277                stdout, stderr = proc.communicate(timeout=timeout)
278                logger.debug(stdout.decode())
279                if proc.returncode != 0:
280                    msg = f'Custom script failure: \n{stderr.decode(errors="ignore")}'
281                    logger.error(msg)
282                    raise TwisterHarnessException(msg)
283
284            except subprocess.TimeoutExpired as exc:
285                terminate_process(proc)
286                proc.communicate(timeout=timeout)
287                msg = f'Timeout occurred ({timeout}s) during execution custom script: {script_path}'
288                logger.error(msg)
289                raise TwisterHarnessTimeoutException(msg) from exc
290