1 // SPDX-License-Identifier: GPL-2.0
2 
3 /*
4  * Important notes about in-place decompression
5  *
 * At least on x86, the kernel is decompressed in place: the compressed data
 * is placed at the end of the output buffer, and the decompressor overwrites
8  * most of the compressed data. There must be enough safety margin to
9  * guarantee that the write position is always behind the read position.
10  *
11  * The safety margin for ZSTD with a 128 KB block size is calculated below.
12  * Note that the margin with ZSTD is bigger than with GZIP or XZ!
13  *
14  * The worst case for in-place decompression is that the beginning of
15  * the file is compressed extremely well, and the rest of the file is
16  * uncompressible. Thus, we must look for worst-case expansion when the
17  * compressor is encoding uncompressible data.
18  *
19  * The structure of the .zst file in case of a compressed kernel is as follows.
 * Maximum sizes (in bytes) of the fields are in parentheses.
21  *
22  *    Frame Header: (18)
23  *    Blocks: (N)
24  *    Checksum: (4)
25  *
26  * The frame header and checksum overhead is at most 22 bytes.
27  *
 * ZSTD stores the data in blocks. Each block has a header whose size is
 * 3 bytes. After the block header, there is up to 128 KB of payload.
30  * The maximum uncompressed size of the payload is 128 KB. The minimum
31  * uncompressed size of the payload is never less than the payload size
32  * (excluding the block header).
33  *
34  * The assumption, that the uncompressed size of the payload is never
35  * smaller than the payload itself, is valid only when talking about
36  * the payload as a whole. It is possible that the payload has parts where
37  * the decompressor consumes more input than it produces output. Calculating
38  * the worst case for this would be tricky. Instead of trying to do that,
39  * let's simply make sure that the decompressor never overwrites any bytes
40  * of the payload which it is currently reading.
41  *
42  * Now we have enough information to calculate the safety margin. We need
43  *   - 22 bytes for the .zst file format headers;
44  *   - 3 bytes per every 128 KiB of uncompressed size (one block header per
45  *     block); and
46  *   - 128 KiB (biggest possible zstd block size) to make sure that the
47  *     decompressor never overwrites anything from the block it is currently
48  *     reading.
49  *
50  * We get the following formula:
51  *
52  *    safety_margin = 22 + uncompressed_size * 3 / 131072 + 131072
53  *                 <= 22 + (uncompressed_size >> 15) + 131072
54  */
55 
56 #include "decompress.h"
57 
58 #include "zstd/entropy_common.c"
59 #include "zstd/fse_decompress.c"
60 #include "zstd/huf_decompress.c"
61 #include "zstd/zstd_common.c"
62 #include "zstd/decompress.c"
63 
/*
 * Largest window size accepted by unzstd(); frames that declare a bigger
 * window are rejected before any workspace is allocated.
 * 128MB is the maximum window size supported by zstd.
 * NOTE(review): the 128MB figure assumes ZSTD_WINDOWLOG_MAX is 27; that
 * macro is defined outside this file - confirm against the zstd headers.
 */
#define ZSTD_WINDOWSIZE_MAX	(1 << ZSTD_WINDOWLOG_MAX)
/*
 * Size of the input and output buffers in multi-call mode
 * (1 << 17 bytes, i.e. 128 KiB).
 * Pick a larger size because it isn't used during kernel decompression,
 * since that is single pass, and we have to allocate a large buffer for
 * zstd's window anyway. The larger size speeds up initramfs decompression.
 */
#define ZSTD_IOBUF_SIZE		(1 << 17)
73 
handle_zstd_error(size_t ret,void (* error)(const char * x))74 static int __init handle_zstd_error(size_t ret, void (*error)(const char *x))
75 {
76 	const int err = ZSTD_getErrorCode(ret);
77 
78 	if (!ZSTD_isError(ret))
79 		return 0;
80 
81 	switch (err) {
82 	case ZSTD_error_memory_allocation:
83 		error("ZSTD decompressor ran out of memory");
84 		break;
85 	case ZSTD_error_prefix_unknown:
86 		error("Input is not in the ZSTD format (wrong magic bytes)");
87 		break;
88 	case ZSTD_error_dstSize_tooSmall:
89 	case ZSTD_error_corruption_detected:
90 	case ZSTD_error_checksum_wrong:
91 		error("ZSTD-compressed data is corrupt");
92 		break;
93 	default:
94 		error("ZSTD-compressed data is probably corrupt");
95 		break;
96 	}
97 	return -1;
98 }
99 
100 /*
101  * Handle the case where we have the entire input and output in one segment.
102  * We can allocate less memory (no circular buffer for the sliding window),
103  * and avoid some memcpy() calls.
104  */
decompress_single(const u8 * in_buf,long in_len,u8 * out_buf,long out_len,unsigned int * in_pos,void (* error)(const char * x))105 static int __init decompress_single(const u8 *in_buf, long in_len, u8 *out_buf,
106 				    long out_len, unsigned int *in_pos,
107 				    void (*error)(const char *x))
108 {
109 	const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
110 	void *wksp = large_malloc(wksp_size);
111 	ZSTD_DCtx *dctx = ZSTD_initDCtx(wksp, wksp_size);
112 	int err;
113 	size_t ret;
114 
115 	if (dctx == NULL) {
116 		error("Out of memory while allocating ZSTD_DCtx");
117 		err = -1;
118 		goto out;
119 	}
120 	/*
121 	 * Find out how large the frame actually is, there may be junk at
122 	 * the end of the frame that ZSTD_decompressDCtx() can't handle.
123 	 */
124 	ret = ZSTD_findFrameCompressedSize(in_buf, in_len);
125 	err = handle_zstd_error(ret, error);
126 	if (err)
127 		goto out;
128 	in_len = (long)ret;
129 
130 	ret = ZSTD_decompressDCtx(dctx, out_buf, out_len, in_buf, in_len);
131 	err = handle_zstd_error(ret, error);
132 	if (err)
133 		goto out;
134 
135 	if (in_pos != NULL)
136 		*in_pos = in_len;
137 
138 	err = 0;
139 out:
140 	if (wksp != NULL)
141 		large_free(wksp);
142 	return err;
143 }
144 
unzstd(unsigned char * in_buf,unsigned int in_len,int (* fill)(void *,unsigned int),int (* flush)(void *,unsigned int),unsigned char * out_buf,unsigned int * in_pos,void (* error)(const char * x))145 int __init unzstd(unsigned char *in_buf, unsigned int in_len,
146 		  int (*fill)(void*, unsigned int),
147 		  int (*flush)(void*, unsigned int),
148 		  unsigned char *out_buf, unsigned int *in_pos,
149 		  void (*error)(const char *x))
150 {
151 	ZSTD_inBuffer in;
152 	ZSTD_outBuffer out;
153 	ZSTD_frameParams params;
154 	void *in_allocated = NULL;
155 	void *out_allocated = NULL;
156 	void *wksp = NULL;
157 	size_t wksp_size;
158 	ZSTD_DStream *dstream;
159 	int err;
160 	size_t ret;
161 	/*
162 	 * ZSTD decompression code won't be happy if the buffer size is so big
163 	 * that its end address overflows. When the size is not provided, make
164 	 * it as big as possible without having the end address overflow.
165 	 */
166 	unsigned long out_len = ULONG_MAX - (unsigned long)out_buf;
167 
168 	if (fill == NULL && flush == NULL)
169 		/*
170 		 * We can decompress faster and with less memory when we have a
171 		 * single chunk.
172 		 */
173 		return decompress_single(in_buf, in_len, out_buf, out_len,
174 					 in_pos, error);
175 
176 	/*
177 	 * If in_buf is not provided, we must be using fill(), so allocate
178 	 * a large enough buffer. If it is provided, it must be at least
179 	 * ZSTD_IOBUF_SIZE large.
180 	 */
181 	if (in_buf == NULL) {
182 		in_allocated = large_malloc(ZSTD_IOBUF_SIZE);
183 		if (in_allocated == NULL) {
184 			error("Out of memory while allocating input buffer");
185 			err = -1;
186 			goto out;
187 		}
188 		in_buf = in_allocated;
189 		in_len = 0;
190 	}
191 	/* Read the first chunk, since we need to decode the frame header. */
192 	if (fill != NULL)
193 		in_len = fill(in_buf, ZSTD_IOBUF_SIZE);
194 	if ((int)in_len < 0) {
195 		error("ZSTD-compressed data is truncated");
196 		err = -1;
197 		goto out;
198 	}
199 	/* Set the first non-empty input buffer. */
200 	in.src = in_buf;
201 	in.pos = 0;
202 	in.size = in_len;
203 	/* Allocate the output buffer if we are using flush(). */
204 	if (flush != NULL) {
205 		out_allocated = large_malloc(ZSTD_IOBUF_SIZE);
206 		if (out_allocated == NULL) {
207 			error("Out of memory while allocating output buffer");
208 			err = -1;
209 			goto out;
210 		}
211 		out_buf = out_allocated;
212 		out_len = ZSTD_IOBUF_SIZE;
213 	}
214 	/* Set the output buffer. */
215 	out.dst = out_buf;
216 	out.pos = 0;
217 	out.size = out_len;
218 
219 	/*
220 	 * We need to know the window size to allocate the ZSTD_DStream.
221 	 * Since we are streaming, we need to allocate a buffer for the sliding
222 	 * window. The window size varies from 1 KB to ZSTD_WINDOWSIZE_MAX
223 	 * (8 MB), so it is important to use the actual value so as not to
224 	 * waste memory when it is smaller.
225 	 */
226 	ret = ZSTD_getFrameParams(&params, in.src, in.size);
227 	err = handle_zstd_error(ret, error);
228 	if (err)
229 		goto out;
230 	if (ret != 0) {
231 		error("ZSTD-compressed data has an incomplete frame header");
232 		err = -1;
233 		goto out;
234 	}
235 	if (params.windowSize > ZSTD_WINDOWSIZE_MAX) {
236 		error("ZSTD-compressed data has too large a window size");
237 		err = -1;
238 		goto out;
239 	}
240 
241 	/*
242 	 * Allocate the ZSTD_DStream now that we know how much memory is
243 	 * required.
244 	 */
245 	wksp_size = ZSTD_DStreamWorkspaceBound(params.windowSize);
246 	wksp = large_malloc(wksp_size);
247 	dstream = ZSTD_initDStream(params.windowSize, wksp, wksp_size);
248 	if (dstream == NULL) {
249 		error("Out of memory while allocating ZSTD_DStream");
250 		err = -1;
251 		goto out;
252 	}
253 
254 	/*
255 	 * Decompression loop:
256 	 * Read more data if necessary (error if no more data can be read).
257 	 * Call the decompression function, which returns 0 when finished.
258 	 * Flush any data produced if using flush().
259 	 */
260 	if (in_pos != NULL)
261 		*in_pos = 0;
262 	do {
263 		/*
264 		 * If we need to reload data, either we have fill() and can
265 		 * try to get more data, or we don't and the input is truncated.
266 		 */
267 		if (in.pos == in.size) {
268 			if (in_pos != NULL)
269 				*in_pos += in.pos;
270 			in_len = fill ? fill(in_buf, ZSTD_IOBUF_SIZE) : -1;
271 			if ((int)in_len < 0) {
272 				error("ZSTD-compressed data is truncated");
273 				err = -1;
274 				goto out;
275 			}
276 			in.pos = 0;
277 			in.size = in_len;
278 		}
279 		/* Returns zero when the frame is complete. */
280 		ret = ZSTD_decompressStream(dstream, &out, &in);
281 		err = handle_zstd_error(ret, error);
282 		if (err)
283 			goto out;
284 		/* Flush all of the data produced if using flush(). */
285 		if (flush != NULL && out.pos > 0) {
286 			if (out.pos != flush(out.dst, out.pos)) {
287 				error("Failed to flush()");
288 				err = -1;
289 				goto out;
290 			}
291 			out.pos = 0;
292 		}
293 	} while (ret != 0);
294 
295 	if (in_pos != NULL)
296 		*in_pos += in.pos;
297 
298 	err = 0;
299 out:
300 	if (in_allocated != NULL)
301 		large_free(in_allocated);
302 	if (out_allocated != NULL)
303 		large_free(out_allocated);
304 	if (wksp != NULL)
305 		large_free(wksp);
306 	return err;
307 }
308