#!/usr/bin/env python3
#
# Copyright (C) 2022 Intel Corporation.
#
# SPDX-License-Identifier: BSD-3-Clause
#

import os
import argparse
from copy import deepcopy

from pipeline import PipelineObject, PipelineStage, PipelineEngine


class SchemaTypeSlicer:
    """Derive trimmed ("sliced") variants of XML Schema type definitions.

    Slicing drops the ``xs:element`` children of ``xs:all`` groups and the ``xs:enumeration`` facets of
    ``xs:restriction`` nodes for which :meth:`is_element_needed` returns False. Named types are sliced on deep
    copies renamed by :meth:`get_name_of_slice`; anonymous types embedded in such a copy are sliced in place.

    Subclasses customize :meth:`is_element_needed` and :meth:`get_name_of_slice`; this base class keeps everything.
    """

    xpath_ns = {
        "xs": "http://www.w3.org/2001/XMLSchema",
        "acrn": "https://projectacrn.org",
    }

    @classmethod
    def get_node(cls, element, xpath):
        """Return the first node matching ``xpath`` (``xs``/``acrn`` prefixes allowed) under ``element``, or None."""
        return element.find(xpath, namespaces=cls.xpath_ns)

    @classmethod
    def get_nodes(cls, element, xpath):
        """Return the list of nodes matching ``xpath`` (``xs``/``acrn`` prefixes allowed) under ``element``."""
        return element.findall(xpath, namespaces=cls.xpath_ns)

    def __init__(self, etree):
        self.etree = etree
        # Names of the sliced named types this slicer has handed back to callers. Callers append new type
        # definitions to the tree only after slice() returns, so without this record a type referenced by several
        # elements in one pass would be sliced, and emitted, several times under the same name.
        self._produced_type_names = set()

    def get_type_definition(self, type_name):
        """Return the ``xs:complexType`` or ``xs:simpleType`` named ``type_name`` in the schema, or None."""
        type_node = self.get_node(self.etree, f"//xs:complexType[@name='{type_name}']")
        if type_node is None:
            type_node = self.get_node(self.etree, f"//xs:simpleType[@name='{type_name}']")
        return type_node

    def _sliced_type_exists(self, sliced_type_name):
        """Tell whether ``sliced_type_name`` is already defined in the tree or was already produced by this slicer."""
        return sliced_type_name in self._produced_type_names or self.get_type_definition(sliced_type_name) is not None

    def slice_element_list(self, element_list_node, new_nodes):
        """Slice the ``xs:element`` children of ``element_list_node`` in place.

        Elements rejected by :meth:`is_element_needed` are removed. Elements with an embedded ``xs:complexType`` have
        that type sliced in place; elements referring to a named type are pointed at a sliced copy of it when
        slicing changes it. An element is removed, too, when slicing leaves its complex type with no sub-element.

        :param element_list_node: the ``xs:all`` node whose children are sliced; modified in place.
        :param new_nodes: list that receives the new type definitions created along the way.
        :return: True if ``element_list_node`` or anything nested in it was changed.
        """
        sliced = False

        for element_node in self.get_nodes(element_list_node, "xs:element"):
            if not self.is_element_needed(element_node):
                element_list_node.remove(element_node)
                sliced = True
                continue

            # For embedded complex type definition, also slice in place. If slicing removes every sub-element of
            # the type, remove the element itself, too.
            element_type_node = self.get_node(element_node, "xs:complexType")
            if element_type_node is not None:
                had_sub_elements = len(self.get_nodes(element_type_node, ".//xs:element")) > 0
                new_sub_nodes, sub_sliced = self._slice(element_type_node, in_place=True)
                if had_sub_elements and len(self.get_nodes(element_type_node, ".//xs:element")) == 0:
                    element_list_node.remove(element_node)
                    sliced = True
                else:
                    new_nodes.extend(new_sub_nodes)
                    # The embedded type was modified in place, so the enclosing definition has changed as well
                    sliced = sliced or sub_sliced
                continue

            # For external type definition, create a copy to slice. If the sliced type contains no sub-element,
            # remove the element itself.
            element_type_name = element_node.get("type")
            if not element_type_name:
                continue
            element_type_node = self.get_type_definition(element_type_name)
            if element_type_node is None:
                # Not defined in this schema (e.g. a built-in XML Schema type): nothing to slice
                continue
            sliced_type_name = self.get_name_of_slice(element_type_name)

            # If a sliced type already exists (in the tree or produced earlier in this pass), do not duplicate the
            # effort
            if self._sliced_type_exists(sliced_type_name):
                element_node.set("type", sliced_type_name)
                sliced = True
                continue

            new_sub_nodes, sub_sliced = self._slice(element_type_node)
            if not sub_sliced:
                # The type needs no slicing; keep referring to the original definition
                continue

            # _slice appends the sliced copy of element_type_node after any nested new types
            sliced_type_node = new_sub_nodes[-1]
            keeps_content = (sliced_type_node.tag.endswith("simpleType")
                             or len(self.get_nodes(sliced_type_node, ".//xs:element")) > 0)
            if keeps_content:
                new_nodes.extend(new_sub_nodes)
                element_node.set("type", sliced_type_name)
                self._produced_type_names.add(sliced_type_name)
            else:
                element_list_node.remove(element_node)
            sliced = True

        return sliced

    def slice_restriction(self, restriction_node):
        """Remove, in place, the ``xs:enumeration`` facets of ``restriction_node`` rejected by is_element_needed.

        :return: True if any facet was removed.
        """
        sliced = False

        for restriction in self.get_nodes(restriction_node, "xs:enumeration"):
            if not self.is_element_needed(restriction):
                restriction_node.remove(restriction)
                sliced = True

        return sliced

    def slice(self, type_node, in_place=False, force_copy=False):
        """Slice the type definition ``type_node``.

        :param type_node: an ``xs:complexType`` or ``xs:simpleType`` node.
        :param in_place: slice ``type_node`` itself instead of a renamed deep copy.
        :param force_copy: when not slicing in place, return the copy even if nothing was removed from it.
        :return: list of new type definitions for the caller to add to the schema. When a copy of ``type_node`` is
            returned, it is the last item, preceded by the sliced copies of named types it now refers to.
        """
        new_nodes, _ = self._slice(type_node, in_place=in_place, force_copy=force_copy)
        return new_nodes

    def _slice(self, type_node, in_place=False, force_copy=False):
        """Implementation of :meth:`slice` that also reports whether anything was sliced.

        :return: a ``(new_nodes, sliced)`` tuple, ``new_nodes`` being what :meth:`slice` returns.
        """
        new_nodes = []

        if in_place:
            new_type_node = type_node
        else:
            new_type_node = deepcopy(type_node)
            type_name = type_node.get("name")
            if type_name is not None:
                new_type_node.set("name", self.get_name_of_slice(type_name))

        sliced = False

        element_list_node = self.get_node(new_type_node, "xs:all")
        if element_list_node is not None:
            # Call first so the slicing always happens, then accumulate the flag
            sliced = self.slice_element_list(element_list_node, new_nodes) or sliced

        restriction_node = self.get_node(new_type_node, "xs:restriction")
        if restriction_node is not None:
            sliced = self.slice_restriction(restriction_node) or sliced

        if not in_place and (sliced or force_copy):
            new_nodes.append(new_type_node)
        return new_nodes, sliced

    def is_element_needed(self, element_node):
        """Tell whether ``element_node`` is kept in the slice; the base slicer keeps everything."""
        return True

    def get_name_of_slice(self, name):
        """Return the name given to the sliced copy of the type named ``name``."""
        return f"Sliced{name}"


class SlicingSchemaByVMTypeStage(PipelineStage):
    """Pipeline stage adding per-VM-type variants of ``VMConfigType`` to the schema.

    For pre-launched VMs, the Service VM and post-launched VMs, the stage drops the items whose ``xs:annotation``
    carries an ``acrn:applicable-vms`` attribute that does not mention that VM type. It then retargets the
    ``xs:alternative`` nodes of the ``vm`` element in ``ACRNConfigType`` at the matching variant.
    """
    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class VMTypeSlicer(SchemaTypeSlicer):
        """Keep only the schema items applicable to the VM type named by ``vm_type_indicator``."""

        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            applicable_vms = annot_node.get("{https://projectacrn.org}applicable-vms")
            return applicable_vms is None or self.vm_type_indicator in applicable_vms

        def get_name_of_slice(self, name):
            return f"{self.type_prefix}{name}"

    class PreLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "pre-launched"
        type_prefix = "PreLaunched"

    class ServiceVMTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "service-vm"
        type_prefix = "Service"

    class PostLaunchedTypeSlicer(VMTypeSlicer):
        vm_type_indicator = "post-launched"
        type_prefix = "PostLaunched"

    def run(self, obj):
        """Add the per-VM-type slices of VMConfigType and point the ``vm`` type alternatives at them."""
        schema_etree = obj.get("schema_etree")

        vm_type_name = "VMConfigType"
        vm_type_node = SchemaTypeSlicer.get_node(schema_etree, f"//xs:complexType[@name='{vm_type_name}']")
        slicers = [
            self.PreLaunchedTypeSlicer(schema_etree),
            self.ServiceVMTypeSlicer(schema_etree),
            self.PostLaunchedTypeSlicer(schema_etree),
        ]

        for slicer in slicers:
            for n in slicer.slice(vm_type_node, force_copy=True):
                schema_etree.getroot().append(n)

        alternatives_xpath = "//xs:complexType[@name='ACRNConfigType']//xs:element[@name='vm']//xs:alternative"
        for node in SchemaTypeSlicer.get_nodes(schema_etree, alternatives_xpath):
            test = node.get("test")
            if test is None:
                # An alternative without a test (the default alternative) mentions no VM type; leave it alone
                continue
            if "PRE_LAUNCHED_VM" in test:
                node.set("type", slicers[0].get_name_of_slice(vm_type_name))
            elif "SERVICE_VM" in test:
                node.set("type", slicers[1].get_name_of_slice(vm_type_name))
            elif "POST_LAUNCHED_VM" in test:
                node.set("type", slicers[2].get_name_of_slice(vm_type_name))

        obj.set("schema_etree", schema_etree)


class SlicingSchemaByViewStage(PipelineStage):
    """Pipeline stage adding Basic and Advanced view variants of the VM and hypervisor config types.

    Items whose ``xs:annotation`` carries an ``acrn:views`` attribute that does not mention the view are dropped.
    """
    uses = {"schema_etree"}
    provides = {"schema_etree"}

    class ViewSlicer(SchemaTypeSlicer):
        """Keep only the schema items shown in the view named by ``view_indicator``."""

        def is_element_needed(self, element_node):
            annot_node = self.get_node(element_node, "xs:annotation")
            if annot_node is None:
                return True
            views = annot_node.get("{https://projectacrn.org}views")
            return views is None or self.view_indicator in views

        def get_name_of_slice(self, name):
            # e.g. "HVConfigType" -> "HVBasicConfigType"; other names just get the prefix
            if "ConfigType" in name:
                return name.replace("ConfigType", f"{self.type_prefix}ConfigType")
            return f"{self.type_prefix}{name}"

    class BasicViewSlicer(ViewSlicer):
        view_indicator = "basic"
        type_prefix = "Basic"

    class AdvancedViewSlicer(ViewSlicer):
        view_indicator = "advanced"
        type_prefix = "Advanced"

    def run(self, obj):
        """Add view slices of HVConfigType and of every complex type whose name ends with VMConfigType."""
        schema_etree = obj.get("schema_etree")

        # Collected before slicing, so the view slices added below are not themselves sliced again
        type_nodes = [
            node for node in SchemaTypeSlicer.get_nodes(schema_etree, "//xs:complexType")
            if (node.get("name") or "").endswith("VMConfigType")
        ]
        hv_type_node = SchemaTypeSlicer.get_node(schema_etree, "//xs:complexType[@name = 'HVConfigType']")
        if hv_type_node is not None:
            type_nodes.append(hv_type_node)

        slicers = [
            self.BasicViewSlicer(schema_etree),
            self.AdvancedViewSlicer(schema_etree),
        ]

        for slicer in slicers:
            for type_node in type_nodes:
                for n in slicer.slice(type_node, force_copy=True):
                    schema_etree.getroot().append(n)

        obj.set("schema_etree", schema_etree)


def main(args):
    """Run the load/slice pipeline on the schema at ``args.schema`` and write the result to ``args.out``."""
    from lxml_loader import LXMLLoadStage

    pipeline = PipelineEngine(["schema_path"])
    pipeline.add_stages([
        LXMLLoadStage("schema"),
        SlicingSchemaByVMTypeStage(),
        SlicingSchemaByViewStage(),
    ])

    obj = PipelineObject(schema_path = args.schema)
    pipeline.run(obj)
    obj.get("schema_etree").write(args.out)

    print(f"Sliced schema written to {args.out}")


if __name__ == "__main__":
    # Resolve __file__ to an absolute path first so the directory computations below also work when __file__ is a
    # bare file name such as 'schema_slicer.py'
    config_tools_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
    schema_dir = os.path.join(config_tools_dir, "schema")

    parser = argparse.ArgumentParser(description="Slice a given scenario schema by VM types and views")
    parser.add_argument("out", nargs="?", default=os.path.join(schema_dir, "sliced.xsd"), help="Path where the output is placed")
    parser.add_argument("--schema", default=os.path.join(schema_dir, "config.xsd"), help="the XML schema that defines the syntax of scenario XMLs")
    args = parser.parse_args()

    main(args)