        WARN_ON(!list_empty(&kvm->arch.tdp_mmu_roots));
 }
 
-#define for_each_tdp_mmu_root(_kvm, _root)                         \
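+/*
+ * Drop a reference to @root; if that was the last reference, free the root.
+ */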
+static void tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root)
+{
+       if (kvm_mmu_put_root(kvm, root))
+               kvm_tdp_mmu_free_root(kvm, root);
+}
+
+static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
+                                          struct kvm_mmu_page *root)
+{
+       lockdep_assert_held(&kvm->mmu_lock);
+
+       if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
+               return false;
+
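+       /*
+        * Take a reference on the root so that it cannot be freed if the
+        * caller releases the MMU lock and yields in the loop.
+        */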
+       kvm_mmu_get_root(kvm, root);
+       return true;
+}
+
+static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+                                                    struct kvm_mmu_page *root)
+{
+       struct kvm_mmu_page *next_root;
+
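+       /* Read the next entry before the put below can free and unlink @root. */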
+       next_root = list_next_entry(root, link);
+       tdp_mmu_put_root(kvm, root);
+       return next_root;
+}
+
+/*
+ * Note: this iterator gets and puts references to the roots it iterates over.
+ * This makes it safe to release the MMU lock and yield within the loop, but
+ * if exiting the loop early, the caller must drop the reference to the most
+ * recent root. (Unless keeping a live reference is desirable.)
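+ *
+ * An early exit might look like the sketch below; done() is a stand-in for
+ * a caller-specific condition, not a helper that exists in this file:
+ *
+ *      for_each_tdp_mmu_root_yield_safe(kvm, root) {
+ *              if (done(kvm, root)) {
+ *                      tdp_mmu_put_root(kvm, root);
+ *                      break;
+ *              }
+ *      }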
+ */
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root)                  \
+       for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,        \
+                                     typeof(*_root), link);            \
+            tdp_mmu_next_root_valid(_kvm, _root);                      \
+            _root = tdp_mmu_next_root(_kvm, _root))
+
+#define for_each_tdp_mmu_root(_kvm, _root)                             \
        list_for_each_entry(_root, &_kvm->arch.tdp_mmu_roots, link)
 
 bool is_tdp_mmu_root(struct kvm *kvm, hpa_t hpa)
        struct kvm_mmu_page *root;
        bool flush = false;
 
-       for_each_tdp_mmu_root(kvm, root) {
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
+       for_each_tdp_mmu_root_yield_safe(kvm, root)
                flush |= zap_gfn_range(kvm, root, start, end, true);
 
-               kvm_mmu_put_root(kvm, root);
-       }
-
        return flush;
 }
 
        int ret = 0;
        int as_id;
 
-       for_each_tdp_mmu_root(kvm, root) {
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
+       for_each_tdp_mmu_root_yield_safe(kvm, root) {
                as_id = kvm_mmu_page_as_id(root);
                slots = __kvm_memslots(kvm, as_id);
                kvm_for_each_memslot(memslot, slots) {
                        ret |= handler(kvm, memslot, root, gfn_start,
                                       gfn_end, data);
                }
-
-               kvm_mmu_put_root(kvm, root);
        }
 
        return ret;
        int root_as_id;
        bool spte_set = false;
 
-       for_each_tdp_mmu_root(kvm, root) {
+       for_each_tdp_mmu_root_yield_safe(kvm, root) {
                root_as_id = kvm_mmu_page_as_id(root);
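+               /*
+                * Skipping a root via continue is safe: the iterator's step
+                * expression drops the reference taken on this root.
+                */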
                if (root_as_id != slot->as_id)
                        continue;
 
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
                spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
                             slot->base_gfn + slot->npages, min_level);
-
-               kvm_mmu_put_root(kvm, root);
        }
 
        return spte_set;
        int root_as_id;
        bool spte_set = false;
 
-       for_each_tdp_mmu_root(kvm, root) {
+       for_each_tdp_mmu_root_yield_safe(kvm, root) {
                root_as_id = kvm_mmu_page_as_id(root);
                if (root_as_id != slot->as_id)
                        continue;
 
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
                spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
                                slot->base_gfn + slot->npages);
-
-               kvm_mmu_put_root(kvm, root);
        }
 
        return spte_set;
        int root_as_id;
        bool spte_set = false;
 
-       for_each_tdp_mmu_root(kvm, root) {
+       for_each_tdp_mmu_root_yield_safe(kvm, root) {
                root_as_id = kvm_mmu_page_as_id(root);
                if (root_as_id != slot->as_id)
                        continue;
 
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
                spte_set |= set_dirty_gfn_range(kvm, root, slot->base_gfn,
                                slot->base_gfn + slot->npages);
-
-               kvm_mmu_put_root(kvm, root);
        }
        return spte_set;
 }
        struct kvm_mmu_page *root;
        int root_as_id;
 
-       for_each_tdp_mmu_root(kvm, root) {
+       for_each_tdp_mmu_root_yield_safe(kvm, root) {
                root_as_id = kvm_mmu_page_as_id(root);
                if (root_as_id != slot->as_id)
                        continue;
 
-               /*
-                * Take a reference on the root so that it cannot be freed if
-                * this thread releases the MMU lock and yields in this loop.
-                */
-               kvm_mmu_get_root(kvm, root);
-
                zap_collapsible_spte_range(kvm, root, slot->base_gfn,
                                           slot->base_gfn + slot->npages);
-
-               kvm_mmu_put_root(kvm, root);
        }
 }