]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
common/async: implement max_concurrent_for_each() for awaitable
authorCasey Bodley <cbodley@redhat.com>
Thu, 27 Jun 2024 20:53:01 +0000 (16:53 -0400)
committerCasey Bodley <cbodley@redhat.com>
Wed, 24 Jul 2024 16:51:07 +0000 (12:51 -0400)
Signed-off-by: Casey Bodley <cbodley@redhat.com>
src/common/async/max_concurrent_for_each.h
src/test/common/test_async_max_concurrent_for_each.cc

index c0789a1606431e41d3a4783698905483dbb7531d..c99a0fbb04832ba4173d34164b0cb9e779ef63ed 100644 (file)
@@ -21,6 +21,7 @@
 #include <utility>
 #include <boost/asio/spawn.hpp>
 #include "cancel_on_error.h"
+#include "co_throttle.h"
 #include "yield_context.h"
 #include "spawn_throttle.h"
 
@@ -54,8 +55,7 @@ void max_concurrent_for_each(Iterator begin,
                              Func&& func,
                              cancel_on_error on_error = cancel_on_error::none)
 {
-  const size_t count = std::ranges::distance(begin, end);
-  if (!count) {
+  if (begin == end) {
     return;
   }
   auto throttle = spawn_throttle{y, max_concurrent, on_error};
@@ -84,6 +84,54 @@ auto max_concurrent_for_each(Range&& range,
                                  on_error);
 }
 
-// TODO: overloads for co_spawn()
+// \overload
+template <typename Iterator, typename Sentinel, typename VoidAwaitableFactory,
+          typename Value = std::iter_reference_t<Iterator>,
+          typename VoidAwaitable = std::invoke_result_t<
+              VoidAwaitableFactory, Value>,
+          typename AwaitableT = typename VoidAwaitable::value_type,
+          typename AwaitableExecutor = typename VoidAwaitable::executor_type>
+    requires (std::input_iterator<Iterator> &&
+              std::sentinel_for<Sentinel, Iterator> &&
+              std::same_as<AwaitableT, void> &&
+              boost::asio::execution::executor<AwaitableExecutor>)
+auto max_concurrent_for_each(Iterator begin,
+                             Sentinel end,
+                             size_t max_concurrent,
+                             VoidAwaitableFactory&& factory,
+                             cancel_on_error on_error = cancel_on_error::none)
+    -> boost::asio::awaitable<void, AwaitableExecutor>
+{
+  if (begin == end) {
+    co_return;
+  }
+  auto ex = co_await boost::asio::this_coro::executor;
+  auto throttle = co_throttle{ex, max_concurrent, on_error};
+  for (Iterator i = begin; i != end; ++i) {
+    co_await throttle.spawn(factory(*i));
+  }
+  co_await throttle.wait();
+}
+
+/// \overload
+template <typename Range, typename VoidAwaitableFactory,
+          typename Value = std::ranges::range_reference_t<Range>,
+          typename VoidAwaitable = std::invoke_result_t<
+              VoidAwaitableFactory, Value>,
+          typename AwaitableT = typename VoidAwaitable::value_type,
+          typename AwaitableExecutor = typename VoidAwaitable::executor_type>
+    requires (std::ranges::range<Range> &&
+              std::same_as<AwaitableT, void> &&
+              boost::asio::execution::executor<AwaitableExecutor>)
+auto max_concurrent_for_each(Range&& range,
+                             size_t max_concurrent,
+                             VoidAwaitableFactory&& factory,
+                             cancel_on_error on_error = cancel_on_error::none)
+    -> boost::asio::awaitable<void, AwaitableExecutor>
+{
+  return max_concurrent_for_each(
+      std::begin(range), std::end(range), max_concurrent,
+      std::forward<VoidAwaitableFactory>(factory), on_error);
+}
 
 } // namespace ceph::async
index 2e6f919a39485445065eb5a95b1e9c14f3632251..b0880dfdb8538b9c4ac1933c92e59a624ece4fb1 100644 (file)
@@ -39,6 +39,12 @@ void wait_for(std::chrono::milliseconds dur, asio::yield_context yield)
   timer.async_wait(yield);
 }
 
+asio::awaitable<void> wait_for(std::chrono::milliseconds dur)
+{
+  auto timer = asio::steady_timer{co_await asio::this_coro::executor, dur};
+  co_await timer.async_wait(asio::use_awaitable);
+}
+
 struct null_sentinel {};
 bool operator==(const char* c, null_sentinel) { return !*c; }
 static_assert(std::sentinel_for<null_sentinel, const char*>);
@@ -222,4 +228,108 @@ TEST(range_yield, over_limit)
   EXPECT_EQ(10, completed);
 }
 
+TEST(iterator_co, empty)
+{
+  int* end = nullptr;
+  auto cr = [] (int) -> asio::awaitable<void> { co_return; };
+
+  asio::io_context ctx;
+  asio::co_spawn(ctx, [&] () -> asio::awaitable<void> {
+        co_await max_concurrent_for_each(end, end, 10, cr);
+      }, rethrow);
+  ctx.run();
+}
+
+TEST(iterator_co, over_limit)
+{
+  int concurrent = 0;
+  int max_concurrent = 0;
+  int completed = 0;
+
+  auto cr = [&] (int) -> asio::awaitable<void> {
+    ++concurrent;
+    if (max_concurrent < concurrent) {
+      max_concurrent = concurrent;
+    }
+
+    co_await wait_for(1ms);
+
+    --concurrent;
+    ++completed;
+  };
+
+  asio::io_context ctx;
+  asio::co_spawn(ctx, [&] () -> asio::awaitable<void> {
+        constexpr auto arr = std::array{1,2,3,4,5,6,7,8,9,10};
+        co_await max_concurrent_for_each(begin(arr), end(arr), 2, cr);
+      }, rethrow);
+  ctx.run();
+
+  EXPECT_EQ(0, concurrent);
+  EXPECT_EQ(2, max_concurrent);
+  EXPECT_EQ(10, completed);
+}
+
+TEST(iterator_co, sentinel)
+{
+  const char* begin = "hello";
+  null_sentinel end;
+
+  size_t completed = 0;
+  auto cr = [&completed] (char c) -> asio::awaitable<void> {
+    ++completed;
+    co_return;
+  };
+
+  asio::io_context ctx;
+  asio::co_spawn(ctx, [&] () -> asio::awaitable<void> {
+        co_await max_concurrent_for_each(begin, end, 10, cr);
+      }, rethrow);
+  ctx.run();
+
+  EXPECT_EQ(completed, 5);
+}
+
+TEST(range_co, empty)
+{
+  constexpr std::array<int, 0> arr{};
+  auto cr = [] (int) -> asio::awaitable<void> { co_return; };
+
+  asio::io_context ctx;
+  asio::co_spawn(ctx, [&] () -> asio::awaitable<void> {
+        co_await max_concurrent_for_each(arr, 10, cr);
+      }, rethrow);
+  ctx.run();
+}
+
+TEST(range_co, over_limit)
+{
+  int concurrent = 0;
+  int max_concurrent = 0;
+  int completed = 0;
+
+  auto cr = [&] (int) -> asio::awaitable<void> {
+    ++concurrent;
+    if (max_concurrent < concurrent) {
+      max_concurrent = concurrent;
+    }
+
+    co_await wait_for(1ms);
+
+    --concurrent;
+    ++completed;
+  };
+
+  asio::io_context ctx;
+  asio::co_spawn(ctx, [&] () -> asio::awaitable<void> {
+        constexpr auto arr = std::array{1,2,3,4,5,6,7,8,9,10};
+        co_await max_concurrent_for_each(arr, 2, cr);
+      }, rethrow);
+  ctx.run();
+
+  EXPECT_EQ(0, concurrent);
+  EXPECT_EQ(2, max_concurrent);
+  EXPECT_EQ(10, completed);
+}
+
 } // namespace ceph::async