1 // SPDX-License-Identifier: GPL-2.0-or-later
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <fcntl.h>
7 #include <limits.h>
8 #include <sched.h>
9 #include <stdlib.h>
10 #include <sys/mount.h>
11 #include <sys/stat.h>
12 #include <sys/wait.h>
13 #include <linux/nsfs.h>
14 #include <linux/stat.h>
15
16 #include "statmount.h"
17 #include "../utils.h"
18 #include "../../kselftest.h"
19
/*
 * Result codes returned by the test helpers in this file. Forked children
 * also use them as their exit status; the parent collects that status with
 * wait_for_pid() and reports it through handle_result().
 */
enum {
	NSID_PASS = 0,
	NSID_FAIL = 1,
	NSID_SKIP = 2,
	NSID_ERROR = 3,
};
24
handle_result(int ret,const char * testname)25 static void handle_result(int ret, const char *testname)
26 {
27 if (ret == NSID_PASS)
28 ksft_test_result_pass("%s\n", testname);
29 else if (ret == NSID_FAIL)
30 ksft_test_result_fail("%s\n", testname);
31 else if (ret == NSID_ERROR)
32 ksft_exit_fail_msg("%s\n", testname);
33 else
34 ksft_test_result_skip("%s\n", testname);
35 }
36
/*
 * Reap @pid and return its exit code.
 *
 * waitpid() is retried while it fails with EINTR. Returns WEXITSTATUS() of
 * the child when it exited normally, or -1 when waitpid() fails for any
 * other reason or the child terminated abnormally (e.g. by a signal).
 */
static inline int wait_for_pid(pid_t pid)
{
	int status;
	int rc;

	do {
		rc = waitpid(pid, &status, 0);
	} while (rc == -1 && errno == EINTR);

	if (rc == -1) {
		ksft_print_msg("waitpid returned -1, errno=%d\n", errno);
		return -1;
	}

	if (WIFEXITED(status))
		return WEXITSTATUS(status);

	ksft_print_msg(
		"waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n",
		WIFSIGNALED(status), WTERMSIG(status));
	return -1;
}
61
get_mnt_ns_id(const char * mnt_ns,uint64_t * mnt_ns_id)62 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id)
63 {
64 int fd = open(mnt_ns, O_RDONLY);
65
66 if (fd < 0) {
67 ksft_print_msg("failed to open for ns %s: %s\n",
68 mnt_ns, strerror(errno));
69 sleep(60);
70 return NSID_ERROR;
71 }
72
73 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) {
74 ksft_print_msg("failed to get the nsid for ns %s: %s\n",
75 mnt_ns, strerror(errno));
76 return NSID_ERROR;
77 }
78 close(fd);
79 return NSID_PASS;
80 }
81
setup_namespace(void)82 static int setup_namespace(void)
83 {
84 if (setup_userns() != 0)
85 return NSID_ERROR;
86
87 return NSID_PASS;
88 }
89
_test_statmount_mnt_ns_id(void)90 static int _test_statmount_mnt_ns_id(void)
91 {
92 struct statmount sm;
93 uint64_t mnt_ns_id;
94 uint64_t root_id;
95 int ret;
96
97 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id);
98 if (ret != NSID_PASS)
99 return ret;
100
101 root_id = get_unique_mnt_id("/");
102 if (!root_id)
103 return NSID_ERROR;
104
105 ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0);
106 if (ret == -1) {
107 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
108 return NSID_ERROR;
109 }
110
111 if (sm.size != sizeof(sm)) {
112 ksft_print_msg("unexpected size: %u != %u\n", sm.size,
113 (uint32_t)sizeof(sm));
114 return NSID_FAIL;
115 }
116 if (sm.mask != STATMOUNT_MNT_NS_ID) {
117 ksft_print_msg("statmount mnt ns id unavailable\n");
118 return NSID_SKIP;
119 }
120
121 if (sm.mnt_ns_id != mnt_ns_id) {
122 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n",
123 (unsigned long long)sm.mnt_ns_id,
124 (unsigned long long)mnt_ns_id);
125 return NSID_FAIL;
126 }
127
128 return NSID_PASS;
129 }
130
test_statmount_mnt_ns_id(void)131 static void test_statmount_mnt_ns_id(void)
132 {
133 pid_t pid;
134 int ret;
135
136 pid = fork();
137 if (pid < 0)
138 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
139
140 /* We're the original pid, wait for the result. */
141 if (pid != 0) {
142 ret = wait_for_pid(pid);
143 handle_result(ret, "test statmount ns id");
144 return;
145 }
146
147 ret = setup_namespace();
148 if (ret != NSID_PASS)
149 exit(ret);
150 ret = _test_statmount_mnt_ns_id();
151 exit(ret);
152 }
153
validate_external_listmount(pid_t pid,uint64_t child_nr_mounts)154 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts)
155 {
156 uint64_t list[256];
157 uint64_t mnt_ns_id;
158 uint64_t nr_mounts;
159 char buf[256];
160 int ret;
161
162 /* Get the mount ns id for our child. */
163 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid);
164 ret = get_mnt_ns_id(buf, &mnt_ns_id);
165
166 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0);
167 if (nr_mounts == (uint64_t)-1) {
168 ksft_print_msg("listmount: %s\n", strerror(errno));
169 return NSID_ERROR;
170 }
171
172 if (nr_mounts != child_nr_mounts) {
173 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts,
174 child_nr_mounts);
175 return NSID_FAIL;
176 }
177
178 /* Validate that all of our entries match our mnt_ns_id. */
179 for (int i = 0; i < nr_mounts; i++) {
180 struct statmount sm;
181
182 ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm,
183 sizeof(sm), 0);
184 if (ret < 0) {
185 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno));
186 return NSID_ERROR;
187 }
188
189 if (sm.mask != STATMOUNT_MNT_NS_ID) {
190 ksft_print_msg("statmount mnt ns id unavailable\n");
191 return NSID_SKIP;
192 }
193
194 if (sm.mnt_ns_id != mnt_ns_id) {
195 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n",
196 (unsigned long long)sm.mnt_ns_id,
197 (unsigned long long)mnt_ns_id);
198 return NSID_FAIL;
199 }
200 }
201
202 return NSID_PASS;
203 }
204
test_listmount_ns(void)205 static void test_listmount_ns(void)
206 {
207 uint64_t nr_mounts;
208 char pval;
209 int child_ready_pipe[2];
210 int parent_ready_pipe[2];
211 pid_t pid;
212 int ret, child_ret;
213
214 if (pipe(child_ready_pipe) < 0)
215 ksft_exit_fail_msg("failed to create the child pipe: %s\n",
216 strerror(errno));
217 if (pipe(parent_ready_pipe) < 0)
218 ksft_exit_fail_msg("failed to create the parent pipe: %s\n",
219 strerror(errno));
220
221 pid = fork();
222 if (pid < 0)
223 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno));
224
225 if (pid == 0) {
226 char cval;
227 uint64_t list[256];
228
229 close(child_ready_pipe[0]);
230 close(parent_ready_pipe[1]);
231
232 ret = setup_namespace();
233 if (ret != NSID_PASS)
234 exit(ret);
235
236 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0);
237 if (nr_mounts == (uint64_t)-1) {
238 ksft_print_msg("listmount: %s\n", strerror(errno));
239 exit(NSID_FAIL);
240 }
241
242 /*
243 * Tell our parent how many mounts we have, and then wait for it
244 * to tell us we're done.
245 */
246 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) !=
247 sizeof(nr_mounts))
248 ret = NSID_ERROR;
249 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval))
250 ret = NSID_ERROR;
251 exit(NSID_PASS);
252 }
253
254 close(child_ready_pipe[1]);
255 close(parent_ready_pipe[0]);
256
257 /* Wait until the child has created everything. */
258 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) !=
259 sizeof(nr_mounts))
260 ret = NSID_ERROR;
261
262 ret = validate_external_listmount(pid, nr_mounts);
263
264 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval))
265 ret = NSID_ERROR;
266
267 child_ret = wait_for_pid(pid);
268 if (child_ret != NSID_PASS)
269 ret = child_ret;
270 handle_result(ret, "test listmount ns id");
271 }
272
/*
 * Entry point. Skips the whole suite when statmount() is not implemented;
 * otherwise runs the two namespace-ID test cases and exits with the
 * aggregate kselftest status.
 */
int main(void)
{
	int probe;

	ksft_print_header();

	/*
	 * Probe for the syscall with all-zero arguments: the call is expected
	 * to fail, and ENOSYS means the kernel has no statmount().
	 */
	probe = statmount(0, 0, 0, NULL, 0, 0);
	assert(probe == -1);
	if (errno == ENOSYS)
		ksft_exit_skip("statmount() syscall not supported\n");

	ksft_set_plan(2);

	test_statmount_mnt_ns_id();
	test_listmount_ns();

	if (ksft_get_fail_cnt() + ksft_get_error_cnt() == 0)
		ksft_exit_pass();
	else
		ksft_exit_fail();
}
292