1 // SPDX-License-Identifier: GPL-2.0
2
3 /*
4 * Important notes about in-place decompression
5 *
6 * At least on x86, the kernel is decompressed in place: the compressed data
7 * is placed to the end of the output buffer, and the decompressor overwrites
8 * most of the compressed data. There must be enough safety margin to
9 * guarantee that the write position is always behind the read position.
10 *
11 * The safety margin for ZSTD with a 128 KB block size is calculated below.
12 * Note that the margin with ZSTD is bigger than with GZIP or XZ!
13 *
14 * The worst case for in-place decompression is that the beginning of
15 * the file is compressed extremely well, and the rest of the file is
16 * uncompressible. Thus, we must look for worst-case expansion when the
17 * compressor is encoding uncompressible data.
18 *
19 * The structure of the .zst file in case of a compressed kernel is as follows.
 * Maximum sizes (in bytes) of the fields are in parentheses.
21 *
22 * Frame Header: (18)
23 * Blocks: (N)
24 * Checksum: (4)
25 *
26 * The frame header and checksum overhead is at most 22 bytes.
27 *
 * ZSTD stores the data in blocks. Each block has a header whose size is
 * 3 bytes. After the block header, there is up to 128 KB of payload.
30 * The maximum uncompressed size of the payload is 128 KB. The minimum
31 * uncompressed size of the payload is never less than the payload size
32 * (excluding the block header).
33 *
34 * The assumption, that the uncompressed size of the payload is never
35 * smaller than the payload itself, is valid only when talking about
36 * the payload as a whole. It is possible that the payload has parts where
37 * the decompressor consumes more input than it produces output. Calculating
38 * the worst case for this would be tricky. Instead of trying to do that,
39 * let's simply make sure that the decompressor never overwrites any bytes
40 * of the payload which it is currently reading.
41 *
42 * Now we have enough information to calculate the safety margin. We need
43 * - 22 bytes for the .zst file format headers;
44 * - 3 bytes per every 128 KiB of uncompressed size (one block header per
45 * block); and
46 * - 128 KiB (biggest possible zstd block size) to make sure that the
47 * decompressor never overwrites anything from the block it is currently
48 * reading.
49 *
50 * We get the following formula:
51 *
52 * safety_margin = 22 + uncompressed_size * 3 / 131072 + 131072
53 * <= 22 + (uncompressed_size >> 15) + 131072
54 */
55
56 #include "decompress.h"
57
58 #include "zstd/entropy_common.c"
59 #include "zstd/fse_decompress.c"
60 #include "zstd/huf_decompress.c"
61 #include "zstd/zstd_common.c"
62 #include "zstd/decompress.c"
63
/*
 * 128MB is the maximum window size supported by zstd.
 * Streaming decompression rejects any frame whose header announces a
 * window larger than this.
 */
#define ZSTD_WINDOWSIZE_MAX (1 << ZSTD_WINDOWLOG_MAX)
/*
 * Size of the input and output buffers in multi-call mode
 * (1 << 17 bytes = 128 KiB).
 * Pick a larger size because it isn't used during kernel decompression,
 * since that is single pass, and we have to allocate a large buffer for
 * zstd's window anyway. The larger size speeds up initramfs decompression.
 */
#define ZSTD_IOBUF_SIZE (1 << 17)
73
handle_zstd_error(size_t ret,void (* error)(const char * x))74 static int __init handle_zstd_error(size_t ret, void (*error)(const char *x))
75 {
76 const int err = ZSTD_getErrorCode(ret);
77
78 if (!ZSTD_isError(ret))
79 return 0;
80
81 switch (err) {
82 case ZSTD_error_memory_allocation:
83 error("ZSTD decompressor ran out of memory");
84 break;
85 case ZSTD_error_prefix_unknown:
86 error("Input is not in the ZSTD format (wrong magic bytes)");
87 break;
88 case ZSTD_error_dstSize_tooSmall:
89 case ZSTD_error_corruption_detected:
90 case ZSTD_error_checksum_wrong:
91 error("ZSTD-compressed data is corrupt");
92 break;
93 default:
94 error("ZSTD-compressed data is probably corrupt");
95 break;
96 }
97 return -1;
98 }
99
100 /*
101 * Handle the case where we have the entire input and output in one segment.
102 * We can allocate less memory (no circular buffer for the sliding window),
103 * and avoid some memcpy() calls.
104 */
decompress_single(const u8 * in_buf,long in_len,u8 * out_buf,long out_len,unsigned int * in_pos,void (* error)(const char * x))105 static int __init decompress_single(const u8 *in_buf, long in_len, u8 *out_buf,
106 long out_len, unsigned int *in_pos,
107 void (*error)(const char *x))
108 {
109 const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
110 void *wksp = large_malloc(wksp_size);
111 ZSTD_DCtx *dctx = ZSTD_initDCtx(wksp, wksp_size);
112 int err;
113 size_t ret;
114
115 if (dctx == NULL) {
116 error("Out of memory while allocating ZSTD_DCtx");
117 err = -1;
118 goto out;
119 }
120 /*
121 * Find out how large the frame actually is, there may be junk at
122 * the end of the frame that ZSTD_decompressDCtx() can't handle.
123 */
124 ret = ZSTD_findFrameCompressedSize(in_buf, in_len);
125 err = handle_zstd_error(ret, error);
126 if (err)
127 goto out;
128 in_len = (long)ret;
129
130 ret = ZSTD_decompressDCtx(dctx, out_buf, out_len, in_buf, in_len);
131 err = handle_zstd_error(ret, error);
132 if (err)
133 goto out;
134
135 if (in_pos != NULL)
136 *in_pos = in_len;
137
138 err = 0;
139 out:
140 if (wksp != NULL)
141 large_free(wksp);
142 return err;
143 }
144
unzstd(unsigned char * in_buf,unsigned int in_len,int (* fill)(void *,unsigned int),int (* flush)(void *,unsigned int),unsigned char * out_buf,unsigned int * in_pos,void (* error)(const char * x))145 int __init unzstd(unsigned char *in_buf, unsigned int in_len,
146 int (*fill)(void*, unsigned int),
147 int (*flush)(void*, unsigned int),
148 unsigned char *out_buf, unsigned int *in_pos,
149 void (*error)(const char *x))
150 {
151 ZSTD_inBuffer in;
152 ZSTD_outBuffer out;
153 ZSTD_frameParams params;
154 void *in_allocated = NULL;
155 void *out_allocated = NULL;
156 void *wksp = NULL;
157 size_t wksp_size;
158 ZSTD_DStream *dstream;
159 int err;
160 size_t ret;
161 /*
162 * ZSTD decompression code won't be happy if the buffer size is so big
163 * that its end address overflows. When the size is not provided, make
164 * it as big as possible without having the end address overflow.
165 */
166 unsigned long out_len = ULONG_MAX - (unsigned long)out_buf;
167
168 if (fill == NULL && flush == NULL)
169 /*
170 * We can decompress faster and with less memory when we have a
171 * single chunk.
172 */
173 return decompress_single(in_buf, in_len, out_buf, out_len,
174 in_pos, error);
175
176 /*
177 * If in_buf is not provided, we must be using fill(), so allocate
178 * a large enough buffer. If it is provided, it must be at least
179 * ZSTD_IOBUF_SIZE large.
180 */
181 if (in_buf == NULL) {
182 in_allocated = large_malloc(ZSTD_IOBUF_SIZE);
183 if (in_allocated == NULL) {
184 error("Out of memory while allocating input buffer");
185 err = -1;
186 goto out;
187 }
188 in_buf = in_allocated;
189 in_len = 0;
190 }
191 /* Read the first chunk, since we need to decode the frame header. */
192 if (fill != NULL)
193 in_len = fill(in_buf, ZSTD_IOBUF_SIZE);
194 if ((int)in_len < 0) {
195 error("ZSTD-compressed data is truncated");
196 err = -1;
197 goto out;
198 }
199 /* Set the first non-empty input buffer. */
200 in.src = in_buf;
201 in.pos = 0;
202 in.size = in_len;
203 /* Allocate the output buffer if we are using flush(). */
204 if (flush != NULL) {
205 out_allocated = large_malloc(ZSTD_IOBUF_SIZE);
206 if (out_allocated == NULL) {
207 error("Out of memory while allocating output buffer");
208 err = -1;
209 goto out;
210 }
211 out_buf = out_allocated;
212 out_len = ZSTD_IOBUF_SIZE;
213 }
214 /* Set the output buffer. */
215 out.dst = out_buf;
216 out.pos = 0;
217 out.size = out_len;
218
219 /*
220 * We need to know the window size to allocate the ZSTD_DStream.
221 * Since we are streaming, we need to allocate a buffer for the sliding
222 * window. The window size varies from 1 KB to ZSTD_WINDOWSIZE_MAX
223 * (8 MB), so it is important to use the actual value so as not to
224 * waste memory when it is smaller.
225 */
226 ret = ZSTD_getFrameParams(¶ms, in.src, in.size);
227 err = handle_zstd_error(ret, error);
228 if (err)
229 goto out;
230 if (ret != 0) {
231 error("ZSTD-compressed data has an incomplete frame header");
232 err = -1;
233 goto out;
234 }
235 if (params.windowSize > ZSTD_WINDOWSIZE_MAX) {
236 error("ZSTD-compressed data has too large a window size");
237 err = -1;
238 goto out;
239 }
240
241 /*
242 * Allocate the ZSTD_DStream now that we know how much memory is
243 * required.
244 */
245 wksp_size = ZSTD_DStreamWorkspaceBound(params.windowSize);
246 wksp = large_malloc(wksp_size);
247 dstream = ZSTD_initDStream(params.windowSize, wksp, wksp_size);
248 if (dstream == NULL) {
249 error("Out of memory while allocating ZSTD_DStream");
250 err = -1;
251 goto out;
252 }
253
254 /*
255 * Decompression loop:
256 * Read more data if necessary (error if no more data can be read).
257 * Call the decompression function, which returns 0 when finished.
258 * Flush any data produced if using flush().
259 */
260 if (in_pos != NULL)
261 *in_pos = 0;
262 do {
263 /*
264 * If we need to reload data, either we have fill() and can
265 * try to get more data, or we don't and the input is truncated.
266 */
267 if (in.pos == in.size) {
268 if (in_pos != NULL)
269 *in_pos += in.pos;
270 in_len = fill ? fill(in_buf, ZSTD_IOBUF_SIZE) : -1;
271 if ((int)in_len < 0) {
272 error("ZSTD-compressed data is truncated");
273 err = -1;
274 goto out;
275 }
276 in.pos = 0;
277 in.size = in_len;
278 }
279 /* Returns zero when the frame is complete. */
280 ret = ZSTD_decompressStream(dstream, &out, &in);
281 err = handle_zstd_error(ret, error);
282 if (err)
283 goto out;
284 /* Flush all of the data produced if using flush(). */
285 if (flush != NULL && out.pos > 0) {
286 if (out.pos != flush(out.dst, out.pos)) {
287 error("Failed to flush()");
288 err = -1;
289 goto out;
290 }
291 out.pos = 0;
292 }
293 } while (ret != 0);
294
295 if (in_pos != NULL)
296 *in_pos += in.pos;
297
298 err = 0;
299 out:
300 if (in_allocated != NULL)
301 large_free(in_allocated);
302 if (out_allocated != NULL)
303 large_free(out_allocated);
304 if (wksp != NULL)
305 large_free(wksp);
306 return err;
307 }
308