1 // SPDX-License-Identifier: GPL-2.0-or-later
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <fcntl.h>
7 #include <limits.h>
8 #include <sched.h>
9 #include <stdlib.h>
10 #include <sys/mount.h>
11 #include <sys/stat.h>
12 #include <sys/wait.h>
13 #include <linux/nsfs.h>
14 #include <linux/stat.h>
15 
16 #include "statmount.h"
17 #include "../utils.h"
18 #include "../../kselftest.h"
19 
/*
 * Result codes shared by the helpers below. Forked test children also use
 * them as their exit status; the parent collects that status with
 * wait_for_pid() and reports it through handle_result(), where NSID_ERROR
 * aborts the whole run via ksft_exit_fail_msg().
 */
#define NSID_PASS 0
#define NSID_FAIL 1
#define NSID_SKIP 2
#define NSID_ERROR 3
24 
handle_result(int ret,const char * testname)25 static void handle_result(int ret, const char *testname)
26 {
27 	if (ret == NSID_PASS)
28 		ksft_test_result_pass("%s\n", testname);
29 	else if (ret == NSID_FAIL)
30 		ksft_test_result_fail("%s\n", testname);
31 	else if (ret == NSID_ERROR)
32 		ksft_exit_fail_msg("%s\n", testname);
33 	else
34 		ksft_test_result_skip("%s\n", testname);
35 }
36 
wait_for_pid(pid_t pid)37 static inline int wait_for_pid(pid_t pid)
38 {
39 	int status, ret;
40 
41 again:
42 	ret = waitpid(pid, &status, 0);
43 	if (ret == -1) {
44 		if (errno == EINTR)
45 			goto again;
46 
47 		ksft_print_msg("waitpid returned -1, errno=%d\n", errno);
48 		return -1;
49 	}
50 
51 	if (!WIFEXITED(status)) {
52 		ksft_print_msg(
53 		       "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n",
54 		       WIFSIGNALED(status), WTERMSIG(status));
55 		return -1;
56 	}
57 
58 	ret = WEXITSTATUS(status);
59 	return ret;
60 }
61 
get_mnt_ns_id(const char * mnt_ns,uint64_t * mnt_ns_id)62 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
63 {
64 	int fd = open(mnt_ns, O_RDONLY);
65 
66 	if (fd < 0) {
67 		ksft_print_msg("failed to open for ns %s: %s\n",
68 			       mnt_ns, strerror(errno));
69 		sleep(60);
70 		return NSID_ERROR;
71 	}
72 
73 	if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
74 		ksft_print_msg("failed to get the nsid for ns %s: %s\n",
75 			       mnt_ns, strerror(errno));
76 		return NSID_ERROR;
77 	}
78 	close(fd);
79 	return NSID_PASS;
80 }
81 
setup_namespace(void)82 static int setup_namespace(void)
83 {
84 	if (setup_userns() != 0)
85 		return NSID_ERROR;
86 
87 	return NSID_PASS;
88 }
89 
_test_statmount_mnt_ns_id(void)90 static int _test_statmount_mnt_ns_id(void)
91 {
92 	struct statmount sm;
93 	uint64_t mnt_ns_id;
94 	uint64_t root_id;
95 	int ret;
96 
97 	ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
98 	if (ret != NSID_PASS)
99 		return ret;
100 
101 	root_id = get_unique_mnt_id("/");
102 	if (!root_id)
103 		return NSID_ERROR;
104 
105 	ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
106 	if (ret == -1) {
107 		ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
108 		return NSID_ERROR;
109 	}
110 
111 	if (sm.size != sizeof(sm)) {
112 		ksft_print_msg("unexpected size: %u != %u\n", sm.size,
113 			       (uint32_t)sizeof(sm));
114 		return NSID_FAIL;
115 	}
116 	if (sm.mask != STATMOUNT_MNT_NS_ID) {
117 		ksft_print_msg("statmount mnt ns id unavailable\n");
118 		return NSID_SKIP;
119 	}
120 
121 	if (sm.mnt_ns_id != mnt_ns_id) {
122 		ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
123 			       (unsigned long long)sm.mnt_ns_id,
124 			       (unsigned long long)mnt_ns_id);
125 		return NSID_FAIL;
126 	}
127 
128 	return NSID_PASS;
129 }
130 
test_statmount_mnt_ns_id(void)131 static void test_statmount_mnt_ns_id(void)
132 {
133 	pid_t pid;
134 	int ret;
135 
136 	pid = fork();
137 	if (pid < 0)
138 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
139 
140 	/* We're the original pid, wait for the result. */
141 	if (pid != 0) {
142 		ret = wait_for_pid(pid);
143 		handle_result(ret, "test statmount ns id");
144 		return;
145 	}
146 
147 	ret = setup_namespace();
148 	if (ret != NSID_PASS)
149 		exit(ret);
150 	ret = _test_statmount_mnt_ns_id();
151 	exit(ret);
152 }
153 
validate_external_listmount(pid_t pid,uint64_t child_nr_mounts)154 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
155 {
156 	uint64_t list[256];
157 	uint64_t mnt_ns_id;
158 	uint64_t nr_mounts;
159 	char buf[256];
160 	int ret;
161 
162 	/* Get the mount ns id for our child. */
163 	snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
164 	ret = get_mnt_ns_id(buf, &mnt_ns_id);
165 
166 	nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
167 	if (nr_mounts == (uint64_t)-1) {
168 		ksft_print_msg("listmount: %s\n", strerror(errno));
169 		return NSID_ERROR;
170 	}
171 
172 	if (nr_mounts != child_nr_mounts) {
173 		ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
174 			       child_nr_mounts);
175 		return NSID_FAIL;
176 	}
177 
178 	/* Validate that all of our entries match our mnt_ns_id. */
179 	for (int i = 0; i < nr_mounts; i++) {
180 		struct statmount sm;
181 
182 		ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm,
183 				sizeof(sm), 0);
184 		if (ret < 0) {
185 			ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
186 			return NSID_ERROR;
187 		}
188 
189 		if (sm.mask != STATMOUNT_MNT_NS_ID) {
190 			ksft_print_msg("statmount mnt ns id unavailable\n");
191 			return NSID_SKIP;
192 		}
193 
194 		if (sm.mnt_ns_id != mnt_ns_id) {
195 			ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
196 				       (unsigned long long)sm.mnt_ns_id,
197 				       (unsigned long long)mnt_ns_id);
198 			return NSID_FAIL;
199 		}
200 	}
201 
202 	return NSID_PASS;
203 }
204 
test_listmount_ns(void)205 static void test_listmount_ns(void)
206 {
207 	uint64_t nr_mounts;
208 	char pval;
209 	int child_ready_pipe[2];
210 	int parent_ready_pipe[2];
211 	pid_t pid;
212 	int ret, child_ret;
213 
214 	if (pipe(child_ready_pipe) < 0)
215 		ksft_exit_fail_msg("failed to create the child pipe: %s\n",
216 				   strerror(errno));
217 	if (pipe(parent_ready_pipe) < 0)
218 		ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
219 				   strerror(errno));
220 
221 	pid = fork();
222 	if (pid < 0)
223 		ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
224 
225 	if (pid == 0) {
226 		char cval;
227 		uint64_t list[256];
228 
229 		close(child_ready_pipe[0]);
230 		close(parent_ready_pipe[1]);
231 
232 		ret = setup_namespace();
233 		if (ret != NSID_PASS)
234 			exit(ret);
235 
236 		nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
237 		if (nr_mounts == (uint64_t)-1) {
238 			ksft_print_msg("listmount: %s\n", strerror(errno));
239 			exit(NSID_FAIL);
240 		}
241 
242 		/*
243 		 * Tell our parent how many mounts we have, and then wait for it
244 		 * to tell us we're done.
245 		 */
246 		if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) !=
247 					sizeof(nr_mounts))
248 			ret = NSID_ERROR;
249 		if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval))
250 			ret = NSID_ERROR;
251 		exit(NSID_PASS);
252 	}
253 
254 	close(child_ready_pipe[1]);
255 	close(parent_ready_pipe[0]);
256 
257 	/* Wait until the child has created everything. */
258 	if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
259 	    sizeof(nr_mounts))
260 		ret = NSID_ERROR;
261 
262 	ret = validate_external_listmount(pid, nr_mounts);
263 
264 	if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
265 		ret = NSID_ERROR;
266 
267 	child_ret = wait_for_pid(pid);
268 	if (child_ret != NSID_PASS)
269 		ret = child_ret;
270 	handle_result(ret, "test listmount ns id");
271 }
272 
main(void)273 int main(void)
274 {
275 	int ret;
276 
277 	ksft_print_header();
278 	ret = statmount(0, 0, 0, NULL, 0, 0);
279 	assert(ret == -1);
280 	if (errno == ENOSYS)
281 		ksft_exit_skip("statmount() syscall not supported\n");
282 
283 	ksft_set_plan(2);
284 	test_statmount_mnt_ns_id();
285 	test_listmount_ns();
286 
287 	if (ksft_get_fail_cnt() + ksft_get_error_cnt() > 0)
288 		ksft_exit_fail();
289 	else
290 		ksft_exit_pass();
291 }
292