]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
rgw/beast: reference count Connections for timeout_handler
authorCasey Bodley <cbodley@redhat.com>
Thu, 11 Nov 2021 17:01:06 +0000 (12:01 -0500)
committerCasey Bodley <cbodley@redhat.com>
Fri, 12 Nov 2021 14:53:01 +0000 (09:53 -0500)
resolves a use-after-free in the timeout_handler, where a timeout fires
and schedules the timeout_handler for execution, but the coroutine exits
and destroys the socket before asio executes the timeout_handler

timeout_handler now holds a reference on the Connection to extend its
lifetime

now that the Connection is allocated on the heap, we can include the
parse_buffer in this memory instead of allocating it separately

Signed-off-by: Casey Bodley <cbodley@redhat.com>
src/rgw/rgw_asio_frontend.cc
src/rgw/rgw_asio_frontend_timer.h

index 2954e64883d81fdbbfe78aa27c05400c229900dd..5aba68072c375e666b439ac14fbf4d4655d1f14d 100644 (file)
@@ -8,6 +8,7 @@
 
 #include <boost/asio.hpp>
 #include <boost/intrusive/list.hpp>
+#include <boost/smart_ptr/intrusive_ref_counter.hpp>
 
 #include <boost/context/protected_fixedsize_stack.hpp>
 #include <spawn/spawn.hpp>
@@ -43,6 +44,8 @@ namespace http = boost::beast::http;
 namespace ssl = boost::asio::ssl;
 #endif
 
+struct Connection;
+
 // use explicit executor types instead of the type-erased boost::asio::executor
 using executor_type = boost::asio::io_context::executor_type;
 
@@ -50,7 +53,7 @@ using tcp_socket = boost::asio::basic_stream_socket<tcp, executor_type>;
 using tcp_stream = boost::beast::basic_stream<tcp, executor_type>;
 
 using timeout_timer = rgw::basic_timeout_timer<ceph::coarse_mono_clock,
-      executor_type>;
+      executor_type, Connection>;
 
 using parse_buffer = boost::beast::flat_static_buffer<65536>;
 
@@ -68,24 +71,20 @@ class StreamIO : public rgw::asio::ClientIO {
   timeout_timer& timeout;
   yield_context yield;
   parse_buffer& buffer;
-  ceph::timespan request_timeout;
  public:
   StreamIO(CephContext *cct, Stream& stream, timeout_timer& timeout,
            rgw::asio::parser_type& parser, yield_context yield,
            parse_buffer& buffer, bool is_ssl,
            const tcp::endpoint& local_endpoint,
-           const tcp::endpoint& remote_endpoint,
-           ceph::timespan request_timeout)
+           const tcp::endpoint& remote_endpoint)
       : ClientIO(parser, is_ssl, local_endpoint, remote_endpoint),
         cct(cct), stream(stream), timeout(timeout), yield(yield),
-        buffer(buffer), request_timeout(request_timeout)
+        buffer(buffer)
   {}
 
   size_t write_data(const char* buf, size_t len) override {
     boost::system::error_code ec;
-    if (request_timeout.count()) {
-      timeout.expires_after(stream.lowest_layer(), request_timeout);
-    }
+    timeout.start();
     auto bytes = boost::asio::async_write(stream, boost::asio::buffer(buf, len),
                                           yield[ec]);
     timeout.cancel();
@@ -108,9 +107,7 @@ class StreamIO : public rgw::asio::ClientIO {
 
     while (body_remaining.size && !parser.is_done()) {
       boost::system::error_code ec;
-      if (request_timeout.count()) {
-        timeout.expires_after(stream.lowest_layer(), request_timeout);
-      }
+      timeout.start();
       http::async_read_some(stream, buffer, parser, yield[ec]);
       timeout.cancel();
       if (ec == http::error::need_buffer) {
@@ -186,8 +183,7 @@ void handle_connection(boost::asio::io_context& context,
                        SharedMutex& pause_mutex,
                        rgw::dmclock::Scheduler *scheduler,
                        boost::system::error_code& ec,
-                       yield_context yield,
-                       ceph::timespan request_timeout)
+                       yield_context yield)
 {
   // limit header to 4k, since we read it all into a single flat_buffer
   static constexpr size_t header_limit = 4096;
@@ -202,9 +198,7 @@ void handle_connection(boost::asio::io_context& context,
     rgw::asio::parser_type parser;
     parser.header_limit(header_limit);
     parser.body_limit(body_limit);
-    if (request_timeout.count()) {
-      timeout.expires_after(stream.lowest_layer(), request_timeout);
-    }
+    timeout.start();
     // parse the header
     http::async_read_header(stream, buffer, parser, yield[ec]);
     timeout.cancel();
@@ -225,9 +219,7 @@ void handle_connection(boost::asio::io_context& context,
       response.result(http::status::bad_request);
       response.version(message.version() == 10 ? 10 : 11);
       response.prepare_payload();
-      if (request_timeout.count()) {
-        timeout.expires_after(stream.lowest_layer(), request_timeout);
-      }
+      timeout.start();
       http::async_write(stream, response, yield[ec]);
       timeout.cancel();
       if (ec) {
@@ -258,7 +250,7 @@ void handle_connection(boost::asio::io_context& context,
 
       StreamIO real_client{cct, stream, timeout, parser, yield, buffer,
                            is_ssl, socket.local_endpoint(),
-                           remote_endpoint, request_timeout};
+                           remote_endpoint};
 
       auto real_client_io = rgw::io::add_reordering(
                               rgw::io::add_buffering(cct,
@@ -306,9 +298,7 @@ void handle_connection(boost::asio::io_context& context,
       body.size = discard_buffer.size();
       body.data = discard_buffer.data();
 
-      if (request_timeout.count()) {
-        timeout.expires_after(stream.lowest_layer(), request_timeout);
-      }
+      timeout.start();
       http::async_read_some(stream, buffer, parser, yield[ec]);
       timeout.cancel();
       if (ec == http::error::need_buffer) {
@@ -326,9 +316,20 @@ void handle_connection(boost::asio::io_context& context,
   }
 }
 
-struct Connection : boost::intrusive::list_base_hook<> {
-  tcp_socket& socket;
-  Connection(tcp_socket& socket) : socket(socket) {}
+// timeout support requires that connections are reference-counted, because the
+// timeout_handler can outlive the coroutine
+struct Connection : boost::intrusive::list_base_hook<>,
+                    boost::intrusive_ref_counter<Connection>
+{
+  tcp_socket socket;
+  parse_buffer buffer;
+
+  explicit Connection(tcp_socket&& socket) noexcept
+      : socket(std::move(socket)) {}
+
+  void close(boost::system::error_code& ec) {
+    socket.close(ec);
+  }
 };
 
 class ConnectionList {
@@ -967,32 +968,29 @@ void AsioFrontend::accept(Listener& l, boost::system::error_code ec)
   if (l.use_ssl) {
     spawn::spawn(context,
       [this, s=std::move(stream)] (yield_context yield) mutable {
-        Connection conn{s};
-        auto c = connections.add(conn);
+        auto conn = boost::intrusive_ptr{new Connection(std::move(s))};
+        auto c = connections.add(*conn);
         // wrap the tcp stream in an ssl stream
-        boost::asio::ssl::stream<tcp_socket&> stream{s, *ssl_context};
-        auto timeout = timeout_timer{context.get_executor()};
-        auto buffer = std::make_unique<parse_buffer>();
+        boost::asio::ssl::stream<tcp_socket&> stream{conn->socket, *ssl_context};
+        auto timeout = timeout_timer{context.get_executor(), request_timeout, conn};
         // do ssl handshake
         boost::system::error_code ec;
-        if (request_timeout.count()) {
-          timeout.expires_after(s, request_timeout);
-        }
+        timeout.start();
         auto bytes = stream.async_handshake(ssl::stream_base::server,
-                                            buffer->data(), yield[ec]);
+                                            conn->buffer.data(), yield[ec]);
         timeout.cancel();
         if (ec) {
           ldout(ctx(), 1) << "ssl handshake failed: " << ec.message() << dendl;
           return;
         }
-        buffer->consume(bytes);
-        handle_connection(context, env, stream, timeout, *buffer, true, pause_mutex,
-                          scheduler.get(), ec, yield, request_timeout);
+        conn->buffer.consume(bytes);
+        handle_connection(context, env, stream, timeout, conn->buffer, true,
+                          pause_mutex, scheduler.get(), ec, yield);
         if (!ec) {
           // ssl shutdown (ignoring errors)
           stream.async_shutdown(yield[ec]);
         }
-        s.shutdown(tcp::socket::shutdown_both, ec);
+        conn->socket.shutdown(tcp::socket::shutdown_both, ec);
       }, make_stack_allocator());
   } else {
 #else
@@ -1000,14 +998,13 @@ void AsioFrontend::accept(Listener& l, boost::system::error_code ec)
 #endif // WITH_RADOSGW_BEAST_OPENSSL
     spawn::spawn(context,
       [this, s=std::move(stream)] (yield_context yield) mutable {
-        Connection conn{s};
-        auto c = connections.add(conn);
-        auto timeout = timeout_timer{context.get_executor()};
-        auto buffer = std::make_unique<parse_buffer>();
+        auto conn = boost::intrusive_ptr{new Connection(std::move(s))};
+        auto c = connections.add(*conn);
+        auto timeout = timeout_timer{context.get_executor(), request_timeout, conn};
         boost::system::error_code ec;
-        handle_connection(context, env, s, timeout, *buffer, false, pause_mutex,
-                          scheduler.get(), ec, yield, request_timeout);
-        s.shutdown(tcp_socket::shutdown_both, ec);
+        handle_connection(context, env, conn->socket, timeout, conn->buffer,
+                          false, pause_mutex, scheduler.get(), ec, yield);
+        conn->socket.shutdown(tcp_socket::shutdown_both, ec);
       }, make_stack_allocator());
   }
 }
index 4fc81ee6585d6230ffcfb6700cec86ab7c7ce970..e5353ade2ce675d9e8078018b244bf890eb1e3ec 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <boost/asio/basic_waitable_timer.hpp>
+#include <boost/intrusive_ptr.hpp>
 
 #include "common/ceph_time.h"
 
@@ -9,9 +10,12 @@ namespace rgw {
 // a WaitHandler that closes a stream if the timeout expires
 template <typename Stream>
 struct timeout_handler {
-  Stream* stream;
+  // this handler may outlive the timer/stream, so we need to hold a reference
+  // to keep the stream alive
+  boost::intrusive_ptr<Stream> stream;
 
-  explicit timeout_handler(Stream* stream) noexcept : stream(stream) {}
+  explicit timeout_handler(boost::intrusive_ptr<Stream> stream) noexcept
+      : stream(std::move(stream)) {}
 
   void operator()(boost::system::error_code ec) {
     if (!ec) { // wait was not canceled
@@ -22,22 +26,24 @@ struct timeout_handler {
 };
 
 // a timeout timer for stream operations
-template <typename Clock, typename Executor>
+template <typename Clock, typename Executor, typename Stream>
 class basic_timeout_timer {
  public:
   using clock_type = Clock;
   using duration = typename clock_type::duration;
   using executor_type = Executor;
 
-  explicit basic_timeout_timer(const executor_type& ex) : timer(ex) {}
+  explicit basic_timeout_timer(const executor_type& ex, duration dur,
+                               boost::intrusive_ptr<Stream> stream)
+      : timer(ex), dur(dur), stream(std::move(stream))
+  {}
 
   basic_timeout_timer(const basic_timeout_timer&) = delete;
   basic_timeout_timer& operator=(const basic_timeout_timer&) = delete;
 
-  template <typename Stream>
-  void expires_after(Stream& stream, duration dur) {
+  void start() {
     timer.expires_after(dur);
-    timer.async_wait(timeout_handler{&stream});
+    timer.async_wait(timeout_handler{stream});
   }
 
   void cancel() {
@@ -48,6 +54,8 @@ class basic_timeout_timer {
   using Timer = boost::asio::basic_waitable_timer<clock_type,
         boost::asio::wait_traits<clock_type>, executor_type>;
   Timer timer;
+  duration dur;
+  boost::intrusive_ptr<Stream> stream;
 };
 
 } // namespace rgw