1 // SPDX-License-Identifier: GPL-2.0+
2 #include "decompress.h"
3 
4 #if IS_ENABLED(CONFIG_ZLIB)
5 #include <u-boot/zlib.h>
6 
7 /* report a zlib or i/o error */
zerr(int ret)8 static int zerr(int ret)
9 {
10 	switch (ret) {
11 	case Z_STREAM_ERROR:
12 		return -EINVAL;
13 	case Z_DATA_ERROR:
14 		return -EIO;
15 	case Z_MEM_ERROR:
16 		return -ENOMEM;
17 	case Z_ERRNO:
18 	default:
19 		return -EFAULT;
20 	}
21 }
22 
z_erofs_decompress_deflate(struct z_erofs_decompress_req * rq)23 static int z_erofs_decompress_deflate(struct z_erofs_decompress_req *rq)
24 {
25 	u8 *dest = (u8 *)rq->out;
26 	u8 *src = (u8 *)rq->in;
27 	u8 *buff = NULL;
28 	unsigned int inputmargin = 0;
29 	z_stream strm;
30 	int ret;
31 
32 	while (!src[inputmargin & (erofs_blksiz() - 1)])
33 		if (!(++inputmargin & (erofs_blksiz() - 1)))
34 			break;
35 
36 	if (inputmargin >= rq->inputsize)
37 		return -EFSCORRUPTED;
38 
39 	if (rq->decodedskip) {
40 		buff = malloc(rq->decodedlength);
41 		if (!buff)
42 			return -ENOMEM;
43 		dest = buff;
44 	}
45 
46 	/* allocate inflate state */
47 	strm.zalloc = Z_NULL;
48 	strm.zfree = Z_NULL;
49 	strm.opaque = Z_NULL;
50 	strm.avail_in = 0;
51 	strm.next_in = Z_NULL;
52 	ret = inflateInit2(&strm, -15);
53 	if (ret != Z_OK) {
54 		free(buff);
55 		return zerr(ret);
56 	}
57 
58 	strm.next_in = src + inputmargin;
59 	strm.avail_in = rq->inputsize - inputmargin;
60 	strm.next_out = dest;
61 	strm.avail_out = rq->decodedlength;
62 
63 	ret = inflate(&strm, rq->partial_decoding ? Z_SYNC_FLUSH : Z_FINISH);
64 	if (ret != Z_STREAM_END || strm.total_out != rq->decodedlength) {
65 		if (ret != Z_OK || !rq->partial_decoding) {
66 			ret = zerr(ret);
67 			goto out_inflate_end;
68 		}
69 	}
70 
71 	if (rq->decodedskip)
72 		memcpy(rq->out, dest + rq->decodedskip,
73 		       rq->decodedlength - rq->decodedskip);
74 
75 out_inflate_end:
76 	inflateEnd(&strm);
77 	if (buff)
78 		free(buff);
79 	return ret;
80 }
81 #endif
82 
83 #if IS_ENABLED(CONFIG_LZ4)
84 #include <u-boot/lz4.h>
z_erofs_decompress_lz4(struct z_erofs_decompress_req * rq)85 static int z_erofs_decompress_lz4(struct z_erofs_decompress_req *rq)
86 {
87 	int ret = 0;
88 	char *dest = rq->out;
89 	char *src = rq->in;
90 	char *buff = NULL;
91 	bool support_0padding = false;
92 	unsigned int inputmargin = 0;
93 
94 	if (erofs_sb_has_lz4_0padding()) {
95 		support_0padding = true;
96 
97 		while (!src[inputmargin & (erofs_blksiz() - 1)])
98 			if (!(++inputmargin & (erofs_blksiz() - 1)))
99 				break;
100 
101 		if (inputmargin >= rq->inputsize)
102 			return -EIO;
103 	}
104 
105 	if (rq->decodedskip) {
106 		buff = malloc(rq->decodedlength);
107 		if (!buff)
108 			return -ENOMEM;
109 		dest = buff;
110 	}
111 
112 	if (rq->partial_decoding || !support_0padding)
113 		ret = LZ4_decompress_safe_partial(src + inputmargin, dest,
114 						  rq->inputsize - inputmargin,
115 						  rq->decodedlength, rq->decodedlength);
116 	else
117 		ret = LZ4_decompress_safe(src + inputmargin, dest,
118 					  rq->inputsize - inputmargin,
119 					  rq->decodedlength);
120 
121 	if (ret != (int)rq->decodedlength) {
122 		erofs_err("failed to %s decompress %d in[%u, %u] out[%u]",
123 			  rq->partial_decoding ? "partial" : "full",
124 			  ret, rq->inputsize, inputmargin, rq->decodedlength);
125 		ret = -EIO;
126 		goto out;
127 	}
128 
129 	if (rq->decodedskip)
130 		memcpy(rq->out, dest + rq->decodedskip,
131 		       rq->decodedlength - rq->decodedskip);
132 
133 out:
134 	if (buff)
135 		free(buff);
136 
137 	return ret;
138 }
139 #endif
140 
z_erofs_decompress(struct z_erofs_decompress_req * rq)141 int z_erofs_decompress(struct z_erofs_decompress_req *rq)
142 {
143 	if (rq->alg == Z_EROFS_COMPRESSION_INTERLACED) {
144 		unsigned int count, rightpart, skip;
145 
146 		/* XXX: should support inputsize >= erofs_blksiz() later */
147 		if (rq->inputsize > erofs_blksiz())
148 			return -EFSCORRUPTED;
149 
150 		if (rq->decodedlength > erofs_blksiz())
151 			return -EFSCORRUPTED;
152 
153 		if (rq->decodedlength < rq->decodedskip)
154 			return -EFSCORRUPTED;
155 
156 		count = rq->decodedlength - rq->decodedskip;
157 		skip = erofs_blkoff(rq->interlaced_offset + rq->decodedskip);
158 		rightpart = min(erofs_blksiz() - skip, count);
159 		memcpy(rq->out, rq->in + skip, rightpart);
160 		memcpy(rq->out + rightpart, rq->in, count - rightpart);
161 		return 0;
162 	} else if (rq->alg == Z_EROFS_COMPRESSION_SHIFTED) {
163 		if (rq->decodedlength > rq->inputsize)
164 			return -EFSCORRUPTED;
165 
166 		DBG_BUGON(rq->decodedlength < rq->decodedskip);
167 		memcpy(rq->out, rq->in + rq->decodedskip,
168 		       rq->decodedlength - rq->decodedskip);
169 		return 0;
170 	}
171 
172 #if IS_ENABLED(CONFIG_LZ4)
173 	if (rq->alg == Z_EROFS_COMPRESSION_LZ4)
174 		return z_erofs_decompress_lz4(rq);
175 #endif
176 #if IS_ENABLED(CONFIG_ZLIB)
177 	if (rq->alg == Z_EROFS_COMPRESSION_DEFLATE)
178 		return z_erofs_decompress_deflate(rq);
179 #endif
180 	return -EOPNOTSUPP;
181 }
182