]> git.apps.os.sepia.ceph.com Git - xfstests-dev.git/commitdiff
idmapped-mounts: refactor helpers
authorChristian Brauner <christian.brauner@ubuntu.com>
Sat, 14 Aug 2021 10:48:02 +0000 (12:48 +0200)
committerEryu Guan <guaneryu@gmail.com>
Sun, 22 Aug 2021 11:51:37 +0000 (19:51 +0800)
Make all userns creation helpers share a commond codebase and move a
bunch of code into utils.{c,h}. This simplifies a bunch of things and
makes it easier to create nested user namespaces in follow up patches.

Cc: fstests@vger.kernel.org
Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Christian Brauner <christian.brauner@ubuntu.com>
Signed-off-by: Eryu Guan <guaneryu@gmail.com>
src/idmapped-mounts/mount-idmapped.c
src/idmapped-mounts/utils.c
src/idmapped-mounts/utils.h

index 219104e763690aada2c55b62d8d4e95c67f801de..b1209057e893a4f5e8c4148a3260bf25f9523a11 100644 (file)
 #include "missing.h"
 #include "utils.h"
 
-/* A few helpful macros. */
-#define STRLITERALLEN(x) (sizeof(""x"") - 1)
-
-#define INTTYPE_TO_STRLEN(type)             \
-       (2 + (sizeof(type) <= 1             \
-                 ? 3                       \
-                 : sizeof(type) <= 2       \
-                       ? 5                 \
-                       : sizeof(type) <= 4 \
-                             ? 10          \
-                             : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
-
-#define syserror(format, ...)                           \
-       ({                                              \
-               fprintf(stderr, format, ##__VA_ARGS__); \
-               (-errno);                               \
-       })
-
-#define syserror_set(__ret__, format, ...)                    \
-       ({                                                    \
-               typeof(__ret__) __internal_ret__ = (__ret__); \
-               errno = labs(__ret__);                        \
-               fprintf(stderr, format, ##__VA_ARGS__);       \
-               __internal_ret__;                             \
-       })
-
-struct list {
-       void *elem;
-       struct list *next;
-       struct list *prev;
-};
-
-#define list_for_each(__iterator, __list) \
-       for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
-
-static inline void list_init(struct list *list)
-{
-       list->elem = NULL;
-       list->next = list->prev = list;
-}
-
-static inline int list_empty(const struct list *list)
-{
-       return list == list->next;
-}
-
-static inline void __list_add(struct list *new, struct list *prev, struct list *next)
-{
-       next->prev = new;
-       new->next = next;
-       new->prev = prev;
-       prev->next = new;
-}
-
-static inline void list_add_tail(struct list *head, struct list *list)
-{
-       __list_add(list, head->prev, head);
-}
-
-typedef enum idmap_type_t {
-       ID_TYPE_UID,
-       ID_TYPE_GID
-} idmap_type_t;
-
-struct id_map {
-       idmap_type_t map_type;
-       __u32 nsid;
-       __u32 hostid;
-       __u32 range;
-};
-
 static struct list active_map;
 
 static int add_map_entry(__u32 id_host,
@@ -166,132 +95,6 @@ static int parse_map(char *map)
        return 0;
 }
 
-static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
-{
-       int fd = -EBADF, setgroups_fd = -EBADF;
-       int fret = -1;
-       int ret;
-       char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
-                 STRLITERALLEN("/setgroups") + 1];
-
-       if (geteuid() != 0 && map_type == ID_TYPE_GID) {
-               ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
-               if (ret < 0 || ret >= sizeof(path))
-                       goto out;
-
-               setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
-               if (setgroups_fd < 0 && errno != ENOENT) {
-                       syserror("Failed to open \"%s\"", path);
-                       goto out;
-               }
-
-               if (setgroups_fd >= 0) {
-                       ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
-                       if (ret != STRLITERALLEN("deny\n")) {
-                               syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
-                               goto out;
-                       }
-               }
-       }
-
-       ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
-       if (ret < 0 || ret >= sizeof(path))
-               goto out;
-
-       fd = open(path, O_WRONLY | O_CLOEXEC);
-       if (fd < 0) {
-               syserror("Failed to open \"%s\"", path);
-               goto out;
-       }
-
-       ret = write_nointr(fd, buf, buf_size);
-       if (ret != buf_size) {
-               syserror("Failed to write %cid mapping to \"%s\"",
-                        map_type == ID_TYPE_UID ? 'u' : 'g', path);
-               goto out;
-       }
-
-       fret = 0;
-out:
-       if (fd >= 0)
-               close(fd);
-       if (setgroups_fd >= 0)
-               close(setgroups_fd);
-
-       return fret;
-}
-
-static int map_ids_from_idmap(struct list *idmap, pid_t pid)
-{
-       int fill, left;
-       char mapbuf[4096] = {};
-       bool had_entry = false;
-       idmap_type_t map_type, u_or_g;
-
-       for (map_type = ID_TYPE_UID, u_or_g = 'u';
-            map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
-               char *pos = mapbuf;
-               int ret;
-               struct list *iterator;
-
-
-               list_for_each(iterator, idmap) {
-                       struct id_map *map = iterator->elem;
-                       if (map->map_type != map_type)
-                               continue;
-
-                       had_entry = true;
-
-                       left = 4096 - (pos - mapbuf);
-                       fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
-                       /*
-                        * The kernel only takes <= 4k for writes to
-                        * /proc/<pid>/{g,u}id_map
-                        */
-                       if (fill <= 0 || fill >= left)
-                               return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
-
-                       pos += fill;
-               }
-               if (!had_entry)
-                       continue;
-
-               ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
-               if (ret < 0)
-                       return syserror("Failed to write mapping: %s", mapbuf);
-
-               memset(mapbuf, 0, sizeof(mapbuf));
-       }
-
-       return 0;
-}
-
-static int get_userns_fd_from_idmap(struct list *idmap)
-{
-       int ret;
-       pid_t pid;
-       char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
-                 STRLITERALLEN("/ns/user") + 1];
-
-       pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
-       if (pid < 0)
-               return -errno;
-
-       ret = map_ids_from_idmap(idmap, pid);
-       if (ret < 0)
-               return ret;
-
-       ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
-       if (ret < 0 || (size_t)ret >= sizeof(path_ns))
-               ret = -EIO;
-       else
-               ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
-
-       (void)kill(pid, SIGKILL);
-       (void)wait_for_pid(pid);
-       return ret;
-}
-
 static inline bool strnequal(const char *str, const char *eq, size_t len)
 {
        return strncmp(str, eq, len) == 0;
index 977443f1e75f7729aa1c93bcc173997f77cdc3b8..e54f481d24e9cd796a4838b44ce676edfb9d645e 100644 (file)
@@ -36,99 +36,192 @@ ssize_t write_nointr(int fd, const void *buf, size_t count)
        return ret;
 }
 
-static int write_file(const char *path, const void *buf, size_t count)
+#define __STACK_SIZE (8 * 1024 * 1024)
+pid_t do_clone(int (*fn)(void *), void *arg, int flags)
 {
-       int fd;
-       ssize_t ret;
+       void *stack;
 
-       fd = open(path, O_WRONLY | O_CLOEXEC | O_NOCTTY | O_NOFOLLOW);
-       if (fd < 0)
-               return -1;
+       stack = malloc(__STACK_SIZE);
+       if (!stack)
+               return -ENOMEM;
 
-       ret = write_nointr(fd, buf, count);
-       close(fd);
-       if (ret < 0 || (size_t)ret != count)
-               return -1;
+#ifdef __ia64__
+       return __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
+#else
+       return clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
+#endif
+}
 
+static int get_userns_fd_cb(void *data)
+{
        return 0;
 }
 
-static int map_ids(pid_t pid, unsigned long nsid, unsigned long hostid,
-                  unsigned long range)
+int wait_for_pid(pid_t pid)
 {
-       char map[100], procfile[256];
+       int status, ret;
 
-       snprintf(procfile, sizeof(procfile), "/proc/%d/uid_map", pid);
-       snprintf(map, sizeof(map), "%lu %lu %lu", nsid, hostid, range);
-       if (write_file(procfile, map, strlen(map)))
-               return -1;
+again:
+       ret = waitpid(pid, &status, 0);
+       if (ret == -1) {
+               if (errno == EINTR)
+                       goto again;
 
+               return -1;
+       }
 
-       snprintf(procfile, sizeof(procfile), "/proc/%d/gid_map", pid);
-       snprintf(map, sizeof(map), "%lu %lu %lu", nsid, hostid, range);
-       if (write_file(procfile, map, strlen(map)))
+       if (!WIFEXITED(status))
                return -1;
 
-       return 0;
+       return WEXITSTATUS(status);
 }
 
-#define __STACK_SIZE (8 * 1024 * 1024)
-pid_t do_clone(int (*fn)(void *), void *arg, int flags)
+static int write_id_mapping(idmap_type_t map_type, pid_t pid, const char *buf, size_t buf_size)
 {
-       void *stack;
+       int fd = -EBADF, setgroups_fd = -EBADF;
+       int fret = -1;
+       int ret;
+       char path[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
+                 STRLITERALLEN("/setgroups") + 1];
+
+       if (geteuid() != 0 && map_type == ID_TYPE_GID) {
+               ret = snprintf(path, sizeof(path), "/proc/%d/setgroups", pid);
+               if (ret < 0 || ret >= sizeof(path))
+                       goto out;
+
+               setgroups_fd = open(path, O_WRONLY | O_CLOEXEC);
+               if (setgroups_fd < 0 && errno != ENOENT) {
+                       syserror("Failed to open \"%s\"", path);
+                       goto out;
+               }
+
+               if (setgroups_fd >= 0) {
+                       ret = write_nointr(setgroups_fd, "deny\n", STRLITERALLEN("deny\n"));
+                       if (ret != STRLITERALLEN("deny\n")) {
+                               syserror("Failed to write \"deny\" to \"/proc/%d/setgroups\"", pid);
+                               goto out;
+                       }
+               }
+       }
 
-       stack = malloc(__STACK_SIZE);
-       if (!stack)
-               return -ENOMEM;
+       ret = snprintf(path, sizeof(path), "/proc/%d/%cid_map", pid, map_type == ID_TYPE_UID ? 'u' : 'g');
+       if (ret < 0 || ret >= sizeof(path))
+               goto out;
 
-#ifdef __ia64__
-       return __clone2(fn, stack, __STACK_SIZE, flags | SIGCHLD, arg, NULL);
-#else
-       return clone(fn, stack + __STACK_SIZE, flags | SIGCHLD, arg, NULL);
-#endif
+       fd = open(path, O_WRONLY | O_CLOEXEC);
+       if (fd < 0) {
+               syserror("Failed to open \"%s\"", path);
+               goto out;
+       }
+
+       ret = write_nointr(fd, buf, buf_size);
+       if (ret != buf_size) {
+               syserror("Failed to write %cid mapping to \"%s\"",
+                        map_type == ID_TYPE_UID ? 'u' : 'g', path);
+               goto out;
+       }
+
+       fret = 0;
+out:
+       if (fd >= 0)
+               close(fd);
+       if (setgroups_fd >= 0)
+               close(setgroups_fd);
+
+       return fret;
 }
 
-int get_userns_fd_cb(void *data)
+static int map_ids_from_idmap(struct list *idmap, pid_t pid)
 {
-       return kill(getpid(), SIGSTOP);
+       int fill, left;
+       char mapbuf[4096] = {};
+       bool had_entry = false;
+
+       for (idmap_type_t map_type = ID_TYPE_UID, u_or_g = 'u';
+            map_type <= ID_TYPE_GID; map_type++, u_or_g = 'g') {
+               char *pos = mapbuf;
+               int ret;
+               struct list *iterator;
+
+
+               list_for_each(iterator, idmap) {
+                       struct id_map *map = iterator->elem;
+                       if (map->map_type != map_type)
+                               continue;
+
+                       had_entry = true;
+
+                       left = 4096 - (pos - mapbuf);
+                       fill = snprintf(pos, left, "%u %u %u\n", map->nsid, map->hostid, map->range);
+                       /*
+                        * The kernel only takes <= 4k for writes to
+                        * /proc/<pid>/{g,u}id_map
+                        */
+                       if (fill <= 0 || fill >= left)
+                               return syserror_set(-E2BIG, "Too many %cid mappings defined", u_or_g);
+
+                       pos += fill;
+               }
+               if (!had_entry)
+                       continue;
+
+               ret = write_id_mapping(map_type, pid, mapbuf, pos - mapbuf);
+               if (ret < 0)
+                       return syserror("Failed to write mapping: %s", mapbuf);
+
+               memset(mapbuf, 0, sizeof(mapbuf));
+       }
+
+       return 0;
 }
 
-int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
+int get_userns_fd_from_idmap(struct list *idmap)
 {
        int ret;
        pid_t pid;
-       char path[256];
+       char path_ns[STRLITERALLEN("/proc/") + INTTYPE_TO_STRLEN(pid_t) +
+                    STRLITERALLEN("/ns/user") + 1];
 
-       pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER);
+       pid = do_clone(get_userns_fd_cb, NULL, CLONE_NEWUSER | CLONE_NEWNS);
        if (pid < 0)
                return -errno;
 
-       ret = map_ids(pid, nsid, hostid, range);
+       ret = map_ids_from_idmap(idmap, pid);
        if (ret < 0)
                return ret;
 
-       snprintf(path, sizeof(path), "/proc/%d/ns/user", pid);
-       ret = open(path, O_RDONLY | O_CLOEXEC);
-       kill(pid, SIGKILL);
-       wait_for_pid(pid);
+       ret = snprintf(path_ns, sizeof(path_ns), "/proc/%d/ns/user", pid);
+       if (ret < 0 || (size_t)ret >= sizeof(path_ns))
+               ret = -EIO;
+       else
+               ret = open(path_ns, O_RDONLY | O_CLOEXEC | O_NOCTTY);
+
+       (void)kill(pid, SIGKILL);
+       (void)wait_for_pid(pid);
        return ret;
 }
 
-int wait_for_pid(pid_t pid)
+int get_userns_fd(unsigned long nsid, unsigned long hostid, unsigned long range)
 {
-       int status, ret;
-
-again:
-       ret = waitpid(pid, &status, 0);
-       if (ret == -1) {
-               if (errno == EINTR)
-                       goto again;
-
-               return -1;
-       }
-
-       if (!WIFEXITED(status))
-               return -1;
-
-       return WEXITSTATUS(status);
+       struct list head, uid_mapl, gid_mapl;
+       struct id_map uid_map = {
+               .map_type       = ID_TYPE_UID,
+               .nsid           = nsid,
+               .hostid         = hostid,
+               .range          = range,
+       };
+       struct id_map gid_map = {
+               .map_type       = ID_TYPE_GID,
+               .nsid           = nsid,
+               .hostid         = hostid,
+               .range          = range,
+       };
+
+       list_init(&head);
+       uid_mapl.elem = &uid_map;
+       gid_mapl.elem = &gid_map;
+       list_add_tail(&head, &uid_mapl);
+       list_add_tail(&head, &gid_mapl);
+
+       return get_userns_fd_from_idmap(&head);
 }
index efbf3bc36b81c243814d269345a85950824c9c95..4f976f9fbfe7cff111774d93347bf31e04c46b36 100644 (file)
 
 #include "missing.h"
 
+/* A few helpful macros. */
+#define STRLITERALLEN(x) (sizeof(""x"") - 1)
+
+#define INTTYPE_TO_STRLEN(type)             \
+       (2 + (sizeof(type) <= 1             \
+                 ? 3                       \
+                 : sizeof(type) <= 2       \
+                       ? 5                 \
+                       : sizeof(type) <= 4 \
+                             ? 10          \
+                             : sizeof(type) <= 8 ? 20 : sizeof(int[-2 * (sizeof(type) > 8)])))
+
+#define syserror(format, ...)                           \
+       ({                                              \
+               fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__); \
+               (-errno);                               \
+       })
+
+#define syserror_set(__ret__, format, ...)                    \
+       ({                                                    \
+               typeof(__ret__) __internal_ret__ = (__ret__); \
+               errno = labs(__ret__);                        \
+               fprintf(stderr, "%m - " format "\n", ##__VA_ARGS__);       \
+               __internal_ret__;                             \
+       })
+
+typedef enum idmap_type_t {
+       ID_TYPE_UID,
+       ID_TYPE_GID
+} idmap_type_t;
+
+struct id_map {
+       idmap_type_t map_type;
+       __u32 nsid;
+       __u32 hostid;
+       __u32 range;
+};
+
+struct list {
+       void *elem;
+       struct list *next;
+       struct list *prev;
+};
+
+#define list_for_each(__iterator, __list) \
+       for (__iterator = (__list)->next; __iterator != __list; __iterator = __iterator->next)
+
+static inline void list_init(struct list *list)
+{
+       list->elem = NULL;
+       list->next = list->prev = list;
+}
+
+static inline int list_empty(const struct list *list)
+{
+       return list == list->next;
+}
+
+static inline void __list_add(struct list *new, struct list *prev, struct list *next)
+{
+       next->prev = new;
+       new->next = next;
+       new->prev = prev;
+       prev->next = new;
+}
+
+static inline void list_add_tail(struct list *head, struct list *list)
+{
+       __list_add(list, head->prev, head);
+}
+
 extern pid_t do_clone(int (*fn)(void *), void *arg, int flags);
-extern int get_userns_fd_cb(void *data);
 extern int get_userns_fd(unsigned long nsid, unsigned long hostid,
                         unsigned long range);
+extern int get_userns_fd_from_idmap(struct list *idmap);
 extern ssize_t read_nointr(int fd, void *buf, size_t count);
 extern int wait_for_pid(pid_t pid);
 extern ssize_t write_nointr(int fd, const void *buf, size_t count);