cdef extern from "limits.h":
     cdef uint64_t INT64_MAX
 
+cdef extern from "rados/librados.h":
+    enum:
+        _LIBRADOS_SNAP_HEAD "LIBRADOS_SNAP_HEAD"
+
 cdef extern from "rbd/librbd.h" nogil:
     enum:
         _RBD_FEATURE_LAYERING "RBD_FEATURE_LAYERING"
     int rbd_snap_set_limit(rbd_image_t image, uint64_t limit)
     int rbd_snap_get_timestamp(rbd_image_t image, uint64_t snap_id, timespec *timestamp)
     int rbd_snap_set(rbd_image_t image, const char *snapname)
+    int rbd_snap_set_by_id(rbd_image_t image, uint64_t snap_id)
     int rbd_snap_get_namespace_type(rbd_image_t image,
                                     uint64_t snap_id,
                                     rbd_snap_namespace_type_t *namespace_type)
         if ret != 0:
             raise make_ex(ret, 'error setting image %s to snapshot %s' % (self.name, name))
 
+    def set_snap_by_id(self, snap_id):
+        """
+        Set the snapshot to read from. Writes will raise ReadOnlyImage
+        while a snapshot is set. Pass None to unset the snapshot
+        (reads come from the current image) , and allow writing again.
+
+        :param snap_id: the snapshot to read from, or None to unset the snapshot
+        :type snap_id: int
+        """
+        if not snap_id:
+            snap_id = _LIBRADOS_SNAP_HEAD
+        cdef int64_t _snap_id = snap_id
+        with nogil:
+            ret = rbd_snap_set_by_id(self.image, _snap_id)
+        if ret != 0:
+            raise make_ex(ret, 'error setting image %s to snapshot %d' % (self.name, snap_id))
+
     def read(self, offset, length, fadvise_flags=0):
         """
         Read data from the image. Raises :class:`InvalidArgument` if
 
         eq(read, data)
         self.image.remove_snap('snap1')
 
+    def test_set_snap_by_id(self):
+        self.image.write(b'\0' * 256, 0)
+        self.image.create_snap('snap1')
+        read = self.image.read(0, 256)
+        eq(read, b'\0' * 256)
+        data = rand_data(256)
+        self.image.write(data, 0)
+        read = self.image.read(0, 256)
+        eq(read, data)
+        snaps = list(self.image.list_snaps())
+        self.image.set_snap_by_id(snaps[0]['id'])
+        read = self.image.read(0, 256)
+        eq(read, b'\0' * 256)
+        self.image.set_snap_by_id(None)
+        read = self.image.read(0, 256)
+        eq(read, data)
+        self.image.remove_snap('snap1')
+
     def test_set_snap_sparse(self):
         self.image.create_snap('snap1')
         read = self.image.read(0, 256)