]> git-server-git.apps.pok.os.sepia.ceph.com Git - rocksdb.git/commitdiff
fix shared state used after free (#11059)
authorehds <ehds@qq.com>
Thu, 5 Jan 2023 03:35:34 +0000 (19:35 -0800)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 5 Jan 2023 03:35:34 +0000 (19:35 -0800)
Summary:
Before this pr,  the destruction order is `shared` -> `db_`(StressTest destruction) -> `stress`, but `compaction_filter` of `db_` will hold the `shared` pointer, so `shared` maybe used after free.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/11059

Reviewed By: hx235

Differential Revision: D42297366

Pulled By: ajkr

fbshipit-source-id: 17b314635359acacd5ba62f9db5f955f451133f7

db_stress_tool/db_stress_driver.cc
db_stress_tool/db_stress_driver.h
db_stress_tool/db_stress_tool.cc

index ed1240e0078c534bd39e83a16c323de2ec732993..2c8dcf6108670fa131008230bbef02f8d57457ad 100644 (file)
@@ -56,12 +56,11 @@ void ThreadBody(void* v) {
   }
 }
 
-bool RunStressTest(StressTest* stress) {
+bool RunStressTest(SharedState* shared) {
   SystemClock* clock = db_stress_env->GetSystemClock().get();
+  StressTest* stress = shared->GetStressTest();
 
-  SharedState shared(db_stress_env, stress);
-
-  if (shared.ShouldVerifyAtBeginning() && FLAGS_preserve_unverified_changes) {
+  if (shared->ShouldVerifyAtBeginning() && FLAGS_preserve_unverified_changes) {
     Status s = InitUnverifiedSubdir(FLAGS_db);
     if (s.ok() && !FLAGS_expected_values_dir.empty()) {
       s = InitUnverifiedSubdir(FLAGS_expected_values_dir);
@@ -73,8 +72,8 @@ bool RunStressTest(StressTest* stress) {
     }
   }
 
-  stress->InitDb(&shared);
-  stress->FinishInitDb(&shared);
+  stress->InitDb(shared);
+  stress->FinishInitDb(shared);
 
   if (FLAGS_sync_fault_injection) {
     fault_fs_guard->SetFilesystemDirectWritable(false);
@@ -88,28 +87,28 @@ bool RunStressTest(StressTest* stress) {
   fprintf(stdout, "%s Initializing worker threads\n",
           clock->TimeToString(now / 1000000).c_str());
 
-  shared.SetThreads(n);
+  shared->SetThreads(n);
 
   if (FLAGS_compaction_thread_pool_adjust_interval > 0) {
-    shared.IncBgThreads();
+    shared->IncBgThreads();
   }
 
   if (FLAGS_continuous_verification_interval > 0) {
-    shared.IncBgThreads();
+    shared->IncBgThreads();
   }
 
   std::vector<ThreadState*> threads(n);
   for (uint32_t i = 0; i < n; i++) {
-    threads[i] = new ThreadState(i, &shared);
+    threads[i] = new ThreadState(i, shared);
     db_stress_env->StartThread(ThreadBody, threads[i]);
   }
 
-  ThreadState bg_thread(0, &shared);
+  ThreadState bg_thread(0, shared);
   if (FLAGS_compaction_thread_pool_adjust_interval > 0) {
     db_stress_env->StartThread(PoolSizeChangeThread, &bg_thread);
   }
 
-  ThreadState continuous_verification_thread(0, &shared);
+  ThreadState continuous_verification_thread(0, shared);
   if (FLAGS_continuous_verification_interval > 0) {
     db_stress_env->StartThread(DbVerificationThread,
                                &continuous_verification_thread);
@@ -120,12 +119,12 @@ bool RunStressTest(StressTest* stress) {
   // wait for others to operate -> verify -> done
 
   {
-    MutexLock l(shared.GetMutex());
-    while (!shared.AllInitialized()) {
-      shared.GetCondVar()->Wait();
+    MutexLock l(shared->GetMutex());
+    while (!shared->AllInitialized()) {
+      shared->GetCondVar()->Wait();
     }
-    if (shared.ShouldVerifyAtBeginning()) {
-      if (shared.HasVerificationFailedYet()) {
+    if (shared->ShouldVerifyAtBeginning()) {
+      if (shared->HasVerificationFailedYet()) {
         fprintf(stderr, "Crash-recovery verification failed :(\n");
       } else {
         fprintf(stdout, "Crash-recovery verification passed :)\n");
@@ -144,17 +143,17 @@ bool RunStressTest(StressTest* stress) {
     // This is after the verification step to avoid making all those `Get()`s
     // and `MultiGet()`s contend on the DB-wide trace mutex.
     if (!FLAGS_expected_values_dir.empty()) {
-      stress->TrackExpectedState(&shared);
+      stress->TrackExpectedState(shared);
     }
 
     now = clock->NowMicros();
     fprintf(stdout, "%s Starting database operations\n",
             clock->TimeToString(now / 1000000).c_str());
 
-    shared.SetStart();
-    shared.GetCondVar()->SignalAll();
-    while (!shared.AllOperated()) {
-      shared.GetCondVar()->Wait();
+    shared->SetStart();
+    shared->GetCondVar()->SignalAll();
+    while (!shared->AllOperated()) {
+      shared->GetCondVar()->Wait();
     }
 
     now = clock->NowMicros();
@@ -169,10 +168,10 @@ bool RunStressTest(StressTest* stress) {
               clock->TimeToString((uint64_t)now / 1000000).c_str());
     }
 
-    shared.SetStartVerify();
-    shared.GetCondVar()->SignalAll();
-    while (!shared.AllDone()) {
-      shared.GetCondVar()->Wait();
+    shared->SetStartVerify();
+    shared->GetCondVar()->SignalAll();
+    while (!shared->AllDone()) {
+      shared->GetCondVar()->Wait();
     }
   }
 
@@ -187,7 +186,7 @@ bool RunStressTest(StressTest* stress) {
   }
   now = clock->NowMicros();
   if (!FLAGS_skip_verifydb && !FLAGS_test_batches_snapshots &&
-      !shared.HasVerificationFailedYet()) {
+      !shared->HasVerificationFailedYet()) {
     fprintf(stdout, "%s Verification successful\n",
             clock->TimeToString(now / 1000000).c_str());
   }
@@ -195,14 +194,14 @@ bool RunStressTest(StressTest* stress) {
 
   if (FLAGS_compaction_thread_pool_adjust_interval > 0 ||
       FLAGS_continuous_verification_interval > 0) {
-    MutexLock l(shared.GetMutex());
-    shared.SetShouldStopBgThread();
-    while (!shared.BgThreadsFinished()) {
-      shared.GetCondVar()->Wait();
+    MutexLock l(shared->GetMutex());
+    shared->SetShouldStopBgThread();
+    while (!shared->BgThreadsFinished()) {
+      shared->GetCondVar()->Wait();
     }
   }
 
-  if (shared.HasVerificationFailedYet()) {
+  if (shared->HasVerificationFailedYet()) {
     fprintf(stderr, "Verification failed :(\n");
     return false;
   }
index ff701fcb2985c8fc631e843fce4e88da7d6350ed..a173470ff7d787f83ee7d4cf2f6809b3500ef61e 100644 (file)
@@ -7,11 +7,12 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file. See the AUTHORS file for names of contributors.
 
+#include "db_stress_tool/db_stress_shared_state.h"
 #ifdef GFLAGS
 #pragma once
 #include "db_stress_tool/db_stress_test_base.h"
 namespace ROCKSDB_NAMESPACE {
 extern void ThreadBody(void* /*thread_state*/);
-extern bool RunStressTest(StressTest*);
+extern bool RunStressTest(SharedState*);
 }  // namespace ROCKSDB_NAMESPACE
 #endif  // GFLAGS
index 7c73a08ba0e3ae8c96cca532494cbc78f2260ea2..fd28856b731187cf42ea33d5194c199a43e1dbba 100644 (file)
@@ -20,6 +20,7 @@
 // NOTE that if FLAGS_test_batches_snapshots is set, the test will have
 // different behavior. See comment of the flag for details.
 
+#include "db_stress_tool/db_stress_shared_state.h"
 #ifdef GFLAGS
 #include "db_stress_tool/db_stress_common.h"
 #include "db_stress_tool/db_stress_driver.h"
@@ -340,7 +341,7 @@ int db_stress_tool(int argc, char** argv) {
     key_gen_ctx.weights.emplace_back(key_gen_ctx.window -
                                      keys_per_level * (levels - 1));
   }
-
+  std::unique_ptr<ROCKSDB_NAMESPACE::SharedState> shared;
   std::unique_ptr<ROCKSDB_NAMESPACE::StressTest> stress;
   if (FLAGS_test_cf_consistency) {
     stress.reset(CreateCfConsistencyStressTest());
@@ -353,7 +354,8 @@ int db_stress_tool(int argc, char** argv) {
   }
   // Initialize the Zipfian pre-calculated array
   InitializeHotKeyGenerator(FLAGS_hot_key_alpha);
-  if (RunStressTest(stress.get())) {
+  shared.reset(new SharedState(db_stress_env, stress.get()));
+  if (RunStressTest(shared.get())) {
     return 0;
   } else {
     return 1;