common/rc: add _scratch_{u}mount_idmapped() helpers
[xfstests-dev.git] / src / idmapped-mounts / mount-idmapped.c
1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5
6 #include "../global.h"
7
8 #include <dirent.h>
9 #include <errno.h>
10 #include <fcntl.h>
11 #include <getopt.h>
12 #include <libgen.h>
13 #include <limits.h>
14 #include <linux/bpf.h>
15 #include <linux/sched.h>
16 #include <linux/seccomp.h>
17 #include <sched.h>
18 #include <signal.h>
19 #include <stdbool.h>
20 #include <stdint.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 #include <sys/mman.h>
25 #include <sys/stat.h>
26 #include <sys/syscall.h>
27 #include <sys/types.h>
28 #include <sys/wait.h>
29 #include <unistd.h>
30
31 #include "missing.h"
32 #include "utils.h"
33
34 /* A few helpful macros. */
35 #define STRLITERALLEN(x) (sizeof(""x"") - 1)
36
37 #define INTTYPE_TO_STRLEN(type)             \
38         (2 + (sizeof(type) <= 1             \
39                   ? 3                       \
40                   : sizeof(type) <= 2       \
41                         ? 5                 \
42                         : sizeof(type) <= 4 \
43                               ? 10          \
44                               : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
45
46 #define syserror(format, ...)                           \
47         ({                                              \
48                 fprintf(stderr, format, ##__VA_ARGS__); \
49                 (-errno);                               \
50         })
51
52 #define syserror_set(__ret__, format, ...)                    \
53         ({                                                    \
54                 typeof(__ret__) __internal_ret__ = (__ret__); \
55                 errno = labs(__ret__);                        \
56                 fprintf(stderr, format, ##__VA_ARGS__);       \
57                 __internal_ret__;                             \
58         })
59
60 struct list {
61         void *elem;
62         struct list *next;
63         struct list *prev;
64 };
65
66 #define list_for_each(__iterator, __list) \
67         for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
68
69 static inline void list_init(struct list *list)
70 {
71         list->elem = NULL;
72         list->next = list->prev = list;
73 }
74
75 static inline int list_empty(const struct list *list)
76 {
77         return list == list->next;
78 }
79
80 static inline void __list_add(struct list *new, struct list *prev, struct list *next)
81 {
82         next->prev = new;
83         new->next = next;
84         new->prev = prev;
85         prev->next = new;
86 }
87
88 static inline void list_add_tail(struct list *head, struct list *list)
89 {
90         __list_add(list, head->prev, head);
91 }
92
93 typedef enum idmap_type_t {
94         ID_TYPE_UID,
95         ID_TYPE_GID
96 } idmap_type_t;
97
98 struct id_map {
99         idmap_type_t map_type;
100         __u32 nsid;
101         __u32 hostid;
102         __u32 range;
103 };
104
105 static struct list active_map;
106
107 static int add_map_entry(__u32 id_host,
108                          __u32 id_ns,
109                          __u32 range,
110                          idmap_type_t map_type)
111 {
112         struct list *new_list = NULL;
113         struct id_map *newmap = NULL;
114
115         newmap = malloc(sizeof(*newmap));
116         if (!newmap)
117                 return -ENOMEM;
118
119         new_list = malloc(sizeof(struct list));
120         if (!new_list) {
121                 free(newmap);
122                 return -ENOMEM;
123         }
124
125         *newmap = (struct id_map){
126                 .hostid         = id_host,
127                 .nsid           = id_ns,
128                 .range          = range,
129                 .map_type       = map_type,
130         };
131
132         new_list->elem = newmap;
133         list_add_tail(&active_map, new_list);
134         return 0;
135 }
136
137 static int parse_map(char *map)
138 {
139         char types[2] = {'u', 'g'};
140         int ret;
141         __u32 id_host, id_ns, range;
142         char which;
143
144         if (!map)
145                 return -1;
146
147         ret = sscanf(map, "%c:%u:%u:%u", &which, &id_ns, &id_host, &range);
148         if (ret != 4)
149                 return -1;
150
151         if (which != 'b' && which != 'u' && which != 'g')
152                 return -1;
153
154         for (int i = 0; i < 2; i++) {
155                 idmap_type_t map_type;
156
157                 if (which != types[i] && which != 'b')
158                         continue;
159
160                 if (types[i] == 'u')
161                         map_type = ID_TYPE_UID;
162                 else
163                         map_type = ID_TYPE_GID;
164
165                 ret = add_map_entry(id_host, id_ns, range, map_type);
166                 if (ret < 0)
167                         return ret;
168         }
169
170         return 0;
171 }
172
173 static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
174 {
175         int fd = -EBADF, setgroups_fd = -EBADF;
176         int fret = -1;
177         int ret;
178         char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
179                   STRLITERALLEN("/setgroups") + 1];
180
181         if (geteuid() != 0 && map_type == ID_TYPE_GID) {
182                 ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
183                 if (ret < 0 || ret >= sizeof(path))
184                         goto out;
185
186                 setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
187                 if (setgroups_fd < 0 && errno != ENOENT) {
188                         syserror("Failed to open \"%s\"", path);
189                         goto out;
190                 }
191
192                 if (setgroups_fd >= 0) {
193                         ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
194                         if (ret != STRLITERALLEN("deny\n")) {
195                                 syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
196                                 goto out;
197                         }
198                 }
199         }
200
201         ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
202         if (ret < 0 || ret >= sizeof(path))
203                 goto out;
204
205         fd = open(path, O_WRONLY | O_CLOEXEC);
206         if (fd < 0) {
207                 syserror("Failed to open \"%s\"", path);
208                 goto out;
209         }
210
211         ret = write_nointr(fd, buf, buf_size);
212         if (ret != buf_size) {
213                 syserror("Failed to write %cid mapping to \"%s\"",
214                          map_type == ID_TYPE_UID ? 'u' : 'g', path);
215                 goto out;
216         }
217
218         fret = 0;
219 out:
220         if (fd >= 0)
221                 close(fd);
222         if (setgroups_fd >= 0)
223                 close(setgroups_fd);
224
225         return fret;
226 }
227
228 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
229 {
230         int fill, left;
231         char mapbuf[4096] = {};
232         bool had_entry = false;
233
234         for (idmap_type_t map_type = ID_TYPE_UID, u_or_g = 'u';
235              map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
236                 char *pos = mapbuf;
237                 int ret;
238                 struct list *iterator;
239
240
241                 list_for_each(iterator, idmap) {
242                         struct id_map *map = iterator->elem;
243                         if (map->map_type != map_type)
244                                 continue;
245
246                         had_entry = true;
247
248                         left = 4096 - (pos - mapbuf);
249                         fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
250                         /*
251                          * The kernel only takes <= 4k for writes to
252                          * /proc/<pid>/{g,u}id_map
253                          */
254                         if (fill <= 0 || fill >= left)
255                                 return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
256
257                         pos += fill;
258                 }
259                 if (!had_entry)
260                         continue;
261
262                 ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
263                 if (ret < 0)
264                         return syserror("Failed to write mapping: %s", mapbuf);
265
266                 memset(mapbuf, 0, sizeof(mapbuf));
267         }
268
269         return 0;
270 }
271
272 static int get_userns_fd_from_idmap(struct list *idmap)
273 {
274         int ret;
275         pid_t pid;
276         char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
277                   STRLITERALLEN("/ns/user") + 1];
278
279         pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
280         if (pid < 0)
281                 return -errno;
282
283         ret = map_ids_from_idmap(idmap, pid);
284         if (ret < 0)
285                 return ret;
286
287         ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
288         if (ret < 0 || (size_t)ret >= sizeof(path_ns))
289                 ret = -EIO;
290         else
291                 ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
292
293         (void)kill(pid, SIGKILL);
294         (void)wait_for_pid(pid);
295         return ret;
296 }
297
298 static inline bool strnequal(const char *str, const char *eq, size_t len)
299 {
300         return strncmp(str, eq, len) == 0;
301 }
302
303 static void usage(void)
304 {
305         const char *text = "\
306 mount-idmapped --map-mount=<idmap> <source> <target>\n\
307 \n\
308 Create an idmapped mount of <source> at <target>\n\
309 Options:\n\
310   --map-mount=<idmap>\n\
311         Specify an idmap for the <target> mount in the format\n\
312         <idmap-type>:<id-from>:<id-to>:<id-range>\n\
313         The <idmap-type> can be:\n\
314         \"b\" or \"both\"       -> map both uids and gids\n\
315         \"u\" or \"uid\"        -> map uids\n\
316         \"g\" or \"gid\"        -> map gids\n\
317         For example, specifying:\n\
318         both:1000:1001:1        -> map uid and gid 1000 to uid and gid 1001 in <target> and no other ids\n\
319         uid:20000:100000:1000   -> map uid 20000 to uid 100000, uid 20001 to uid 100001 [...] in <target>\n\
320         Currently up to 340 separate idmappings may be specified.\n\n\
321   --map-mount=/proc/<pid>/ns/user\n\
322         Specify a path to a user namespace whose idmap is to be used.\n\n\
323   --recursive\n\
324         Copy the whole mount tree from <source> and apply the idmap to everyone at <target>.\n\n\
325 Examples:\n\
326   - Create an idmapped mount of /source on /target with both ('b') uids and gids mapped:\n\
327         mount-idmapped --map-mount b:0:10000:10000 /source /target\n\n\
328   - Create an idmapped mount of /source on /target with uids ('u') and gids ('g') mapped separately:\n\
329         mount-idmapped --map-mount u:0:10000:10000 g:0:20000:20000 /source /target\n\n\
330 ";
331         fprintf(stderr, "%s", text);
332         _exit(EXIT_SUCCESS);
333 }
334
335 #define exit_usage(format, ...)                         \
336         ({                                              \
337                 fprintf(stderr, format, ##__VA_ARGS__); \
338                 usage();                                \
339         })
340
341 #define exit_log(format, ...)                           \
342         ({                                              \
343                 fprintf(stderr, format, ##__VA_ARGS__); \
344                 exit(EXIT_FAILURE);                     \
345         })
346
347 static const struct option longopts[] = {
348         {"map-mount",   required_argument,      0,      'a'},
349         {"help",        no_argument,            0,      'c'},
350         {"recursive",   no_argument,            0,      'd'},
351         { NULL,         0,                      0,      0  },
352 };
353
354 int main(int argc, char *argv[])
355 {
356         int fd_userns = -EBADF;
357         int index = 0;
358         const char *source = NULL, *target = NULL;
359         bool recursive = false;
360         int fd_tree, new_argc, ret;
361         char *const *new_argv;
362
363         list_init(&active_map);
364         while ((ret = getopt_long_only(argc, argv, "", longopts, &index)) != -1) {
365                 switch (ret) {
366                 case 'a':
367                         if (strnequal(optarg, "/proc/", STRLITERALLEN("/proc/"))) {
368                                 fd_userns = open(optarg, O_RDONLY | O_CLOEXEC);
369                                 if (fd_userns < 0)
370                                         exit_log("%m - Failed top open user namespace path %s\n", optarg);
371                                 break;
372                         }
373
374                         ret = parse_map(optarg);
375                         if (ret < 0)
376                                 exit_log("Failed to parse idmaps for mount\n");
377                         break;
378                 case 'd':
379                         recursive = true;
380                         break;
381                 case 'c':
382                         /* fallthrough */
383                 default:
384                         usage();
385                 }
386         }
387
388         new_argv = &argv[optind];
389         new_argc = argc - optind;
390         if (new_argc < 2)
391                 exit_usage("Missing source or target mountpoint\n\n");
392         source = new_argv[0];
393         target = new_argv[1];
394
395         fd_tree = sys_open_tree(-EBADF, source,
396                                 OPEN_TREE_CLONE |
397                                 OPEN_TREE_CLOEXEC |
398                                 AT_EMPTY_PATH |
399                                 (recursive ? AT_RECURSIVE : 0));
400         if (fd_tree < 0) {
401                 exit_log("%m - Failed to open %s\n", source);
402                 exit(EXIT_FAILURE);
403         }
404
405         if (!list_empty(&active_map) || fd_userns >= 0) {
406                 struct mount_attr attr = {
407                         .attr_set = MOUNT_ATTR_IDMAP,
408                 };
409
410                 if (fd_userns >= 0)
411                         attr.userns_fd = fd_userns;
412                 else
413                         attr.userns_fd = get_userns_fd_from_idmap(&active_map);
414                 if (attr.userns_fd < 0)
415                         exit_log("%m - Failed to create user namespace\n");
416
417                 ret = sys_mount_setattr(fd_tree, "", AT_EMPTY_PATH | AT_RECURSIVE,
418                                         &attr, sizeof(attr));
419                 if (ret < 0)
420                         exit_log("%m - Failed to change mount attributes\n");
421                 close(attr.userns_fd);
422         }
423
424         ret = sys_move_mount(fd_tree, "", -EBADF, target,
425                              MOVE_MOUNT_F_EMPTY_PATH);
426         if (ret < 0)
427                 exit_log("%m - Failed to attach mount to %s\n", target);
428         close(fd_tree);
429
430         exit(EXIT_SUCCESS);
431 }