1#!/usr/bin/env python3
2#
3# Copyright (C) 2022 Intel Corporation.
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7
8import sys, os
9import argparse
10import logging
11from copy import copy
12from collections import namedtuple
13import re
14
15try:
16    import elementpath
17    import elementpath_overlay
18    from elementpath.xpath_context import XPathContext
19    import xmlschema
20except ImportError:
21    logging.error("Python package `xmlschema` is not installed.\n" +
22                  "The scenario XML file will NOT be validated against the schema, which may cause build-time or runtime errors.\n" +
23                  "To enable the validation, install the python package by executing: pip3 install xmlschema.")
24    sys.exit(0)
25
26from pipeline import PipelineObject, PipelineStage, PipelineEngine
27from schema_slicer import SlicingSchemaByVMTypeStage
28from default_populator import DefaultValuePopulatingStage
29
def existing_file_type(parser):
    """Build an argparse ``type=`` callable that accepts only existing regular files.

    The returned callable gives back the path unchanged when it names a regular
    file; otherwise it reports the problem through ``parser.error``.
    """
    def aux(arg):
        # A regular file necessarily exists, so the accepting case is checked first.
        if os.path.isfile(arg):
            return arg
        if os.path.exists(arg):
            # Something is there, but it is not a regular file (e.g. a directory).
            parser.error(f"can't open {arg}: Is not a file")
        else:
            parser.error(f"can't open {arg}: No such file or directory")
    return aux
39
def log_level_type(parser):
    """Build an argparse ``type=`` callable that normalizes a log level name.

    The returned callable lower-cases its argument and returns it when it is one
    of the known level names; otherwise it reports the (lower-cased) value
    through ``parser.error``.
    """
    valid_levels = ("critical", "error", "warning", "info", "debug")

    def aux(arg):
        level = arg.lower()
        if level not in valid_levels:
            parser.error(f"{level} is not a valid log level")
            return None
        return level
    return aux
48
class ValidationError(dict):
    """A single validation finding, stored as a dict.

    The dict always has three keys: ``paths`` (an iterable of path strings),
    ``message`` and ``severity``.
    """

    # Maps a severity name to the module-level logging function of the same name.
    logging_fns = {
        name: getattr(logging, name)
        for name in ("critical", "error", "warning", "info", "debug")
    }

    def __init__(self, paths, message, severity):
        super().__init__({"paths": paths, "message": message, "severity": severity})

    def __str__(self):
        joined_paths = ", ".join(self["paths"])
        return f"{joined_paths}: {self['message']}"

    def log(self):
        """Log this finding at its own severity; unknown severities go to debug."""
        try:
            emit = self.logging_fns[self["severity"]]
            emit(self)
        except KeyError:
            logging.debug(self)
69
class ScenarioValidator:
    """Check scenario XMLs against a syntax schema and optional data checks.

    ``schema`` is the XSD 1.1 schema used by check_syntax(). ``datachecks`` is
    the XSD 1.1 schema used by check_semantics(), or None when no data checks
    were given.
    """

    # Matches one `{...}` placeholder (without nested braces) in a rule description.
    _EXPR_REGEX = re.compile(r"{[^{}]*}")

    def __init__(self, schema_etree, datachecks_etree):
        """Initialize the validator with preprocessed schemas in ElementTree."""
        self.schema = xmlschema.XMLSchema11(schema_etree)
        self.datachecks = xmlschema.XMLSchema11(datachecks_etree) if datachecks_etree else None

    def check_syntax(self, scenario_etree):
        """Validate the scenario against the syntax schema.

        Every reported error is logged and returned as a ValidationError with
        severity "critical".
        """
        errors = []

        for error in self.schema.iter_errors(scenario_etree):
            # Syntactic errors are always critical.
            e = ValidationError([error.path], error.reason, "critical")
            e.log()
            errors.append(e)

        return errors

    def check_semantics(self, board_etree, scenario_etree):
        """Validate board and scenario data together against the data checks.

        Returns the list of logged ValidationError objects, which is empty when
        no data checks are configured.
        """
        errors = []

        if not self.datachecks:
            return errors

        unified_node = copy(scenario_etree.getroot())
        # The parent map is built before the board nodes are added, so it only
        # covers nodes that came from the scenario.
        parent_map = {c: p for p in unified_node.iter() for c in p}
        # NOTE(review): with lxml, extend() may move the children out of
        # board_etree rather than copy them -- confirm no caller relies on them
        # afterwards.
        unified_node.extend(board_etree.getroot())

        for error in self.datachecks.iter_errors(unified_node):
            e = self.format_error(unified_node, parent_map, error)
            e.log()
            errors.append(e)

        return errors

    @staticmethod
    def format_paths(unified_node, parent_map, report_on, variables):
        """Return a "/a/b[2]/c"-style path for each node selected by ``report_on``.

        ``report_on`` is an XPath 2.0 expression evaluated on ``unified_node``
        with ``variables`` bound. A 1-based index is added to a segment only
        when the parent has more than one child with the same tag.
        """
        elems = elementpath.select(unified_node, report_on, variables=variables, parser=elementpath.XPath2Parser)
        paths = []
        for elem in elems:
            # Walk from the node up to the root, collecting segments leaf-first.
            segments = []
            while elem is not None:
                segment = elem.tag
                parent = parent_map.get(elem)
                if parent is not None:
                    siblings = parent.findall(elem.tag)
                    if len(siblings) > 1:
                        segment += f"[{siblings.index(elem) + 1}]"
                segments.append(segment)
                elem = parent
            paths.append("/" + "/".join(reversed(segments)))
        return paths

    @staticmethod
    def get_counter_example(error):
        """Return the counter example of a failed assertion, or {} if there is none.

        Errors that do not come from an XSD assertion yield {}.
        """
        assertion = error.validator
        if not isinstance(assertion, xmlschema.validators.assertions.XsdAssert):
            return {}

        elem = error.obj
        context = XPathContext(elem, variables={'value': None})
        # Presumably filled in during evaluation by the elementpath_overlay
        # extension imported by this file -- confirm against that module.
        context.counter_example = {}
        result = assertion.token.evaluate(context)

        if result == False:
            return context.counter_example
        return {}

    @staticmethod
    def _format_node(n):
        """Render one XPath result item as text for an error message.

        Strings are returned unchanged. Objects whose class name ends with
        "Element" are rendered as their text, with "" standing in for missing
        text so that the result is always a string. Anything else goes through
        str().
        """
        if isinstance(n, str):
            return n
        if isinstance(n, (int, float)):
            return str(n)
        if n.__class__.__name__.endswith("Element"):
            # An element without text has text None, which str.join() and
            # str.replace() both reject.
            return n.text or ""
        return str(n)

    @staticmethod
    def format_error(unified_node, parent_map, error):
        """Turn a data check failure into a ValidationError.

        Paths, description and severity are read from the annotation of the
        failed validator. Each `{xpath}` placeholder in the description is
        replaced by the value of that expression, evaluated with the variables
        of the counter example.
        """
        anno = error.validator.annotation
        counter_example = ScenarioValidator.get_counter_example(error)
        variables = {k.obj.source.strip("$"): v for k, v in counter_example.items()}

        paths = ScenarioValidator.format_paths(unified_node, parent_map, anno.elem.get("{https://projectacrn.org}report-on"), variables)
        description = anno.elem.find("{http://www.w3.org/2001/XMLSchema}documentation").text
        severity = anno.elem.get("{https://projectacrn.org}severity")

        # A set is used so that a repeated placeholder is evaluated only once.
        exprs = set(ScenarioValidator._EXPR_REGEX.findall(description))
        for expr in exprs:
            result = elementpath.select(unified_node, expr.strip("{}"), variables=variables, parser=elementpath.XPath2Parser)
            if isinstance(result, list):
                if len(result) == 1:
                    value = ScenarioValidator._format_node(result[0])
                elif len(result) > 1:
                    s = ', '.join(map(ScenarioValidator._format_node, result))
                    value = f"[{s}]"
                else:
                    value = "{unknown}"
            else:
                value = str(result)
            description = description.replace(expr, value)

        return ValidationError(paths, description, severity)
174
class ValidatorConstructionStage(PipelineStage):
    """Pipeline stage that builds a ScenarioValidator from already loaded schema etrees."""

    # The schema etree may still be useful for schema-based transformation. Do not consume it.
    uses = {"schema_etree"}
    consumes = {"datachecks_etree"}
    provides = {"validator"}

    def run(self, obj):
        """Read both etrees from ``obj`` and store the new validator as "validator"."""
        schema_etree = obj.get("schema_etree")
        datachecks_etree = obj.get("datachecks_etree")
        obj.set("validator", ScenarioValidator(schema_etree, datachecks_etree))
184
class ValidatorConstructionByFileStage(PipelineStage):
    """Pipeline stage that builds a ScenarioValidator directly from the schema paths."""

    uses = {"schema_path", "datachecks_path"}
    provides = {"validator"}

    def run(self, obj):
        """Read both paths from ``obj`` and store the new validator as "validator"."""
        schema_path = obj.get("schema_path")
        datachecks_path = obj.get("datachecks_path")
        obj.set("validator", ScenarioValidator(schema_path, datachecks_path))
192
class SyntacticValidationStage(PipelineStage):
    """Pipeline stage that runs the syntax check on one etree of the pipeline object."""

    provides = {"syntactic_errors"}

    def __init__(self, etree_tag = "scenario"):
        # The etree to check is looked up under "<etree_tag>_etree".
        tag = f"{etree_tag}_etree"
        self.etree_tag = tag
        self.uses = {"validator", tag}

    def run(self, obj):
        """Check the selected etree and store the errors as "syntactic_errors"."""
        validator = obj.get("validator")
        target_etree = obj.get(self.etree_tag)
        obj.set("syntactic_errors", validator.check_syntax(target_etree))
203
class SemanticValidationStage(PipelineStage):
    """Pipeline stage that runs the data checks on the board and scenario etrees."""

    uses = {"validator", "board_etree", "scenario_etree"}
    provides = {"semantic_errors"}

    def run(self, obj):
        """Check board and scenario together and store the errors as "semantic_errors"."""
        validator = obj.get("validator")
        board_etree = obj.get("board_etree")
        scenario_etree = obj.get("scenario_etree")
        obj.set("semantic_errors", validator.check_semantics(board_etree, scenario_etree))
211
class ReportValidationResultStage(PipelineStage):
    """Pipeline stage that logs a one-line summary and stores the total finding count."""

    consumes = {"board_etree", "scenario_etree", "syntactic_errors", "semantic_errors"}
    provides = {"nr_all_errors"}

    def run(self, obj):
        """Count findings per category, log the verdict and set "nr_all_errors"."""
        board_name = obj.get("board_etree").getroot().get("board")
        scenario_name = obj.get("scenario_etree").getroot().get("scenario")

        # Every syntactic error counts as critical; semantic findings are
        # split by their own severity.
        nr_critical = len(obj.get("syntactic_errors"))
        semantic_errors = obj.get("semantic_errors")
        nr_error = sum(1 for e in semantic_errors if e["severity"] == "error")
        nr_warning = sum(1 for e in semantic_errors if e["severity"] == "warning")

        has_errors = nr_critical > 0 or nr_error > 0
        if has_errors:
            logging.error(f"Board {board_name} and scenario {scenario_name} are inconsistent: {nr_critical} syntax errors, {nr_error} data errors, {nr_warning} warnings.")
        elif nr_warning > 0:
            logging.warning(f"Board {board_name} and scenario {scenario_name} are potentially inconsistent: {nr_warning} warnings.")
        else:
            logging.info(f"Board {board_name} and scenario {scenario_name} are valid and consistent.")

        # Warnings are included in the total as well.
        obj.set("nr_all_errors", nr_critical + nr_error + nr_warning)
232
def validate_one(validation_pipeline, pipeline_obj, board_xml, scenario_xml):
    """Validate one board/scenario pair and return the number of findings.

    Both paths are stored on ``pipeline_obj``, the pipeline is run on it, and
    the "nr_all_errors" result is consumed from it and returned.
    """
    inputs = (("board_path", board_xml), ("scenario_path", scenario_xml))
    for tag, path in inputs:
        pipeline_obj.set(tag, path)

    validation_pipeline.run(pipeline_obj)

    nr_all_errors = pipeline_obj.consume("nr_all_errors")
    return nr_all_errors
238
def validate_board(validation_pipeline, pipeline_obj, board_xml):
    """Validate every scenario XML that sits next to a board XML.

    All ``*.xml`` files in the directory of ``board_xml`` are treated as
    scenarios, except the board file itself and files with "launch" in their
    name. Returns the sum of the findings reported by validate_one().
    """
    # dirname() is "" for a bare file name, and os.listdir("") raises
    # FileNotFoundError, so fall back to the current directory.
    board_dir = os.path.dirname(board_xml) or os.curdir
    board_file = os.path.basename(board_xml)
    nr_all_errors = 0

    # Sorted so that scenarios are validated in a deterministic order.
    for f in sorted(os.listdir(board_dir)):
        if not f.endswith(".xml"):
            continue
        if f == board_file or "launch" in f:
            continue
        nr_all_errors += validate_one(validation_pipeline, pipeline_obj, board_xml, os.path.join(board_dir, f))

    return nr_all_errors
251
def validate_all(validation_pipeline, pipeline_obj, data_dir):
    """Validate the scenarios of every board found under ``data_dir``.

    For each entry ``<name>`` of ``data_dir`` the board XML is expected at
    ``<data_dir>/<name>/<name>.xml``. Entries without such a file only produce
    a warning. Returns the sum of the findings reported by validate_board().
    """
    total = 0

    for name in os.listdir(data_dir):
        board_root = os.path.join(data_dir, name)
        board_xml = os.path.join(board_root, f"{name}.xml")
        if not os.path.isfile(board_xml):
            logging.warning(f"Cannot find a board XML under {board_root}")
            continue
        total += validate_board(validation_pipeline, pipeline_obj, board_xml)

    return total
263
def main(args):
    """Build the validator, run the requested validation and exit.

    ``args`` must provide ``schema``, ``datachecks``, ``board`` and
    ``scenario``. With both a board and a scenario, that single pair is
    validated. With only a board, every scenario next to it is validated.
    With neither, every board under the ``data`` directory next to this
    script's parent directory is validated.

    The process exits with status 1 if any finding was counted, otherwise 0.
    """
    from lxml_loader import LXMLLoadStage

    validator_construction_pipeline = PipelineEngine(["schema_path", "datachecks_path"])
    validator_construction_pipeline.add_stages([
        LXMLLoadStage("schema"),
        LXMLLoadStage("datachecks"),
        SlicingSchemaByVMTypeStage(),
        ValidatorConstructionStage(),
    ])

    validation_pipeline = PipelineEngine(["board_path", "scenario_path", "schema_etree", "validator"])
    validation_pipeline.add_stages([
        LXMLLoadStage("board"),
        LXMLLoadStage("scenario"),
        DefaultValuePopulatingStage(),
        SyntacticValidationStage(),
        SemanticValidationStage(),
        ReportValidationResultStage(),
    ])

    obj = PipelineObject(schema_path = args.schema, datachecks_path = args.datachecks)
    validator_construction_pipeline.run(obj)

    if args.board and args.scenario:
        nr_all_errors = validate_one(validation_pipeline, obj, args.board, args.scenario)
    elif args.board:
        nr_all_errors = validate_board(validation_pipeline, obj, args.board)
    else:
        # Derive the data directory here instead of reading a global that only
        # exists when this file is run as a script; otherwise calling main()
        # from an importing module raises NameError.
        data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
        nr_all_errors = validate_all(validation_pipeline, obj, data_dir)

    sys.exit(1 if nr_all_errors > 0 else 0)
295
if __name__ == "__main__":
    # The default schema files live in the "schema" directory next to this
    # script's parent directory.
    config_tools_dir = os.path.join(os.path.dirname(__file__), "..")
    schema_dir = os.path.join(config_tools_dir, "schema")

    parser = argparse.ArgumentParser()
    # Both positionals are optional; main() picks the mode from which are given.
    parser.add_argument(
        "board",
        nargs="?",
        type=existing_file_type(parser),
        help="the board XML file to be validated",
    )
    parser.add_argument(
        "scenario",
        nargs="?",
        type=existing_file_type(parser),
        help="the scenario XML file to be validated",
    )
    parser.add_argument(
        "--loglevel",
        default="warning",
        type=log_level_type(parser),
        help="choose log level, e.g. debug, info, warning or error",
    )
    parser.add_argument(
        "--schema",
        default=os.path.join(schema_dir, "config.xsd"),
        help="the XML schema that defines the syntax of scenario XMLs",
    )
    parser.add_argument(
        "--datachecks",
        default=os.path.join(schema_dir, "datachecks.xsd"),
        help="the XML schema that defines the semantic rules against board and scenario data",
    )
    args = parser.parse_args()

    # logging expects the level name in upper case.
    level_name = args.loglevel.upper()
    logging.basicConfig(level=level_name)
    main(args)
310