]> git-server-git.apps.pok.os.sepia.ceph.com Git - rocksdb.git/commitdiff
MultiCfIterator Impl Follow up (#12465)
authorJay Huh <jewoongh@meta.com>
Fri, 22 Mar 2024 21:51:16 +0000 (14:51 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2024 21:51:16 +0000 (14:51 -0700)
Summary:
As a follow up for https://github.com/facebook/rocksdb/issues/12422 , this PR includes the following two changes.
- Removal of `direction_` in the MultiCfIterator
- Use of Member Func Template instead of `std::function`

Pull Request resolved: https://github.com/facebook/rocksdb/pull/12465

Test Plan:
```
./multi_cf_iterator_test
```

Reviewed By: pdillinger, ltamasi

Differential Revision: D55208448

Pulled By: jaykorean

fbshipit-source-id: 8b3167c1d59839d076afc29097b5ad21a453460a

db/multi_cf_iterator.cc
db/multi_cf_iterator.h

index 4398edca72c861bd12f697c5e160c0b39e00157c..80e4171d54d2f1b3fc538e6e1d7efaa14a893e74 100644 (file)
@@ -9,11 +9,10 @@
 
 namespace ROCKSDB_NAMESPACE {
 
-void MultiCfIterator::SeekCommon(
-    const std::function<void(Iterator*)>& child_seek_func,
-    Direction direction) {
-  direction_ = direction;
-  Reset();
+template <typename BinaryHeap, typename ChildSeekFuncType>
+void MultiCfIterator::SeekCommon(BinaryHeap& heap,
+                                 ChildSeekFuncType child_seek_func) {
+  heap.clear();
   int i = 0;
   for (auto& cfh_iter_pair : cfh_iter_pairs_) {
     auto& cfh = cfh_iter_pair.first;
@@ -21,13 +20,7 @@ void MultiCfIterator::SeekCommon(
     child_seek_func(iter.get());
     if (iter->Valid()) {
       assert(iter->status().ok());
-      if (direction_ == kReverse) {
-        auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
-        max_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i});
-      } else {
-        auto& min_heap = std::get<MultiCfMinHeap>(heap_);
-        min_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i});
-      }
+      heap.push(MultiCfIteratorInfo{iter.get(), cfh, i});
     } else {
       considerStatus(iter->status());
     }
@@ -35,9 +28,9 @@ void MultiCfIterator::SeekCommon(
   }
 }
 
-template <typename BinaryHeap>
-void MultiCfIterator::AdvanceIterator(
-    BinaryHeap& heap, const std::function<void(Iterator*)>& advance_func) {
+template <typename BinaryHeap, typename AdvanceFuncType>
+void MultiCfIterator::AdvanceIterator(BinaryHeap& heap,
+                                      AdvanceFuncType advance_func) {
   // 1. Keep the top iterator (by popping it from the heap)
   // 2. Make sure all others have iterated past the top iterator key slice
   // 3. Advance the top iterator, and add it back to the heap if valid
@@ -70,33 +63,39 @@ void MultiCfIterator::AdvanceIterator(
 }
 
 void MultiCfIterator::SeekToFirst() {
-  SeekCommon([](Iterator* iter) { iter->SeekToFirst(); }, kForward);
+  auto& min_heap = GetHeap<MultiCfMinHeap>([this]() { InitMinHeap(); });
+  SeekCommon(min_heap, [](Iterator* iter) { iter->SeekToFirst(); });
 }
 void MultiCfIterator::Seek(const Slice& target) {
-  SeekCommon([&target](Iterator* iter) { iter->Seek(target); }, kForward);
+  auto& min_heap = GetHeap<MultiCfMinHeap>([this]() { InitMinHeap(); });
+  SeekCommon(min_heap, [&target](Iterator* iter) { iter->Seek(target); });
 }
 void MultiCfIterator::SeekToLast() {
-  SeekCommon([](Iterator* iter) { iter->SeekToLast(); }, kReverse);
+  auto& max_heap = GetHeap<MultiCfMaxHeap>([this]() { InitMaxHeap(); });
+  SeekCommon(max_heap, [](Iterator* iter) { iter->SeekToLast(); });
 }
 void MultiCfIterator::SeekForPrev(const Slice& target) {
-  SeekCommon([&target](Iterator* iter) { iter->SeekForPrev(target); },
-             kReverse);
+  auto& max_heap = GetHeap<MultiCfMaxHeap>([this]() { InitMaxHeap(); });
+  SeekCommon(max_heap,
+             [&target](Iterator* iter) { iter->SeekForPrev(target); });
 }
 
 void MultiCfIterator::Next() {
   assert(Valid());
-  if (direction_ != kForward) {
-    SwitchToDirection(kForward);
-  }
-  auto& min_heap = std::get<MultiCfMinHeap>(heap_);
+  auto& min_heap = GetHeap<MultiCfMinHeap>([this]() {
+    Slice target = key();
+    InitMinHeap();
+    Seek(target);
+  });
   AdvanceIterator(min_heap, [](Iterator* iter) { iter->Next(); });
 }
 void MultiCfIterator::Prev() {
   assert(Valid());
-  if (direction_ != kReverse) {
-    SwitchToDirection(kReverse);
-  }
-  auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
+  auto& max_heap = GetHeap<MultiCfMaxHeap>([this]() {
+    Slice target = key();
+    InitMaxHeap();
+    SeekForPrev(target);
+  });
   AdvanceIterator(max_heap, [](Iterator* iter) { iter->Prev(); });
 }
 
index 4269422b3c131fba3fbed2e0c65b97e4857d10fd..cdd09c16df06b58678022ca934207f2e6647d3a4 100644 (file)
@@ -86,13 +86,10 @@ class MultiCfIterator : public Iterator {
 
   MultiCfIterHeap heap_;
 
-  enum Direction : uint8_t { kForward, kReverse };
-  Direction direction_ = kForward;
-
   // TODO: Lower and Upper bounds
 
   Iterator* current() const {
-    if (direction_ == kReverse) {
+    if (std::holds_alternative<MultiCfMaxHeap>(heap_)) {
       auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
       return max_heap.top().iterator;
     }
@@ -114,7 +111,7 @@ class MultiCfIterator : public Iterator {
   }
 
   bool Valid() const override {
-    if (direction_ == kReverse) {
+    if (std::holds_alternative<MultiCfMaxHeap>(heap_)) {
       auto& max_heap = std::get<MultiCfMaxHeap>(heap_);
       return !max_heap.empty() && status_.ok();
     }
@@ -128,21 +125,13 @@ class MultiCfIterator : public Iterator {
       status_ = std::move(s);
     }
   }
-  void Reset() {
-    std::visit(overload{[&](MultiCfMinHeap& min_heap) -> void {
-                          min_heap.clear();
-                          if (direction_ == kReverse) {
-                            InitMaxHeap();
-                          }
-                        },
-                        [&](MultiCfMaxHeap& max_heap) -> void {
-                          max_heap.clear();
-                          if (direction_ == kForward) {
-                            InitMinHeap();
-                          }
-                        }},
-               heap_);
-    status_ = Status::OK();
+
+  template <typename HeapType, typename InitFunc>
+  HeapType& GetHeap(InitFunc initFunc) {
+    if (!std::holds_alternative<HeapType>(heap_)) {
+      initFunc();
+    }
+    return std::get<HeapType>(heap_);
   }
 
   void InitMinHeap() {
@@ -154,21 +143,10 @@ class MultiCfIterator : public Iterator {
         MultiCfHeapItemComparator<std::less<int>>(comparator_));
   }
 
-  void SwitchToDirection(Direction new_direction) {
-    assert(direction_ != new_direction);
-    Slice target = key();
-    if (new_direction == kForward) {
-      Seek(target);
-    } else {
-      SeekForPrev(target);
-    }
-  }
-
-  void SeekCommon(const std::function<void(Iterator*)>& child_seek_func,
-                  Direction direction);
-  template <typename BinaryHeap>
-  void AdvanceIterator(BinaryHeap& heap,
-                       const std::function<void(Iterator*)>& advance_func);
+  template <typename BinaryHeap, typename ChildSeekFuncType>
+  void SeekCommon(BinaryHeap& heap, ChildSeekFuncType child_seek_func);
+  template <typename BinaryHeap, typename AdvanceFuncType>
+  void AdvanceIterator(BinaryHeap& heap, AdvanceFuncType advance_func);
 
   void SeekToFirst() override;
   void SeekToLast() override;