1#!/usr/bin/env python3
2# type: ignore[attr-defined]
3
4#
5# Copyright (c) 2024, Arm Limited. All rights reserved.
6#
7# SPDX-License-Identifier: BSD-3-Clause
8#
9
10"""Contains unit tests for the CLI functionality."""
11
12from math import ceil, log2
13from pathlib import Path
14from re import findall, search
15from unittest import mock
16
17import pytest
18import yaml
19from click.testing import CliRunner
20
21from tlc.cli import cli
22from tlc.te import TransferEntry
23from tlc.tl import TransferList
24
25
def test_create_empty_tl(tmpdir):
    """`create` with no extra options writes a TL file that parses back."""
    tl_file = tmpdir.join("tl.bin")

    outcome = CliRunner().invoke(cli, ["create", tl_file.strpath])

    assert outcome.exit_code == 0
    assert TransferList.fromfile(tl_file) is not None
33
34
def test_create_with_fdt(tmpdir):
    """`create --fdt FILE --size N` succeeds with a 100-byte zero-filled FDT."""
    dtb = tmpdir.join("fdt.dtb")
    dtb.write_binary(bytes(100))

    out_path = tmpdir.join("tl.bin").strpath
    argv = ["create", "--fdt", dtb.strpath, "--size", "1000", out_path]
    outcome = CliRunner().invoke(cli, argv)

    assert outcome.exit_code == 0
52
53
def test_add_single_entry(tlcrunner, tmptlstr):
    """Adding one entry with tag 0 yields a TL holding exactly that entry."""
    tlcrunner.invoke(cli, ["add", "--entry", "0", "/dev/null", tmptlstr])

    parsed = TransferList.fromfile(tmptlstr)
    assert parsed is not None

    entries = parsed.entries
    assert len(entries) == 1
    assert entries[0].id == 0
61
62
def test_add_multiple_entries(tlcrunner, tlc_entries, tmptlstr):
    """Each (tag, path) pair added through the CLI becomes its own entry."""
    for tag_id, blob_path in tlc_entries:
        tlcrunner.invoke(cli, ["add", "--entry", tag_id, blob_path, tmptlstr])

    parsed = TransferList.fromfile(tmptlstr)
    assert parsed is not None
    assert len(parsed.entries) == len(tlc_entries)
70
71
def test_info(tlcrunner, tmptlstr, tmpfdt):
    """`info` prints header and entry fields; --header/--entries select one."""
    tlcrunner.invoke(cli, ["add", "--entry", "0", "/dev/null", tmptlstr])
    tlcrunner.invoke(cli, ["add", "--fdt", tmpfdt.strpath, tmptlstr])

    # (extra flags, header text expected?, entry text expected?)
    cases = [
        ([], True, True),
        (["--header"], True, False),
        (["--entries"], False, True),
    ]
    for flags, shows_header, shows_entries in cases:
        outcome = tlcrunner.invoke(cli, ["info", *flags, tmptlstr])
        assert outcome.exit_code == 0
        assert ("signature" in outcome.stdout) == shows_header
        assert ("id" in outcome.stdout) == shows_entries
90
91
def test_raises_max_size_error(tmptlstr, tmpfdt):
    """A 6000-byte FDT passed to `create` without -s raises MemoryError.

    The MemoryError must be chained from the underlying allocation error.
    """
    tmpfdt.write_binary(bytes(6000))

    err = CliRunner().invoke(cli, ["create", "--fdt", tmpfdt, tmptlstr]).exception

    assert err
    assert isinstance(err, MemoryError)
    assert "TL max size exceeded, consider increasing with the option -s" in str(err)
    assert "TL size has exceeded the maximum allocation" in str(err.__cause__)
106
107
def test_info_get_fdt_offset(tmptlstr, tmpfdt):
    """`info --fdt-offset` prints only a decimal offset once an FDT (tag 1) is added."""
    runner = CliRunner()
    setup_commands = (
        ["create", "--size", "1000", tmptlstr],
        ["add", "--entry", "1", tmpfdt.strpath, tmptlstr],
    )
    with runner.isolated_filesystem():
        for argv in setup_commands:
            runner.invoke(cli, argv)
        outcome = runner.invoke(cli, ["info", "--fdt-offset", tmptlstr])

    assert outcome.exit_code == 0
    assert outcome.output.strip("\n").isdigit()
117
118
def test_remove_tag(tlcrunner, tmptlstr):
    """Removing the only entry by its tag ID leaves the TL with no entries."""
    tlcrunner.invoke(cli, ["add", "--entry", "0", "/dev/null", tmptlstr])
    result = tlcrunner.invoke(cli, ["info", tmptlstr])

    assert result.exit_code == 0
    assert "signature" in result.stdout

    # Capture and check the exit status of `remove` itself; previously the
    # stale `info` result was re-checked, so a failing `remove` slipped by.
    result = tlcrunner.invoke(cli, ["remove", "--tags", "0", tmptlstr])
    assert result.exit_code == 0

    tl = TransferList.fromfile(tmptlstr)
    assert len(tl.entries) == 0
131
132
def test_unpack_tl(tlcrunner, tmptlstr, tmpfdt, tmpdir):
    """`unpack` writes entry 0 (tag 1) to te_0_1.bin in the working directory."""
    add_args = ["add", "--entry", 1, tmpfdt.strpath, tmptlstr]
    with tlcrunner.isolated_filesystem(temp_dir=tmpdir):
        tlcrunner.invoke(cli, add_args)
        tlcrunner.invoke(cli, ["unpack", tmptlstr])
        unpacked = Path("te_0_1.bin")
        assert unpacked.exists()
138
139
def test_unpack_multiple_tes(tlcrunner, tlc_entries, tmptlstr, tmpdir):
    """Unpacking a TL that holds several entries writes one file per entry.

    The previous version never ran `unpack`. It also asserted
    `all(filter(...))`, which is always true because `filter` only yields
    truthy tuples, so the test could not fail.
    """
    with tlcrunner.isolated_filesystem(temp_dir=tmpdir):
        for tag_id, blob_path in tlc_entries:
            tlcrunner.invoke(cli, ["add", "--entry", tag_id, blob_path, tmptlstr])

        result = tlcrunner.invoke(cli, ["unpack", tmptlstr])
        assert result.exit_code == 0

        # Unpacked files are named te_<index>_<tag>.bin in the working
        # directory. Match on the index so the check does not depend on how
        # each tag ID is rendered in the file name.
        cwd = Path.cwd()
        missing = [
            index
            for index in range(len(tlc_entries))
            if not any(cwd.glob(f"te_{index}_*.bin"))
        ]
        assert not missing, f"no unpacked file for entry indices {missing}"
150
151
def test_unpack_into_dir(tlcrunner, tmpdir, tmptlstr, tmpfdt):
    """`unpack -C DIR` writes the entry files into DIR."""
    out_dir = Path(tmpdir.strpath)

    tlcrunner.invoke(cli, ["add", "--entry", 1, tmpfdt.strpath, tmptlstr])
    tlcrunner.invoke(cli, ["unpack", "-C", tmpdir.strpath, tmptlstr])

    assert out_dir.joinpath("te_0_1.bin").exists()
157
158
def test_unpack_into_dir_with_conflicting_tags(tlcrunner, tmpdir, tmptlstr, tmpfdt):
    """Two entries sharing tag 1 unpack to distinct, index-prefixed files."""
    for _ in range(2):
        tlcrunner.invoke(cli, ["add", "--entry", 1, tmpfdt.strpath, tmptlstr])
    tlcrunner.invoke(cli, ["unpack", "-C", tmpdir.strpath, tmptlstr])

    out_dir = Path(tmpdir.strpath)
    for name in ("te_0_1.bin", "te_1_1.bin"):
        assert (out_dir / name).exists()
166
167
def test_validate_invalid_signature(tmptlstr, tlcrunner, monkeypatch):
    """`validate` exits non-zero when the TL header has a bogus signature."""
    bad_tl = TransferList()
    bad_tl.signature = 0xDEADBEEF

    def fake_open(path, mode):
        # Every open() returns a fresh handle that reads the forged header.
        return mock.mock_open(read_data=bad_tl.header_to_bytes())()

    monkeypatch.setattr("builtins.open", fake_open)

    outcome = tlcrunner.invoke(cli, ["validate", tmptlstr])
    assert outcome.exit_code != 0
177
178
def test_validate_misaligned_entries(tmptlstr, tlcrunner, monkeypatch):
    """Base address of a TE must be 8-byte aligned."""

    def fake_open(path, mode):
        # A TL header, then 5 stray bytes to push the following TE header off
        # 8-byte alignment (assumes the TL header length is itself a multiple
        # of 8 -- TODO confirm).
        image = (
            TransferList().header_to_bytes()
            + bytes(5)
            # NOTE(review): accessed without call parentheses, unlike
            # TransferList.header_to_bytes(); presumably a property on
            # TransferEntry -- confirm, otherwise this concatenation raises
            # and the test passes for the wrong reason.
            + TransferEntry(0, 0, bytes(0)).header_to_bytes
        )
        return mock.mock_open(read_data=image)()

    monkeypatch.setattr("builtins.open", fake_open)

    outcome = tlcrunner.invoke(cli, ["validate", tmptlstr])

    assert outcome.exit_code == 1
191
192
@pytest.mark.parametrize(
    "version", [0, TransferList.version, TransferList.version + 1, 1 << 8]
)
def test_validate_unsupported_version(version, tmptlstr, tlcrunner, monkeypatch):
    """`validate` is expected to exit 0 for versions in the range
    [TransferList.version, 0xFF] and 1 for anything outside it."""
    tl = TransferList()
    tl.version = version

    def fake_open(path, mode):
        # Serve only the TL header carrying the parametrized version.
        return mock.mock_open(read_data=tl.header_to_bytes())()

    monkeypatch.setattr("builtins.open", fake_open)

    outcome = tlcrunner.invoke(cli, ["validate", tmptlstr])

    supported = TransferList.version <= version <= 0xFF
    assert outcome.exit_code == (0 if supported else 1)
209
210
def test_create_entry_from_yaml_and_blob_file(
    tlcrunner, tmpyamlconfig_blob_file, tmptlstr, non_empty_tag_id
):
    """`create --from-yaml` with the blob-file YAML fixture yields a single
    entry tagged `non_empty_tag_id`."""
    argv = ["create", "--from-yaml", tmpyamlconfig_blob_file.strpath, tmptlstr]
    tlcrunner.invoke(cli, argv)

    created = TransferList.fromfile(tmptlstr)
    assert created is not None

    entries = created.entries
    assert len(entries) == 1
    assert entries[0].id == non_empty_tag_id
228
229
@pytest.mark.parametrize(
    "entry",
    [
        {"tag_id": 0},
        {
            "tag_id": 0x104,
            "addr": 0x0400100000000010,
            "size": 0x0003300000000000,
        },
        {
            "tag_id": 0x100,
            "pp_addr": 100,
        },
        {
            "tag_id": "optee_pageable_part",
            "pp_addr": 100,
        },
    ],
)
def test_create_from_yaml_check_sum_bytes(tlcrunner, tmpyamlconfig, tmptlstr, entry):
    """Test creating a TL from a yaml file, but only check that the sum of the
    data in the yaml file matches the sum of the data in the TL. This means
    you don't have to type the exact sequence of expected bytes. All the data
    in the yaml file must be integers (except for the tag IDs, which can be
    strings).
    """
    # create yaml config file
    config = {
        "has_checksum": True,
        "max_size": 0x1000,
        "entries": [entry],
    }
    with open(tmpyamlconfig, "w") as f:
        yaml.safe_dump(config, f)

    # invoke TLC and make sure it actually succeeded before inspecting output
    result = tlcrunner.invoke(
        cli,
        [
            "create",
            "--from-yaml",
            tmpyamlconfig,
            tmptlstr,
        ],
    )
    assert result.exit_code == 0

    # open created TL, and check
    tl = TransferList.fromfile(tmptlstr)
    assert tl is not None
    assert len(tl.entries) == 1

    # Check that the sum of all the data in the transfer entry in the yaml file
    # is the same as the sum of all the data in the transfer list. Don't count
    # the tag id or the TE headers.

    # every item in the entry dict must be an integer
    yaml_total = 0
    for key, data in iter_nested_dict(entry):
        if key == "tag_id":
            continue
        # Minimal little-endian width, computed exactly. The former
        # ceil(log2(data + 1) / 8) goes through a float and undercounts for
        # large values: data == 2**64 gives log2 == 64.0, so 8 bytes are
        # requested and to_bytes() raises OverflowError.
        num_bytes = (data.bit_length() + 7) // 8
        yaml_total += sum(data.to_bytes(num_bytes, "little"))

    tl_total = sum(tl.entries[0].data)

    assert tl_total == yaml_total
295
296
@pytest.mark.parametrize(
    "entry,expected",
    [
        (
            {
                "tag_id": 0x102,
                "ep_info": {
                    "h": {
                        "type": 0x01,
                        "version": 0x02,
                        "attr": 8,
                    },
                    "pc": 67239936,
                    "spsr": 965,
                    "args": [67112976, 67112960, 0, 0, 0, 0, 0, 0],
                },
            },
            (
                "0x00580201 0x00000008 0x04020000 0x00000000 "
                "0x000003C5 0x00000000 0x04001010 0x00000000 "
                "0x04001000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000"
            ),
        ),
        (
            {
                "tag_id": 0x102,
                "ep_info": {
                    "h": {
                        "type": 0x01,
                        "version": 0x02,
                        "attr": "EP_NON_SECURE | EP_ST_ENABLE",
                    },
                    "pc": 67239936,
                    "spsr": 965,
                    "args": [67112976, 67112960, 0, 0, 0, 0, 0, 0],
                },
            },
            (
                "0x00580201 0x00000005 0x04020000 0x00000000 "
                "0x000003C5 0x00000000 0x04001010 0x00000000 "
                "0x04001000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000 0x00000000 0x00000000 "
                "0x00000000 0x00000000"
            ),
        ),
    ],
)
def test_create_from_yaml_check_exact_data(
    tlcrunner, tmpyamlconfig, tmptlstr, entry, expected
):
    """Build a TL from a one-entry YAML config and compare the entry payload,
    word by word, against a dump copied from the ArmDS debugger. Checking the
    exact bytes also pins down field alignment and padding.
    """
    with open(tmpyamlconfig, "w") as f:
        yaml.safe_dump(
            {"has_checksum": True, "max_size": 0x1000, "entries": [entry]}, f
        )

    tlcrunner.invoke(cli, ["create", "--from-yaml", tmpyamlconfig, tmptlstr])

    tl = TransferList.fromfile(tmptlstr)
    assert tl is not None
    assert len(tl.entries) == 1

    # Render the payload as 32-bit little-endian words, ArmDS style.
    assert bytes_to_hex(tl.entries[0].data) == expected
385
386
@pytest.mark.parametrize("option", ["-O", "--output"])
def test_gen_tl_header_with_output_name(tlcrunner, tmptlstr, option, filename="test.h"):
    """Both the short and long output flags make `gen-header` write the named file."""
    argv = ["gen-header", option, filename, tmptlstr]
    with tlcrunner.isolated_filesystem():
        outcome = tlcrunner.invoke(cli, argv)

        assert outcome.exit_code == 0
        assert Path(filename).exists()
402
403
def test_gen_tl_with_fdt_header(tmptlstr, tmpfdt):
    """A TL created with an FDT produces a header defining a numeric DTB_OFFSET."""
    runner = CliRunner()

    with runner.isolated_filesystem():
        runner.invoke(cli, ["create", "--size", 1000, "--fdt", tmpfdt, tmptlstr])
        outcome = runner.invoke(cli, ["gen-header", tmptlstr])

        assert outcome.exit_code == 0
        assert Path("header.h").exists()

        with open("header.h", "r") as f:
            contents = f.read()

        dtb_match = search(r"DTB_OFFSET\s+(\d+)", contents)
        assert dtb_match and dtb_match[1].isnumeric()
424
425
def test_gen_empty_tl_c_header(tlcrunner, tmptlstr):
    """`gen-header` emits SIZE and VERSION values that match TransferList's
    header size and version."""
    with tlcrunner.isolated_filesystem():
        result = tlcrunner.invoke(
            cli,
            [
                "gen-header",
                tmptlstr,
            ],
        )

        assert result.exit_code == 0
        assert Path("header.h").exists()

        with open("header.h", "r") as f:
            lines = f.read()

        # The patterns accept either a 0x-prefixed hex literal or a decimal
        # literal, so parse with base 0 to honour whichever was emitted.
        # (The old int(..., 16) misread decimal "24" as 36.)
        size = findall(r"SIZE\s+(0x[0-9a-fA-F]+|\d+)", lines)[0]
        # Non-greedy `.+?`: a greedy `.+` consumes all but the last digit,
        # so "0x10" was captured as "0", and int() without a base rejects
        # any 0x-prefixed capture.
        version = findall(r"VERSION.+?(0x[0-9a-fA-F]+|\d+)", lines)[0]

        assert TransferList.hdr_size == int(size, 0)
        assert TransferList.version == int(version, 0)
448
449
def bytes_to_hex(data: bytes) -> str:
    """Render *data* the way the Arm Development Studio debugger shows memory.

    This lets a byte dump copied from the ArmDS debugger be pasted into a
    unit test and compared directly against tlc's output.

    The output is a sequence of 4-byte words, each decoded as little-endian
    and written as ``0x`` plus eight upper-case hex digits, separated by
    single spaces. A trailing chunk shorter than 4 bytes is decoded as-is and
    zero-padded to eight digits.
    """
    return " ".join(
        f"0x{int.from_bytes(data[offset : offset + 4], 'little'):08X}"
        for offset in range(0, len(data), 4)
    )
469
470
def iter_nested_dict(dictionary: dict):
    """Yield ``(key, value)`` for every non-dict value in *dictionary*.

    Nested dicts are descended into depth-first, in insertion order. Their
    own keys are not yielded, only the leaf pairs beneath them. Non-dict
    containers such as lists are yielded unchanged as leaf values.
    """
    # Stack of in-progress item iterators; the top one is the dict being walked.
    pending = [iter(dictionary.items())]
    while pending:
        for key, value in pending[-1]:
            if isinstance(value, dict):
                # Descend now; the outer iterator resumes once this one drains.
                pending.append(iter(value.items()))
                break
            yield key, value
        else:
            pending.pop()
477