* add rsp, 8                      // skip eth_type_trans's frame
  * ret                             // return to its caller
  */
-int arch_prepare_bpf_trampoline(void *image, void *image_end,
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
                                const struct btf_func_model *m, u32 flags,
                                struct bpf_tramp_progs *tprogs,
                                void *orig_call)
 
        save_regs(m, &prog, nr_args, stack_size);
 
+       if (flags & BPF_TRAMP_F_CALL_ORIG) {
+               /* arg1: mov rdi, im */
+               emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+               if (emit_call(&prog, __bpf_tramp_enter, prog)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
+       }
+
        if (fentry->nr_progs)
                if (invoke_bpf(m, &prog, fentry, stack_size))
                        return -EINVAL;
        }
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
-               if (fentry->nr_progs || fmod_ret->nr_progs)
-                       restore_regs(m, &prog, nr_args, stack_size);
+               restore_regs(m, &prog, nr_args, stack_size);
 
                /* call original function */
                if (emit_call(&prog, orig_call, prog)) {
                }
                /* remember return value in a stack for bpf prog to access */
                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
+               im->ip_after_call = prog;
+               emit_nops(&prog, 5);
        }
 
        if (fmod_ret->nr_progs) {
         * the return value is only updated on the stack and still needs to be
         * restored to R0.
         */
-       if (flags & BPF_TRAMP_F_CALL_ORIG)
+       if (flags & BPF_TRAMP_F_CALL_ORIG) {
+               im->ip_epilogue = prog;
+               /* arg1: mov rdi, im */
+               emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
+               if (emit_call(&prog, __bpf_tramp_exit, prog)) {
+                       ret = -EINVAL;
+                       goto cleanup;
+               }
                /* restore original return value back into RAX */
                emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
+       }
 
        EMIT1(0x5B); /* pop rbx */
        EMIT1(0xC9); /* leave */
 
 #include <linux/capability.h>
 #include <linux/sched/mm.h>
 #include <linux/slab.h>
+#include <linux/percpu-refcount.h>
 
 struct bpf_verifier_env;
 struct bpf_verifier_log;
  *      fentry = a set of program to run before calling original function
  *      fexit = a set of program to run after original function
  */
-int arch_prepare_bpf_trampoline(void *image, void *image_end,
+struct bpf_tramp_image;
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *tr, void *image, void *image_end,
                                const struct btf_func_model *m, u32 flags,
                                struct bpf_tramp_progs *tprogs,
                                void *orig_call);
 void notrace __bpf_prog_exit(struct bpf_prog *prog, u64 start);
 u64 notrace __bpf_prog_enter_sleepable(struct bpf_prog *prog);
 void notrace __bpf_prog_exit_sleepable(struct bpf_prog *prog, u64 start);
+void notrace __bpf_tramp_enter(struct bpf_tramp_image *tr);
+void notrace __bpf_tramp_exit(struct bpf_tramp_image *tr);
 
 struct bpf_ksym {
        unsigned long            start;
        BPF_TRAMP_REPLACE, /* more than MAX */
 };
 
+struct bpf_tramp_image {
+       void *image;
+       struct bpf_ksym ksym;
+       struct percpu_ref pcref;
+       void *ip_after_call;
+       void *ip_epilogue;
+       union {
+               struct rcu_head rcu;
+               struct work_struct work;
+       };
+};
+
 struct bpf_trampoline {
        /* hlist for trampoline_table */
        struct hlist_node hlist;
        /* Number of attached programs. A counter per kind. */
        int progs_cnt[BPF_TRAMP_MAX];
        /* Executable image of trampoline */
-       void *image;
+       struct bpf_tramp_image *cur_image;
        u64 selector;
-       struct bpf_ksym ksym;
 };
 
 struct bpf_attach_target_info {
 void bpf_image_ksym_del(struct bpf_ksym *ksym);
 void bpf_ksym_add(struct bpf_ksym *ksym);
 void bpf_ksym_del(struct bpf_ksym *ksym);
+int bpf_jit_charge_modmem(u32 pages);
+void bpf_jit_uncharge_modmem(u32 pages);
 #else
 static inline int bpf_trampoline_link_prog(struct bpf_prog *prog,
                                           struct bpf_trampoline *tr)
        bool func_proto_unreliable;
        bool sleepable;
        bool tail_call_reachable;
-       enum bpf_tramp_prog_type trampoline_prog_type;
        struct hlist_node tramp_hlist;
        /* BTF_KIND_FUNC_PROTO for valid attach_btf_id */
        const struct btf_type *attach_func_proto;
 
                           PAGE_SIZE, true, ksym->name);
 }
 
-static void bpf_trampoline_ksym_add(struct bpf_trampoline *tr)
-{
-       struct bpf_ksym *ksym = &tr->ksym;
-
-       snprintf(ksym->name, KSYM_NAME_LEN, "bpf_trampoline_%llu", tr->key);
-       bpf_image_ksym_add(tr->image, ksym);
-}
-
 static struct bpf_trampoline *bpf_trampoline_lookup(u64 key)
 {
        struct bpf_trampoline *tr;
        struct hlist_head *head;
-       void *image;
        int i;
 
        mutex_lock(&trampoline_mutex);
        if (!tr)
                goto out;
 
-       /* is_root was checked earlier. No need for bpf_jit_charge_modmem() */
-       image = bpf_jit_alloc_exec_page();
-       if (!image) {
-               kfree(tr);
-               tr = NULL;
-               goto out;
-       }
-
        tr->key = key;
        INIT_HLIST_NODE(&tr->hlist);
        hlist_add_head(&tr->hlist, head);
        mutex_init(&tr->mutex);
        for (i = 0; i < BPF_TRAMP_MAX; i++)
                INIT_HLIST_HEAD(&tr->progs_hlist[i]);
-       tr->image = image;
-       INIT_LIST_HEAD_RCU(&tr->ksym.lnode);
-       bpf_trampoline_ksym_add(tr);
 out:
        mutex_unlock(&trampoline_mutex);
        return tr;
        return tprogs;
 }
 
+static void __bpf_tramp_image_put_deferred(struct work_struct *work)
+{
+       struct bpf_tramp_image *im;
+
+       im = container_of(work, struct bpf_tramp_image, work);
+       bpf_image_ksym_del(&im->ksym);
+       bpf_jit_free_exec(im->image);
+       bpf_jit_uncharge_modmem(1);
+       percpu_ref_exit(&im->pcref);
+       kfree_rcu(im, rcu);
+}
+
+/* callback, fexit step 3 or fentry step 2 */
+static void __bpf_tramp_image_put_rcu(struct rcu_head *rcu)
+{
+       struct bpf_tramp_image *im;
+
+       im = container_of(rcu, struct bpf_tramp_image, rcu);
+       INIT_WORK(&im->work, __bpf_tramp_image_put_deferred);
+       schedule_work(&im->work);
+}
+
+/* callback, fexit step 2. Called after percpu_ref_kill confirms. */
+static void __bpf_tramp_image_release(struct percpu_ref *pcref)
+{
+       struct bpf_tramp_image *im;
+
+       im = container_of(pcref, struct bpf_tramp_image, pcref);
+       call_rcu_tasks(&im->rcu, __bpf_tramp_image_put_rcu);
+}
+
+/* callback, fexit or fentry step 1 */
+static void __bpf_tramp_image_put_rcu_tasks(struct rcu_head *rcu)
+{
+       struct bpf_tramp_image *im;
+
+       im = container_of(rcu, struct bpf_tramp_image, rcu);
+       if (im->ip_after_call)
+               /* the case of fmod_ret/fexit trampoline and CONFIG_PREEMPTION=y */
+               percpu_ref_kill(&im->pcref);
+       else
+               /* the case of fentry trampoline */
+               call_rcu_tasks(&im->rcu, __bpf_tramp_image_put_rcu);
+}
+
+static void bpf_tramp_image_put(struct bpf_tramp_image *im)
+{
+       /* The trampoline image that calls original function is using:
+        * rcu_read_lock_trace to protect sleepable bpf progs
+        * rcu_read_lock to protect normal bpf progs
+        * percpu_ref to protect trampoline itself
+        * rcu tasks to protect trampoline asm not covered by percpu_ref
+        * (which are few asm insns before __bpf_tramp_enter and
+        *  after __bpf_tramp_exit)
+        *
+        * The trampoline is unreachable before bpf_tramp_image_put().
+        *
+        * First, patch the trampoline to avoid calling into fexit progs.
+        * The progs will be freed even if the original function is still
+        * executing or sleeping.
+        * In case of CONFIG_PREEMPT=y use call_rcu_tasks() to wait on
+        * first few asm instructions to execute and call into
+        * __bpf_tramp_enter->percpu_ref_get.
+        * Then use percpu_ref_kill to wait for the trampoline and the original
+        * function to finish.
+        * Then use call_rcu_tasks() to make sure few asm insns in
+        * the trampoline epilogue are done as well.
+        *
+        * In !PREEMPT case the task that got interrupted in the first asm
+        * insns won't go through an RCU quiescent state which the
+        * percpu_ref_kill will be waiting for. Hence the first
+        * call_rcu_tasks() is not necessary.
+        */
+       if (im->ip_after_call) {
+               int err = bpf_arch_text_poke(im->ip_after_call, BPF_MOD_JUMP,
+                                            NULL, im->ip_epilogue);
+               WARN_ON(err);
+               if (IS_ENABLED(CONFIG_PREEMPTION))
+                       call_rcu_tasks(&im->rcu, __bpf_tramp_image_put_rcu_tasks);
+               else
+                       percpu_ref_kill(&im->pcref);
+               return;
+       }
+
+       /* The trampoline without fexit and fmod_ret progs doesn't call original
+        * function and doesn't use percpu_ref.
+        * Use call_rcu_tasks_trace() to wait for sleepable progs to finish.
+        * Then use call_rcu_tasks() to wait for the rest of trampoline asm
+        * and normal progs.
+        */
+       call_rcu_tasks_trace(&im->rcu, __bpf_tramp_image_put_rcu_tasks);
+}
+
+static struct bpf_tramp_image *bpf_tramp_image_alloc(u64 key, u32 idx)
+{
+       struct bpf_tramp_image *im;
+       struct bpf_ksym *ksym;
+       void *image;
+       int err = -ENOMEM;
+
+       im = kzalloc(sizeof(*im), GFP_KERNEL);
+       if (!im)
+               goto out;
+
+       err = bpf_jit_charge_modmem(1);
+       if (err)
+               goto out_free_im;
+
+       err = -ENOMEM;
+       im->image = image = bpf_jit_alloc_exec_page();
+       if (!image)
+               goto out_uncharge;
+
+       err = percpu_ref_init(&im->pcref, __bpf_tramp_image_release, 0, GFP_KERNEL);
+       if (err)
+               goto out_free_image;
+
+       ksym = &im->ksym;
+       INIT_LIST_HEAD_RCU(&ksym->lnode);
+       snprintf(ksym->name, KSYM_NAME_LEN, "bpf_trampoline_%llu_%u", key, idx);
+       bpf_image_ksym_add(image, ksym);
+       return im;
+
+out_free_image:
+       bpf_jit_free_exec(im->image);
+out_uncharge:
+       bpf_jit_uncharge_modmem(1);
+out_free_im:
+       kfree(im);
+out:
+       return ERR_PTR(err);
+}
+
 static int bpf_trampoline_update(struct bpf_trampoline *tr)
 {
-       void *old_image = tr->image + ((tr->selector + 1) & 1) * PAGE_SIZE/2;
-       void *new_image = tr->image + (tr->selector & 1) * PAGE_SIZE/2;
+       struct bpf_tramp_image *im;
        struct bpf_tramp_progs *tprogs;
        u32 flags = BPF_TRAMP_F_RESTORE_REGS;
        int err, total;
                return PTR_ERR(tprogs);
 
        if (total == 0) {
-               err = unregister_fentry(tr, old_image);
+               err = unregister_fentry(tr, tr->cur_image->image);
+               bpf_tramp_image_put(tr->cur_image);
+               tr->cur_image = NULL;
                tr->selector = 0;
                goto out;
        }
 
+       im = bpf_tramp_image_alloc(tr->key, tr->selector);
+       if (IS_ERR(im)) {
+               err = PTR_ERR(im);
+               goto out;
+       }
+
        if (tprogs[BPF_TRAMP_FEXIT].nr_progs ||
            tprogs[BPF_TRAMP_MODIFY_RETURN].nr_progs)
                flags = BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_SKIP_FRAME;
 
-       /* Though the second half of trampoline page is unused a task could be
-        * preempted in the middle of the first half of trampoline and two
-        * updates to trampoline would change the code from underneath the
-        * preempted task. Hence wait for tasks to voluntarily schedule or go
-        * to userspace.
-        * The same trampoline can hold both sleepable and non-sleepable progs.
-        * synchronize_rcu_tasks_trace() is needed to make sure all sleepable
-        * programs finish executing.
-        * Wait for these two grace periods together.
-        */
-       synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
-
-       err = arch_prepare_bpf_trampoline(new_image, new_image + PAGE_SIZE / 2,
+       err = arch_prepare_bpf_trampoline(im, im->image, im->image + PAGE_SIZE,
                                          &tr->func.model, flags, tprogs,
                                          tr->func.addr);
        if (err < 0)
                goto out;
 
-       if (tr->selector)
+       WARN_ON(tr->cur_image && tr->selector == 0);
+       WARN_ON(!tr->cur_image && tr->selector);
+       if (tr->cur_image)
                /* progs already running at this address */
-               err = modify_fentry(tr, old_image, new_image);
+               err = modify_fentry(tr, tr->cur_image->image, im->image);
        else
                /* first time registering */
-               err = register_fentry(tr, new_image);
+               err = register_fentry(tr, im->image);
        if (err)
                goto out;
+       if (tr->cur_image)
+               bpf_tramp_image_put(tr->cur_image);
+       tr->cur_image = im;
        tr->selector++;
 out:
        kfree(tprogs);
                goto out;
        if (WARN_ON_ONCE(!hlist_empty(&tr->progs_hlist[BPF_TRAMP_FEXIT])))
                goto out;
-       bpf_image_ksym_del(&tr->ksym);
-       /* This code will be executed when all bpf progs (both sleepable and
-        * non-sleepable) went through
-        * bpf_prog_put()->call_rcu[_tasks_trace]()->bpf_prog_free_deferred().
-        * Hence no need for another synchronize_rcu_tasks_trace() here,
-        * but synchronize_rcu_tasks() is still needed, since trampoline
-        * may not have had any sleepable programs and we need to wait
-        * for tasks to get out of trampoline code before freeing it.
+       /* This code will be executed even when the last bpf_tramp_image
+        * is alive. All progs are detached from the trampoline and the
+        * trampoline image is patched with jmp into epilogue to skip
+        * fexit progs. The fentry-only trampoline will be freed via
+        * multiple rcu callbacks.
         */
-       synchronize_rcu_tasks();
-       bpf_jit_free_exec(tr->image);
        hlist_del(&tr->hlist);
        kfree(tr);
 out:
        rcu_read_unlock_trace();
 }
 
+void notrace __bpf_tramp_enter(struct bpf_tramp_image *tr)
+{
+       percpu_ref_get(&tr->pcref);
+}
+
+void notrace __bpf_tramp_exit(struct bpf_tramp_image *tr)
+{
+       percpu_ref_put(&tr->pcref);
+}
+
 int __weak
-arch_prepare_bpf_trampoline(void *image, void *image_end,
+arch_prepare_bpf_trampoline(struct bpf_tramp_image *tr, void *image, void *image_end,
                            const struct btf_func_model *m, u32 flags,
                            struct bpf_tramp_progs *tprogs,
                            void *orig_call)