]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
common/async: add yield_waiter template
authorCasey Bodley <cbodley@redhat.com>
Tue, 30 Apr 2024 16:04:15 +0000 (12:04 -0400)
committerYuval Lifshitz <ylifshit@ibm.com>
Tue, 22 Apr 2025 14:33:04 +0000 (14:33 +0000)
Signed-off-by: Casey Bodley <cbodley@redhat.com>
(cherry picked from commit dd779c74e1eebaf888d95b2329c7d5ead176f0a9)

src/common/async/yield_waiter.h [new file with mode: 0644]
src/test/common/CMakeLists.txt
src/test/common/test_async_yield_waiter.cc [new file with mode: 0644]

diff --git a/src/common/async/yield_waiter.h b/src/common/async/yield_waiter.h
new file mode 100644 (file)
index 0000000..9c14d9b
--- /dev/null
@@ -0,0 +1,191 @@
+// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
+// vim: ts=8 sw=2 smarttab ft=cpp
+
+/*
+ * Ceph - scalable distributed file system
+ *
+ * Copyright contributors to the Ceph project
+ *
+ * This is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License version 2.1, as published by the Free Software
+ * Foundation. See file COPYING.
+ *
+ */
+
+#pragma once
+
+#include <exception>
+#include <optional>
+#include <boost/asio/append.hpp>
+#include <boost/asio/associated_cancellation_slot.hpp>
+#include <boost/asio/async_result.hpp>
+#include <boost/asio/dispatch.hpp>
+#include <boost/asio/spawn.hpp>
+
+namespace ceph::async {
+
+/// Captures a yield_context handler for deferred completion or cancellation.
+template <typename Ret>
+class yield_waiter {
+ public:
+  /// Function signature for the completion handler.
+  using Signature = void(boost::system::error_code, Ret);
+
+  yield_waiter() = default;
+
+  // copy and move are disabled because the cancellation handler captures 'this'
+  yield_waiter(const yield_waiter&) = delete;
+  yield_waiter& operator=(const yield_waiter&) = delete;
+
+  /// Returns true if there's a handler awaiting completion.
+  operator bool() const { return state.has_value(); }
+
+  /// Suspends the given yield_context until the captured handler is invoked
+  /// via complete() or cancel().
+  template <typename CompletionToken>
+  auto async_wait(CompletionToken&& token)
+  {
+    return boost::asio::async_initiate<CompletionToken, Signature>(
+        [this] (handler_type h) {
+          auto slot = get_associated_cancellation_slot(h);
+          if (slot.is_connected()) {
+            slot.template emplace<op_cancellation>(this);
+          }
+          state.emplace(std::move(h));
+        }, token);
+  }
+
+  /// Schedule the completion handler with the given arguments.
+  void complete(boost::system::error_code ec, Ret value)
+  {
+    auto s = std::move(*state);
+    state.reset();
+    auto h = boost::asio::append(std::move(s.handler), ec, std::move(value));
+    boost::asio::dispatch(std::move(h));
+  }
+
+  /// Destroy the completion handler.
+  void shutdown()
+  {
+    state.reset();
+  }
+
+ private:
+  using handler_type = typename boost::asio::async_result<
+      boost::asio::yield_context, Signature>::handler_type;
+  using work_guard = boost::asio::executor_work_guard<
+      boost::asio::any_io_executor>;
+
+  struct handler_state {
+    handler_type handler;
+    work_guard work;
+
+    explicit handler_state(handler_type&& h)
+      : handler(std::move(h)),
+        work(make_work_guard(handler))
+    {}
+  };
+  std::optional<handler_state> state;
+
+  struct op_cancellation {
+    yield_waiter* self;
+    op_cancellation(yield_waiter* self) : self(self) {}
+    void operator()(boost::asio::cancellation_type type) {
+      if (type != boost::asio::cancellation_type::none) {
+        self->cancel();
+      }
+    }
+  };
+
+  // Cancel the coroutine with an operation_aborted error.
+  void cancel()
+  {
+    if (state) {
+      complete(make_error_code(boost::asio::error::operation_aborted), Ret{});
+    }
+  }
+};
+
+// specialization for Ret=void
+template <>
+class yield_waiter<void> {
+ public:
+  /// Function signature for the completion handler.
+  using Signature = void(boost::system::error_code);
+
+  yield_waiter() = default;
+
+  // copy and move are disabled because the cancellation handler captures 'this'
+  yield_waiter(const yield_waiter&) = delete;
+  yield_waiter& operator=(const yield_waiter&) = delete;
+
+  /// Returns true if there's a handler awaiting completion.
+  operator bool() const { return state.has_value(); }
+
+  /// Suspends the given yield_context until the captured handler is invoked
+  /// via complete() or cancel().
+  template <typename CompletionToken>
+  auto async_wait(CompletionToken&& token)
+  {
+    return boost::asio::async_initiate<CompletionToken, Signature>(
+        [this] (handler_type h) {
+          auto slot = get_associated_cancellation_slot(h);
+          if (slot.is_connected()) {
+            slot.template emplace<op_cancellation>(this);
+          }
+          state.emplace(std::move(h));
+        }, token);
+  }
+
+  /// Schedule the completion handler with the given arguments.
+  void complete(boost::system::error_code ec)
+  {
+    auto s = std::move(*state);
+    state.reset();
+    boost::asio::dispatch(boost::asio::append(std::move(s.handler), ec));
+  }
+
+  /// Destroy the completion handler.
+  void shutdown()
+  {
+    state.reset();
+  }
+
+ private:
+  using handler_type = typename boost::asio::async_result<
+      boost::asio::yield_context, Signature>::handler_type;
+  using work_guard = boost::asio::executor_work_guard<
+      boost::asio::any_io_executor>;
+
+  struct handler_state {
+    handler_type handler;
+    work_guard work;
+
+    explicit handler_state(handler_type&& h)
+      : handler(std::move(h)),
+        work(make_work_guard(handler))
+    {}
+  };
+  std::optional<handler_state> state;
+
+  struct op_cancellation {
+    yield_waiter* self;
+    op_cancellation(yield_waiter* self) : self(self) {}
+    void operator()(boost::asio::cancellation_type type) {
+      if (type != boost::asio::cancellation_type::none) {
+        self->cancel();
+      }
+    }
+  };
+
+  // Cancel the coroutine with an operation_aborted error.
+  void cancel()
+  {
+    if (state) {
+      complete(make_error_code(boost::asio::error::operation_aborted));
+    }
+  }
+};
+
+} // namespace ceph::async
index 8aabfaa98e991bdadaaa6b698b89298494621253..5fbd3be6ac434d9715a263a1a7cef232df6beca8 100644 (file)
@@ -367,6 +367,10 @@ add_executable(unittest_async_shared_mutex test_async_shared_mutex.cc)
 add_ceph_unittest(unittest_async_shared_mutex)
 target_link_libraries(unittest_async_shared_mutex ceph-common Boost::system)
 
+add_executable(unittest_async_yield_waiter test_async_yield_waiter.cc)
+add_ceph_unittest(unittest_async_yield_waiter)
+target_link_libraries(unittest_async_yield_waiter ceph-common Boost::system Boost::context)
+
 add_executable(unittest_cdc test_cdc.cc
   $<TARGET_OBJECTS:unit-main>)
 target_link_libraries(unittest_cdc global ceph-common)
diff --git a/src/test/common/test_async_yield_waiter.cc b/src/test/common/test_async_yield_waiter.cc
new file mode 100644 (file)
index 0000000..6746825
--- /dev/null
@@ -0,0 +1,243 @@
+// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
+// vim: ts=8 sw=2 smarttab ft=cpp
+
+/*
+ * Ceph - scalable distributed file system
+ *
+ * Copyright contributors to the Ceph project
+ *
+ * This is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License version 2.1, as published by the Free Software
+ * Foundation. See file COPYING.
+ *
+ */
+
+#include "common/async/yield_waiter.h"
+#include <exception>
+#include <memory>
+#include <optional>
+#include <boost/asio/io_context.hpp>
+#include <boost/asio/spawn.hpp>
+#include <gtest/gtest.h>
+
+namespace ceph::async {
+
+namespace asio = boost::asio;
+using error_code = boost::system::error_code;
+
+void rethrow(std::exception_ptr eptr)
+{
+  if (eptr) std::rethrow_exception(eptr);
+}
+
+auto capture(std::optional<std::exception_ptr>& eptr)
+{
+  return [&eptr] (std::exception_ptr e) { eptr = e; };
+}
+
+
+TEST(YieldWaiterVoid, wait_shutdown)
+{
+  asio::io_context ctx;
+  yield_waiter<void> waiter;
+
+  asio::spawn(ctx, [&waiter] (asio::yield_context yield) {
+        waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+}
+
+TEST(YieldWaiterVoid, wait_complete)
+{
+  asio::io_context ctx;
+  yield_waiter<void> waiter;
+
+  asio::spawn(ctx, [&waiter] (asio::yield_context yield) {
+        waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(error_code{});
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  EXPECT_TRUE(ctx.stopped());
+}
+
+TEST(YieldWaiterVoid, wait_error)
+{
+  asio::io_context ctx;
+  yield_waiter<void> waiter;
+  std::optional<std::exception_ptr> eptr;
+
+  asio::spawn(ctx, [&waiter] (asio::yield_context yield) {
+        waiter.async_wait(yield);
+      }, capture(eptr));
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(make_error_code(asio::error::operation_aborted));
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  ASSERT_TRUE(ctx.stopped());
+  ASSERT_TRUE(eptr);
+  ASSERT_TRUE(*eptr);
+  try {
+    std::rethrow_exception(*eptr);
+  } catch (const boost::system::system_error& e) {
+    EXPECT_EQ(e.code(), asio::error::operation_aborted);
+  } catch (const std::exception&) {
+    EXPECT_THROW(throw, boost::system::system_error);
+  }
+}
+
+
+TEST(YieldWaiterInt, wait_shutdown)
+{
+  asio::io_context ctx;
+  yield_waiter<int> waiter;
+
+  asio::spawn(ctx, [&waiter] (asio::yield_context yield) {
+        waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+}
+
+TEST(YieldWaiterInt, wait_complete)
+{
+  asio::io_context ctx;
+  yield_waiter<int> waiter;
+  std::optional<int> result;
+
+  asio::spawn(ctx, [&waiter, &result] (asio::yield_context yield) {
+        result = waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(error_code{}, 42);
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  EXPECT_TRUE(ctx.stopped());
+  ASSERT_TRUE(result);
+  EXPECT_EQ(42, *result);
+}
+
+TEST(YieldWaiterInt, wait_error)
+{
+  asio::io_context ctx;
+  yield_waiter<int> waiter;
+  std::optional<int> result;
+  std::optional<std::exception_ptr> eptr;
+
+  asio::spawn(ctx, [&waiter, &result] (asio::yield_context yield) {
+        result = waiter.async_wait(yield);
+      }, capture(eptr));
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(make_error_code(std::errc::no_such_file_or_directory), 0);
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  ASSERT_TRUE(ctx.stopped());
+  EXPECT_FALSE(result);
+  ASSERT_TRUE(eptr);
+  ASSERT_TRUE(*eptr);
+  try {
+    std::rethrow_exception(*eptr);
+  } catch (const boost::system::system_error& e) {
+    EXPECT_EQ(e.code(), std::errc::no_such_file_or_directory);
+  } catch (const std::exception&) {
+    EXPECT_THROW(throw, boost::system::system_error);
+  }
+}
+
+
+// test with move-only value type
+TEST(YieldWaiterPtr, wait_shutdown)
+{
+  asio::io_context ctx;
+  yield_waiter<std::unique_ptr<int>> waiter;
+
+  asio::spawn(ctx, [&waiter] (asio::yield_context yield) {
+        waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+}
+
+TEST(YieldWaiterPtr, wait_complete)
+{
+  asio::io_context ctx;
+  yield_waiter<std::unique_ptr<int>> waiter;
+  std::optional<std::unique_ptr<int>> result;
+
+  asio::spawn(ctx, [&waiter, &result] (asio::yield_context yield) {
+        result = waiter.async_wait(yield);
+      }, rethrow);
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(error_code{}, std::make_unique<int>(42));
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  EXPECT_TRUE(ctx.stopped());
+  ASSERT_TRUE(result);
+  ASSERT_TRUE(*result);
+  EXPECT_EQ(42, **result);
+}
+
+TEST(YieldWaiterPtr, wait_error)
+{
+  asio::io_context ctx;
+  yield_waiter<std::unique_ptr<int>> waiter;
+  std::optional<std::unique_ptr<int>> result;
+  std::optional<std::exception_ptr> eptr;
+
+  asio::spawn(ctx, [&waiter, &result] (asio::yield_context yield) {
+        result = waiter.async_wait(yield);
+      }, capture(eptr));
+
+  ctx.poll();
+  ASSERT_FALSE(ctx.stopped());
+
+  ASSERT_TRUE(waiter);
+  waiter.complete(make_error_code(std::errc::no_such_file_or_directory), nullptr);
+  EXPECT_FALSE(waiter);
+
+  ctx.poll();
+  ASSERT_TRUE(ctx.stopped());
+  EXPECT_FALSE(result);
+  ASSERT_TRUE(eptr);
+  ASSERT_TRUE(*eptr);
+  try {
+    std::rethrow_exception(*eptr);
+  } catch (const boost::system::system_error& e) {
+    EXPECT_EQ(e.code(), std::errc::no_such_file_or_directory);
+  } catch (const std::exception&) {
+    EXPECT_THROW(throw, boost::system::system_error);
+  }
+}
+
+} // namespace ceph::async