from __future__ import print_function

import collections
import re
import sys

import gzip
import zlib


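# The marker byte below is emitted (as an octal escape) in front of each string
# that was actually compressed, so the decompressor can tell compressed entries
# apart from ones left as plain text (see the prefix handling in main() below).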
_COMPRESSED_MARKER = 0xFF


def check_non_ascii(msg):
    for c in msg:
        if ord(c) >= 0x80:
            print(
                'Unable to generate compressed data: message "{}" contains a non-ascii character "{}".'.format(
                    msg, c
                ),
                file=sys.stderr,
            )
            sys.exit(1)


# Replace <char><space> with <char | 0x80>.
# Trivial scheme to demo/test.
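# Example (the high bit marks a folded character-plus-space pair): "foo bar"
# compresses to "fo\357bar", since the "o " pair becomes the single escaped
# byte 0o357 == 0x80 | ord("o").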
def space_compression(error_strings):
    for line in error_strings:
        check_non_ascii(line)
        result = ""
        for i in range(len(line)):
            if i > 0 and line[i] == " ":
                result = result[:-1]
                result += "\\{:03o}".format(0b10000000 | ord(line[i - 1]))
            else:
                result += line[i]
        error_strings[line] = result
    return None


# Replace common words with <0x80 | index>.
# Index is into a table of words stored as aaaaa<0x80|a>bbb<0x80|b>...
# Replaced words are assumed to have spaces either side to avoid having to store the spaces in the compressed strings.
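# Example (the table index used here is purely illustrative): if "not" ends up
# at index 3, then "name not defined" compresses to "name\203defined"
# (0o203 == 0x80 | 3), and the table stores "not" as "no\364"
# (0o364 == 0x80 | ord("t")) so the high bit marks where each word ends.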
def word_compression(error_strings):
    topn = collections.Counter()

    for line in error_strings.keys():
        check_non_ascii(line)
        for word in line.split(" "):
            topn[word] += 1

    # Order not just by frequency, but by expected saving. i.e. prefer a longer string that is used less frequently.
    # Use the word itself for ties so that compression is deterministic.
    def bytes_saved(item):
        w, n = item
        return -((len(w) + 1) * (n - 1)), w

    top128 = sorted(topn.items(), key=bytes_saved)[:128]

    index = [w for w, _ in top128]
    index_lookup = {w: i for i, w in enumerate(index)}

    for line in error_strings.keys():
        result = ""
        need_space = False
        for word in line.split(" "):
            if word in index_lookup:
                result += "\\{:03o}".format(0b10000000 | index_lookup[word])
                need_space = False
            else:
                if need_space:
                    result += " "
                need_space = True
                result += word
        error_strings[line] = result.strip()

    return "".join(w[:-1] + "\\{:03o}".format(0b10000000 | ord(w[-1])) for w in index)


# Replace chars in text with a variable-length bit sequence.
# For comparison only (the table is not emitted).
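# Rough shape of the encoding: each line becomes a leading "1" bit followed by
# the per-character codes, emitted 8 bits at a time as octal escapes (the final
# byte may hold fewer than 8 bits). Lines whose escaped form would be longer
# than the original are left as-is.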
def huffman_compression(error_strings):
    # https://github.com/tannewt/huffman
    import huffman

    all_strings = "".join(error_strings)
    cb = huffman.codebook(collections.Counter(all_strings).items())

    for line in error_strings:
        b = "1"
        for c in line:
            b += cb[c]
        n = len(b)
        if n % 8 != 0:
            n += 8 - (n % 8)
        result = ""
        for i in range(0, n, 8):
            result += "\\{:03o}".format(int(b[i : i + 8], 2))
        if len(result) > len(line) * 4:
            result = line
        error_strings[line] = result

    # TODO: This would be the prefix lengths and the table ordering.
    return "_" * (10 + len(cb))


# Replace common N-letter sequences with <0x80 | index>, where
# the common sequences are stored in a separate table.
# This isn't very useful; a smarter way to find the top n-grams is needed.
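# Example with N == 2 (the table index used here is purely illustrative): if
# the bigram "er" lands at index 0, "error" compresses to "\200ror": "er"
# becomes 0o200 == 0x80 | 0, "ro" is not in the table so it is kept, and the
# odd trailing "r" is appended verbatim.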
def ngram_compression(error_strings):
    topn = collections.Counter()
    N = 2

    for line in error_strings.keys():
        check_non_ascii(line)
        if len(line) < N:
            continue
        for i in range(0, len(line) - N, N):
            topn[line[i : i + N]] += 1

    def bytes_saved(item):
        w, n = item
        return -(len(w) * (n - 1))

    top128 = sorted(topn.items(), key=bytes_saved)[:128]

    index = [w for w, _ in top128]
    index_lookup = {w: i for i, w in enumerate(index)}

    for line in error_strings.keys():
        result = ""
        for i in range(0, len(line) - N + 1, N):
            word = line[i : i + N]
            if word in index_lookup:
                result += "\\{:03o}".format(0b10000000 | index_lookup[word])
            else:
                result += word
        if len(line) % N != 0:
            result += line[len(line) - len(line) % N :]
        error_strings[line] = result.strip()

    return "".join(index)


def main(collected_path, fn):
    error_strings = collections.OrderedDict()
    max_uncompressed_len = 0
    num_uses = 0

    # Read in all MP_ERROR_TEXT strings.
    with open(collected_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            num_uses += 1
            error_strings[line] = None
            max_uncompressed_len = max(max_uncompressed_len, len(line))

    # So that objexcept.c can figure out how big the buffer needs to be.
    print("#define MP_MAX_UNCOMPRESSED_TEXT_LEN ({})".format(max_uncompressed_len))

    # Run the compression.
    compressed_data = fn(error_strings)

    # Print the data table.
    print('MP_COMPRESSED_DATA("{}")'.format(compressed_data))

    # Print the replacements.
    for uncomp, comp in error_strings.items():
        if uncomp == comp:
            prefix = ""
        else:
            prefix = "\\{:03o}".format(_COMPRESSED_MARKER)
        print('MP_MATCH_COMPRESSED("{}", "{}{}")'.format(uncomp, prefix, comp))

    # Used to calculate the "true" length of the (escaped) compressed strings.
    def unescape(s):
        return re.sub(r"\\\d\d\d", "!", s)

    # Stats. Note this doesn't include the cost of the decompressor code.
    uncomp_len = sum(len(s) + 1 for s in error_strings.keys())
    comp_len = sum(1 + len(unescape(s)) + 1 for s in error_strings.values())
    data_len = len(compressed_data) + 1 if compressed_data else 0
    print("// Total input length:      {}".format(uncomp_len))
    print("// Total compressed length: {}".format(comp_len))
    print("// Total data length:       {}".format(data_len))
    print("// Predicted saving:        {}".format(uncomp_len - comp_len - data_len))

    # Somewhat meaningless comparison to zlib/gzip.
    all_input_bytes = "\\0".join(error_strings.keys()).encode()
    print()
    if hasattr(gzip, "compress"):
        gzip_len = len(gzip.compress(all_input_bytes)) + num_uses * 4
        print("// gzip length:             {}".format(gzip_len))
        print("// Percentage of gzip:      {:.1f}%".format(100 * (comp_len + data_len) / gzip_len))
    if hasattr(zlib, "compress"):
        zlib_len = len(zlib.compress(all_input_bytes)) + num_uses * 4
        print("// zlib length:             {}".format(zlib_len))
        print("// Percentage of zlib:      {:.1f}%".format(100 * (comp_len + data_len) / zlib_len))


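# Illustrative invocation (the script and file names here are hypothetical):
#
#   python makecompresseddata.py compressed.collected > compressed.data.h
#
# which prints output of the shape:
#
#   #define MP_MAX_UNCOMPRESSED_TEXT_LEN (...)
#   MP_COMPRESSED_DATA("...")
#   MP_MATCH_COMPRESSED("name not defined", "\377name\203defined")
#   ...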
if __name__ == "__main__":
    main(sys.argv[1], word_compression)
