# Copyright 2025 The Hafnium Authors.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/BSD-3-Clause.

import collections
import datetime
import json
import os
import platform
import re
import time
import xml.etree.ElementTree as ET

import click
MACHINE = platform.machine()

HFTEST_LOG_PREFIX = "[hftest] "
HFTEST_LOG_FAILURE_PREFIX = "Failure:"
HFTEST_LOG_FINISHED = "FINISHED"

HFTEST_CTRL_JSON_START = "[hftest_ctrl:json_start]"
HFTEST_CTRL_JSON_END = "[hftest_ctrl:json_end]"

HFTEST_CTRL_GET_COMMAND_LINE = "[hftest_ctrl:get_command_line]"
HFTEST_CTRL_FINISHED = "[hftest_ctrl:finished]"

# Matches the "[<hex> <hex>] " prefix (VM and vCPU IDs) that may be
# prepended to each line of VM output.
HFTEST_CTRL_JSON_REGEX = re.compile(r"^\[[0-9a-fA-F]+ [0-9a-fA-F]+\] ")

HF_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
    os.path.abspath(__file__)))))

HF_PREBUILTS = os.path.join(HF_ROOT, "prebuilts")

VM_NODE_REGEX = "vm[1-9]"

def read_file(path):
    with open(path, "r", encoding="utf-8", errors="backslashreplace") as f:
        return f.read()

def write_file(path, to_write, append=False):
    with open(path, "a" if append else "w") as f:
        f.write(to_write)

def append_file(path, to_write):
    write_file(path, to_write, append=True)

def join_if_not_None(*args):
    """Join the arguments with spaces, skipping any that are None or empty."""
    return " ".join(filter(None, args))

def get_vm_node_from_manifest(dts: str):
    """Get the VM node string from a partition's extension to the Partition
    Manager's manifest."""
    match = re.search(VM_NODE_REGEX, dts)
    if not match:
        raise Exception("Partition's node is not defined in its manifest.")
    return match.group()

def correct_vm_node(dts: str, node_index: int):
    """Rename the VM node in a Partition Manager manifest to `vm<node_index>`.
    Ideally, these files would be reused across various test setups."""
    return dts.replace(get_vm_node_from_manifest(dts), f"vm{node_index}")
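
# For example, renumbering the node of a manifest that declares "vm1"
# (the snippet is illustrative, not a real manifest):
#
#   correct_vm_node('vm1 { debug_name = "primary"; };', 3)
#   # -> 'vm3 { debug_name = "primary"; };'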

def shared_options(f):
    """Decorate `f` with the command-line options shared by all drivers."""
    f = click.option("--hypervisor")(f)
    f = click.option("--log", required=True)(f)
    f = click.option("--initrd")(f)
    f = click.option("--out_initrd")(f)
    f = click.option("--suite")(f)
    f = click.option("--test")(f)
    f = click.option("--vm_args")(f)
    f = click.option("--skip-long-running-tests", is_flag=True)(f)
    f = click.option("--force-long-running", is_flag=True)(f)
    f = click.option("--debug", is_flag=True,
        help="Make the platform stall waiting for a debugger connection.")(f)
    f = click.option("--show-output", is_flag=True)(f)
    f = click.option("--disable_visualisation", is_flag=True)(f)
    f = click.option("--log-level",
        help="Set the log level (DEBUG=10, INFO=20, WARNING=30, ERROR=40)")(f)
    return f
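
# A minimal sketch of how a driver command might consume these options
# (the `qemu` command name and its body are illustrative, not part of
# this module):
#
#   @click.command()
#   @shared_options
#   def qemu(**kwargs):
#       ...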

# Device tree pair: source (.dts) and compiled blob (.dtb).
DT = collections.namedtuple("DT", ["dts", "dtb"])

class ArtifactsManager:
    """Class which manages the folder with test artifacts."""

    def __init__(self, log_dir):
        self.created_files = []
        self.log_dir = log_dir

        # Create the directory if it does not already exist.
        try:
            os.makedirs(self.log_dir)
        except OSError:
            if not os.path.isdir(self.log_dir):
                raise
        print("Logs saved under", log_dir)

        # Create files expected by the Sponge test result parser.
        self.sponge_log_path = self.create_file("sponge_log", ".log")
        self.sponge_xml_path = self.create_file("sponge_log", ".xml")

    def gen_file_path(self, basename, extension):
        """Generate the path to a file in the log directory."""
        return os.path.join(self.log_dir, basename + extension)

    def create_file(self, basename, extension):
        """Create and touch a new file in the log folder. Ensure that no other
        file of the same name was created by this instance of ArtifactsManager.
        """
        # Determine the path of the file.
        path = self.gen_file_path(basename, extension)

        # Check that the path is unique.
        assert path not in self.created_files
        self.created_files.append(path)

        # Touch the file.
        with open(path, "w"):
            pass

        return path

    def get_file(self, basename, extension):
        """Return the path to a file in the log folder. Assert that it was
        created by this instance of ArtifactsManager."""
        path = self.gen_file_path(basename, extension)
        assert path in self.created_files
        return path
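
# A minimal usage sketch, assuming "out/hftest_logs" as the log directory
# (the path and file names are illustrative):
#
#   artifacts = ArtifactsManager("out/hftest_logs")
#   log_path = artifacts.create_file("qemu.boot", ".log")
#   append_file(log_path, "[hftest] FINISHED\n")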


# Tuple used to return information about the results of running a set of tests.
TestRunnerResult = collections.namedtuple("TestRunnerResult", [
        "tests_run",
        "tests_failed",
        "tests_skipped",
    ])

class TestRunner:
    """Class which communicates with a test platform to obtain the list of
    available tests and drive their execution."""

    def __init__(self, artifacts, driver, test_set_up, suite_regex, test_regex,
            skip_long_running_tests, force_long_running, debug, show_output):
        self.artifacts = artifacts
        self.driver = driver
        self.test_set_up = test_set_up
        self.skip_long_running_tests = skip_long_running_tests
        self.force_long_running = force_long_running
        self.debug = debug
        self.show_output = show_output

        self.suite_re = re.compile(suite_regex or ".*")
        self.test_re = re.compile(test_regex or ".*")

    def extract_hftest_lines(self, raw):
        """Extract hftest-specific lines from the raw output of an invocation
        of the test platform."""
        lines = []
        lines_to_process = raw.splitlines()

        try:
            # If the log interleaves output from more than one VM, the loop
            # below cannot extract the right lines on its own. Restrict
            # processing to the lines between the start and end control
            # markers: HFTEST_CTRL_GET_COMMAND_LINE and HFTEST_CTRL_FINISHED.
            hftest_start = lines_to_process.index(HFTEST_CTRL_GET_COMMAND_LINE) + 1
            hftest_end = lines_to_process.index(HFTEST_CTRL_FINISHED)
        except ValueError:
            hftest_start = 0
            hftest_end = len(lines_to_process)

        lines_to_process = lines_to_process[hftest_start:hftest_end]

        for line in lines_to_process:
            # Strip the optional "[<vm> <vcpu>] " prefix, then keep only the
            # lines emitted by hftest itself.
            match = HFTEST_CTRL_JSON_REGEX.search(line)
            if match is not None:
                line = line[match.end():]
            if line.startswith(HFTEST_LOG_PREFIX):
                lines.append(line[len(HFTEST_LOG_PREFIX):])
        return lines
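
    # For example, given raw platform output such as (illustrative):
    #
    #   [0001 0] [hftest] FINISHED
    #   unrelated console noise
    #
    # extract_hftest_lines() returns ["FINISHED"].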

    def get_test_json(self):
        """Invoke the test platform and request a JSON list of the available
        tests and test suites."""
        out = self.driver.run("json", "json", self.force_long_running)
        hf_out = self.extract_hftest_lines(out)
        try:
            hf_out = hf_out[hf_out.index(HFTEST_CTRL_JSON_START) + 1
                        :hf_out.index(HFTEST_CTRL_JSON_END)]
        except ValueError as e:
            print("Unable to find JSON control string:")
            print(f"out={out}")
            print(f"hf_out={hf_out}")
            raise e

        hf_out = "\n".join(hf_out)
        try:
            return json.loads(hf_out)
        except ValueError as e:
            print("Unable to parse JSON:")
            print(f"out={out}")
            print(f"hf_out={hf_out}")
            raise e

    def collect_results(self, fn, it, xml_node):
        """Run `fn` on every entry in `it` and collect their TestRunnerResults.
        Set the "tests", "failures", "skipped" and "time" attributes on
        `xml_node`."""
        tests_run = 0
        tests_failed = 0
        tests_skipped = 0
        start_time = time.perf_counter()
        for i in it:
            sub_result = fn(i)
            assert sub_result.tests_run >= sub_result.tests_failed
            tests_run += sub_result.tests_run
            tests_failed += sub_result.tests_failed
            tests_skipped += sub_result.tests_skipped
        elapsed_time = time.perf_counter() - start_time

        xml_node.set("tests", str(tests_run + tests_skipped))
        xml_node.set("failures", str(tests_failed))
        xml_node.set("skipped", str(tests_skipped))
        xml_node.set("time", str(elapsed_time))
        return TestRunnerResult(tests_run, tests_failed, tests_skipped)

    def is_passed_test(self, test_out):
        """Parse the output of a test and return True if it passed."""
        return (
            len(test_out) > 0 and
            test_out[-1] == HFTEST_LOG_FINISHED and
            not any(l.startswith(HFTEST_LOG_FAILURE_PREFIX) for l in test_out))

    def get_failure_message(self, test_out):
        """Parse the output of a test and return the message of the first
        assertion failure, or None if there is none."""
        for i, line in enumerate(test_out):
            if line.startswith(HFTEST_LOG_FAILURE_PREFIX) and i + 1 < len(test_out):
                # The assertion message is on the line after the "Failure:"
                # marker.
                return test_out[i + 1].strip()

        return None
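
    # For example, for hftest output such as (illustrative):
    #
    #   ["Failure:", "  ASSERT_EQ(a, b) failed", "FINISHED"]
    #
    # is_passed_test() returns False and get_failure_message() returns
    # "ASSERT_EQ(a, b) failed".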

    def get_log_name(self, suite, test):
        """Return the generated log name for the test, of the form
        `[<cpu>.]<suite>.<test>`."""
        log_name = ""

        cpu = self.driver.args.cpu
        if cpu:
            log_name += cpu + "."

        log_name += suite["name"] + "." + test["name"]

        return log_name

    def run_test(self, suite, test, suite_xml):
        """Invoke the test platform and request that it run the given `test`
        in the given `suite`. Create a new XML node with results under
        `suite_xml`. The test is only invoked if it matches the regex given
        to the constructor."""
        if not self.test_re.match(test["name"]):
            return TestRunnerResult(tests_run=0, tests_failed=0, tests_skipped=0)

        test_xml = ET.SubElement(suite_xml, "testcase")
        test_xml.set("name", test["name"])
        test_xml.set("classname", suite["name"])

        if (self.skip_long_running_tests and test["is_long_running"]) or test["skip_test"]:
            print("      SKIP", test["name"])
            test_xml.set("status", "notrun")
            skipped_xml = ET.SubElement(test_xml, "skipped")
            skipped_xml.set("message",
                "Skipped" if test["skip_test"] else "Long running")
            return TestRunnerResult(tests_run=0, tests_failed=0, tests_skipped=1)

        action_log = "DEBUG" if self.debug else "RUN"
        print(f"      {action_log}", test["name"])
        log_name = self.get_log_name(suite, test)

        test_xml.set("status", "run")

        start_time = time.perf_counter()
        out = self.driver.run(
            log_name, "run {} {}".format(suite["name"], test["name"]),
            test["is_long_running"] or self.force_long_running,
            self.debug, self.show_output)

        hftest_out = self.extract_hftest_lines(out)
        elapsed_time = time.perf_counter() - start_time

        test_xml.set("time", str(elapsed_time))

        system_out_xml = ET.SubElement(test_xml, "system-out")
        system_out_xml.text = out

        if self.is_passed_test(hftest_out):
            print("        PASS")
            return TestRunnerResult(tests_run=1, tests_failed=0, tests_skipped=0)
        else:
            print("[x]     FAIL --", self.driver.get_run_log(log_name))
            failure_xml = ET.SubElement(test_xml, "failure")
            failure_message = self.get_failure_message(hftest_out) or "Test failed"
            failure_xml.set("message", failure_message)
            failure_xml.text = '\n'.join(hftest_out)
            return TestRunnerResult(tests_run=1, tests_failed=1, tests_skipped=0)

    def run_suite(self, suite, xml):
        """Invoke the test platform and request that it run all matching tests
        in `suite`. Create new XML nodes with results under `xml`. The suite
        is skipped if it does not match the regex given to the constructor."""
        if not self.suite_re.match(suite["name"]):
            return TestRunnerResult(tests_run=0, tests_failed=0, tests_skipped=0)

        print("    SUITE", suite["name"])
        suite_xml = ET.SubElement(xml, "testsuite")
        suite_xml.set("name", suite["name"])
        properties_xml = ET.SubElement(suite_xml, "properties")

        property_xml = ET.SubElement(properties_xml, "property")
        property_xml.set("name", "driver")
        property_xml.set("value", type(self.driver).__name__)

        if self.driver.args.cpu:
            property_xml = ET.SubElement(properties_xml, "property")
            property_xml.set("name", "cpu")
            property_xml.set("value", self.driver.args.cpu)

        return self.collect_results(
            lambda test: self.run_test(suite, test, suite_xml),
            suite["tests"],
            suite_xml)

    def run_tests(self):
        """Run all suites and tests matching the regexes given to the
        constructor. Write the results to the sponge log XML file. Return a
        TestRunnerResult with the counts of run, failed and skipped tests."""

        test_spec = self.get_test_json()
        timestamp = datetime.datetime.now().replace(microsecond=0).isoformat()

        xml = ET.Element("testsuites")
        xml.set("name", self.test_set_up)
        xml.set("timestamp", timestamp)

        result = self.collect_results(
            lambda suite: self.run_suite(suite, xml),
            test_spec["suites"],
            xml)

        # Write the XML to file.
        ET.ElementTree(xml).write(self.artifacts.sponge_xml_path,
            encoding='utf-8', xml_declaration=True)

        if result.tests_failed > 0:
            print("[x] FAIL:", result.tests_failed, "of", result.tests_run,
                    "tests failed")
        elif result.tests_run > 0:
            print("    PASS: all", result.tests_run, "tests passed")

        # Let the driver clean up.
        self.driver.finish()

        return result

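# A minimal sketch of typical end-to-end usage, assuming a driver object that
# implements run(), get_run_log() and finish() as used above (the driver
# construction itself lives elsewhere, and "qemu" as the test set-up name is
# illustrative):
#
#   artifacts = ArtifactsManager("out/hftest_logs")
#   runner = TestRunner(artifacts, driver, "qemu", None, None,
#                       skip_long_running_tests=True, force_long_running=False,
#                       debug=False, show_output=False)
#   result = runner.run_tests()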