1# Copyright (C) 2021-2022 Intel Corporation.
2#
3# SPDX-License-Identifier: BSD-3-Clause
4#
5
6from . import grammar
7from . import datatypes
8from .tree import Tree
9
10### Basic data types
11
def __build_value(label, value):
    """Build a finished leaf ``Tree`` named `label` around `value`.

    The node is created with `value` as its only child, registers the
    single-field structure ``("value",)`` and has its parsing completed
    before it is returned.
    """
    leaf = Tree(label, [value])
    leaf.register_structure(("value",))
    leaf.complete_parsing()
    return leaf
17
def __build_string(label, s):
    """Build a leaf node named `label` holding the string `s`.

    `s` must be a ``str``; this is enforced with an ``assert``.
    """
    is_text = isinstance(s, str)
    assert is_text
    return __build_value(label, s)
21
def __build_const_data(label, data):
    """Build a leaf node named `label` holding the integer `data`.

    `data` must be an ``int``; this is enforced with an ``assert``.
    """
    is_integer = isinstance(data, int)
    assert is_integer
    return __build_value(label, data)
25
def NameSeg(x):
    """Build a ``NameSeg`` leaf node holding the string `x`."""
    return __build_string("NameSeg", x)

def NameString(x):
    """Build a ``NameString`` leaf node holding the string `x`."""
    return __build_string("NameString", x)

def String(x):
    """Build a ``String`` leaf node holding the string `x`."""
    return __build_string("String", x)
29
def ByteList(data):
    """Build a ``ByteList`` leaf node holding the raw bytes `data`.

    `data` must be a ``bytes`` or ``bytearray`` object; this is enforced
    with an ``assert``.
    """
    is_raw_bytes = isinstance(data, (bytes, bytearray))
    assert is_raw_bytes
    return __build_value("ByteList", data)
33
def ByteData(x):
    """Build a ``ByteData`` leaf node holding the integer `x`."""
    return __build_const_data("ByteData", x)

def WordData(x):
    """Build a ``WordData`` leaf node holding the integer `x`."""
    return __build_const_data("WordData", x)

def DWordData(x):
    """Build a ``DWordData`` leaf node holding the integer `x`."""
    return __build_const_data("DWordData", x)

def TWordData(x):
    """Build a ``TWordData`` leaf node holding the integer `x`."""
    return __build_const_data("TWordData", x)

def QWordData(x):
    """Build a ``QWordData`` leaf node holding the integer `x`."""
    return __build_const_data("QWordData", x)
39
def PkgLength(length=0):
    """Build a ``PkgLength`` leaf node holding the integer `length`.

    `length` defaults to 0 when the caller does not supply one.
    """
    pkg_length = __build_const_data("PkgLength", length)
    return pkg_length
42
def FieldLength(x):
    """Build a ``FieldLength`` leaf node holding the integer `x`."""
    return __build_const_data("FieldLength", x)
44
45### Sequences
46
def MethodInvocation(name, *args):
    """Build a ``MethodInvocation`` node that calls `name` with `args`.

    Args:
        name: The invoked method, either as a ``str`` (which is wrapped into
            a ``NameString`` node) or as an already built node, which is
            used as is.
        *args: The argument nodes. Each one must be a ``Tree``; this is
            enforced with an ``assert``.

    Returns:
        The completed ``MethodInvocation`` tree whose children are the name
        node followed by the argument nodes.
    """
    # Promote a plain string to a NameString node so the first child is a
    # node, as the registered "NameString" structure entry suggests. The
    # original code computed this node but then stored the raw `name`.
    name_tree = NameString(name) if isinstance(name, str) else name
    tree = Tree("MethodInvocation", [name_tree])
    for arg in args:
        assert isinstance(arg, Tree)
        tree.append_child(arg)
    tree.register_structure(("NameString", "TermArg*"))
    tree.complete_parsing()
    return tree
59
def __create_sequence_builder(label):
    """Create a builder function for the grammar sequence named `label`.

    The sequence definition and its field names are fetched from ``grammar``
    once, when the builder is created.

    Args:
        label: Name of the sequence in ``grammar``; it is also used as the
            label of every tree the builder produces.

    Returns:
        A function ``fn(*args)`` that consumes `args` in the order of the
        sequence elements and returns the completed ``Tree``:

        * integer elements (the leading opcode) take no argument;
        * an element ending in ``*`` takes all remaining arguments;
        * an element ending in ``?`` takes one argument if any is left;
        * any other element takes exactly one argument.

        ``fn`` raises ``StopIteration`` when a mandatory element has no
        argument left and ``KeyError`` when a non-``Tree`` argument has no
        builder registered in this module. Surplus arguments are ignored.
    """
    def append_component(tree, elem, arg):
        if isinstance(arg, Tree):
            # TODO: validate the given arg
            tree.append_child(arg)
        else:
            # Builders are registered under the bare element name, so the
            # "*" / "?" quantifier has to be dropped before the lookup;
            # looking up the suffixed name could never succeed.
            builder_name = elem[:-1] if elem.endswith(("*", "?")) else elem
            tree.append_child(globals()[builder_name](arg))

    seq = grammar.get_definition(label)
    structure = grammar.get_names(label)

    def fn(*args):
        tree = Tree(label)
        # One shared iterator so every element continues where the previous
        # one stopped consuming arguments.
        it = iter(args)
        for elem in seq:
            if isinstance(elem, int):    # The leading opcode
                continue
            elif elem.endswith("*"):
                for arg in it:
                    append_component(tree, elem, arg)
            elif elem.endswith("?"):
                try:
                    append_component(tree, elem, next(it))
                except StopIteration:
                    # Optional element without an argument: leave it out.
                    pass
            else:
                append_component(tree, elem, next(it))
        tree.register_structure(structure)
        tree.complete_parsing()
        return tree
    return fn
91
def build_value(value):
    """Convert a Python or ``datatypes`` value into the tree that encodes it.

    Handled inputs, checked in this order:

    * ``int`` / ``datatypes.Integer``: ``ZeroOp`` for 0, ``OneOp`` for 1,
      otherwise the first of ``ByteConst``, ``WordConst``, ``DWordConst``
      whose upper bound (0xff, 0xffff, 0xffffffff) is not exceeded, else
      ``QWordConst``.
    * ``datatypes.Buffer``: a ``DefBuffer`` whose size constant is chosen
      with the same bounds, applied to the buffer length.
    * ``datatypes.Package``: a ``DefPackage`` whose elements are converted
      recursively.
    * ``str`` / ``datatypes.String``: a ``String`` node.
    * ``datatypes.BufferField``: converted through its ``to_integer()``.

    Returns ``None`` for any other kind of value.
    """
    if isinstance(value, (int, datatypes.Integer)):
        integer = datatypes.Integer(value) if isinstance(value, int) else value
        v = integer.get()
        if v == 0:
            return ZeroOp()
        if v == 1:
            return OneOp()
        if v <= 0xff:
            return ByteConst(v)
        if v <= 0xffff:
            return WordConst(v)
        if v <= 0xffffffff:
            return DWordConst(v)
        return QWordConst(v)

    if isinstance(value, datatypes.Buffer):
        size = len(value.get())
        # Pick the smallest constant encoding able to hold the size.
        if size <= 0xff:
            size_builder = ByteConst
        elif size <= 0xffff:
            size_builder = WordConst
        elif size <= 0xffffffff:
            size_builder = DWordConst
        else:
            size_builder = QWordConst
        return DefBuffer(PkgLength(), size_builder(size), ByteList(value.get()))

    if isinstance(value, datatypes.Package):
        converted = [build_value(element) for element in value.elements]
        return DefPackage(
            PkgLength(),
            len(value.elements),
            PackageElementList(*converted))

    if isinstance(value, (str, datatypes.String)):
        return String(value if isinstance(value, str) else value.get())

    if isinstance(value, datatypes.BufferField):
        return build_value(value.to_integer())

    return None
129
for sym in dir(grammar):
    # Skip dunder members and all-uppercase names (the opcode constants)
    if sym.startswith("__") or sym == sym.upper():
        continue

    definition = getattr(grammar, sym)
    if isinstance(definition, tuple):
        # Tuple definition: expose a sequence builder under the same name
        globals()[sym] = __create_sequence_builder(sym)
        continue
    if not (isinstance(definition, list) and len(definition) == 1):
        continue
    # Single-entry list: alias the name to the builder of its only entry,
    # provided that builder already exists in this module
    if definition[0] in globals():
        globals()[sym] = globals()[definition[0]]
141