1# Copyright (C) 2021-2022 Intel Corporation.
2#
3# SPDX-License-Identifier: BSD-3-Clause
4#
5
6import logging
7
8from . import grammar
9from .context import *
10from .exception import *
11from .tree import Tree, Transformer, Direction
12
class Factory:
    """Base class of all AML grammar-element parsers.

    A factory parses one grammar element from ``context.current_stream`` into a
    ``Tree``. Subclasses implement :meth:`match`. Per-instance hook functions
    (assigned by ``register_hooks``) replace the no-op static hooks below.
    """

    @staticmethod
    def hook_pre(context, tree):
        """Called before :meth:`match`; no-op by default."""

    @staticmethod
    def hook_named(context, tree, name):
        """Called by sequence parsers after a NameString child is decoded; no-op by default."""

    @staticmethod
    def hook_post(context, tree):
        """Called after :meth:`match`, whether it succeeded or raised; no-op by default."""

    def __init__(self):
        # Nesting depth of in-progress parse() calls on this factory.
        self.level = 0
        # Label written into every tree this factory parses.
        self.label = "unknown"

    def mark_begin(self):
        self.level += 1

    def mark_end(self):
        self.level -= 1

    def match(self, context, stream, tree):
        """Decode this element from *stream* into *tree*; implemented by subclasses."""
        raise NotImplementedError

    def parse(self, context, tree):
        """Parse this element from the context's current stream into *tree*.

        Sets ``tree.label`` and ``tree.scope``, then runs ``hook_pre``,
        ``match`` and ``hook_post``, and returns *tree*. ``hook_post`` also runs
        when ``match`` raises, so post hooks such as scope pops still happen.
        The exception is then propagated. ``level`` is restored on every exit
        path, including when a hook itself raises.
        """
        self.mark_begin()
        try:
            tree.label = self.label
            tree.scope = context.get_scope()
            self.hook_pre(context, tree)
            try:
                self.match(context, context.current_stream, tree)
            finally:
                # Only reached once hook_pre succeeded, so hook_post never undoes
                # work that hook_pre did not do.
                self.hook_post(context, tree)
        finally:
            self.mark_end()
        return tree

    @property
    def decoder(self):
        """Mapping from lead opcode/character values to the factory that handles them."""
        raise NotImplementedError
58
59################################################################################
60# 20.2.2 Name Objects Encoding
61################################################################################
62
class NameSegFactory(Factory):
    """Parser for a NameSeg: a fixed four-character name segment.

    Its decoder accepts the lead characters 'A'-'Z' and '_'.
    """

    def __init__(self):
        super().__init__()
        lead_chars = [ord(c) for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ_"]
        self.__decoder = dict.fromkeys(lead_chars, self)
        self.label = "NameSeg"

    def match(self, context, stream, tree):
        """Read the four-character segment and store it as the tree's value."""
        tree.register_structure(("value",))
        segment = stream.get_fixed_length_string(4)
        tree.append_child(segment)
        tree.complete_parsing()

    @property
    def decoder(self):
        return self.__decoder


NameSeg = NameSegFactory()
82
class NameStringFactory(Factory):
    r"""Parser for a NameString (ACPI spec 20.2.2).

    Accepts any number of leading root ('\') / parent ('^') prefix characters
    followed by a NameSeg, a DualNamePath, a MultiNamePath or a NullName. The
    result is one string with the prefixes kept verbatim and the segments
    joined by '.', e.g. ``\_SB_.PCI0``.
    """

    def __init__(self):
        super().__init__()
        self.label = "NameString"
        lead_chars = [ord(c) for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ_\\^"]
        self.__decoder = dict.fromkeys(lead_chars, self)
        self.__decoder[grammar.AML_DUAL_NAME_PREFIX] = self
        self.__decoder[grammar.AML_MULTI_NAME_PREFIX] = self

    def match(self, context, stream, tree):
        tree.register_structure(("value",))
        prefixes = ("\\", "^")
        path = ""

        # Root and parent prefixes are kept verbatim.
        char = stream.get_char()
        while char in prefixes:
            path += char
            char = stream.get_char()

        lead = ord(char)
        if lead == grammar.AML_DUAL_NAME_PREFIX:
            # DualNamePath: exactly two segments.
            path += stream.get_fixed_length_string(4) + "."
            path += stream.get_fixed_length_string(4)
        elif lead == grammar.AML_MULTI_NAME_PREFIX:
            # MultiNamePath: a one-byte segment count, then that many segments.
            for _ in range(stream.get_integer(1)):
                if path and path[-1] not in prefixes:
                    path += "."
                path += stream.get_fixed_length_string(4)
        elif char == "\x00":
            # NullName: no segment follows.
            pass
        else:
            # Plain NameSeg: the character just read is its first byte, so step back.
            stream.seek(-1)
            path += stream.get_fixed_length_string(4)

        tree.append_child(path)
        tree.complete_parsing()

    @property
    def decoder(self):
        return self.__decoder


NameString = NameStringFactory()
128
129################################################################################
130# 20.2.3 Data Objects Encoding
131################################################################################
132
class ConstDataFactory(Factory):
    """Parser for a fixed-width integer constant of ``width`` bytes."""

    def __init__(self, label, width):
        super().__init__()
        self.label, self.width = label, width

    def match(self, context, stream, tree):
        """Read ``self.width`` bytes as an integer and store it as the tree's value."""
        tree.register_structure(("value",))
        value = stream.get_integer(self.width)
        tree.append_child(value)
        tree.complete_parsing()
        return tree


# One parser per data width (in bytes).
ByteData = ConstDataFactory("ByteData", 1)
WordData = ConstDataFactory("WordData", 2)
DWordData = ConstDataFactory("DWordData", 4)
TWordData = ConstDataFactory("TWordData", 6)
QWordData = ConstDataFactory("QWordData", 8)
150
class StringFactory(Factory):
    """Parser for a String object: the string prefix opcode followed by data read
    with ``stream.get_string()``."""

    def __init__(self):
        super().__init__()
        self.label = "String"

    def match(self, context, stream, tree):
        # Consume the prefix outside the assert. Assert statements are removed
        # under `python -O`, which would otherwise leave the prefix unread and
        # misalign the rest of the stream.
        opcode, _ = stream.get_opcode()
        assert opcode == grammar.AML_STRING_PREFIX

        tree.register_structure(("value",))
        tree.append_child(stream.get_string())
        tree.complete_parsing()
        return tree

    @property
    def decoder(self):
        return {grammar.AML_STRING_PREFIX: self}


String = StringFactory()
169
class ByteListFactory(Factory):
    """Parser for a ByteList.

    Reads a buffer via ``stream.get_buffer()``, stores it as the tree's value,
    then pops the current stream scope.
    """

    def __init__(self):
        super().__init__()
        self.label = "ByteList"

    def match(self, context, stream, tree):
        tree.register_structure(("value",))
        data = stream.get_buffer()
        tree.append_child(data)
        tree.complete_parsing()
        # Close the scope the buffer was read from (presumably opened by a
        # preceding PkgLength -- confirm against the grammar).
        stream.pop_scope()


ByteList = ByteListFactory()
182
183################################################################################
184# 20.2.4 Package Length Encoding
185################################################################################
186
class PkgLengthFactory(Factory):
    """Parser for the package-length encoding (ACPI spec 20.2.4).

    When ``create_new_scope`` is set, a stream scope covering the rest of the
    package is pushed after decoding, and ``tree.package_range`` records the
    ``(start, size)`` of that scope.
    """

    @staticmethod
    def get_package_length(byte_count, value):
        """Decode a package length from its raw little-end-first integer form.

        ``value`` holds the lead byte in bits 0-7 followed by ``byte_count``
        extra bytes. With no extra bytes the length is the low 6 bits of the
        lead byte. Otherwise the low nibble of the lead byte is the least
        significant part and each extra byte adds 8 more bits above it.
        """
        if byte_count == 0:
            return value & 0x3F
        length = value & 0x0F
        for index in range(1, byte_count + 1):
            shift = index * 8
            length |= ((value >> shift) & 0xFF) << (shift - 4)
        return length

    def __init__(self, label, create_new_scope):
        super().__init__()
        self.label = label
        self.create_new_scope = create_new_scope

    def match(self, context, stream, tree):
        # Bits 7-6 of the lead byte give the number of bytes that follow it.
        follow_count = stream.peek_integer(1) >> 6
        assert follow_count <= 3

        tree.register_structure(("value",))
        raw = stream.get_integer(follow_count + 1)
        tree.append_child(self.get_package_length(follow_count, raw))
        tree.complete_parsing()

        if self.create_new_scope:
            # The encoded length includes the PkgLength bytes themselves.
            remaining = tree.value - (follow_count + 1)
            stream.push_scope(remaining)
            tree.package_range = (stream.current, remaining)
        return tree


PkgLength = PkgLengthFactory("PkgLength", True)
FieldLength = PkgLengthFactory("FieldLength", False)
221
222################################################################################
223# 20.2.5 Term Objects Encoding
224################################################################################
225
class MethodInvocationFactory(Factory):
    """Parser for a name reference that may be a method call.

    A NameString is parsed first. If it resolves to a MethodDecl or
    PredefinedMethodDecl, that symbol's ``nargs`` TermArg children follow.
    """

    def __init__(self):
        super().__init__()
        self.__decoder = None
        self.label = "MethodInvocation"

    def match(self, context, stream, tree):
        tree.register_structure(("NameString", "TermArg*"))

        target_name = Tree()
        globals()["NameString"].parse(context, target_name)
        tree.append_child(target_name)

        # Only symbols declared as methods take arguments.
        resolved = context.lookup_symbol(target_name.value)
        arg_count = resolved.nargs if isinstance(resolved, (MethodDecl, PredefinedMethodDecl)) else 0
        for _ in range(arg_count):
            argument = Tree()
            globals()["TermArg"].parse(context, argument)
            tree.append_child(argument)

        tree.complete_parsing()
        return tree

    @property
    def decoder(self):
        # Built lazily on first use from the module-level NameString factory at that time.
        if not self.__decoder:
            self.__decoder = dict.fromkeys(globals()["NameString"].decoder.keys(), self)
        return self.__decoder


MethodInvocation = MethodInvocationFactory()
258
259################################################################################
260# Infrastructure Factories
261################################################################################
262
class SequenceFactory(Factory):
    """Parser for a grammar element defined as an ordered sequence of sub-elements.

    Each entry of ``seq`` is either an int or a factory name:

    * an int is the leading opcode, which must equal the next opcode read from
      the stream;
    * otherwise it names another factory in this module, optionally suffixed:
      ``Name*`` parses ``Name`` repeatedly until the current stream scope is
      exhausted and then pops that scope; ``Name?`` parses ``Name`` once if the
      next opcode is in its decoder.

    Once a PkgLength child has opened a stream scope, a recoverable error
    (DecodeError, DeferLater, ScopeMismatch, UndefinedSymbol) in a later
    element does not propagate. Instead, the remaining elements are recorded on
    the tree (``deferred_range``, ``context_scope``, ``factory``) for a later
    pass, and the stream skips to the end of the package.
    """

    def __init__(self, label, seq):
        super().__init__()
        self.label = label
        # Some objects in ACPI AML have multiple occurrences of the same type of object in the grammar. In order to
        # refer to these different occurrences, the grammar module uses the following notation to give names to each of
        # them:
        #
        #     "<object type>:<alias name>"
        #
        # The grammar module provides the get_definition() and get_names() methods to get the specification solely in
        # object types or alias names, respectively. For objects without aliases, the type is reused as the name.
        try:
            self.seq = grammar.get_definition(label)
            self.structure = grammar.get_names(label)
        except KeyError:
            self.seq = seq
            self.structure = seq
        self.__decoder = None

    def match(self, context, stream, tree):
        tree.register_structure(self.structure)

        # When a TermList is empty, the stream has already come to the end of the current scope here. Do not attempt to
        # peek the next opcode in such cases. A leading int opcode has no suffix, so check the type before looking at it
        # (indexing an int would raise TypeError).
        if stream.at_end() and \
           isinstance(self.seq[0], str) and self.seq[0].endswith(("*", "?")):
            stream.pop_scope()
            tree.complete_parsing()
            return tree

        package_end = 0

        # Under any case this function shall maintain the balance of stream scopes. The following flags indicate the
        # cleanup actions upon exceptions.
        to_recover_from_deferred_mode = False
        to_pop_stream_scope = False
        completed = True

        for i, elem in enumerate(self.seq):
            pos = stream.current
            try:
                if isinstance(elem, int):  # The leading opcode
                    opcode, _ = stream.get_opcode()
                    assert elem == opcode
                elif elem.endswith("*"):
                    factory = globals()[elem[:-1]]
                    while not stream.at_end():
                        child = Tree()
                        factory.parse(context, child)
                        tree.append_child(child)
                    stream.pop_scope()
                elif elem.endswith("?"):
                    factory = globals()[elem[:-1]]
                    if not stream.at_end():
                        sub_opcode, _ = stream.peek_opcode()
                        if sub_opcode in factory.decoder:
                            child = Tree()
                            factory.parse(context, child)
                            tree.append_child(child)
                else:
                    # It is likely that a method body has forward definitions, while typically it does not define
                    # symbols that are referred later. Thus always defer the parsing of method bodies to the second
                    # phase.
                    #
                    # In second phase the labels of sequence factories always have the ".deferred" suffix. Thus it is
                    # safe to check self.label against "DefMethod" here.
                    if elem == "TermList" and self.label == "DefMethod":
                        raise DeferLater(self.label, [elem])
                    factory = globals()[elem]
                    child = Tree()
                    factory.parse(context, child)
                    tree.append_child(child)
                    if child.label == "PkgLength":
                        to_pop_stream_scope = True
                        if child.package_range:
                            package_end = child.package_range[0] + child.package_range[1]
                            context.enter_deferred_mode()
                            to_recover_from_deferred_mode = True
                    elif child.label == "NameString":
                        self.hook_named(context, tree, child.value)
            except (DecodeError, DeferLater, ScopeMismatch, UndefinedSymbol):
                if not to_pop_stream_scope:
                    # This sequence opened no scope; let the caller handle the error.
                    raise
                stream.pop_scope(force=True)
                if not to_recover_from_deferred_mode:
                    # The package end is unknown, so the remainder cannot be skipped. Propagate instead of silently
                    # continuing with the next element after the scope has been popped.
                    raise
                # Defer the failing element and everything after it to a later pass, then skip to the package end.
                tree.deferred_range = (pos, package_end - pos)
                tree.context_scope = context.get_scope()
                tree.factory = SequenceFactory(f"{self.label}.deferred", self.seq[i:])
                stream.seek(package_end, absolute=True)
                completed = False
                break

        if completed:
            tree.complete_parsing()

        if to_recover_from_deferred_mode:
            context.exit_deferred_mode()
        return tree

    @property
    def decoder(self):
        if not self.__decoder:
            if isinstance(self.seq[0], int):
                self.__decoder = {self.seq[0]: self}
            else:
                self.__decoder = {}
                for k in globals()[self.seq[0]].decoder.keys():
                    self.__decoder[k] = self
        return self.__decoder
376
class OptionFactory(Factory):
    """Parser for a grammar element that is one of several alternatives.

    With a single option, that option is used unconditionally. Otherwise the
    next opcode is looked up in the merged decoders of all options.
    """

    def __init__(self, label, opts):
        super().__init__()
        self.label = label
        self.opts = opts
        self.__decoder = None

    def match(self, context, stream, tree):
        opcode, _ = stream.peek_opcode()
        # NOTE: the try also covers the chosen option's parse(), so a KeyError
        # raised anywhere beneath it is reported as a DecodeError for this opcode.
        try:
            chosen = globals()[self.opts[0]] if len(self.opts) == 1 else self.decoder[opcode]
            chosen.parse(context, tree)
        except KeyError:
            raise DecodeError(opcode, self.label)
        return tree

    @property
    def decoder(self):
        # Merged lazily on first use; later options override earlier ones on key clashes.
        if not self.__decoder:
            table = self.__decoder = {}
            for option_name in self.opts:
                table.update(globals()[option_name].decoder)
        return self.__decoder
402
class DeferredExpansion(Transformer):
    """Top-down transformer that parses sequence tails whose parsing was deferred.

    Nodes of the listed labels that carry a ``deferred_range`` are re-parsed
    with the ``factory`` saved on them, inside the namespace scope saved in
    ``context_scope``. On success the new children are appended and the
    deferral markers are cleared. If parsing fails again with a recoverable
    error, the node keeps its markers and the failure is logged at debug level.
    """

    def __init__(self, context):
        super().__init__(Direction.TOPDOWN)
        self.context = context

        # Node labels whose parsing may have been deferred.
        deferrable = ("DefScope", "DefDevice", "DefMethod", "DefPowerRes", "DefProcessor",
                      "DefThermalZone", "DefIfElse", "DefElse", "DefWhile")
        for label in deferrable:
            setattr(self, label, self.__expand_deferred_range)

    def __expand_deferred_range(self, tree):
        # Nodes that were parsed completely need no work.
        if not tree.deferred_range:
            return tree

        start, size = tree.deferred_range
        # Position the stream on the deferred bytes and limit it to that range.
        self.context.current_stream.reset()
        self.context.current_stream.seek(start, absolute=True)
        self.context.current_stream.push_scope(size)

        expanded = Tree()
        self.context.change_scope(tree.context_scope)
        try:
            tree.factory.parse(self.context, expanded)
            tree.children.extend(expanded.children)
            tree.deferred_range = None
            tree.factory = None
            tree.complete_parsing()
        except (DecodeError, DeferLater, ScopeMismatch, UndefinedSymbol) as err:
            logging.debug(f"expansion of {tree.label} at {hex(tree.deferred_range[0])} failed due to: " + str(err))
        self.context.pop_scope()

        return tree
435
436################################################################################
437# Hook functions
438################################################################################
439
def DefAlias_hook_post(context, tree):
    """Register an AliasDecl mapping the alias name to its source object."""
    source_name = tree.SourceObject.value
    context.register_symbol(AliasDecl(tree.AliasObject.value, source_name, tree))
445
def DefName_hook_named(context, tree, name):
    """Register a NamedDecl for the object declared by Name()."""
    declaration = NamedDecl(name, tree)
    context.register_symbol(declaration)
449
def DefScope_hook_named(context, tree, name):
    """Enter the namespace scope named by a Scope() block."""
    context.change_scope(name)

def DefScope_hook_post(context, tree):
    """Leave the namespace scope entered for a Scope() block."""
    context.pop_scope()
455
456
def _register_created_field(context, tree, name_index, width):
    """Register a FieldDecl for a buffer field created by a Create*Field node.

    The field name is read from ``tree.children[name_index]``; the hook's own
    ``name`` argument is not used. ``width`` is passed to FieldDecl as the
    field width in bits. CreateField passes 0 because its opcode does not fix
    the width.
    """
    field_name = tree.children[name_index].value
    context.register_symbol(FieldDecl(field_name, width, tree))

def DefCreateBitField_hook_named(context, tree, name):
    _register_created_field(context, tree, 2, 1)

def DefCreateByteField_hook_named(context, tree, name):
    _register_created_field(context, tree, 2, 8)

def DefCreateDWordField_hook_named(context, tree, name):
    _register_created_field(context, tree, 2, 32)

def DefCreateField_hook_named(context, tree, name):
    # CreateField has one more operand (the bit count) before the name.
    _register_created_field(context, tree, 3, 0)

def DefCreateQWordField_hook_named(context, tree, name):
    _register_created_field(context, tree, 2, 64)

def DefCreateWordField_hook_named(context, tree, name):
    _register_created_field(context, tree, 2, 16)
486
def DefDevice_hook_named(context, tree, name):
    """Register a DeviceDecl for the device, then enter its namespace scope."""
    device = DeviceDecl(name, tree)
    context.register_symbol(device)
    context.change_scope(name)

def DefDevice_hook_post(context, tree):
    """Leave the device's namespace scope."""
    context.pop_scope()
494
def DefExternal_hook_post(context, tree):
    """Register the symbol declared by an External() statement.

    Method-typed externals become MethodDecl with the declared argument count;
    everything else becomes a NamedDecl.
    """
    name = tree.NameString.value
    object_type = tree.ObjectType.value
    arg_count = tree.ArgumentCount.value

    is_method = object_type == MethodDecl.object_type()
    declaration = MethodDecl(name, arg_count, tree) if is_method else NamedDecl(name, tree)
    context.register_symbol(declaration)
505
# AccessType (low nibble of FieldFlags) -> access width in bits.
access_width_map = {
    0: 8,    # AnyAcc
    1: 8,    # ByteAcc
    2: 16,   # WordAcc
    3: 32,   # DWordAcc
    4: 64,   # QWordAcc
    5: 8,    # BufferAcc
    # The other values are reserved
}

def _assign_field_locations(context, fields, place):
    """Walk the leading NamedField/ReservedField entries, tracking the bit offset.

    ``place(sym, bit_offset)`` is called for the OperationFieldDecl of every
    NamedField. ReservedField entries only advance the offset. The walk stops
    at the first entry of any other kind.
    """
    bit_offset = 0
    for field in fields:
        if field.label == "NamedField":
            name = field.NameSeg.value
            length = field.FieldLength.value
            sym = context.lookup_symbol(name)
            assert isinstance(sym, OperationFieldDecl)
            place(sym, bit_offset)
            bit_offset += length
        elif field.label == "ReservedField":
            bit_offset += field.FieldLength.value
        else:
            break

def DefField_hook_post(context, tree):
    """Give every named field of a Field() its region, bit offset and access width."""
    region_name = context.lookup_symbol(tree.NameString.value).name
    flags = tree.FieldFlags.value
    access_width = access_width_map[flags & 0xF]
    fields = tree.FieldList.FieldElements

    def place(sym, bit_offset):
        sym.set_location(region_name, bit_offset, access_width)
        sym.parent_tree = tree

    _assign_field_locations(context, fields, place)

def DefIndexField_hook_post(context, tree):
    """Give every named field of an IndexField() its index/data registers, bit offset and access width."""
    index_register = context.lookup_symbol(tree.IndexName.value)
    data_register = context.lookup_symbol(tree.DataName.value)
    flags = tree.FieldFlags.value
    access_width = access_width_map[flags & 0xF]
    fields = tree.FieldList.FieldElements

    def place(sym, bit_offset):
        sym.set_indexed_location(index_register, data_register, bit_offset, access_width)
        # NOTE(review): unlike DefField_hook_post, parent_tree is not updated here -- confirm this is intended.

    _assign_field_locations(context, fields, place)
559
def NamedField_hook_post(context, tree):
    """Register an OperationFieldDecl for a NamedField.

    Its location is set afterwards by DefField_hook_post / DefIndexField_hook_post.
    """
    field_decl = OperationFieldDecl(tree.NameSeg.value, tree.FieldLength.value, tree)
    context.register_symbol(field_decl)
565
def DefMethod_hook_named(context, tree, name):
    """Enter the method's namespace scope."""
    context.change_scope(name)

def DefMethod_hook_post(context, tree):
    """Leave the method's scope and register a MethodDecl in the enclosing scope.

    Nothing is registered if fewer than three children (PkgLength, name, flags)
    were parsed.
    """
    context.pop_scope()
    if len(tree.children) < 3:
        return
    # Parsing of the method may be deferred, so index children positionally
    # instead of using named fields.
    method_name = tree.children[1].value
    arg_count = tree.children[2].value & 0x7  # MethodFlags bits 0-2: argument count
    context.register_symbol(MethodDecl(method_name, arg_count, tree))
578
def DefOpRegion_hook_named(context, tree, name):
    """Register an OperationRegionDecl for the region declared by OperationRegion()."""
    region = OperationRegionDecl(name, tree)
    context.register_symbol(region)
582
def _declare_and_enter(context, tree, name):
    """Register a NamedDecl for *name*, then enter its namespace scope."""
    context.register_symbol(NamedDecl(name, tree))
    context.change_scope(name)

def DefPowerRes_hook_named(context, tree, name):
    _declare_and_enter(context, tree, name)

def DefPowerRes_hook_post(context, tree):
    """Leave the power resource's namespace scope."""
    context.pop_scope()

def DefThermalZone_hook_named(context, tree, name):
    _declare_and_enter(context, tree, name)

def DefThermalZone_hook_post(context, tree):
    """Leave the thermal zone's namespace scope."""
    context.pop_scope()
598
599################################################################################
600# Instantiate parsers
601################################################################################
602
def register_hooks(factory, label):
    """Attach the module-level hook functions defined for grammar element *label*.

    For each stage ``pre``, ``named`` and ``post``, a module-level name
    ``<label>_hook_<stage>``, if present, is assigned to the factory's
    ``hook_<stage>`` attribute, replacing the default no-op hook.

    The lookup uses the ``label`` argument. The previous body read the
    module-global loop variable ``sym`` instead, so it only worked when called
    from the registration loop.
    """
    namespace = globals()
    for stage in ("pre", "named", "post"):
        hook_name = f"{label}_hook_{stage}"
        if hook_name in namespace:
            setattr(factory, f"hook_{stage}", namespace[hook_name])
610
# Build one factory per grammar definition and publish it as a module global:
# tuples describe sequences and lists describe alternatives. Dunder attributes
# and ALL-CAPS opcode constants are skipped.
# Keep `sym` as the module-level loop name; check register_hooks before renaming it.
for sym in dir(grammar):
    if sym.startswith("__") or sym.upper() == sym:
        continue

    definition = getattr(grammar, sym)
    if isinstance(definition, tuple):
        factory_class = SequenceFactory
    elif isinstance(definition, list):
        factory_class = OptionFactory
    else:
        continue

    factory = factory_class(sym, definition)
    register_hooks(factory, sym)
    globals()[sym] = factory
625