1#!/usr/bin/env python3
2# vim: set syntax=python ts=4 :
3#
4# Copyright (c) 2022 Intel Corporation
5# SPDX-License-Identifier: Apache-2.0
6
7import logging
8import os
9import platform
10import re
11from multiprocessing import Lock, Value
12from pathlib import Path
13
14import scl
15import yaml
16from natsort import natsorted
17from twisterlib.environment import ZEPHYR_BASE
18
19try:
20    # Use the C LibYAML parser if available, rather than the Python parser.
21    # It's much faster.
22    from yaml import CDumper as Dumper
23    from yaml import CSafeLoader as SafeLoader
24except ImportError:
25    from yaml import Dumper, SafeLoader
26
27try:
28    from tabulate import tabulate
29except ImportError:
30    print("Install tabulate python module with pip to use --device-testing option.")
31
# Module logger: every logger.info/warning call in this file goes to the
# logger registered under the name 'twister'.
logger = logging.getLogger('twister')
33
34
class DUT:
    """One device under test (DUT) entry of the twister hardware map.

    ``counter``, ``failures`` and ``available`` are backed by
    ``multiprocessing.Value`` objects, and every access goes through that
    value's own lock. Use the properties and the ``*_increment`` helpers
    rather than the underscored attributes.
    """

    # Instance attributes never emitted by to_dict(): the Value-backed
    # fields, the ``match`` flag used while merging hardware maps, and the
    # multiprocessing lock. The lock is runtime state, not map data, and a
    # Lock object is always truthy, so the falsy-value filter alone would
    # not drop it.
    _NON_SERIALIZED = ('_available', '_counter', '_failures', 'match', 'lock')

    def __init__(self,
                 id=None,
                 serial=None,
                 serial_baud=None,
                 platform=None,
                 product=None,
                 serial_pty=None,
                 connected=False,
                 runner_params=None,
                 pre_script=None,
                 post_script=None,
                 post_flash_script=None,
                 script_param=None,
                 runner=None,
                 flash_timeout=60,
                 flash_with_test=False,
                 flash_before=False):
        """Store the DUT description.

        ``serial_baud`` defaults to 115200 when not given (or falsy).
        ``fixtures`` starts empty and is filled in by the hardware map.
        """

        self.serial = serial
        self.baud = serial_baud or 115200
        self.platform = platform
        self.serial_pty = serial_pty
        self._counter = Value("i", 0)
        self._available = Value("i", 1)
        self._failures = Value("i", 0)
        self.connected = connected
        self.pre_script = pre_script
        self.id = id
        self.product = product
        self.runner = runner
        self.runner_params = runner_params
        self.flash_before = flash_before
        self.fixtures = []
        self.post_flash_script = post_flash_script
        self.post_script = post_script
        self.script_param = script_param
        self.probe_id = None
        self.notes = None
        self.lock = Lock()
        # Set by HardwareMap.save() once this detected device has been
        # matched against an entry of an existing map file.
        self.match = False
        self.flash_timeout = flash_timeout
        self.flash_with_test = flash_with_test

    @property
    def available(self):
        """Availability flag (1 = free, 0 = in use), read under its lock."""
        with self._available.get_lock():
            return self._available.value

    @available.setter
    def available(self, value):
        with self._available.get_lock():
            self._available.value = value

    @property
    def counter(self):
        """Counter value for this DUT, read under its lock."""
        with self._counter.get_lock():
            return self._counter.value

    @counter.setter
    def counter(self, value):
        with self._counter.get_lock():
            self._counter.value = value

    def counter_increment(self, value=1):
        """Atomically add ``value`` to ``counter``."""
        with self._counter.get_lock():
            self._counter.value += value

    @property
    def failures(self):
        """Failure count for this DUT, read under its lock."""
        with self._failures.get_lock():
            return self._failures.value

    @failures.setter
    def failures(self, value):
        with self._failures.get_lock():
            self._failures.value = value

    def failures_increment(self, value=1):
        """Atomically add ``value`` to ``failures``."""
        with self._failures.get_lock():
            self._failures.value += value

    def to_dict(self):
        """Return the non-empty, serializable attributes as a plain dict.

        Attributes with falsy values (``None``, ``False``, ``0``, empty
        lists or strings) are omitted, and so are the runtime-only
        attributes listed in ``_NON_SERIALIZED``.
        """
        return {key: value for key, value in vars(self).items()
                if key not in self._NON_SERIALIZED and value}

    def __repr__(self):
        return f"<{self.platform} ({self.product}) on {self.serial}>"
130
class HardwareMap:
    """Collection of devices under test (DUTs) available to twister.

    DUTs are added from three sources:

    * a hardware-map YAML file (``load``);
    * the command line (``add_device``);
    * a scan of the host's serial ports (``scan``), whose results are kept
      in ``detected``.

    ``save`` writes the scanned devices to a hardware-map file.
    """

    schema_path = os.path.join(ZEPHYR_BASE, "scripts", "schemas", "twister", "hwmap-schema.yaml")

    # USB manufacturer strings accepted by scan(); compared case-insensitively.
    manufacturer = [
        'ARM',
        'SEGGER',
        'MBED',
        'STMicroelectronics',
        'Atmel Corp.',
        'Texas Instruments',
        'Silicon Labs',
        'NXP',
        'NXP Semiconductors',
        'Microchip Technology Inc.',
        'FTDI',
        'Digilent',
        'Microsoft',
        'Nuvoton',
        'Espressif',
    ]

    # Runner name -> probe product strings. scan() assigns a runner when the
    # product equals an entry or matches it as a regular expression
    # (re.match). If several runners match, the last one in this mapping wins.
    runner_mapping = {
        'pyocd': [
            'DAPLink CMSIS-DAP',
            'MBED CMSIS-DAP'
        ],
        'jlink': [
            'J-Link',
            'J-Link OB'
        ],
        'openocd': [
            'STM32 STLink', '^XDS110.*', 'STLINK-V3'
        ],
        'dediprog': [
            'TTL232R-3V3',
            'MCP2200 USB Serial Port Emulator'
        ]
    }

    def __init__(self, env=None):
        """Create an empty map bound to ``env.options``.

        NOTE(review): ``env`` defaults to None but is dereferenced
        unconditionally, so callers must pass an environment object.
        """
        self.detected = []
        self.duts = []
        self.options = env.options

    def discover(self):
        """Populate the map according to the command-line options.

        * ``generate_hardware_map``: scan the host and save the map file.
        * ``hardware_map`` without ``device_testing``: load the file and
          print the connected devices.
        * ``device_testing``: do one of the following, then append
          ``options.fixture`` to every DUT:

          - load the map, defaulting ``options.platform`` to the platforms
            of connected, known DUTs;
          - register the single device given by ``device_serial``;
          - register the single device given by ``device_serial_pty``.

        Returns:
            0 when a map was only generated or listed, 1 otherwise.
        """

        if self.options.generate_hardware_map:
            self.scan(persistent=self.options.persistent_hardware_map)
            self.save(self.options.generate_hardware_map)
            return 0

        if not self.options.device_testing and self.options.hardware_map:
            self.load(self.options.hardware_map)
            logger.info("Available devices:")
            self.dump(connected_only=True)
            return 0

        if self.options.device_testing:
            if self.options.hardware_map:
                self.load(self.options.hardware_map)
                if not self.options.platform:
                    self.options.platform = []
                    for d in self.duts:
                        if d.connected and d.platform != 'unknown':
                            self.options.platform.append(d.platform)

            elif self.options.device_serial:
                self.add_device(self.options.device_serial,
                                self.options.platform[0],
                                self.options.pre_script,
                                False,
                                baud=self.options.device_serial_baud,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test,
                                flash_before=self.options.flash_before,
                                )

            elif self.options.device_serial_pty:
                self.add_device(self.options.device_serial_pty,
                                self.options.platform[0],
                                self.options.pre_script,
                                True,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test,
                                flash_before=self.options.flash_before,
                                )

            # the fixtures given by twister command explicitly should be assigned to each DUT
            if self.options.fixture:
                for d in self.duts:
                    d.fixtures.extend(self.options.fixture)
        return 1

    def summary(self, selected_platforms):
        """Print the counter and failures of each connected DUT whose platform is selected."""
        print("\nHardware distribution summary:\n")
        header = ['Board', 'ID', 'Counter', 'Failures']
        table = [
            [d.platform, d.id, d.counter, d.failures]
            for d in self.duts
            if d.connected and d.platform in selected_platforms
        ]
        print(tabulate(table, headers=header, tablefmt="github"))

    def add_device(
        self,
        serial,
        platform,
        pre_script,
        is_pty,
        baud=None,
        flash_timeout=60,
        flash_with_test=False,
        flash_before=False
    ):
        """Register one connected DUT given on the command line.

        ``serial`` is stored as ``serial_pty`` when ``is_pty`` is true, and
        as ``serial`` otherwise.
        """
        device = DUT(
            platform=platform,
            connected=True,
            pre_script=pre_script,
            serial_baud=baud,
            flash_timeout=flash_timeout,
            flash_with_test=flash_with_test,
            flash_before=flash_before
        )
        if is_pty:
            device.serial_pty = serial
        else:
            device.serial = serial

        self.duts.append(device)

    def load(self, map_file):
        """Load DUTs from a hardware-map YAML file validated against ``schema_path``.

        Entries that are not ``connected``, or that have neither ``serial``
        nor ``serial_pty``, are skipped. An entry whose ``platform`` holds
        several platforms (a whitespace-separated string or a list) yields
        one DUT per platform. Each of those DUTs gets its own ``fixtures``
        list.

        Raises:
            ValueError: if an entry's ``platform`` is neither a string nor a list.
        """
        hwm_schema = scl.yaml_load(self.schema_path)
        duts = scl.yaml_load_verify(map_file, hwm_schema)
        for dut in duts:
            pre_script = dut.get('pre_script')
            script_param = dut.get('script_param')
            post_script = dut.get('post_script')
            post_flash_script = dut.get('post_flash_script')
            flash_timeout = dut.get('flash_timeout') or self.options.device_flash_timeout
            flash_with_test = dut.get('flash_with_test')
            if flash_with_test is None:
                flash_with_test = self.options.device_flash_with_test
            serial_pty = dut.get('serial_pty')
            flash_before = dut.get('flash_before')
            if flash_before is None:
                # Fall back to the command-line value, but never together
                # with flash_with_test.
                flash_before = self.options.flash_before and (not flash_with_test)

            plat_field = dut.get('platform')
            if isinstance(plat_field, str):
                platforms = plat_field.split()
            elif isinstance(plat_field, list):
                platforms = plat_field
            else:
                raise ValueError(f"Invalid platform value: {plat_field}")

            serial = dut.get('serial')
            connected = dut.get('connected') and ((serial or serial_pty) is not None)
            if not connected:
                continue

            dut_id = dut.get('id')
            runner = dut.get('runner')
            runner_params = dut.get('runner_params')
            baud = dut.get('baud')
            product = dut.get('product')
            # A null 'fixtures' key is treated as "no fixtures".
            fixtures = dut.get('fixtures') or []
            for plat in platforms:
                new_dut = DUT(platform=plat,
                              product=product,
                              runner=runner,
                              runner_params=runner_params,
                              id=dut_id,
                              serial_pty=serial_pty,
                              serial=serial,
                              serial_baud=baud,
                              connected=connected,
                              pre_script=pre_script,
                              flash_before=flash_before,
                              post_script=post_script,
                              post_flash_script=post_flash_script,
                              script_param=script_param,
                              flash_timeout=flash_timeout,
                              flash_with_test=flash_with_test)
                # Copy the list: DUTs built from the same entry must not share
                # it, otherwise discover() extending each DUT's fixtures would
                # duplicate entries in the shared list.
                new_dut.fixtures = list(fixtures)
                new_dut.counter = 0
                self.duts.append(new_dut)

    def scan(self, persistent=False):
        """Detect supported USB serial devices and append them to ``self.detected``.

        Devices from manufacturers not listed in ``manufacturer`` are logged
        as unsupported and skipped. Detected devices get platform
        ``'unknown'`` and a runner chosen from ``runner_mapping``; the
        runner stays ``'unknown'`` when nothing matches.

        Args:
            persistent: on Linux, record the ``/dev/serial/by-id`` symlink
                instead of the device node when one exists.
        """
        from serial.tools import list_ports

        persistent_map = {}
        if persistent and platform.system() == 'Linux':
            # On Linux, /dev/serial/by-id provides symlinks to
            # '/dev/ttyACMx' nodes using names which are unique as
            # long as manufacturers fill out USB metadata nicely.
            #
            # This creates a map from '/dev/ttyACMx' device nodes
            # to '/dev/serial/by-id/usb-...' symlinks. The symlinks
            # go into the hardware map because they stay the same
            # even when the user unplugs / replugs the device.
            #
            # Some inexpensive USB/serial adapters don't result
            # in unique names here, though, so use of this feature
            # requires explicitly setting persistent=True.
            by_id = Path('/dev/serial/by-id')
            if by_id.exists():
                # iterdir() yields absolute paths under by_id; resolve()
                # follows each symlink to the real device node.
                persistent_map = {str(link.resolve()): str(link)
                                  for link in by_id.iterdir()}

        known_manufacturers = {m.casefold() for m in self.manufacturer}
        serial_devices = list_ports.comports()
        logger.info("Scanning connected hardware...")
        for d in serial_devices:
            if not d.manufacturer or d.manufacturer.casefold() not in known_manufacturers:
                logger.warning(f"Unsupported device ({d.manufacturer}): {d}")
                continue

            # TI XDS110 can have multiple serial devices for a single board.
            # Assume endpoint 0 is the serial and skip all others. `location`
            # can be None; such a port is kept rather than crashing on it.
            if (d.manufacturer == 'Texas Instruments'
                    and d.location
                    and not d.location.endswith('0')):
                continue

            if d.product is None:
                d.product = 'unknown'

            s_dev = DUT(platform="unknown",
                        id=d.serial_number,
                        serial=persistent_map.get(d.device, d.device),
                        product=d.product,
                        runner='unknown',
                        connected=True)

            for runner, products in self.runner_mapping.items():
                if d.product in products or any(re.match(p, d.product) for p in products):
                    s_dev.runner = runner

            # Detected devices are serialized with DUT.to_dict() in save().
            # Drop the process lock so it never becomes part of that data.
            s_dev.lock = None
            self.detected.append(s_dev)

    def save(self, hwm_file):
        """Write the detected devices to the hardware-map file ``hwm_file``.

        If the file already exists, its entries are merged with the
        detected devices, the result is written back, then it is reloaded
        and printed. Otherwise a new file listing every detected device is
        created.
        """
        self.detected = natsorted(self.detected, key=lambda x: x.serial or '')
        if os.path.exists(hwm_file):
            self._update_existing_map(hwm_file)
        else:
            self._create_new_map(hwm_file)

    def _update_existing_map(self, hwm_file):
        """Merge ``self.detected`` into the existing map file, save it, then reload it.

        The merge works in four steps:

        1. Every entry except ``BOOT-SERIAL`` products is marked
           disconnected.
        2. Entries matching a detected device by ``id`` and ``product`` are
           reconnected with that device's serial port.
        3. Unmatched detected devices are appended.
        4. Appended ``'unknown'``-platform duplicates of boot-serial boards
           are dropped.
        """
        # ids of boards with a boot-serial sequence
        boot_ids = set()

        with open(hwm_file) as yaml_file:
            hwm = yaml.load(yaml_file, Loader=SafeLoader)

        if hwm:
            # 'id' may be missing or null (device without a USB serial
            # number); map those to '' so sorting never compares None with str.
            hwm.sort(key=lambda x: x.get('id') or '')

            # disconnect everything except boards with boot-serial sequence
            for h in hwm:
                if h.get('product') == 'BOOT-SERIAL':
                    if h.get('id') is not None:
                        boot_ids.add(h['id'])
                else:
                    h['connected'] = False
                    h['serial'] = None

            for detected in self.detected:
                for h in hwm:
                    if (detected.id == h.get('id')
                            and detected.product == h.get('product')
                            and detected.match is False
                            and h.get('connected') is False):
                        h['connected'] = True
                        h['serial'] = detected.serial
                        detected.match = True
                        break

        new = [d.to_dict() for d in self.detected if not d.match]
        hwm = (hwm or []) + new

        # Remove duplicated devices with unknown platform names before saving
        # the file. Build a new list instead of calling remove() while
        # iterating, which would skip the element after each removal.
        hwm = [h for h in hwm
               if not (h.get('id') in boot_ids and h.get('platform') == 'unknown')]

        with open(hwm_file, 'w') as yaml_file:
            yaml.dump(hwm, yaml_file, Dumper=Dumper, default_flow_style=False)

        self.load(hwm_file)
        logger.info("Registered devices:")
        self.dump()

    def _create_new_map(self, hwm_file):
        """Write a new map file listing every detected device, then print them."""
        dl = [
            {
                'platform': dev.platform,
                'id': dev.id,
                'runner': dev.runner,
                'serial': dev.serial,
                'product': dev.product,
                'connected': dev.connected,
            }
            for dev in self.detected
        ]
        with open(hwm_file, 'w') as yaml_file:
            yaml.dump(dl, yaml_file, Dumper=Dumper, default_flow_style=False)
        logger.info("Detected devices:")
        self.dump(detected=True)

    def dump(self, filtered=None, header=None, connected_only=False, detected=False):
        """Print a table of devices (Platform, ID, Serial device by default).

        Args:
            filtered: if non-empty, only show these platforms.
            header: column headers; the default is used when empty.
            connected_only: hide devices that are not connected.
            detected: show ``self.detected`` instead of ``self.duts``.
        """
        print("")
        to_show = self.detected if detected else self.duts
        if not header:
            header = ["Platform", "ID", "Serial device"]
        table = [
            [p.platform, p.id, p.serial]
            for p in to_show
            if (not filtered or p.platform in filtered)
            and (not connected_only or p.connected)
        ]
        print(tabulate(table, headers=header, tablefmt="github"))
486