1#!/usr/bin/env python3
2#
3# Copyright (C) 2022 Intel Corporation.
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7
8import os
9import argparse
10from copy import deepcopy
11
12from pipeline import PipelineObject, PipelineStage, PipelineEngine
13
class SchemaTypeSlicer:
    """Produce trimmed copies of XML Schema (XSD) type definitions.

    Subclasses decide which ``xs:element`` / ``xs:enumeration`` nodes survive by overriding ``is_element_needed``,
    and how copied types are renamed by overriding ``get_name_of_slice``. The base implementation keeps everything.
    Only the direct ``xs:all`` child of a complex type and the direct ``xs:restriction`` child of a simple type are
    filtered; other content models (e.g. ``xs:sequence``) are left untouched.
    """

    # Namespace prefixes usable in the ElementPath expressions passed to get_node/get_nodes.
    xpath_ns = {
        "xs": "http://www.w3.org/2001/XMLSchema",
        "acrn": "https://projectacrn.org",
    }

    @classmethod
    def get_node(cls, element, xpath):
        """Return the first node matching ``xpath`` under ``element`` (or None), resolving ``xs``/``acrn`` prefixes."""
        return element.find(xpath, namespaces=cls.xpath_ns)

    @classmethod
    def get_nodes(cls, element, xpath):
        """Return all nodes matching ``xpath`` under ``element``, resolving ``xs``/``acrn`` prefixes."""
        return element.findall(xpath, namespaces=cls.xpath_ns)

    def __init__(self, etree):
        # The whole schema (tree or root element); named type definitions are looked up here.
        self.etree = etree

    def get_type_definition(self, type_name):
        """Return the named ``xs:complexType`` or ``xs:simpleType`` definition, or None if there is none.

        The search path is relative (``.//``): an absolute ``//`` path is rejected by ``Element.find`` and only
        accepted with a FutureWarning by ``ElementTree.find``.
        """
        for kind in ("complexType", "simpleType"):
            type_node = self.get_node(self.etree, f".//xs:{kind}[@name='{type_name}']")
            if type_node is not None:
                return type_node
        return None

    def slice_element_list(self, element_list_node, new_nodes, emitted=None):
        """Filter the ``xs:element`` children of ``element_list_node`` in place.

        Elements rejected by ``is_element_needed`` are removed. An element with an embedded ``xs:complexType`` has
        that type sliced in place. An element referring to a named type gets the type sliced into a renamed copy,
        which is appended to ``new_nodes`` together with any copies it depends on. An element whose complex type
        loses all of its sub-elements through slicing is removed as well.

        Args:
            element_list_node: an ``xs:all`` node, modified in place.
            new_nodes: list receiving newly created type definitions.
            emitted: names of sliced copies already placed in a result list during the current slicing pass, so a
                type referenced several times is copied only once. A fresh set is used when omitted.

        Returns:
            True if ``element_list_node`` or anything nested in it was changed.
        """
        if emitted is None:
            emitted = set()
        sliced = False

        # get_nodes returns a list, so removing children while iterating over it is safe.
        for element_node in self.get_nodes(element_list_node, "xs:element"):
            if not self.is_element_needed(element_node):
                element_list_node.remove(element_node)
                sliced = True
                continue

            # Embedded complex type: slice it in place. Drop the element only if slicing removed every sub-element it
            # had; an embedded type that never had sub-elements (e.g. attributes only) is kept as is, matching how
            # element-less referenced types are treated below.
            embedded_type_node = self.get_node(element_node, "xs:complexType")
            if embedded_type_node is not None:
                had_sub_elements = bool(self.get_nodes(embedded_type_node, ".//xs:element"))
                sub_nodes, sub_sliced = self._slice(embedded_type_node, in_place=True, emitted=emitted)
                if had_sub_elements and not self.get_nodes(embedded_type_node, ".//xs:element"):
                    element_list_node.remove(element_node)
                    sliced = True
                else:
                    new_nodes.extend(sub_nodes)
                    # The embedded type lives inside this list, so any change to it is a change here too.
                    sliced = sliced or sub_sliced
                continue

            # Referenced (named) type: slice a renamed copy.
            type_name = element_node.get("type")
            if not type_name:
                continue
            type_node = self.get_type_definition(type_name)
            if type_node is None:
                continue  # built-in or otherwise unresolvable type, e.g. xs:string

            sliced_type_name = self.get_name_of_slice(type_name)
            # Reuse a copy made earlier in this pass or already present in the schema instead of copying again.
            if sliced_type_name in emitted or self.get_type_definition(sliced_type_name) is not None:
                element_node.set("type", sliced_type_name)
                sliced = True
                continue

            sub_nodes, sub_sliced = self._slice(type_node, emitted=emitted)
            if not sub_sliced:
                continue  # nothing filtered in the referenced type; keep referring to the original

            # The sliced copy of the referenced type is always the last node returned.
            sliced_copy = sub_nodes[-1]
            if sliced_copy.tag.endswith("simpleType") or self.get_nodes(sliced_copy, ".//xs:element"):
                new_nodes.extend(sub_nodes)
                emitted.add(sliced_type_name)
                element_node.set("type", sliced_type_name)
            else:
                element_list_node.remove(element_node)
            sliced = True

        return sliced

    def slice_restriction(self, restriction_node):
        """Remove, in place, the ``xs:enumeration`` children rejected by ``is_element_needed``.

        Returns:
            True if at least one enumeration was removed.
        """
        sliced = False
        for enumeration_node in self.get_nodes(restriction_node, "xs:enumeration"):
            if not self.is_element_needed(enumeration_node):
                restriction_node.remove(enumeration_node)
                sliced = True
        return sliced

    def _slice(self, type_node, in_place=False, force_copy=False, emitted=None):
        """Implement ``slice``, additionally reporting whether anything was filtered.

        Returns:
            ``(new_nodes, sliced)``. ``sliced`` tells whether ``type_node`` (or its copy) was changed.
        """
        if emitted is None:
            emitted = set()
        new_nodes = []

        if in_place:
            new_type_node = type_node
        else:
            new_type_node = deepcopy(type_node)
            type_name = type_node.get("name")
            if type_name is not None:
                new_type_node.set("name", self.get_name_of_slice(type_name))

        sliced = False
        element_list_node = self.get_node(new_type_node, "xs:all")
        if element_list_node is not None:
            sliced = self.slice_element_list(element_list_node, new_nodes, emitted) or sliced
        restriction_node = self.get_node(new_type_node, "xs:restriction")
        if restriction_node is not None:
            # Combine rather than overwrite, so a change found in the element list is never lost.
            sliced = self.slice_restriction(restriction_node) or sliced

        if not in_place and (sliced or force_copy):
            new_nodes.append(new_type_node)
        return new_nodes, sliced

    def slice(self, type_node, in_place=False, force_copy=False):
        """Slice a type definition and return the new type definitions the caller should add to the schema.

        Args:
            type_node: an ``xs:complexType`` or ``xs:simpleType`` node.
            in_place: filter ``type_node`` itself instead of a renamed deep copy; the node itself is then never part
                of the result.
            force_copy: return the renamed copy even if nothing in it was filtered.

        Returns:
            Newly created type definitions, dependencies first. When a copy of ``type_node`` is returned it is the
            last item.
        """
        new_nodes, _ = self._slice(type_node, in_place=in_place, force_copy=force_copy)
        return new_nodes

    def is_element_needed(self, element_node):
        """Return whether ``element_node`` (an xs:element or xs:enumeration) survives slicing; keeps all here."""
        return True

    def get_name_of_slice(self, name):
        """Return the name given to the sliced copy of the type called ``name``."""
        return f"Sliced{name}"
123
class SlicingSchemaByVMTypeStage(PipelineStage):
    """Pipeline stage that specializes ``VMConfigType`` for each kind of VM.

    For pre-launched VMs, the Service VM and post-launched VMs, a copy of ``VMConfigType`` is added to the schema.
    Each copy keeps only the elements whose ``acrn:applicable-vms`` annotation attribute mentions that kind, or that
    carry no such attribute. The ``xs:alternative`` nodes of the ``vm`` element in ``ACRNConfigType`` are then
    retargeted to those copies.
    """
    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class VMTypeSlicer(SchemaTypeSlicer):
        """Slicer keyed on the ``acrn:applicable-vms`` annotation attribute.

        Subclasses define ``vm_type_indicator`` (substring searched for in the attribute) and ``type_prefix``
        (prepended to the names of sliced types).
        """
        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            applicable_vms = annot_node.get("{https://projectacrn.org}applicable-vms")
            # Elements without the attribute are kept; otherwise a plain substring test is applied.
            return applicable_vms is None or self.vm_type_indicator in applicable_vms

        def get_name_of_slice(self, name):
            return f"{self.type_prefix}{name}"

    class PreLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "pre-launched"
        type_prefix = "PreLaunched"

    class ServiceVMTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "service-vm"
        type_prefix = "Service"

    class PostLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "post-launched"
        type_prefix = "PostLaunched"

    def run(self, obj):
        """Add the per-VM-kind copies of ``VMConfigType`` and point the ``vm`` alternatives at them.

        Raises:
            ValueError: if the schema defines no ``VMConfigType`` complex type.
        """
        schema_etree = obj.get("schema_etree")

        vm_type_name = "VMConfigType"
        # Relative ".//" path: an absolute "//" path on an element tree is only accepted with a FutureWarning.
        vm_type_node = SchemaTypeSlicer.get_node(schema_etree, f".//xs:complexType[@name='{vm_type_name}']")
        if vm_type_node is None:
            raise ValueError(f"Complex type {vm_type_name} is not defined in the schema")

        slicers = [
            self.PreLaunchedTypeSlicer(schema_etree),
            self.ServiceVMTypeSlicer(schema_etree),
            self.PostLaunchedTypeSlicer(schema_etree),
        ]
        root = schema_etree.getroot()
        for slicer in slicers:
            for new_node in slicer.slice(vm_type_node, force_copy=True):
                root.append(new_node)

        # Marker looked for in an alternative's test expression, paired with the slicer whose copy it should use.
        # The first matching marker wins.
        markers = tuple(zip(("PRE_LAUNCHED_VM", "SERVICE_VM", "POST_LAUNCHED_VM"), slicers))
        vm_alternatives = SchemaTypeSlicer.get_nodes(
            schema_etree,
            ".//xs:complexType[@name='ACRNConfigType']//xs:element[@name='vm']//xs:alternative")
        for node in vm_alternatives:
            test = node.get("test")
            if test is None:
                # In XSD 1.1 an alternative without a test is the default alternative; leave it untouched.
                continue
            for marker, slicer in markers:
                if marker in test:
                    node.set("type", slicer.get_name_of_slice(vm_type_name))
                    break

        obj.set("schema_etree", schema_etree)
177
class SlicingSchemaByViewStage(PipelineStage):
    """Pipeline stage that adds basic-view and advanced-view copies of the VM and hypervisor configuration types.

    Every complex type whose name ends in ``VMConfigType`` is sliced by its ``acrn:views`` annotation attribute. When
    SlicingSchemaByVMTypeStage ran first, this includes the per-VM-kind copies it added. ``HVConfigType`` is sliced
    the same way, if the schema defines it.
    """
    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class ViewSlicer(SchemaTypeSlicer):
        """Slicer keyed on the ``acrn:views`` annotation attribute.

        Subclasses define ``view_indicator`` (substring searched for in the attribute) and ``type_prefix``
        (inserted into the names of sliced types).
        """
        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            views = annot_node.get("{https://projectacrn.org}views")
            # Elements without the attribute are kept; otherwise a plain substring test is applied.
            return views is None or self.view_indicator in views

        def get_name_of_slice(self, name):
            # "FooConfigType" becomes "Foo<prefix>ConfigType"; other names just get the prefix prepended.
            if "ConfigType" in name:
                return name.replace("ConfigType", f"{self.type_prefix}ConfigType")
            return f"{self.type_prefix}{name}"

    class BasicViewSlicer(ViewSlicer):
        view_indicator = "basic"
        type_prefix = "Basic"

    class AdvancedViewSlicer(ViewSlicer):
        view_indicator = "advanced"
        type_prefix = "Advanced"

    def run(self, obj):
        """Append the basic and advanced slices of the VM and hypervisor config types to ``schema_etree``."""
        schema_etree = obj.get("schema_etree")

        # Collect the types to slice up front; the copies appended below are not sliced again.
        # Relative ".//" paths: an absolute "//" path on an element tree is only accepted with a FutureWarning.
        type_nodes = [
            node for node in SchemaTypeSlicer.get_nodes(schema_etree, ".//xs:complexType")
            if (node.get("name") or "").endswith("VMConfigType")
        ]
        hv_type_node = SchemaTypeSlicer.get_node(schema_etree, ".//xs:complexType[@name='HVConfigType']")
        if hv_type_node is not None:
            # Appending None would make slice() crash on a schema without HVConfigType.
            type_nodes.append(hv_type_node)

        slicers = [
            self.BasicViewSlicer(schema_etree),
            self.AdvancedViewSlicer(schema_etree),
        ]

        root = schema_etree.getroot()
        for slicer in slicers:
            for type_node in type_nodes:
                for new_node in slicer.slice(type_node, force_copy=True):
                    root.append(new_node)

        obj.set("schema_etree", schema_etree)
222
def main(args):
    """Load the schema at ``args.schema``, slice it by VM type and then by view, and write it to ``args.out``."""
    # Deferred import: only needed once main() actually runs.
    from lxml_loader import LXMLLoadStage

    engine = PipelineEngine(["schema_path"])
    stages = [
        LXMLLoadStage("schema"),
        SlicingSchemaByVMTypeStage(),
        SlicingSchemaByViewStage(),
    ]
    engine.add_stages(stages)

    pipeline_obj = PipelineObject(schema_path=args.schema)
    engine.run(pipeline_obj)

    sliced_etree = pipeline_obj.get("schema_etree")
    sliced_etree.write(args.out)

    print(f"Sliced schema written to {args.out}")
238
239
if __name__ == "__main__":
    # __file__ can be a bare relative name such as 'schema_slicer.py'; make it absolute before taking its dirname.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_tools_dir = os.path.abspath(os.path.join(script_dir, ".."))
    schema_dir = os.path.join(config_tools_dir, "schema")
    default_out = os.path.join(schema_dir, "sliced.xsd")
    default_schema = os.path.join(schema_dir, "config.xsd")

    parser = argparse.ArgumentParser(description="Slice a given scenario schema by VM types and views")
    parser.add_argument("out", nargs="?", default=default_out,
                        help="Path where the output is placed")
    parser.add_argument("--schema", default=default_schema,
                        help="the XML schema that defines the syntax of scenario XMLs")
    args = parser.parse_args()

    main(args)
251