c2afa8dc8fa04ac98bdb19ff0f7e7503e329ba12
[xfstests-dev.git] / src / idmapped-mounts / utils.c
1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5 #include <fcntl.h>
6 #include <grp.h>
7 #include <linux/limits.h>
8 #include <sched.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <sys/eventfd.h>
12 #include <sys/mount.h>
13 #include <sys/prctl.h>
14 #include <sys/socket.h>
15 #include <sys/stat.h>
16 #include <sys/types.h>
17 #include <sys/wait.h>
18
19 #include "utils.h"
20
21 ssize_t read_nointr(int fd, void *buf, size_t count)
22 {
23         ssize_t ret;
24
25         do {
26                 ret = read(fd, buf, count);
27         } while (ret < 0 && errno == EINTR);
28
29         return ret;
30 }
31
32 ssize_t write_nointr(int fd, const void *buf, size_t count)
33 {
34         ssize_t ret;
35
36         do {
37                 ret = write(fd, buf, count);
38         } while (ret < 0 && errno == EINTR);
39
40         return ret;
41 }
42
43 #define __STACK_SIZE (8 * 1024 * 1024)
44 pid_t do_clone(int (*fn)(void *), void *arg, int flags)
45 {
46         void *stack;
47
48         stack = malloc(__STACK_SIZE);
49         if (!stack)
50                 return -ENOMEM;
51
52 #ifdef __ia64__
53         return __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
54 #else
55         return clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
56 #endif
57 }
58
59 static int get_userns_fd_cb(void *data)
60 {
61         return 0;
62 }
63
64 int wait_for_pid(pid_t pid)
65 {
66         int status, ret;
67
68 again:
69         ret = waitpid(pid, &status, 0);
70         if (ret == -1) {
71                 if (errno == EINTR)
72                         goto again;
73
74                 return -1;
75         }
76
77         if (!WIFEXITED(status))
78                 return -1;
79
80         return WEXITSTATUS(status);
81 }
82
83 static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
84 {
85         int fd = -EBADF, setgroups_fd = -EBADF;
86         int fret = -1;
87         int ret;
88         char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
89                   STRLITERALLEN("/setgroups") + 1];
90
91         if (geteuid() != 0 && map_type == ID_TYPE_GID) {
92                 ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
93                 if (ret < 0 || ret >= sizeof(path))
94                         goto out;
95
96                 setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
97                 if (setgroups_fd < 0 && errno != ENOENT) {
98                         syserror("Failed to open \"%s\"", path);
99                         goto out;
100                 }
101
102                 if (setgroups_fd >= 0) {
103                         ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
104                         if (ret != STRLITERALLEN("deny\n")) {
105                                 syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
106                                 goto out;
107                         }
108                 }
109         }
110
111         ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
112         if (ret < 0 || ret >= sizeof(path))
113                 goto out;
114
115         fd = open(path, O_WRONLY | O_CLOEXEC);
116         if (fd < 0) {
117                 syserror("Failed to open \"%s\"", path);
118                 goto out;
119         }
120
121         ret = write_nointr(fd, buf, buf_size);
122         if (ret != buf_size) {
123                 syserror("Failed to write %cid mapping to \"%s\"",
124                          map_type == ID_TYPE_UID ? 'u' : 'g', path);
125                 goto out;
126         }
127
128         fret = 0;
129 out:
130         if (fd >= 0)
131                 close(fd);
132         if (setgroups_fd >= 0)
133                 close(setgroups_fd);
134
135         return fret;
136 }
137
138 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
139 {
140         int fill, left;
141         char mapbuf[4096] = {};
142         bool had_entry = false;
143         idmap_type_t map_type, u_or_g;
144
145         if (list_empty(idmap))
146                 return 0;
147
148         for (map_type = ID_TYPE_UID, u_or_g = 'u';
149              map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
150                 char *pos = mapbuf;
151                 int ret;
152                 struct list *iterator;
153
154
155                 list_for_each(iterator, idmap) {
156                         struct id_map *map = iterator->elem;
157                         if (map->map_type != map_type)
158                                 continue;
159
160                         had_entry = true;
161
162                         left = 4096 - (pos - mapbuf);
163                         fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
164                         /*
165                          * The kernel only takes <= 4k for writes to
166                          * /proc/<pid>/{g,u}id_map
167                          */
168                         if (fill <= 0 || fill >= left)
169                                 return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
170
171                         pos += fill;
172                 }
173                 if (!had_entry)
174                         continue;
175
176                 ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
177                 if (ret < 0)
178                         return syserror("Failed to write mapping: %s", mapbuf);
179
180                 memset(mapbuf, 0, sizeof(mapbuf));
181         }
182
183         return 0;
184 }
185
186 int get_userns_fd_from_idmap(struct list *idmap)
187 {
188         int ret;
189         pid_t pid;
190         char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
191                      STRLITERALLEN("/ns/user") + 1];
192
193         pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
194         if (pid < 0)
195                 return -errno;
196
197         ret = map_ids_from_idmap(idmap, pid);
198         if (ret < 0)
199                 return ret;
200
201         ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
202         if (ret < 0 || (size_t)ret >= sizeof(path_ns))
203                 ret = -EIO;
204         else
205                 ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
206
207         (void)kill(pid, SIGKILL);
208         (void)wait_for_pid(pid);
209         return ret;
210 }
211
212 int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
213 {
214         struct list head, uid_mapl, gid_mapl;
215         struct id_map uid_map = {
216                 .map_type       = ID_TYPE_UID,
217                 .nsid           = nsid,
218                 .hostid         = hostid,
219                 .range          = range,
220         };
221         struct id_map gid_map = {
222                 .map_type       = ID_TYPE_GID,
223                 .nsid           = nsid,
224                 .hostid         = hostid,
225                 .range          = range,
226         };
227
228         list_init(&head);
229         uid_mapl.elem = &uid_map;
230         gid_mapl.elem = &gid_map;
231         list_add_tail(&head, &uid_mapl);
232         list_add_tail(&head, &gid_mapl);
233
234         return get_userns_fd_from_idmap(&head);
235 }
236
237 bool switch_ids(uid_t uid, gid_t gid)
238 {
239         if (setgroups(0, NULL))
240                 return syserror("failure: setgroups");
241
242         if (setresgid(gid, gid, gid))
243                 return syserror("failure: setresgid");
244
245         if (setresuid(uid, uid, uid))
246                 return syserror("failure: setresuid");
247
248         return true;
249 }
250
251 static int userns_fd_cb(void *data)
252 {
253         struct userns_hierarchy *h = data;
254         char c;
255         int ret;
256
257         ret = read_nointr(h->fd_event, &c, 1);
258         if (ret < 0)
259                 return syserror("failure: read from socketpair");
260
261         /* Only switch ids if someone actually wrote a mapping for us. */
262         if (c == '1') {
263                 if (!switch_ids(0, 0))
264                         return syserror("failure: switch ids to 0");
265
266                 /* Ensure we can access proc files from processes we can ptrace. */
267                 ret = prctl(PR_SET_DUMPABLE, 1, 0, 0, 0);
268                 if (ret < 0)
269                         return syserror("failure: make dumpable");
270         }
271
272         ret = write_nointr(h->fd_event, "1", 1);
273         if (ret < 0)
274                 return syserror("failure: write to socketpair");
275
276         ret = create_userns_hierarchy(++h);
277         if (ret < 0)
278                 return syserror("failure: userns level %d", h->level);
279
280         return 0;
281 }
282
283 int create_userns_hierarchy(struct userns_hierarchy *h)
284 {
285         int fret = -1;
286         char c;
287         int fd_socket[2];
288         int fd_userns = -EBADF, ret = -1;
289         ssize_t bytes;
290         pid_t pid;
291         char path[256];
292
293         if (h->level == MAX_USERNS_LEVEL)
294                 return 0;
295
296         ret = socketpair(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0, fd_socket);
297         if (ret < 0)
298                 return syserror("failure: create socketpair");
299
300         /* Note the CLONE_FILES | CLONE_VM when mucking with fds and memory. */
301         h->fd_event = fd_socket[1];
302         pid = do_clone(userns_fd_cb, h, CLONE_NEWUSER | CLONE_FILES | CLONE_VM);
303         if (pid < 0) {
304                 syserror("failure: userns level %d", h->level);
305                 goto out_close;
306         }
307
308         ret = map_ids_from_idmap(&h->id_map, pid);
309         if (ret < 0) {
310                 kill(pid, SIGKILL);
311                 syserror("failure: writing id mapping for userns level %d for %d", h->level, pid);
312                 goto out_wait;
313         }
314
315         if (!list_empty(&h->id_map))
316                 bytes = write_nointr(fd_socket[0], "1", 1); /* Inform the child we wrote a mapping. */
317         else
318                 bytes = write_nointr(fd_socket[0], "0", 1); /* Inform the child we didn't write a mapping. */
319         if (bytes < 0) {
320                 kill(pid, SIGKILL);
321                 syserror("failure: write to socketpair");
322                 goto out_wait;
323         }
324
325         /* Wait for child to set*id() and become dumpable. */
326         bytes = read_nointr(fd_socket[0], &c, 1);
327         if (bytes < 0) {
328                 kill(pid, SIGKILL);
329                 syserror("failure: read from socketpair");
330                 goto out_wait;
331         }
332
333         snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
334         fd_userns = open(path, O_RDONLY | O_CLOEXEC);
335         if (fd_userns < 0) {
336                 kill(pid, SIGKILL);
337                 syserror("failure: open userns level %d for %d", h->level, pid);
338                 goto out_wait;
339         }
340
341         fret = 0;
342
343 out_wait:
344         if (!wait_for_pid(pid) && !fret) {
345                 h->fd_userns = fd_userns;
346                 fd_userns = -EBADF;
347         }
348
349 out_close:
350         if (fd_userns >= 0)
351                 close(fd_userns);
352         close(fd_socket[0]);
353         close(fd_socket[1]);
354         return fret;
355 }
356
357 int add_map_entry(struct list *head,
358                   __u32 id_host,
359                   __u32 id_ns,
360                   __u32 range,
361                   idmap_type_t map_type)
362 {
363         struct list *new_list = NULL;
364         struct id_map *newmap = NULL;
365
366         newmap = malloc(sizeof(*newmap));
367         if (!newmap)
368                 return -ENOMEM;
369
370         new_list = malloc(sizeof(struct list));
371         if (!new_list) {
372                 free(newmap);
373                 return -ENOMEM;
374         }
375
376         *newmap = (struct id_map){
377                 .hostid         = id_host,
378                 .nsid           = id_ns,
379                 .range          = range,
380                 .map_type       = map_type,
381         };
382
383         new_list->elem = newmap;
384         list_add_tail(head, new_list);
385         return 0;
386 }