A destructor kfunc can be defined as void func(type *), where type may
be void or any other pointer type as per convenience.
In this patch, we ensure that the type is sane and capture the function
pointer into off_desc of ptr_off_tab for the specific pointer offset,
with the invariant that the dtor pointer is always set when 'kptr_ref'
tag is applied to the pointer's pointee type, which is indicated by the
flag BPF_MAP_VALUE_OFF_F_REF.
Note that only BTF IDs whose destructor kfunc is registered, thus become
the allowed BTF IDs for embedding as referenced kptr. Hence it serves
the purpose of finding dtor kfunc BTF ID, as well acting as a check
against the whitelist of allowed BTF IDs for this purpose.
Finally, wire up the actual freeing of the referenced pointer if any at
all available offsets, so that no references are leaked after the BPF
map goes away and the BPF program previously moved the ownership a
referenced pointer into it.
The behavior is similar to BPF timers, where bpf_map_{update,delete}_elem
will free any existing referenced kptr. The same case is with LRU map's
bpf_lru_push_free/htab_lru_push_free functions, which are extended to
reset unreferenced and free referenced kptr.
Note that unlike BPF timers, kptr is not reset or freed when map uref
drops to zero.
Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20220424214901.2743946-8-memxor@gmail.com
 #include <linux/slab.h>
 #include <linux/percpu-refcount.h>
 #include <linux/bpfptr.h>
+#include <linux/btf.h>
 
 struct bpf_verifier_env;
 struct bpf_verifier_log;
        enum bpf_kptr_type type;
        struct {
                struct btf *btf;
+               struct module *module;
+               btf_dtor_kfunc_t dtor;
                u32 btf_id;
        } kptr;
 };
 void bpf_map_free_kptr_off_tab(struct bpf_map *map);
 struct bpf_map_value_off *bpf_map_copy_kptr_off_tab(const struct bpf_map *map);
 bool bpf_map_equal_kptr_off_tab(const struct bpf_map *map_a, const struct bpf_map *map_b);
+void bpf_map_free_kptrs(struct bpf_map *map, void *map_value);
 
 struct bpf_map *bpf_map_get(u32 ufd);
 struct bpf_map *bpf_map_get_with_uref(u32 ufd);
 
        u32 kfunc_btf_id;
 };
 
+typedef void (*btf_dtor_kfunc_t)(void *);
+
 extern const struct file_operations btf_fops;
 
 void btf_get(struct btf *btf);
 
        return 0;
 }
 
-static void check_and_free_timer_in_array(struct bpf_array *arr, void *val)
+static void check_and_free_fields(struct bpf_array *arr, void *val)
 {
-       if (unlikely(map_value_has_timer(&arr->map)))
+       if (map_value_has_timer(&arr->map))
                bpf_timer_cancel_and_free(val + arr->map.timer_off);
+       if (map_value_has_kptrs(&arr->map))
+               bpf_map_free_kptrs(&arr->map, val);
 }
 
 /* Called from syscall or from eBPF program */
                        copy_map_value_locked(map, val, value, false);
                else
                        copy_map_value(map, val, value);
-               check_and_free_timer_in_array(array, val);
+               check_and_free_fields(array, val);
        }
        return 0;
 }
        struct bpf_array *array = container_of(map, struct bpf_array, map);
        int i;
 
-       if (likely(!map_value_has_timer(map)))
+       /* We don't reset or free kptr on uref dropping to zero. */
+       if (!map_value_has_timer(map))
                return;
 
        for (i = 0; i < array->map.max_entries; i++)
 static void array_map_free(struct bpf_map *map)
 {
        struct bpf_array *array = container_of(map, struct bpf_array, map);
+       int i;
+
+       if (map_value_has_kptrs(map)) {
+               for (i = 0; i < array->map.max_entries; i++)
+                       bpf_map_free_kptrs(map, array->value + array->elem_size * i);
+               bpf_map_free_kptr_off_tab(map);
+       }
 
        if (array->map.map_type == BPF_MAP_TYPE_PERCPU_ARRAY)
                bpf_array_free_percpu(array);
 
        struct btf_field_info info_arr[BPF_MAP_VALUE_OFF_MAX];
        struct bpf_map_value_off *tab;
        struct btf *kernel_btf = NULL;
+       struct module *mod = NULL;
        int ret, i, nr_off;
 
        ret = btf_find_field(btf, t, BTF_FIELD_KPTR, info_arr, ARRAY_SIZE(info_arr));
                        goto end;
                }
 
+               /* Find and stash the function pointer for the destruction function that
+                * needs to be eventually invoked from the map free path.
+                */
+               if (info_arr[i].type == BPF_KPTR_REF) {
+                       const struct btf_type *dtor_func;
+                       const char *dtor_func_name;
+                       unsigned long addr;
+                       s32 dtor_btf_id;
+
+                       /* This call also serves as a whitelist of allowed objects that
+                        * can be used as a referenced pointer and be stored in a map at
+                        * the same time.
+                        */
+                       dtor_btf_id = btf_find_dtor_kfunc(kernel_btf, id);
+                       if (dtor_btf_id < 0) {
+                               ret = dtor_btf_id;
+                               goto end_btf;
+                       }
+
+                       dtor_func = btf_type_by_id(kernel_btf, dtor_btf_id);
+                       if (!dtor_func) {
+                               ret = -ENOENT;
+                               goto end_btf;
+                       }
+
+                       if (btf_is_module(kernel_btf)) {
+                               mod = btf_try_get_module(kernel_btf);
+                               if (!mod) {
+                                       ret = -ENXIO;
+                                       goto end_btf;
+                               }
+                       }
+
+                       /* We already verified dtor_func to be btf_type_is_func
+                        * in register_btf_id_dtor_kfuncs.
+                        */
+                       dtor_func_name = __btf_name_by_offset(kernel_btf, dtor_func->name_off);
+                       addr = kallsyms_lookup_name(dtor_func_name);
+                       if (!addr) {
+                               ret = -EINVAL;
+                               goto end_mod;
+                       }
+                       tab->off[i].kptr.dtor = (void *)addr;
+               }
+
                tab->off[i].offset = info_arr[i].off;
                tab->off[i].type = info_arr[i].type;
                tab->off[i].kptr.btf_id = id;
                tab->off[i].kptr.btf = kernel_btf;
+               tab->off[i].kptr.module = mod;
        }
        tab->nr_off = nr_off;
        return tab;
+end_mod:
+       module_put(mod);
+end_btf:
+       btf_put(kernel_btf);
 end:
-       while (i--)
+       while (i--) {
                btf_put(tab->off[i].kptr.btf);
+               if (tab->off[i].kptr.module)
+                       module_put(tab->off[i].kptr.module);
+       }
        kfree(tab);
        return ERR_PTR(ret);
 }
        return dtor->kfunc_btf_id;
 }
 
+static int btf_check_dtor_kfuncs(struct btf *btf, const struct btf_id_dtor_kfunc *dtors, u32 cnt)
+{
+       const struct btf_type *dtor_func, *dtor_func_proto, *t;
+       const struct btf_param *args;
+       s32 dtor_btf_id;
+       u32 nr_args, i;
+
+       for (i = 0; i < cnt; i++) {
+               dtor_btf_id = dtors[i].kfunc_btf_id;
+
+               dtor_func = btf_type_by_id(btf, dtor_btf_id);
+               if (!dtor_func || !btf_type_is_func(dtor_func))
+                       return -EINVAL;
+
+               dtor_func_proto = btf_type_by_id(btf, dtor_func->type);
+               if (!dtor_func_proto || !btf_type_is_func_proto(dtor_func_proto))
+                       return -EINVAL;
+
+               /* Make sure the prototype of the destructor kfunc is 'void func(type *)' */
+               t = btf_type_by_id(btf, dtor_func_proto->type);
+               if (!t || !btf_type_is_void(t))
+                       return -EINVAL;
+
+               nr_args = btf_type_vlen(dtor_func_proto);
+               if (nr_args != 1)
+                       return -EINVAL;
+               args = btf_params(dtor_func_proto);
+               t = btf_type_by_id(btf, args[0].type);
+               /* Allow any pointer type, as width on targets Linux supports
+                * will be same for all pointer types (i.e. sizeof(void *))
+                */
+               if (!t || !btf_type_is_ptr(t))
+                       return -EINVAL;
+       }
+       return 0;
+}
+
 /* This function must be invoked only from initcalls/module init functions */
 int register_btf_id_dtor_kfuncs(const struct btf_id_dtor_kfunc *dtors, u32 add_cnt,
                                struct module *owner)
                goto end;
        }
 
+       /* Ensure that the prototype of dtor kfuncs being registered is sane */
+       ret = btf_check_dtor_kfuncs(btf, dtors, add_cnt);
+       if (ret < 0)
+               goto end;
+
        tab = btf->dtor_kfunc_tab;
        /* Only one call allowed for modules */
        if (WARN_ON_ONCE(tab && btf_is_module(btf))) {
 
        u32 num_entries = htab->map.max_entries;
        int i;
 
-       if (likely(!map_value_has_timer(&htab->map)))
+       if (!map_value_has_timer(&htab->map))
                return;
        if (htab_has_extra_elems(htab))
                num_entries += num_possible_cpus();
        }
 }
 
+static void htab_free_prealloced_kptrs(struct bpf_htab *htab)
+{
+       u32 num_entries = htab->map.max_entries;
+       int i;
+
+       if (!map_value_has_kptrs(&htab->map))
+               return;
+       if (htab_has_extra_elems(htab))
+               num_entries += num_possible_cpus();
+
+       for (i = 0; i < num_entries; i++) {
+               struct htab_elem *elem;
+
+               elem = get_htab_elem(htab, i);
+               bpf_map_free_kptrs(&htab->map, elem->key + round_up(htab->map.key_size, 8));
+               cond_resched();
+       }
+}
+
 static void htab_free_elems(struct bpf_htab *htab)
 {
        int i;
        return insn - insn_buf;
 }
 
-static void check_and_free_timer(struct bpf_htab *htab, struct htab_elem *elem)
+static void check_and_free_fields(struct bpf_htab *htab,
+                                 struct htab_elem *elem)
 {
-       if (unlikely(map_value_has_timer(&htab->map)))
-               bpf_timer_cancel_and_free(elem->key +
-                                         round_up(htab->map.key_size, 8) +
-                                         htab->map.timer_off);
+       void *map_value = elem->key + round_up(htab->map.key_size, 8);
+
+       if (map_value_has_timer(&htab->map))
+               bpf_timer_cancel_and_free(map_value + htab->map.timer_off);
+       if (map_value_has_kptrs(&htab->map))
+               bpf_map_free_kptrs(&htab->map, map_value);
 }
 
 /* It is called from the bpf_lru_list when the LRU needs to delete
        hlist_nulls_for_each_entry_rcu(l, n, head, hash_node)
                if (l == tgt_l) {
                        hlist_nulls_del_rcu(&l->hash_node);
-                       check_and_free_timer(htab, l);
+                       check_and_free_fields(htab, l);
                        break;
                }
 
 {
        if (htab->map.map_type == BPF_MAP_TYPE_PERCPU_HASH)
                free_percpu(htab_elem_get_ptr(l, htab->map.key_size));
-       check_and_free_timer(htab, l);
+       check_and_free_fields(htab, l);
        kfree(l);
 }
 
        htab_put_fd_value(htab, l);
 
        if (htab_is_prealloc(htab)) {
-               check_and_free_timer(htab, l);
+               check_and_free_fields(htab, l);
                __pcpu_freelist_push(&htab->freelist, &l->fnode);
        } else {
                atomic_dec(&htab->count);
                if (!htab_is_prealloc(htab))
                        free_htab_elem(htab, l_old);
                else
-                       check_and_free_timer(htab, l_old);
+                       check_and_free_fields(htab, l_old);
        }
        ret = 0;
 err:
 
 static void htab_lru_push_free(struct bpf_htab *htab, struct htab_elem *elem)
 {
-       check_and_free_timer(htab, elem);
+       check_and_free_fields(htab, elem);
        bpf_lru_push_free(&htab->lru, &elem->lru_node);
 }
 
                struct hlist_nulls_node *n;
                struct htab_elem *l;
 
-               hlist_nulls_for_each_entry(l, n, head, hash_node)
-                       check_and_free_timer(htab, l);
+               hlist_nulls_for_each_entry(l, n, head, hash_node) {
+                       /* We don't reset or free kptr on uref dropping to zero,
+                        * hence just free timer.
+                        */
+                       bpf_timer_cancel_and_free(l->key +
+                                                 round_up(htab->map.key_size, 8) +
+                                                 htab->map.timer_off);
+               }
                cond_resched_rcu();
        }
        rcu_read_unlock();
 {
        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
 
-       if (likely(!map_value_has_timer(&htab->map)))
+       /* We don't reset or free kptr on uref dropping to zero. */
+       if (!map_value_has_timer(&htab->map))
                return;
        if (!htab_is_prealloc(htab))
                htab_free_malloced_timers(htab);
         * not have executed. Wait for them.
         */
        rcu_barrier();
-       if (!htab_is_prealloc(htab))
+       if (!htab_is_prealloc(htab)) {
                delete_all_elements(htab);
-       else
+       } else {
+               htab_free_prealloced_kptrs(htab);
                prealloc_destroy(htab);
+       }
 
+       bpf_map_free_kptr_off_tab(map);
        free_percpu(htab->extra_elems);
        bpf_map_area_free(htab->buckets);
        for (i = 0; i < HASHTAB_MAP_LOCK_COUNT; i++)
 
 
        if (!map_value_has_kptrs(map))
                return;
-       for (i = 0; i < tab->nr_off; i++)
+       for (i = 0; i < tab->nr_off; i++) {
+               if (tab->off[i].kptr.module)
+                       module_put(tab->off[i].kptr.module);
                btf_put(tab->off[i].kptr.btf);
+       }
        kfree(tab);
        map->kptr_off_tab = NULL;
 }
        if (!new_tab)
                return ERR_PTR(-ENOMEM);
        /* Do a deep copy of the kptr_off_tab */
-       for (i = 0; i < tab->nr_off; i++)
+       for (i = 0; i < tab->nr_off; i++) {
                btf_get(tab->off[i].kptr.btf);
+               if (tab->off[i].kptr.module && !try_module_get(tab->off[i].kptr.module)) {
+                       while (i--) {
+                               if (tab->off[i].kptr.module)
+                                       module_put(tab->off[i].kptr.module);
+                               btf_put(tab->off[i].kptr.btf);
+                       }
+                       kfree(new_tab);
+                       return ERR_PTR(-ENXIO);
+               }
+       }
        return new_tab;
 }
 
        return !memcmp(tab_a, tab_b, size);
 }
 
+/* Caller must ensure map_value_has_kptrs is true. Note that this function can
+ * be called on a map value while the map_value is visible to BPF programs, as
+ * it ensures the correct synchronization, and we already enforce the same using
+ * the bpf_kptr_xchg helper on the BPF program side for referenced kptrs.
+ */
+void bpf_map_free_kptrs(struct bpf_map *map, void *map_value)
+{
+       struct bpf_map_value_off *tab = map->kptr_off_tab;
+       unsigned long *btf_id_ptr;
+       int i;
+
+       for (i = 0; i < tab->nr_off; i++) {
+               struct bpf_map_value_off_desc *off_desc = &tab->off[i];
+               unsigned long old_ptr;
+
+               btf_id_ptr = map_value + off_desc->offset;
+               if (off_desc->type == BPF_KPTR_UNREF) {
+                       u64 *p = (u64 *)btf_id_ptr;
+
+                       WRITE_ONCE(p, 0);
+                       continue;
+               }
+               old_ptr = xchg(btf_id_ptr, 0);
+               off_desc->kptr.dtor((void *)old_ptr);
+       }
+}
+
 /* called from workqueue */
 static void bpf_map_free_deferred(struct work_struct *work)
 {
 
        security_bpf_map_free(map);
        kfree(map->off_arr);
-       bpf_map_free_kptr_off_tab(map);
        bpf_map_release_memcg(map);
-       /* implementation dependent freeing */
+       /* implementation dependent freeing, map_free callback also does
+        * bpf_map_free_kptr_off_tab, if needed.
+        */
        map->ops->map_free(map);
 }