1#!/usr/bin/env python3
2
3"""
4This is a middle-processor for MicroPython source files.  It takes the output
5of the C preprocessor, has the option to change it, then feeds this into the
6C compiler.
7
8It currently has the ability to reorder static hash tables so they are actually
9hashed, resulting in faster lookup times at runtime.
10
11To use, configure the Python variables below, and add the following line to the
12Makefile:
13
14CFLAGS += -no-integrated-cpp -B$(shell pwd)/../tools
15"""
16
17import sys
18import os
19import re
20
################################################################################
# these are the configuration variables
# TODO somehow make them externally configurable

# this is the path to the true C compiler (cc1); main() runs it with the
# arguments that this script was given
cc1_path = '/usr/lib/gcc/x86_64-unknown-linux-gnu/5.3.0/cc1'
#cc1_path = '/usr/lib/gcc/arm-none-eabi/5.3.0/cc1'

# this must be the same as MICROPY_QSTR_BYTES_IN_HASH
# (compute_hash() keeps only the low 8 * bytes_in_qstr_hash bits of the hash)
bytes_in_qstr_hash = 2

# this must be 1 or more (can be a decimal)
# larger uses more code size but yields faster lookups
# (each table is given int(number_of_entries * table_size_mult) slots)
table_size_mult = 1

# these control output during processing
# print_stats: print one summary line (size, load, avg lookups) per table
# print_debug: additionally trace every insertion and lookup
print_stats = True
print_debug = False

# end configuration variables
################################################################################
42
# precompile regexes

# a preprocessor line marker, e.g. '# 123 "file.c"'
re_preproc_line = re.compile(r'# [0-9]+ ')
# one table entry of the form '{ ...(MP_QSTR_<name>)... },'; group 1 is <name>
re_map_entry = re.compile(r'\{.+?\(MP_QSTR_([A-Za-z0-9_]+)\).+\},')
# a const mp_obj_dict_t definition whose map has is_ordered = 1; 'head' is the
# text up to the is_ordered value, 'tail' is the text after it, and 'id' is the
# name of the variable being defined
re_mp_obj_dict_t = re.compile(r'(?P<head>(static )?const mp_obj_dict_t (?P<id>[a-z0-9_]+) = \{ \.base = \{&mp_type_dict\}, \.map = \{ \.all_keys_are_qstrs = 1, \.is_fixed = 1, \.is_ordered = )1(?P<tail>, \.used = .+ };)$')
# the same as above, but for a bare const mp_map_t definition
re_mp_map_t = re.compile(r'(?P<head>(static )?const mp_map_t (?P<id>[a-z0-9_]+) = \{ \.all_keys_are_qstrs = 1, \.is_fixed = 1, \.is_ordered = )1(?P<tail>, \.used = .+ };)$')
# the opening line of a static mp_rom_map_elem_t array (ends with '= {')
re_mp_rom_map_elem_t = re.compile(r'static const mp_rom_map_elem_t [a-z_0-9]+\[\] = {$')
49
# this must match the equivalent function in qstr.c
def compute_hash(qstr):
    """Return the hash of the string `qstr`, truncated to the configured width.

    The running value starts at 5381 and, for every character, is multiplied
    by 33 and XOR-ed with the character's code point.  Only the low
    8 * bytes_in_qstr_hash bits of the result are kept.
    """
    acc = 5381
    for ch in qstr:
        acc = (acc * 33) ^ ord(ch)
    mask = (1 << (8 * bytes_in_qstr_hash)) - 1
    truncated = acc & mask
    # Make sure that valid hash is never zero, zero means "hash not computed"
    return truncated if truncated else 1
57
# this algo must match the equivalent in map.c
def hash_insert(map, key, value):
    """Insert `(key, value)` into the open-addressing hash table `map`.

    The table is a list whose slots are either None (empty) or a
    `(key, value)` tuple.  Probing starts at `compute_hash(key) % len(map)`
    and moves forward one slot at a time, wrapping around at the end, until
    an empty slot is found.  The list is modified in place.

    Args:
        map: the table (a list) to insert into.
        key: the string key; it is hashed with compute_hash().
        value: the value stored alongside the key.

    Raises:
        AssertionError: if `key` is already present, if the table has no
            slots at all, or if every slot is already occupied.
    """
    size = len(map)
    if size == 0:
        # without this the modulo below would fail with an unhelpful
        # ZeroDivisionError
        raise AssertionError("hash table has no slots, cannot insert '%s'" % (key,))

    start_pos = compute_hash(key) % size
    pos = start_pos
    if print_debug:
        print('  insert %s: start at %u/%u -- ' % (key, pos, size), end='')
    while True:
        slot = map[pos]
        if slot is None:
            # found empty slot, so key is not in table
            if print_debug:
                print('put at %u' % pos)
            map[pos] = (key, value)
            return
        # slot is taken: reject duplicates, otherwise keep searching
        if slot[0] == key:
            raise AssertionError("duplicate key '%s'" % (key,))
        pos = (pos + 1) % size
        if pos == start_pos:
            # Raised explicitly rather than with an assert statement: asserts
            # are removed under "python -O", which would turn a full table
            # into an infinite loop.
            raise AssertionError("hash table is full, cannot insert '%s'" % (key,))
78
def hash_find(map, key):
    """Look up `key` in the open-addressing hash table `map`.

    Probing starts at `compute_hash(key) % len(map)` and advances one slot at
    a time (with wrap-around).  It stops at the first empty slot, at the slot
    holding `key`, or after every slot has been inspected once.

    Returns:
        A tuple `(attempts, value)`: the number of slots inspected, and the
        value stored for `key` (None if the key is not in the table).
    """
    key_hash = compute_hash(key)
    size = len(map)
    first = key_hash % size
    for probes in range(1, size + 1):
        slot = map[(first + probes - 1) % size]
        if slot is None:
            # an empty slot ends the probe sequence: the key is absent
            return probes, None
        if slot[0] == key:
            return probes, slot[1]
    # wrapped all the way around a completely full table without a match
    return size, None
94
def _read_table_body(file):
    """Read table lines from `file` up to the closing '};' and concatenate them.

    Blank lines and preprocessor line markers are dropped.  Exits the program
    if the input ends before the closing '};' is seen.
    """
    table_contents = []
    while True:
        line = file.readline()
        if not line:
            print('unexpected end of input')
            sys.exit(1)
        line = line.strip()
        if not line:
            # empty line
            continue
        if re_preproc_line.match(line):
            # preprocessor line number comment
            continue
        if line == '};':
            # end of table (we assume it appears on a single line)
            break
        table_contents.append(line)

    # the lines are concatenated because there can be multiple entries on one
    # line (and the splitting below does not depend on line boundaries)
    return ''.join(table_contents)


def _split_map_entries(entries_str):
    """Split a concatenated table body into a list of `(qstr, entry_text)`.

    Exits the program if the text does not start with a recognisable entry.
    """
    entries = []
    while entries_str:
        # look for single entry, by matching nested braces
        match = None
        if entries_str[0] == '{':
            nested_braces = 0
            for i, char in enumerate(entries_str):
                if char == '{':
                    nested_braces += 1
                elif char == '}':
                    nested_braces -= 1
                    if nested_braces == 0:
                        # i + 2 so that the slice takes in the closing brace
                        # and the comma that is expected right after it
                        match = re_map_entry.match(entries_str[:i + 2])
                        break

        if not match:
            print('unknown line in table:', entries_str)
            sys.exit(1)

        # extract single entry and keep the qstr together with its full text
        entry = match.group(0)
        entries.append((match.group(1), entry))
        entries_str = entries_str[len(entry):].lstrip()
    return entries


def process_map_table(file, line, output):
    """Rewrite one static mp_rom_map_elem_t table as a real hash table.

    The table entries are read from `file`, placed into a hash table with
    hash_insert(), and written to `output` in slot order (empty slots become
    '{ 0, 0 },').  The mp_obj_dict_t / mp_map_t definition that follows the
    table is then copied to `output` with its is_ordered field changed from 1
    to 0.

    Args:
        file: open text file, positioned just after the table's opening line.
        line: the table's opening line; it is copied to `output` unchanged.
        output: list of output lines; it is appended to in place.

    Returns:
        A tuple `(id, size, load, avg_lookups)` where `id` is the name of the
        dict/map variable, `size` the number of slots, `load` the fraction of
        slots in use and `avg_lookups` the mean number of probes per key.

    The program exits with status 1 if the input ends prematurely, if a table
    entry cannot be parsed, or if the table is not followed by a matching
    dict/map definition.
    """
    output.append(line)

    # consume all the entries of the table
    entries = _split_map_entries(_read_table_body(file))

    # sort entries so hash table construction is deterministic
    entries.sort()

    # create hash table
    table = [None] * int(len(entries) * table_size_mult)
    for qstr, entry in entries:
        # We assume that qstr does not have any escape sequences in it.
        # This is reasonably safe, since keys in a module or class dict
        # should be standard identifiers.
        # TODO verify this and raise an error if escape sequence found
        hash_insert(table, qstr, entry)

    # compute statistics
    total_attempts = 0
    for qstr, _ in entries:
        attempts, found = hash_find(table, qstr)
        # internal consistency check: everything just inserted must be found
        assert found is not None
        if print_debug:
            print('  %s lookup took %u attempts' % (qstr, attempts))
        total_attempts += attempts
    if entries:
        stats = len(table), len(entries) / len(table), total_attempts / len(entries)
    else:
        stats = 0, 0, 0
    if print_debug:
        print('  table stats: size=%d, load=%.2f, avg_lookups=%.1f' % stats)

    # output hash table
    for row in table:
        if row is None:
            output.append('{ 0, 0 },\n')
        else:
            output.append(row[1] + '\n')
    output.append('};\n')

    # skip to the next line that has real content
    while True:
        line = file.readline()
        if not line:
            print('unexpected end of input')
            sys.exit(1)
        line = line.strip()
        if not line:
            continue
        if re_preproc_line.match(line):
            # A line marker can sit between the table and its dict/map
            # definition; pass it through instead of mistaking it for the
            # definition and aborting.
            output.append(line + '\n')
            continue
        break

    # transform the is_ordered param from 1 to 0
    match = re_mp_obj_dict_t.match(line)
    if match is None:
        match = re_mp_map_t.match(line)
    if match is None:
        print('expecting mp_obj_dict_t or mp_map_t definition')
        print(output[0])
        print(line)
        sys.exit(1)
    line = match.group('head') + '0' + match.group('tail') + '\n'
    output.append(line)

    return (match.group('id'),) + stats
208
def process_file(filename):
    """Re-hash every static map table found in `filename`, in place.

    The file is read line by line.  Each line that opens a static
    mp_rom_map_elem_t array is handed to process_map_table(); all other lines
    are kept as they are.  The file is written back only if at least one
    table was found.
    """
    rewritten = []
    found_table = False
    with open(filename, 'rt') as src:
        # readline() is used (with '' marking end of file) because
        # process_map_table() reads further lines from the same file object
        for line in iter(src.readline, ''):
            if not re_mp_rom_map_elem_t.match(line):
                rewritten.append(line)
                continue
            found_table = True
            table_stats = process_map_table(src, line, rewritten)
            if print_stats:
                print('  [%s: size=%d, load=%.2f, avg_lookups=%.1f]' % table_stats)

    if not found_table:
        # nothing was modified, so leave the file untouched
        return

    if print_debug:
        print('  modifying static maps in', rewritten[0].strip())
    with open(filename, 'wt') as dst:
        dst.writelines(rewritten)
231
def main():
    """Run the real cc1 with our arguments, then post-process its output.

    The command-line arguments are forwarded unchanged to the compiler at
    `cc1_path`.  If the compiler fails, this script exits with a non-zero
    status in the range 1-127.  Otherwise the first argument selects what
    happens next:

      -E              the preprocessor was run; the file named after "-o" is
                      rewritten by process_file()
      -fpreprocessed  the compiler proper was run; nothing more to do
      anything else   reported as an error, exit status 1
    """
    # imported here so that this function does not depend on the module's
    # import list
    import subprocess

    args = sys.argv[1:]
    if not args:
        print('%s: missing compiler arguments' % (sys.argv[0],))
        sys.exit(1)

    # run actual C compiler
    # The arguments are passed as a list, without going through a shell, so
    # that spaces, quotes, '<', '>', '$' etc. reach the compiler unchanged.
    try:
        ret = subprocess.call([cc1_path] + args)
    except OSError as er:
        print('%s: could not run "%s": %s' % (sys.argv[0], cc1_path, er))
        sys.exit(127)
    if ret != 0:
        # a negative value means the compiler was killed by signal number -ret
        ret = (abs(ret) & 0x7f) or 127 # make it in range 0-127, but non-zero
        sys.exit(ret)

    if args[0] == '-E':
        # CPP has been run, now do our processing stage
        # (the last argument is not examined: "-o" needs a filename after it)
        for i, arg in enumerate(args[:-1]):
            if arg == '-o':
                return process_file(args[i + 1])

        print('%s: could not find "-o" option' % (sys.argv[0],))
        sys.exit(1)
    elif args[0] == '-fpreprocessed':
        # compiler has been run, nothing more to do
        return
    else:
        # unknown processing stage
        print('%s: unknown first option "%s"' % (sys.argv[0], args[0]))
        sys.exit(1)
260
261if __name__ == '__main__':
262    main()
263