]> git-server-git.apps.pok.os.sepia.ceph.com Git - rocksdb.git/commitdiff
Verify write batch checksum before WAL (#10114)
authorChangyu Bi <changyubi@fb.com>
Wed, 15 Jun 2022 20:43:58 +0000 (13:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Jun 2022 20:43:58 +0000 (13:43 -0700)
Summary:
Context: WriteBatch can have key-value checksums when it was created `with protection_bytes_per_key > 0`.
This PR added checksum verification for write batches before they are written to WAL.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/10114

Test Plan:
- Added new unit tests to db_kv_checksum_test.cc: `make check -j32`
- benchmark on performance regression: `./db_bench --benchmarks=fillrandom[-X20] -db=/dev/shm/test_rocksdb -write_batch_protection_bytes_per_key=8`
  - Pre-PR:
`
fillrandom [AVG    20 runs] : 198875 (± 3006) ops/sec;   22.0 (± 0.3) MB/sec
`
  - Post-PR:
`
fillrandom [AVG    20 runs] : 196487 (± 2279) ops/sec;   21.7 (± 0.3) MB/sec
`
  Mean regressed about 1% (198875 -> 196487 ops/sec).

Reviewed By: ajkr

Differential Revision: D36917464

Pulled By: cbi42

fbshipit-source-id: 29beb74edf65f04b1a890b4f650d873dc7ed790d

db/db_impl/db_impl.h
db/db_impl/db_impl_write.cc
db/db_kv_checksum_test.cc
db/write_batch.cc
db/write_batch_internal.h
db/write_thread.cc
include/rocksdb/write_batch.h
tools/db_bench_tool.cc

index 018d7904cf9609b1ab251999d24d71c52eea4788..733a87a0c8063d654d88de183c9f8c55c76e5ea4 100644 (file)
@@ -1915,9 +1915,12 @@ class DBImpl : public DB {
   Status PreprocessWrite(const WriteOptions& write_options, bool* need_log_sync,
                          WriteContext* write_context);
 
-  WriteBatch* MergeBatch(const WriteThread::WriteGroup& write_group,
-                         WriteBatch* tmp_batch, size_t* write_with_wal,
-                         WriteBatch** to_be_cached_state);
+  // Merge write batches in the write group into merged_batch.
+  // Returns OK if merge is successful.
+  // Returns Corruption if corruption in write batch is detected.
+  Status MergeBatch(const WriteThread::WriteGroup& write_group,
+                    WriteBatch* tmp_batch, WriteBatch** merged_batch,
+                    size_t* write_with_wal, WriteBatch** to_be_cached_state);
 
   // rate_limiter_priority is used to charge `DBOptions::rate_limiter`
   // for automatic WAL flush (`Options::manual_wal_flush` == false)
index c6ce801ae45b4494a43b7a401f618c83c5da8bcf..787006d35c8052c7413f88b715e7b3d0154d895d 100644 (file)
@@ -533,15 +533,18 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
   }
   PERF_TIMER_START(write_pre_and_post_process_time);
 
+  if (!io_s.ok()) {
+    // Check WriteToWAL status
+    IOStatusCheck(io_s);
+  }
   if (!w.CallbackFailed()) {
     if (!io_s.ok()) {
       assert(pre_release_cb_status.ok());
-      IOStatusCheck(io_s);
     } else {
       WriteStatusCheck(pre_release_cb_status);
     }
   } else {
-    assert(io_s.ok() && pre_release_cb_status.ok());
+    assert(pre_release_cb_status.ok());
   }
 
   if (need_log_sync) {
@@ -695,12 +698,11 @@ Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options,
       w.status = io_s;
     }
 
-    if (!w.CallbackFailed()) {
-      if (!io_s.ok()) {
-        IOStatusCheck(io_s);
-      } else {
-        WriteStatusCheck(w.status);
-      }
+    if (!io_s.ok()) {
+      // Check WriteToWAL status
+      IOStatusCheck(io_s);
+    } else if (!w.CallbackFailed()) {
+      WriteStatusCheck(w.status);
     }
 
     if (need_log_sync) {
@@ -936,11 +938,18 @@ Status DBImpl::WriteImplWALOnly(
     seq_inc = total_batch_cnt;
   }
   Status status;
-  IOStatus io_s;
-  io_s.PermitUncheckedError();  // Allow io_s to be uninitialized
   if (!write_options.disableWAL) {
-    io_s = ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc);
+    IOStatus io_s =
+        ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc);
     status = io_s;
+    // last_sequence may not be set if there is an error
+    // This error checking and return is moved up to avoid using uninitialized
+    // last_sequence.
+    if (!io_s.ok()) {
+      IOStatusCheck(io_s);
+      write_thread->ExitAsBatchGroupLeader(write_group, status);
+      return status;
+    }
   } else {
     // Otherwise we inc seq number to do solely the seq allocation
     last_sequence = versions_->FetchAddLastAllocatedSequence(seq_inc);
@@ -975,11 +984,7 @@ Status DBImpl::WriteImplWALOnly(
   PERF_TIMER_START(write_pre_and_post_process_time);
 
   if (!w.CallbackFailed()) {
-    if (!io_s.ok()) {
-      IOStatusCheck(io_s);
-    } else {
-      WriteStatusCheck(status);
-    }
+    WriteStatusCheck(status);
   }
   if (status.ok()) {
     size_t index = 0;
@@ -1171,13 +1176,13 @@ Status DBImpl::PreprocessWrite(const WriteOptions& write_options,
   return status;
 }
 
-WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
-                               WriteBatch* tmp_batch, size_t* write_with_wal,
-                               WriteBatch** to_be_cached_state) {
+Status DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
+                          WriteBatch* tmp_batch, WriteBatch** merged_batch,
+                          size_t* write_with_wal,
+                          WriteBatch** to_be_cached_state) {
   assert(write_with_wal != nullptr);
   assert(tmp_batch != nullptr);
   assert(*to_be_cached_state == nullptr);
-  WriteBatch* merged_batch = nullptr;
   *write_with_wal = 0;
   auto* leader = write_group.leader;
   assert(!leader->disable_wal);  // Same holds for all in the batch group
@@ -1186,22 +1191,24 @@ WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
     // we simply write the first WriteBatch to WAL if the group only
     // contains one batch, that batch should be written to the WAL,
     // and the batch is not wanting to be truncated
-    merged_batch = leader->batch;
-    if (WriteBatchInternal::IsLatestPersistentState(merged_batch)) {
-      *to_be_cached_state = merged_batch;
+    *merged_batch = leader->batch;
+    if (WriteBatchInternal::IsLatestPersistentState(*merged_batch)) {
+      *to_be_cached_state = *merged_batch;
     }
     *write_with_wal = 1;
   } else {
     // WAL needs all of the batches flattened into a single batch.
     // We could avoid copying here with an iov-like AddRecord
     // interface
-    merged_batch = tmp_batch;
+    *merged_batch = tmp_batch;
     for (auto writer : write_group) {
       if (!writer->CallbackFailed()) {
-        Status s = WriteBatchInternal::Append(merged_batch, writer->batch,
+        Status s = WriteBatchInternal::Append(*merged_batch, writer->batch,
                                               /*WAL_only*/ true);
-        // Always returns Status::OK.
-        assert(s.ok());
+        if (!s.ok()) {
+          tmp_batch->Clear();
+          return s;
+        }
         if (WriteBatchInternal::IsLatestPersistentState(writer->batch)) {
           // We only need to cache the last of such write batch
           *to_be_cached_state = writer->batch;
@@ -1210,7 +1217,8 @@ WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
       }
     }
   }
-  return merged_batch;
+  // return merged_batch;
+  return Status::OK();
 }
 
 // When two_write_queues_ is disabled, this function is called from the only
@@ -1223,6 +1231,11 @@ IOStatus DBImpl::WriteToWAL(const WriteBatch& merged_batch,
   assert(log_size != nullptr);
 
   Slice log_entry = WriteBatchInternal::Contents(&merged_batch);
+  TEST_SYNC_POINT_CALLBACK("DBImpl::WriteToWAL:log_entry", &log_entry);
+  auto s = merged_batch.VerifyChecksum();
+  if (!s.ok()) {
+    return status_to_io_status(std::move(s));
+  }
   *log_size = log_entry.size();
   // When two_write_queues_ WriteToWAL has to be protected from concurretn calls
   // from the two queues anyway and log_write_mutex_ is already held. Otherwise
@@ -1260,8 +1273,13 @@ IOStatus DBImpl::WriteToWAL(const WriteThread::WriteGroup& write_group,
   // Same holds for all in the batch group
   size_t write_with_wal = 0;
   WriteBatch* to_be_cached_state = nullptr;
-  WriteBatch* merged_batch = MergeBatch(write_group, &tmp_batch_,
-                                        &write_with_wal, &to_be_cached_state);
+  WriteBatch* merged_batch;
+  io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch_, &merged_batch,
+                                        &write_with_wal, &to_be_cached_state));
+  if (UNLIKELY(!io_s.ok())) {
+    return io_s;
+  }
+
   if (merged_batch == write_group.leader->batch) {
     write_group.leader->log_used = logfile_number_;
   } else if (write_with_wal > 1) {
@@ -1351,8 +1369,12 @@ IOStatus DBImpl::ConcurrentWriteToWAL(
   WriteBatch tmp_batch;
   size_t write_with_wal = 0;
   WriteBatch* to_be_cached_state = nullptr;
-  WriteBatch* merged_batch =
-      MergeBatch(write_group, &tmp_batch, &write_with_wal, &to_be_cached_state);
+  WriteBatch* merged_batch;
+  io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch, &merged_batch,
+                                        &write_with_wal, &to_be_cached_state));
+  if (UNLIKELY(!io_s.ok())) {
+    return io_s;
+  }
 
   // We need to lock log_write_mutex_ since logs_ and alive_log_files might be
   // pushed back concurrently
index 44ee56786fd7815b4e6995ecdfc06eb295ed61a9..5636c9e6ebac71f684b95d72fc641946b642d397 100644 (file)
@@ -25,6 +25,49 @@ WriteBatchOpType operator+(WriteBatchOpType lhs, const int rhs) {
   return static_cast<WriteBatchOpType>(static_cast<T>(lhs) + rhs);
 }
 
+std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle,
+                                            WriteBatchOpType op_type) {
+  Status s;
+  WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
+                8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
+  switch (op_type) {
+    case WriteBatchOpType::kPut:
+      s = wb.Put(cf_handle, "key", "val");
+      break;
+    case WriteBatchOpType::kDelete:
+      s = wb.Delete(cf_handle, "key");
+      break;
+    case WriteBatchOpType::kSingleDelete:
+      s = wb.SingleDelete(cf_handle, "key");
+      break;
+    case WriteBatchOpType::kDeleteRange:
+      s = wb.DeleteRange(cf_handle, "begin", "end");
+      break;
+    case WriteBatchOpType::kMerge:
+      s = wb.Merge(cf_handle, "key", "val");
+      break;
+    case WriteBatchOpType::kBlobIndex: {
+      // TODO(ajkr): use public API once available.
+      uint32_t cf_id;
+      if (cf_handle == nullptr) {
+        cf_id = 0;
+      } else {
+        cf_id = cf_handle->GetID();
+      }
+
+      std::string blob_index;
+      BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
+                                  "val");
+
+      s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
+      break;
+    }
+    case WriteBatchOpType::kNum:
+      assert(false);
+  }
+  return {std::move(wb), std::move(s)};
+}
+
 class DbKvChecksumTest
     : public DBTestBase,
       public ::testing::WithParamInterface<std::tuple<WriteBatchOpType, char>> {
@@ -35,48 +78,6 @@ class DbKvChecksumTest
     corrupt_byte_addend_ = std::get<1>(GetParam());
   }
 
-  std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle) {
-    Status s;
-    WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
-                  8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
-    switch (op_type_) {
-      case WriteBatchOpType::kPut:
-        s = wb.Put(cf_handle, "key", "val");
-        break;
-      case WriteBatchOpType::kDelete:
-        s = wb.Delete(cf_handle, "key");
-        break;
-      case WriteBatchOpType::kSingleDelete:
-        s = wb.SingleDelete(cf_handle, "key");
-        break;
-      case WriteBatchOpType::kDeleteRange:
-        s = wb.DeleteRange(cf_handle, "begin", "end");
-        break;
-      case WriteBatchOpType::kMerge:
-        s = wb.Merge(cf_handle, "key", "val");
-        break;
-      case WriteBatchOpType::kBlobIndex: {
-        // TODO(ajkr): use public API once available.
-        uint32_t cf_id;
-        if (cf_handle == nullptr) {
-          cf_id = 0;
-        } else {
-          cf_id = cf_handle->GetID();
-        }
-
-        std::string blob_index;
-        BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
-                                    "val");
-
-        s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
-        break;
-      }
-      case WriteBatchOpType::kNum:
-        assert(false);
-    }
-    return {std::move(wb), std::move(s)};
-  }
-
   void CorruptNextByteCallBack(void* arg) {
     Slice encoded = *static_cast<Slice*>(arg);
     if (entry_len_ == std::numeric_limits<size_t>::max()) {
@@ -99,34 +100,28 @@ class DbKvChecksumTest
   size_t entry_len_ = std::numeric_limits<size_t>::max();
 };
 
-std::string GetTestNameSuffix(
-    ::testing::TestParamInfo<std::tuple<WriteBatchOpType, char>> info) {
-  std::ostringstream oss;
-  switch (std::get<0>(info.param)) {
+std::string GetOpTypeString(const WriteBatchOpType& op_type) {
+  switch (op_type) {
     case WriteBatchOpType::kPut:
-      oss << "Put";
-      break;
+      return "Put";
     case WriteBatchOpType::kDelete:
-      oss << "Delete";
-      break;
+      return "Delete";
     case WriteBatchOpType::kSingleDelete:
-      oss << "SingleDelete";
-      break;
+      return "SingleDelete";
     case WriteBatchOpType::kDeleteRange:
-      oss << "DeleteRange";
+      return "DeleteRange";
       break;
     case WriteBatchOpType::kMerge:
-      oss << "Merge";
+      return "Merge";
       break;
     case WriteBatchOpType::kBlobIndex:
-      oss << "BlobIndex";
+      return "BlobIndex";
       break;
     case WriteBatchOpType::kNum:
       assert(false);
   }
-  oss << "Add"
-      << static_cast<int>(static_cast<unsigned char>(std::get<1>(info.param)));
-  return oss.str();
+  assert(false);
+  return "";
 }
 
 INSTANTIATE_TEST_CASE_P(
@@ -134,7 +129,13 @@ INSTANTIATE_TEST_CASE_P(
     ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
                                         WriteBatchOpType::kNum),
                        ::testing::Values(2, 103, 251)),
-    GetTestNameSuffix);
+    [](const testing::TestParamInfo<std::tuple<WriteBatchOpType, char>>& args) {
+      std::ostringstream oss;
+      oss << GetOpTypeString(std::get<0>(args.param)) << "Add"
+          << static_cast<int>(
+                 static_cast<unsigned char>(std::get<1>(args.param)));
+      return oss.str();
+    });
 
 TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
   // This test repeatedly attempts to write `WriteBatch`es containing a single
@@ -157,11 +158,16 @@ TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
     Reopen(options);
 
     SyncPoint::GetInstance()->EnableProcessing();
-    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */);
+    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
     ASSERT_OK(batch_and_status.second);
     ASSERT_TRUE(
         db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
     SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test will run
+    // numeric_limits<size_t>::max() times until it reports an error (or will
+    // exhaust disk space). Added this assert to report error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
   }
 }
 
@@ -188,14 +194,373 @@ TEST_P(DbKvChecksumTest, MemTableAddWithColumnFamilyCorrupted) {
     ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
 
     SyncPoint::GetInstance()->EnableProcessing();
-    auto batch_and_status = GetWriteBatch(handles_[1]);
+    auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
+    ASSERT_OK(batch_and_status.second);
+    ASSERT_TRUE(
+        db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
+    SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test will run
+    // numeric_limits<size_t>::max() times until it reports an error (or will
+    // exhaust disk space). Added this assert to report error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
+  }
+}
+
+TEST_P(DbKvChecksumTest, NoCorruptionCase) {
+  // If this test fails, we may have found a piece of malfunctioned hardware
+  auto batch_and_status = GetWriteBatch(nullptr, op_type_);
+  ASSERT_OK(batch_and_status.second);
+  ASSERT_OK(batch_and_status.first.VerifyChecksum());
+}
+
+TEST_P(DbKvChecksumTest, WriteToWALCorrupted) {
+  // This test repeatedly attempts to write `WriteBatch`es containing a single
+  // entry of type `op_type_`. Each attempt has one byte corrupted by adding
+  // `corrupt_byte_addend_` to its original value. The test repeats until an
+  // attempt has been made on each byte in the encoded write batch. All attempts
+  // are expected to fail with `Status::Corruption`
+  Options options = CurrentOptions();
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::WriteToWAL:log_entry",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  // First 8 bytes are for sequence number which is not protected in write batch
+  corrupt_byte_offset_ = 8;
+
+  while (MoreBytesToCorrupt()) {
+    // Corrupted write batch leads to read-only mode, so we have to
+    // reopen for every attempt.
+    Reopen(options);
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
+    ASSERT_OK(batch_and_status.second);
+    ASSERT_TRUE(
+        db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
+    // Confirm that nothing was written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+    SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test will run
+    // numeric_limits<size_t>::max() times until it reports an error (or will
+    // exhaust disk space). Added this assert to report error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
+  }
+}
+
+TEST_P(DbKvChecksumTest, WriteToWALWithColumnFamilyCorrupted) {
+  // This test repeatedly attempts to write `WriteBatch`es containing a single
+  // entry of type `op_type_`. Each attempt has one byte corrupted by adding
+  // `corrupt_byte_addend_` to its original value. The test repeats until an
+  // attempt has been made on each byte in the encoded write batch. All attempts
+  // are expected to fail with `Status::Corruption`
+  Options options = CurrentOptions();
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  CreateAndReopenWithCF({"pikachu"}, options);
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::WriteToWAL:log_entry",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  // First 8 bytes are for sequence number which is not protected in write batch
+  corrupt_byte_offset_ = 8;
+
+  while (MoreBytesToCorrupt()) {
+    // Corrupted write batch leads to read-only mode, so we have to
+    // reopen for every attempt.
+    ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
     ASSERT_OK(batch_and_status.second);
     ASSERT_TRUE(
         db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
+    // Confirm that nothing was written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
     SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test will run
+    // numeric_limits<size_t>::max() times until it reports an error (or will
+    // exhaust disk space). Added this assert to report error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
+  }
+}
+
+class DbKvChecksumTestMergedBatch
+    : public DBTestBase,
+      public ::testing::WithParamInterface<
+          std::tuple<WriteBatchOpType, WriteBatchOpType, char>> {
+ public:
+  DbKvChecksumTestMergedBatch()
+      : DBTestBase("db_kv_checksum_test", /*env_do_fsync=*/false) {
+    op_type1_ = std::get<0>(GetParam());
+    op_type2_ = std::get<1>(GetParam());
+    corrupt_byte_addend_ = std::get<2>(GetParam());
   }
+
+ protected:
+  WriteBatchOpType op_type1_;
+  WriteBatchOpType op_type2_;
+  char corrupt_byte_addend_;
+};
+
+void CorruptWriteBatch(Slice* content, size_t offset,
+                       char corrupt_byte_addend) {
+  ASSERT_TRUE(offset < content->size());
+  char* buf = const_cast<char*>(content->data());
+  buf[offset] += corrupt_byte_addend;
+}
+
+TEST_P(DbKvChecksumTestMergedBatch, NoCorruptionCase) {
+  // Veirfy write batch checksum after write batch append
+  auto batch1 = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+  ASSERT_OK(batch1.second);
+  auto batch2 = GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+  ASSERT_OK(batch2.second);
+  ASSERT_OK(WriteBatchInternal::Append(&batch1.first, &batch2.first));
+  ASSERT_OK(batch1.first.VerifyChecksum());
 }
 
+TEST_P(DbKvChecksumTestMergedBatch, WriteToWALCorrupted) {
+  // This test has two writers repeatedly attempt to write `WriteBatch`es
+  // containing a single entry of type op_type1_ and op_type2_ respectively. The
+  // leader of the write group writes the batch containinng the entry of type
+  // op_type1_. One byte of the pre-merged write batches is corrupted by adding
+  // `corrupt_byte_addend_` to the batch's original value during each attempt.
+  // The test repeats until an attempt has been made on each byte in both
+  // pre-merged write batches. All attempts are expected to fail with
+  // `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type1_ == WriteBatchOpType::kMerge ||
+      op_type2_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  auto leader_batch_and_status =
+      GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+  ASSERT_OK(leader_batch_and_status.second);
+  auto follower_batch_and_status =
+      GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+  size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
+  size_t total_bytes =
+      leader_batch_size + follower_batch_and_status.first.GetDataSize();
+  // First 8 bytes are for sequence number which is not protected in write batch
+  size_t corrupt_byte_offset = 8;
+
+  std::atomic<bool> follower_joined{false};
+  std::atomic<int> leader_count{0};
+  port::Thread follower_thread;
+  // This callback should only be called by the leader thread
+  SyncPoint::GetInstance()->SetCallBack(
+      "WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
+        auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
+        ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
+
+        // This callback should only be called by the follower thread
+        SyncPoint::GetInstance()->SetCallBack(
+            "WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
+              auto* follower =
+                  reinterpret_cast<WriteThread::Writer*>(arg_follower);
+              // The leader thread will wait on this bool and hence wait until
+              // this writer joins the write group
+              ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
+              if (corrupt_byte_offset >= leader_batch_size) {
+                Slice batch_content = follower->batch->Data();
+                CorruptWriteBatch(&batch_content,
+                                  corrupt_byte_offset - leader_batch_size,
+                                  corrupt_byte_addend_);
+              }
+              // Leader busy waits on this flag
+              follower_joined = true;
+              // So the follower does not enter the outer callback at
+              // WriteThread::JoinBatchGroup:Wait2
+              SyncPoint::GetInstance()->DisableProcessing();
+            });
+
+        // Start the other writer thread which will join the write group as
+        // follower
+        follower_thread = port::Thread([&]() {
+          follower_batch_and_status =
+              GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+          ASSERT_OK(follower_batch_and_status.second);
+          ASSERT_TRUE(
+              db_->Write(WriteOptions(), &follower_batch_and_status.first)
+                  .IsCorruption());
+        });
+
+        ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
+        if (corrupt_byte_offset < leader_batch_size) {
+          Slice batch_content = leader->batch->Data();
+          CorruptWriteBatch(&batch_content, corrupt_byte_offset,
+                            corrupt_byte_addend_);
+        }
+        leader_count++;
+        while (!follower_joined) {
+          // busy waiting
+        }
+      });
+  while (corrupt_byte_offset < total_bytes) {
+    // Reopen DB since it failed WAL write which lead to read-only mode
+    Reopen(options);
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+    leader_batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+    ASSERT_OK(leader_batch_and_status.second);
+    ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first)
+                    .IsCorruption());
+    follower_thread.join();
+    // Prevent leader thread from entering this callback
+    SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait");
+    ASSERT_EQ(1, leader_count);
+    // Nothing should have been written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+
+    corrupt_byte_offset++;
+    if (corrupt_byte_offset == leader_batch_size) {
+      // skip over the sequence number part of follower's write batch
+      corrupt_byte_offset += 8;
+    }
+    follower_joined = false;
+    leader_count = 0;
+  }
+  SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_P(DbKvChecksumTestMergedBatch, WriteToWALWithColumnFamilyCorrupted) {
+  // This test has two writers repeatedly attempt to write `WriteBatch`es
+  // containing a single entry of type op_type1_ and op_type2_ respectively. The
+  // leader of the write group writes the batch containinng the entry of type
+  // op_type1_. One byte of the pre-merged write batches is corrupted by adding
+  // `corrupt_byte_addend_` to the batch's original value during each attempt.
+  // The test repeats until an attempt has been made on each byte in both
+  // pre-merged write batches. All attempts are expected to fail with
+  // `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type1_ == WriteBatchOpType::kMerge ||
+      op_type2_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  CreateAndReopenWithCF({"ramen"}, options);
+
+  auto leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_);
+  ASSERT_OK(leader_batch_and_status.second);
+  auto follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
+  size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
+  size_t total_bytes =
+      leader_batch_size + follower_batch_and_status.first.GetDataSize();
+  // First 8 bytes are for sequence number which is not protected in write batch
+  size_t corrupt_byte_offset = 8;
+
+  std::atomic<bool> follower_joined{false};
+  std::atomic<int> leader_count{0};
+  port::Thread follower_thread;
+  // This callback should only be called by the leader thread
+  SyncPoint::GetInstance()->SetCallBack(
+      "WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
+        auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
+        ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
+
+        // This callback should only be called by the follower thread
+        SyncPoint::GetInstance()->SetCallBack(
+            "WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
+              auto* follower =
+                  reinterpret_cast<WriteThread::Writer*>(arg_follower);
+              // The leader thread will wait on this bool and hence wait until
+              // this writer joins the write group
+              ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
+              if (corrupt_byte_offset >= leader_batch_size) {
+                Slice batch_content =
+                    WriteBatchInternal::Contents(follower->batch);
+                CorruptWriteBatch(&batch_content,
+                                  corrupt_byte_offset - leader_batch_size,
+                                  corrupt_byte_addend_);
+              }
+              follower_joined = true;
+              // So the follower does not enter the outer callback at
+              // WriteThread::JoinBatchGroup:Wait2
+              SyncPoint::GetInstance()->DisableProcessing();
+            });
+
+        // Start the other writer thread which will join the write group as
+        // follower
+        follower_thread = port::Thread([&]() {
+          follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
+          ASSERT_OK(follower_batch_and_status.second);
+          ASSERT_TRUE(
+              db_->Write(WriteOptions(), &follower_batch_and_status.first)
+                  .IsCorruption());
+        });
+
+        ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
+        if (corrupt_byte_offset < leader_batch_size) {
+          Slice batch_content = WriteBatchInternal::Contents(leader->batch);
+          CorruptWriteBatch(&batch_content, corrupt_byte_offset,
+                            corrupt_byte_addend_);
+        }
+        leader_count++;
+        while (!follower_joined) {
+          // busy waiting
+        }
+      });
+  SyncPoint::GetInstance()->EnableProcessing();
+  while (corrupt_byte_offset < total_bytes) {
+    // Reopen DB since it failed WAL write which lead to read-only mode
+    ReopenWithColumnFamilies({kDefaultColumnFamilyName, "ramen"}, options);
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+    leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_);
+    ASSERT_OK(leader_batch_and_status.second);
+    ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first)
+                    .IsCorruption());
+    follower_thread.join();
+    // Prevent leader thread from entering this callback
+    SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait");
+
+    ASSERT_EQ(1, leader_count);
+    // Nothing should have been written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+
+    corrupt_byte_offset++;
+    if (corrupt_byte_offset == leader_batch_size) {
+      // skip over the sequence number part of follower's write batch
+      corrupt_byte_offset += 8;
+    }
+    follower_joined = false;
+    leader_count = 0;
+  }
+  SyncPoint::GetInstance()->DisableProcessing();
+}
+
+INSTANTIATE_TEST_CASE_P(
+    DbKvChecksumTestMergedBatch, DbKvChecksumTestMergedBatch,
+    ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
+                                        WriteBatchOpType::kNum),
+                       ::testing::Range(static_cast<WriteBatchOpType>(0),
+                                        WriteBatchOpType::kNum),
+                       ::testing::Values(2, 103, 251)),
+    [](const testing::TestParamInfo<
+        std::tuple<WriteBatchOpType, WriteBatchOpType, char>>& args) {
+      std::ostringstream oss;
+      oss << GetOpTypeString(std::get<0>(args.param))
+          << GetOpTypeString(std::get<1>(args.param)) << "Add"
+          << static_cast<int>(
+                 static_cast<unsigned char>(std::get<2>(args.param)));
+      return oss.str();
+    });
+
+// TODO: add test for transactions
+// TODO: add test for corrupted write batch with WAL disabled
 }  // namespace ROCKSDB_NAMESPACE
 
 int main(int argc, char** argv) {
index 788b9bae49e2fbca64b1dc237162ef598e53fe9c..b919ea056d67f943c2d225ca4dee9d161f1fe86e 100644 (file)
@@ -1491,6 +1491,94 @@ Status WriteBatch::UpdateTimestamps(
   return s;
 }
 
+Status WriteBatch::VerifyChecksum() const {
+  if (prot_info_ == nullptr) {
+    return Status::OK();
+  }
+  Slice input(rep_.data() + WriteBatchInternal::kHeader,
+              rep_.size() - WriteBatchInternal::kHeader);
+  Slice key, value, blob, xid;
+  char tag = 0;
+  uint32_t column_family = 0;  // default
+  Status s;
+  size_t prot_info_idx = 0;
+  bool checksum_protected = true;
+  while (!input.empty() && prot_info_idx < prot_info_->entries_.size()) {
+    // In case key/value/column_family are not updated by
+    // ReadRecordFromWriteBatch
+    key.clear();
+    value.clear();
+    column_family = 0;
+    s = ReadRecordFromWriteBatch(&input, &tag, &column_family, &key, &value,
+                                 &blob, &xid);
+    if (!s.ok()) {
+      return s;
+    }
+    checksum_protected = true;
+    // Write batch checksum uses op_type without ColumnFamily (e.g., if op_type
+    // in the write batch is kTypeColumnFamilyValue, kTypeValue is used to
+    // compute the checksum), and encodes column family id separately. See
+    // comment in first `WriteBatchInternal::Put()` for more detail.
+    switch (tag) {
+      case kTypeColumnFamilyValue:
+      case kTypeValue:
+        tag = kTypeValue;
+        break;
+      case kTypeColumnFamilyDeletion:
+      case kTypeDeletion:
+        tag = kTypeDeletion;
+        break;
+      case kTypeColumnFamilySingleDeletion:
+      case kTypeSingleDeletion:
+        tag = kTypeSingleDeletion;
+        break;
+      case kTypeColumnFamilyRangeDeletion:
+      case kTypeRangeDeletion:
+        tag = kTypeRangeDeletion;
+        break;
+      case kTypeColumnFamilyMerge:
+      case kTypeMerge:
+        tag = kTypeMerge;
+        break;
+      case kTypeColumnFamilyBlobIndex:
+      case kTypeBlobIndex:
+        tag = kTypeBlobIndex;
+        break;
+      case kTypeLogData:
+      case kTypeBeginPrepareXID:
+      case kTypeEndPrepareXID:
+      case kTypeCommitXID:
+      case kTypeRollbackXID:
+      case kTypeNoop:
+      case kTypeBeginPersistedPrepareXID:
+      case kTypeBeginUnprepareXID:
+      case kTypeDeletionWithTimestamp:
+      case kTypeCommitXIDAndTimestamp:
+        checksum_protected = false;
+        break;
+      default:
+        return Status::Corruption(
+            "unknown WriteBatch tag",
+            std::to_string(static_cast<unsigned int>(tag)));
+    }
+    if (checksum_protected) {
+      s = prot_info_->entries_[prot_info_idx++]
+              .StripC(column_family)
+              .StripKVO(key, value, static_cast<ValueType>(tag))
+              .GetStatus();
+      if (!s.ok()) {
+        return s;
+      }
+    }
+  }
+
+  if (prot_info_idx != WriteBatchInternal::Count(this)) {
+    return Status::Corruption("WriteBatch has wrong count");
+  }
+  assert(WriteBatchInternal::Count(this) == prot_info_->entries_.size());
+  return Status::OK();
+}
+
 namespace {
 
 class MemTableInserter : public WriteBatch::Handler {
@@ -2773,6 +2861,14 @@ Status WriteBatchInternal::Append(WriteBatch* dst, const WriteBatch* src,
                                   const bool wal_only) {
   assert(dst->Count() == 0 ||
          (dst->prot_info_ == nullptr) == (src->prot_info_ == nullptr));
+  if ((src->prot_info_ != nullptr &&
+       src->prot_info_->entries_.size() != src->Count()) ||
+      (dst->prot_info_ != nullptr &&
+       dst->prot_info_->entries_.size() != dst->Count())) {
+    return Status::Corruption(
+        "Write batch has inconsistent count and number of checksums");
+  }
+
   size_t src_len;
   int src_count;
   uint32_t src_flags;
index 49abed74e10c3cd3a654266ec2f5fb691cf2ab76..926acc63a0776e50684f38c27f29ad075035f5bb 100644 (file)
@@ -206,6 +206,10 @@ class WriteBatchInternal {
                            bool batch_per_txn = true,
                            bool hint_per_batch = false);
 
+  // Appends src write batch to dst write batch and updates count in dst
+  // write batch. Returns OK if the append is successful. Checks number of
+  // checksum against count in dst and src write batches, and returns Corruption
+  // if the count is inconsistent.
   static Status Append(WriteBatch* dst, const WriteBatch* src,
                        const bool WAL_only = false);
 
index d59eba263522759a68d561974b8d9e0ea0796028..06d7f4500041afb7a8678fd44aa80f56541abe96 100644 (file)
@@ -389,6 +389,7 @@ void WriteThread::JoinBatchGroup(Writer* w) {
   }
 
   TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait", w);
+  TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait2", w);
 
   if (!linked_as_leader) {
     /**
index 618e4734c86c9dc32b9754fa9e2c24f07ecdb961..d8bd108ea01ac3d780f488c124338c413c8fb047 100644 (file)
@@ -391,6 +391,12 @@ class WriteBatch : public WriteBatchBase {
   Status UpdateTimestamps(const Slice& ts,
                           std::function<size_t(uint32_t /*cf*/)> ts_sz_func);
 
+  // Verify the per-key-value checksums of this write batch.
+  // Corruption status will be returned if the verification fails.
+  // If this write batch does not have per-key-value checksum,
+  // OK status will be returned.
+  Status VerifyChecksum() const;
+
   using WriteBatchBase::GetWriteBatch;
   WriteBatch* GetWriteBatch() override { return this; }
 
index a163d86677cb2f1c103f0d1f36dee633d06d53c4..46d8a9af17607291b388b1b294cf18b9f52efc86 100644 (file)
@@ -1656,6 +1656,10 @@ static const bool FLAGS_table_cache_numshardbits_dummy __attribute__((__unused__
     RegisterFlagValidator(&FLAGS_table_cache_numshardbits,
                           &ValidateTableCacheNumshardbits);
 
+DEFINE_uint32(write_batch_protection_bytes_per_key, 0,
+              "Size of per-key-value checksum in each write batch. Currently "
+              "only value 0 and 8 are supported.");
+
 namespace ROCKSDB_NAMESPACE {
 namespace {
 static Status CreateMemTableRepFactory(
@@ -4910,7 +4914,8 @@ class Benchmark {
 
     RandomGenerator gen;
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Status s;
     int64_t bytes = 0;
 
@@ -6699,7 +6704,8 @@ class Benchmark {
 
   void DoDelete(ThreadState* thread, bool seq) {
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Duration duration(seq ? 0 : FLAGS_duration, deletes_);
     int64_t i = 0;
     std::unique_ptr<const char[]> key_guard;
@@ -6899,7 +6905,8 @@ class Benchmark {
     std::string keys[3];
 
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Status s;
     for (int i = 0; i < 3; i++) {
       keys[i] = key.ToString() + suffixes[i];
@@ -6931,7 +6938,7 @@ class Benchmark {
     std::string suffixes[3] = {"1", "2", "0"};
     std::string keys[3];
 
-    WriteBatch batch(0, 0, /*protection_bytes_per_key=*/0,
+    WriteBatch batch(0, 0, FLAGS_write_batch_protection_bytes_per_key,
                      user_timestamp_size_);
     Status s;
     for (int i = 0; i < 3; i++) {