src: Fix an error for the loop initialization declaration
[xfstests-dev.git] / src / idmapped-mounts / mount-idmapped.c
1 // SPDX-License-Identifier: GPL-2.0
2 #ifndef _GNU_SOURCE
3 #define _GNU_SOURCE
4 #endif
5
6 #include "../global.h"
7
8 #include <dirent.h>
9 #include <errno.h>
10 #include <fcntl.h>
11 #include <getopt.h>
12 #include <libgen.h>
13 #include <limits.h>
14 #include <linux/bpf.h>
15 #include <linux/sched.h>
16 #include <linux/seccomp.h>
17 #include <sched.h>
18 #include <signal.h>
19 #include <stdbool.h>
20 #include <stdint.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 #include <sys/mman.h>
25 #include <sys/stat.h>
26 #include <sys/syscall.h>
27 #include <sys/types.h>
28 #include <sys/wait.h>
29 #include <unistd.h>
30
31 #include "missing.h"
32 #include "utils.h"
33
34 /* A few helpful macros. */
35 #define STRLITERALLEN(x) (sizeof(""x"") - 1)
36
37 #define INTTYPE_TO_STRLEN(type)             \
38         (2 + (sizeof(type) <= 1             \
39                   ? 3                       \
40                   : sizeof(type) <= 2       \
41                         ? 5                 \
42                         : sizeof(type) <= 4 \
43                               ? 10          \
44                               : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
45
46 #define syserror(format, ...)                           \
47         ({                                              \
48                 fprintf(stderr, format, ##__VA_ARGS__); \
49                 (-errno);                               \
50         })
51
52 #define syserror_set(__ret__, format, ...)                    \
53         ({                                                    \
54                 typeof(__ret__) __internal_ret__ = (__ret__); \
55                 errno = labs(__ret__);                        \
56                 fprintf(stderr, format, ##__VA_ARGS__);       \
57                 __internal_ret__;                             \
58         })
59
60 struct list {
61         void *elem;
62         struct list *next;
63         struct list *prev;
64 };
65
66 #define list_for_each(__iterator, __list) \
67         for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
68
69 static inline void list_init(struct list *list)
70 {
71         list->elem = NULL;
72         list->next = list->prev = list;
73 }
74
75 static inline int list_empty(const struct list *list)
76 {
77         return list == list->next;
78 }
79
80 static inline void __list_add(struct list *new, struct list *prev, struct list *next)
81 {
82         next->prev = new;
83         new->next = next;
84         new->prev = prev;
85         prev->next = new;
86 }
87
88 static inline void list_add_tail(struct list *head, struct list *list)
89 {
90         __list_add(list, head->prev, head);
91 }
92
93 typedef enum idmap_type_t {
94         ID_TYPE_UID,
95         ID_TYPE_GID
96 } idmap_type_t;
97
98 struct id_map {
99         idmap_type_t map_type;
100         __u32 nsid;
101         __u32 hostid;
102         __u32 range;
103 };
104
105 static struct list active_map;
106
107 static int add_map_entry(__u32 id_host,
108                          __u32 id_ns,
109                          __u32 range,
110                          idmap_type_t map_type)
111 {
112         struct list *new_list = NULL;
113         struct id_map *newmap = NULL;
114
115         newmap = malloc(sizeof(*newmap));
116         if (!newmap)
117                 return -ENOMEM;
118
119         new_list = malloc(sizeof(struct list));
120         if (!new_list) {
121                 free(newmap);
122                 return -ENOMEM;
123         }
124
125         *newmap = (struct id_map){
126                 .hostid         = id_host,
127                 .nsid           = id_ns,
128                 .range          = range,
129                 .map_type       = map_type,
130         };
131
132         new_list->elem = newmap;
133         list_add_tail(&active_map, new_list);
134         return 0;
135 }
136
137 static int parse_map(char *map)
138 {
139         char types[2] = {'u', 'g'};
140         int ret, i;
141         __u32 id_host, id_ns, range;
142         char which;
143
144         if (!map)
145                 return -1;
146
147         ret = sscanf(map, "%c:%u:%u:%u", &which, &id_ns, &id_host, &range);
148         if (ret != 4)
149                 return -1;
150
151         if (which != 'b' && which != 'u' && which != 'g')
152                 return -1;
153
154         for (i = 0; i < 2; i++) {
155                 idmap_type_t map_type;
156
157                 if (which != types[i] && which != 'b')
158                         continue;
159
160                 if (types[i] == 'u')
161                         map_type = ID_TYPE_UID;
162                 else
163                         map_type = ID_TYPE_GID;
164
165                 ret = add_map_entry(id_host, id_ns, range, map_type);
166                 if (ret < 0)
167                         return ret;
168         }
169
170         return 0;
171 }
172
173 static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
174 {
175         int fd = -EBADF, setgroups_fd = -EBADF;
176         int fret = -1;
177         int ret;
178         char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
179                   STRLITERALLEN("/setgroups") + 1];
180
181         if (geteuid() != 0 && map_type == ID_TYPE_GID) {
182                 ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
183                 if (ret < 0 || ret >= sizeof(path))
184                         goto out;
185
186                 setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
187                 if (setgroups_fd < 0 && errno != ENOENT) {
188                         syserror("Failed to open \"%s\"", path);
189                         goto out;
190                 }
191
192                 if (setgroups_fd >= 0) {
193                         ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
194                         if (ret != STRLITERALLEN("deny\n")) {
195                                 syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
196                                 goto out;
197                         }
198                 }
199         }
200
201         ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
202         if (ret < 0 || ret >= sizeof(path))
203                 goto out;
204
205         fd = open(path, O_WRONLY | O_CLOEXEC);
206         if (fd < 0) {
207                 syserror("Failed to open \"%s\"", path);
208                 goto out;
209         }
210
211         ret = write_nointr(fd, buf, buf_size);
212         if (ret != buf_size) {
213                 syserror("Failed to write %cid mapping to \"%s\"",
214                          map_type == ID_TYPE_UID ? 'u' : 'g', path);
215                 goto out;
216         }
217
218         fret = 0;
219 out:
220         if (fd >= 0)
221                 close(fd);
222         if (setgroups_fd >= 0)
223                 close(setgroups_fd);
224
225         return fret;
226 }
227
228 static int map_ids_from_idmap(struct list *idmap, pid_t pid)
229 {
230         int fill, left;
231         char mapbuf[4096] = {};
232         bool had_entry = false;
233         idmap_type_t map_type, u_or_g;
234
235         for (map_type = ID_TYPE_UID, u_or_g = 'u';
236              map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
237                 char *pos = mapbuf;
238                 int ret;
239                 struct list *iterator;
240
241
242                 list_for_each(iterator, idmap) {
243                         struct id_map *map = iterator->elem;
244                         if (map->map_type != map_type)
245                                 continue;
246
247                         had_entry = true;
248
249                         left = 4096 - (pos - mapbuf);
250                         fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
251                         /*
252                          * The kernel only takes <= 4k for writes to
253                          * /proc/<pid>/{g,u}id_map
254                          */
255                         if (fill <= 0 || fill >= left)
256                                 return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
257
258                         pos += fill;
259                 }
260                 if (!had_entry)
261                         continue;
262
263                 ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
264                 if (ret < 0)
265                         return syserror("Failed to write mapping: %s", mapbuf);
266
267                 memset(mapbuf, 0, sizeof(mapbuf));
268         }
269
270         return 0;
271 }
272
273 static int get_userns_fd_from_idmap(struct list *idmap)
274 {
275         int ret;
276         pid_t pid;
277         char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
278                   STRLITERALLEN("/ns/user") + 1];
279
280         pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
281         if (pid < 0)
282                 return -errno;
283
284         ret = map_ids_from_idmap(idmap, pid);
285         if (ret < 0)
286                 return ret;
287
288         ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
289         if (ret < 0 || (size_t)ret >= sizeof(path_ns))
290                 ret = -EIO;
291         else
292                 ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
293
294         (void)kill(pid, SIGKILL);
295         (void)wait_for_pid(pid);
296         return ret;
297 }
298
299 static inline bool strnequal(const char *str, const char *eq, size_t len)
300 {
301         return strncmp(str, eq, len) == 0;
302 }
303
304 static void usage(void)
305 {
306         const char *text = "\
307 mount-idmapped --map-mount=<idmap> <source> <target>\n\
308 \n\
309 Create an idmapped mount of <source> at <target>\n\
310 Options:\n\
311   --map-mount=<idmap>\n\
312         Specify an idmap for the <target> mount in the format\n\
313         <idmap-type>:<id-from>:<id-to>:<id-range>\n\
314         The <idmap-type> can be:\n\
315         \"b\" or \"both\"       -> map both uids and gids\n\
316         \"u\" or \"uid\"        -> map uids\n\
317         \"g\" or \"gid\"        -> map gids\n\
318         For example, specifying:\n\
319         both:1000:1001:1        -> map uid and gid 1000 to uid and gid 1001 in <target> and no other ids\n\
320         uid:20000:100000:1000   -> map uid 20000 to uid 100000, uid 20001 to uid 100001 [...] in <target>\n\
321         Currently up to 340 separate idmappings may be specified.\n\n\
322   --map-mount=/proc/<pid>/ns/user\n\
323         Specify a path to a user namespace whose idmap is to be used.\n\n\
324   --recursive\n\
325         Copy the whole mount tree from <source> and apply the idmap to everyone at <target>.\n\n\
326 Examples:\n\
327   - Create an idmapped mount of /source on /target with both ('b') uids and gids mapped:\n\
328         mount-idmapped --map-mount b:0:10000:10000 /source /target\n\n\
329   - Create an idmapped mount of /source on /target with uids ('u') and gids ('g') mapped separately:\n\
330         mount-idmapped --map-mount u:0:10000:10000 g:0:20000:20000 /source /target\n\n\
331 ";
332         fprintf(stderr, "%s", text);
333         _exit(EXIT_SUCCESS);
334 }
335
336 #define exit_usage(format, ...)                         \
337         ({                                              \
338                 fprintf(stderr, format, ##__VA_ARGS__); \
339                 usage();                                \
340         })
341
342 #define exit_log(format, ...)                           \
343         ({                                              \
344                 fprintf(stderr, format, ##__VA_ARGS__); \
345                 exit(EXIT_FAILURE);                     \
346         })
347
348 static const struct option longopts[] = {
349         {"map-mount",   required_argument,      0,      'a'},
350         {"help",        no_argument,            0,      'c'},
351         {"recursive",   no_argument,            0,      'd'},
352         { NULL,         0,                      0,      0  },
353 };
354
355 int main(int argc, char *argv[])
356 {
357         int fd_userns = -EBADF;
358         int index = 0;
359         const char *source = NULL, *target = NULL;
360         bool recursive = false;
361         int fd_tree, new_argc, ret;
362         char *const *new_argv;
363
364         list_init(&active_map);
365         while ((ret = getopt_long_only(argc, argv, "", longopts, &index)) != -1) {
366                 switch (ret) {
367                 case 'a':
368                         if (strnequal(optarg, "/proc/", STRLITERALLEN("/proc/"))) {
369                                 fd_userns = open(optarg, O_RDONLY | O_CLOEXEC);
370                                 if (fd_userns < 0)
371                                         exit_log("%m - Failed top open user namespace path %s\n", optarg);
372                                 break;
373                         }
374
375                         ret = parse_map(optarg);
376                         if (ret < 0)
377                                 exit_log("Failed to parse idmaps for mount\n");
378                         break;
379                 case 'd':
380                         recursive = true;
381                         break;
382                 case 'c':
383                         /* fallthrough */
384                 default:
385                         usage();
386                 }
387         }
388
389         new_argv = &argv[optind];
390         new_argc = argc - optind;
391         if (new_argc < 2)
392                 exit_usage("Missing source or target mountpoint\n\n");
393         source = new_argv[0];
394         target = new_argv[1];
395
396         fd_tree = sys_open_tree(-EBADF, source,
397                                 OPEN_TREE_CLONE |
398                                 OPEN_TREE_CLOEXEC |
399                                 AT_EMPTY_PATH |
400                                 (recursive ? AT_RECURSIVE : 0));
401         if (fd_tree < 0) {
402                 exit_log("%m - Failed to open %s\n", source);
403                 exit(EXIT_FAILURE);
404         }
405
406         if (!list_empty(&active_map) || fd_userns >= 0) {
407                 struct mount_attr attr = {
408                         .attr_set = MOUNT_ATTR_IDMAP,
409                 };
410
411                 if (fd_userns >= 0)
412                         attr.userns_fd = fd_userns;
413                 else
414                         attr.userns_fd = get_userns_fd_from_idmap(&active_map);
415                 if (attr.userns_fd < 0)
416                         exit_log("%m - Failed to create user namespace\n");
417
418                 ret = sys_mount_setattr(fd_tree, "", AT_EMPTY_PATH | AT_RECURSIVE,
419                                         &attr, sizeof(attr));
420                 if (ret < 0)
421                         exit_log("%m - Failed to change mount attributes\n");
422                 close(attr.userns_fd);
423         }
424
425         ret = sys_move_mount(fd_tree, "", -EBADF, target,
426                              MOVE_MOUNT_F_EMPTY_PATH);
427         if (ret < 0)
428                 exit_log("%m - Failed to attach mount to %s\n", target);
429         close(fd_tree);
430
431         exit(EXIT_SUCCESS);
432 }