1#!/usr/bin/env python3
2#
3# Copyright (C), 2022 Intel Corporation.
4# Copyright (c), 2018-2021, SISSA (International School for Advanced Studies).
5#
6# SPDX-License-Identifier: BSD-3-Clause
7#
8
9import sys, os
10from decimal import Decimal
11from copy import copy
12import operator
13import elementpath
14
15# Allow this script to find the library module at misc/config_tools/library.
16#
17# TODO: Reshuffle the module structure of the configuration toolset for clearer imports.
18sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
19import library.rdt as rdt
20
# The stock elementpath XPath 2.0 parser that CustomParser extends.
BaseParser = elementpath.XPath2Parser


class CustomParser(BaseParser):
    """XPath 2.0 parser extended with the ACRN configuration helper functions.

    When the base parser exposes a ``SYMBOLS`` set (guarded with ``hasattr``,
    presumably because not every elementpath release has it), the names of the
    custom functions registered in this module are added to it.
    """
    if hasattr(BaseParser, "SYMBOLS"):
        SYMBOLS = BaseParser.SYMBOLS | {
            'bitwise-and',  # bit-wise operations
            'bits-of',
            'duplicate-values',
            'has',
            'number-of-clos-id-needed',
        }


# Decorator shorthands used below to register the custom functions on CustomParser.
method, function = CustomParser.method, CustomParser.function
38
39###
40# Custom functions
41
# Maps the name of each binary bit-wise XPath function to the Python operator
# that implements it; the bit-wise function looks it up by its token symbol.
OPERATORS_MAP = {
    'bitwise-and': operator.and_,
}
45
def hex_to_int(value):
    """Convert an XPath function argument to a Python ``int``.

    Accepted inputs:
      * an object exposing ``.text`` (e.g. an element); its text is converted;
      * an ``int``, returned unchanged;
      * a ``float`` or ``Decimal``, converted with ``int()`` (truncation);
      * a string holding a hexadecimal literal with a ``0x`` or ``0X`` prefix,
        optionally surrounded by whitespace (element text often spans lines).

    Raises:
        TypeError: for any other type (including ``None`` element text) or for
            a string without a hexadecimal prefix.
        ValueError: for a ``0x``-prefixed string with invalid hex digits.
    """
    if hasattr(value, 'text'):
        value = value.text
    if isinstance(value, int):
        return value
    if isinstance(value, (float, Decimal)):
        return int(value)
    if isinstance(value, str):
        # Tolerate surrounding whitespace and an upper-case prefix; int() itself
        # accepts both, only the prefix check needs the normalized text.
        literal = value.strip()
        if literal[:2].lower() == '0x':
            return int(literal, base=16)
    raise TypeError('invalid type {!r} for integer'.format(type(value)))
57
@method(function('bitwise-and', nargs=2))
def evaluate_bitwise_and(self, context=None):
    """Evaluate a binary bit-wise function such as ``bitwise-and(a, b)``.

    Both arguments are converted with ``hex_to_int`` and combined with the
    Python operator that ``OPERATORS_MAP`` associates with this token's symbol.

    Raises the XPath error FORG0001 when an argument has an invalid value
    (``ValueError``) and XPTY0004 when it has an unsupported type
    (``TypeError``).
    """
    # NOTE(review): only the ``evaluate`` prefix of the name matters to the
    # ``method`` decorator (cf. ``evaluate_bits_of``); the suffix just keeps
    # this function from sharing a module-level name with other handlers.
    op = OPERATORS_MAP[self.symbol]
    op1 = self.get_argument(context, 0)
    op2 = self.get_argument(context, 1)

    try:
        return op(hex_to_int(op1), hex_to_int(op2))
    except ValueError as err:
        raise self.error('FORG0001', err) from None
    except TypeError as err:
        raise self.error('XPTY0004', err) from None
72
@method(function('bits-of', nargs=1))
def evaluate_bits_of(self, context=None):
    """Return the positions of the set bits of the argument, lowest first.

    The argument is converted with ``hex_to_int``; e.g. a value of 10 (0b1010)
    gives ``[1, 3]`` and 0 gives the empty sequence.

    The sequence is built eagerly and returned as a list. The previous
    implementation was a generator function. NOTE(review): elementpath's default
    ``select`` appears to expand only ``list`` results of ``evaluate``, so a
    generator object would have been treated as a single item.

    Raises the XPath error FORG0001 for an invalid value (``ValueError``) and
    XPTY0004 for an unsupported type (``TypeError``), like ``bitwise-and``.
    """
    op = self.get_argument(context)

    try:
        value = hex_to_int(op)
    except ValueError as err:
        raise self.error('FORG0001', err) from None
    except TypeError as err:
        raise self.error('XPTY0004', err) from None

    # bin(10) == '0b1010': drop the prefix and reverse the digits so that the
    # character index equals the bit position.
    return [idx for idx, bit in enumerate(reversed(bin(value)[2:])) if bit == '1']
84
@method(function('has', nargs=2))
def evaluate_has_function(self, context=None):
    """Return True when any item of the first argument equals the second one.

    The second argument is requested with ``cls=str``. Each item selected by
    the first argument is atomized with ``data_value`` before the comparison,
    and the scan stops at the first match.
    """
    needle = self.get_argument(context, index=1, cls=str)
    return any(self.data_value(item) == needle for item in self[0].select(context))
93
@method(function('duplicate-values', nargs=1))
def select_duplicate_values_function(self, context=None):
    """Yield every value that occurs more than once in the argument sequence.

    Items are atomized with ``data_value`` and compared with ``==``. Each
    duplicated value is yielded once, when its second occurrence is reached,
    so the output follows the order in which duplicates are discovered. While
    scanning, the atomized value is also stored as ``context.item`` when a
    context is given.
    """
    # NOTE(review): list membership makes this O(n^2); sets would be faster but
    # presume the atomized values hash consistently with ``==`` — confirm first.
    seen = []
    already_yielded = []
    for node in self[0].select(context):
        atom = self.data_value(node)
        if context is not None:
            context.item = atom

        if atom not in seen:
            seen.append(atom)
            continue
        if atom in already_yielded:
            continue
        yield atom
        already_yielded.append(atom)
112
@method(function('number-of-clos-id-needed', nargs=1))
def evaluate_number_of_clos_id_needed(self, context=None):
    """Return ``len(rdt.get_policy_list(node))`` for the given node, else 0.

    A ``TypedElement`` argument is unwrapped to its ``elem`` first. The result is
    0 when the argument is missing or the node has no ``xpath`` attribute.
    """
    node = self.get_argument(context, index=0)
    if node is None:
        return 0

    if isinstance(node, elementpath.TypedElement):
        node = node.elem

    # While xmlschema parses the data check schemas this may be called with an
    # Xsd11Element rather than a real acrn-config node; only the latter has
    # ``xpath``, so only then are the needed CLOS IDs computed.
    if not hasattr(node, "xpath"):
        return 0
    return len(rdt.get_policy_list(node))
127
128###
129# Collection of counter examples
130
class Hashable:
    """Wrapper that makes an object usable as a dict key, hashed by its ``id``.

    The wrapped object is available as ``obj``. NOTE(review): no ``__eq__`` is
    defined, so equality falls back to the identity of the wrapper itself; two
    wrappers around the same object are distinct keys. Confirm this is intended.
    """

    def __init__(self, obj):
        # The wrapped object (an XPath token where this module creates wrappers).
        self.obj = obj

    def __hash__(self):
        wrapped = self.obj
        return id(wrapped)
137
def copy_context(context):
    """Return a shallow copy of ``context`` with its own empty counter examples.

    When the context carries a ``counter_example`` dict, the copy gets a fresh
    empty one, so entries recorded on the copy do not leak into the original.
    """
    duplicate = copy(context)
    if hasattr(context, 'counter_example'):
        duplicate.counter_example = {}
    return duplicate
143
def add_counter_example(context, private_context, kvlist):
    """Merge counter-example entries into ``context.counter_example``.

    The ``(key, value)`` pairs of ``kvlist`` are added first, followed by
    everything collected in ``private_context.counter_example`` when a truthy
    private context is given. Contexts without ``counter_example`` are left
    untouched.
    """
    if not hasattr(context, 'counter_example'):
        return
    examples = context.counter_example
    examples.update(kvlist)
    if private_context:
        examples.update(private_context.counter_example)
149
@method('every')
@method('some')
def evaluate(self, context=None):
    """Evaluate a ``some``/``every`` quantified expression, recording a witness.

    Even-indexed children are the variable bindings, odd-indexed children the
    expressions they range over, and the last child is the test. Every
    combination of bound values is tried on a private copy of the context:

    * ``some`` returns True at the first combination that passes the test;
    * ``every`` returns False at the first combination that fails it.

    The deciding combination (wrapped variable reference -> value) and any
    counter examples gathered on the private context are merged into
    ``context.counter_example`` via ``add_counter_example``. If no combination
    decides the outcome, ``some`` returns False and ``every`` returns True.
    """
    if context is None:
        raise self.missing_context()

    is_some = self.symbol == 'some'
    last = len(self) - 1
    varrefs = [Hashable(self[pos]) for pos in range(0, last, 2)]
    varnames = [self[pos][0].value for pos in range(0, last, 2)]
    selectors = [self[pos].select for pos in range(1, last, 2)]

    for bound_values in copy(context).iter_product(selectors, varnames):
        scope = copy_context(context)
        scope.variables.update(zip(varnames, bound_values))
        satisfied = bool(self.boolean_value(list(self[-1].select(scope))))
        # `some` is decided by the first success, `every` by the first failure.
        if satisfied == is_some:
            add_counter_example(context, scope, zip(varrefs, bound_values))
            return is_some

    return not is_some
173
# Install the extended parser in place of elementpath's stock XPath 2.0 parser,
# so code that looks up ``elementpath.XPath2Parser`` after this module has been
# imported gets the custom functions.
# NOTE(review): modules that already bound the class via
# ``from elementpath import XPath2Parser`` keep the original — confirm import order.
setattr(elementpath, 'XPath2Parser', CustomParser)
175