static int is_rmap_pte(u64 pte)
 {
-       return (pte & (PT_WRITABLE_MASK | PT_PRESENT_MASK))
-               == (PT_WRITABLE_MASK | PT_PRESENT_MASK);
+       return pte != shadow_trap_nonpresent_pte
+               && pte != shadow_notrap_nonpresent_pte;
 }
 
 static void set_shadow_pte(u64 *sptep, u64 spte)
 {
        unsigned long *rmapp;
        u64 *spte;
-       u64 *prev_spte;
 
        gfn = unalias_gfn(kvm, gfn);
        rmapp = gfn_to_rmap(kvm, gfn);
        while (spte) {
                BUG_ON(!spte);
                BUG_ON(!(*spte & PT_PRESENT_MASK));
-               BUG_ON(!(*spte & PT_WRITABLE_MASK));
                rmap_printk("rmap_write_protect: spte %p %llx\n", spte, *spte);
-               prev_spte = spte;
-               spte = rmap_next(kvm, rmapp, spte);
-               rmap_remove(kvm, prev_spte);
-               set_shadow_pte(prev_spte, *prev_spte & ~PT_WRITABLE_MASK);
+               if (is_writeble_pte(*spte))
+                       set_shadow_pte(spte, *spte & ~PT_WRITABLE_MASK);
                kvm_flush_remote_tlbs(kvm);
+               spte = rmap_next(kvm, rmapp, spte);
        }
 }
 
                table = __va(table_addr);
 
                if (level == 1) {
+                       int was_rmapped;
+
                        pte = table[index];
+                       was_rmapped = is_rmap_pte(pte);
                        if (is_shadow_present_pte(pte) && is_writeble_pte(pte))
                                return 0;
                        mark_page_dirty(vcpu->kvm, v >> PAGE_SHIFT);
                        page_header_update_slot(vcpu->kvm, table, v);
                        table[index] = p | PT_PRESENT_MASK | PT_WRITABLE_MASK |
                                                                PT_USER_MASK;
-                       rmap_add(vcpu, &table[index], v >> PAGE_SHIFT);
+                       if (!was_rmapped)
+                               rmap_add(vcpu, &table[index], v >> PAGE_SHIFT);
                        return 0;
                }
 
                pt = page->spt;
                for (i = 0; i < PT64_ENT_PER_PAGE; ++i)
                        /* avoid RMW */
-                       if (pt[i] & PT_WRITABLE_MASK) {
-                               rmap_remove(kvm, &pt[i]);
+                       if (pt[i] & PT_WRITABLE_MASK)
                                pt[i] &= ~PT_WRITABLE_MASK;
-                       }
        }
 }