From a68fbb18a96fa89ad2bc61c0f711cf8ce1b9ed31 Mon Sep 17 00:00:00 2001 From: Casey Bodley Date: Thu, 11 Nov 2021 12:01:06 -0500 Subject: [PATCH] rgw/beast: reference count Connections for timeout_handler resolves a use-after-free in the timeout_handler, where a timeout fires and schedules the timeout_handler for execution, but the coroutine exits and destroys the socket before asio executes the timeout_handler timeout_handler now holds a reference on the Connection to extend its lifetime now that the Connection is allocated on the heap, we can include the parse_buffer in this memory instead of allocating it separately Signed-off-by: Casey Bodley --- src/rgw/rgw_asio_frontend.cc | 89 +++++++++++++++---------------- src/rgw/rgw_asio_frontend_timer.h | 22 +++++--- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/src/rgw/rgw_asio_frontend.cc b/src/rgw/rgw_asio_frontend.cc index 2954e64883d81..5aba68072c375 100644 --- a/src/rgw/rgw_asio_frontend.cc +++ b/src/rgw/rgw_asio_frontend.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -43,6 +44,8 @@ namespace http = boost::beast::http; namespace ssl = boost::asio::ssl; #endif +struct Connection; + // use explicit executor types instead of the type-erased boost::asio::executor using executor_type = boost::asio::io_context::executor_type; @@ -50,7 +53,7 @@ using tcp_socket = boost::asio::basic_stream_socket; using tcp_stream = boost::beast::basic_stream; using timeout_timer = rgw::basic_timeout_timer; + executor_type, Connection>; using parse_buffer = boost::beast::flat_static_buffer<65536>; @@ -68,24 +71,20 @@ class StreamIO : public rgw::asio::ClientIO { timeout_timer& timeout; yield_context yield; parse_buffer& buffer; - ceph::timespan request_timeout; public: StreamIO(CephContext *cct, Stream& stream, timeout_timer& timeout, rgw::asio::parser_type& parser, yield_context yield, parse_buffer& buffer, bool is_ssl, const tcp::endpoint& local_endpoint, - const tcp::endpoint& remote_endpoint, - ceph::timespan request_timeout) + const tcp::endpoint& remote_endpoint) : ClientIO(parser, is_ssl, local_endpoint, remote_endpoint), cct(cct), stream(stream), timeout(timeout), yield(yield), - buffer(buffer), request_timeout(request_timeout) + buffer(buffer) {} size_t write_data(const char* buf, size_t len) override { boost::system::error_code ec; - if (request_timeout.count()) { - timeout.expires_after(stream.lowest_layer(), request_timeout); - } + timeout.start(); auto bytes = boost::asio::async_write(stream, boost::asio::buffer(buf, len), yield[ec]); timeout.cancel(); @@ -108,9 +107,7 @@ class StreamIO : public rgw::asio::ClientIO { while (body_remaining.size && !parser.is_done()) { boost::system::error_code ec; - if (request_timeout.count()) { - timeout.expires_after(stream.lowest_layer(), request_timeout); - } + timeout.start(); http::async_read_some(stream, buffer, parser, yield[ec]); timeout.cancel(); if (ec == http::error::need_buffer) { @@ -186,8 +183,7 @@ void handle_connection(boost::asio::io_context& context, SharedMutex& pause_mutex, rgw::dmclock::Scheduler *scheduler, boost::system::error_code& ec, - yield_context yield, - ceph::timespan request_timeout) + yield_context yield) { // limit header to 4k, since we read it all into a single flat_buffer static constexpr size_t header_limit = 4096; @@ -202,9 +198,7 @@ void handle_connection(boost::asio::io_context& context, rgw::asio::parser_type parser; parser.header_limit(header_limit); parser.body_limit(body_limit); - if (request_timeout.count()) { - timeout.expires_after(stream.lowest_layer(), request_timeout); - } + timeout.start(); // parse the header http::async_read_header(stream, buffer, parser, yield[ec]); timeout.cancel(); @@ -225,9 +219,7 @@ void handle_connection(boost::asio::io_context& context, response.result(http::status::bad_request); response.version(message.version() == 10 ? 10 : 11); response.prepare_payload(); - if (request_timeout.count()) { - timeout.expires_after(stream.lowest_layer(), request_timeout); - } + timeout.start(); http::async_write(stream, response, yield[ec]); timeout.cancel(); if (ec) { @@ -258,7 +250,7 @@ void handle_connection(boost::asio::io_context& context, StreamIO real_client{cct, stream, timeout, parser, yield, buffer, is_ssl, socket.local_endpoint(), - remote_endpoint, request_timeout}; + remote_endpoint}; auto real_client_io = rgw::io::add_reordering( rgw::io::add_buffering(cct, @@ -306,9 +298,7 @@ void handle_connection(boost::asio::io_context& context, body.size = discard_buffer.size(); body.data = discard_buffer.data(); - if (request_timeout.count()) { - timeout.expires_after(stream.lowest_layer(), request_timeout); - } + timeout.start(); http::async_read_some(stream, buffer, parser, yield[ec]); timeout.cancel(); if (ec == http::error::need_buffer) { @@ -326,9 +316,20 @@ void handle_connection(boost::asio::io_context& context, } } -struct Connection : boost::intrusive::list_base_hook<> { - tcp_socket& socket; - Connection(tcp_socket& socket) : socket(socket) {} +// timeout support requires that connections are reference-counted, because the +// timeout_handler can outlive the coroutine +struct Connection : boost::intrusive::list_base_hook<>, + boost::intrusive_ref_counter +{ + tcp_socket socket; + parse_buffer buffer; + + explicit Connection(tcp_socket&& socket) noexcept + : socket(std::move(socket)) {} + + void close(boost::system::error_code& ec) { + socket.close(ec); + } }; class ConnectionList { @@ -967,32 +968,29 @@ void AsioFrontend::accept(Listener& l, boost::system::error_code ec) if (l.use_ssl) { spawn::spawn(context, [this, s=std::move(stream)] (yield_context yield) mutable { - Connection conn{s}; - auto c = connections.add(conn); + auto conn = boost::intrusive_ptr{new Connection(std::move(s))}; + auto c = connections.add(*conn); // wrap the tcp stream in an ssl stream - boost::asio::ssl::stream stream{s, *ssl_context}; - auto timeout = timeout_timer{context.get_executor()}; - auto buffer = std::make_unique(); + boost::asio::ssl::stream stream{conn->socket, *ssl_context}; + auto timeout = timeout_timer{context.get_executor(), request_timeout, conn}; // do ssl handshake boost::system::error_code ec; - if (request_timeout.count()) { - timeout.expires_after(s, request_timeout); - } + timeout.start(); auto bytes = stream.async_handshake(ssl::stream_base::server, - buffer->data(), yield[ec]); + conn->buffer.data(), yield[ec]); timeout.cancel(); if (ec) { ldout(ctx(), 1) << "ssl handshake failed: " << ec.message() << dendl; return; } - buffer->consume(bytes); - handle_connection(context, env, stream, timeout, *buffer, true, pause_mutex, - scheduler.get(), ec, yield, request_timeout); + conn->buffer.consume(bytes); + handle_connection(context, env, stream, timeout, conn->buffer, true, + pause_mutex, scheduler.get(), ec, yield); if (!ec) { // ssl shutdown (ignoring errors) stream.async_shutdown(yield[ec]); } - s.shutdown(tcp::socket::shutdown_both, ec); + conn->socket.shutdown(tcp::socket::shutdown_both, ec); }, make_stack_allocator()); } else { #else @@ -1000,14 +998,13 @@ void AsioFrontend::accept(Listener& l, boost::system::error_code ec) #endif // WITH_RADOSGW_BEAST_OPENSSL spawn::spawn(context, [this, s=std::move(stream)] (yield_context yield) mutable { - Connection conn{s}; - auto c = connections.add(conn); - auto timeout = timeout_timer{context.get_executor()}; - auto buffer = std::make_unique(); + auto conn = boost::intrusive_ptr{new Connection(std::move(s))}; + auto c = connections.add(*conn); + auto timeout = timeout_timer{context.get_executor(), request_timeout, conn}; boost::system::error_code ec; - handle_connection(context, env, s, timeout, *buffer, false, pause_mutex, - scheduler.get(), ec, yield, request_timeout); - s.shutdown(tcp_socket::shutdown_both, ec); + handle_connection(context, env, conn->socket, timeout, conn->buffer, + false, pause_mutex, scheduler.get(), ec, yield); + conn->socket.shutdown(tcp_socket::shutdown_both, ec); }, make_stack_allocator()); } } diff --git a/src/rgw/rgw_asio_frontend_timer.h b/src/rgw/rgw_asio_frontend_timer.h index 4fc81ee6585d6..e5353ade2ce67 100644 --- a/src/rgw/rgw_asio_frontend_timer.h +++ b/src/rgw/rgw_asio_frontend_timer.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "common/ceph_time.h" @@ -9,9 +10,12 @@ namespace rgw { // a WaitHandler that closes a stream if the timeout expires template struct timeout_handler { - Stream* stream; + // this handler may outlive the timer/stream, so we need to hold a reference + // to keep the stream alive + boost::intrusive_ptr stream; - explicit timeout_handler(Stream* stream) noexcept : stream(stream) {} + explicit timeout_handler(boost::intrusive_ptr stream) noexcept + : stream(std::move(stream)) {} void operator()(boost::system::error_code ec) { if (!ec) { // wait was not canceled @@ -22,22 +26,24 @@ struct timeout_handler { }; // a timeout timer for stream operations -template +template class basic_timeout_timer { public: using clock_type = Clock; using duration = typename clock_type::duration; using executor_type = Executor; - explicit basic_timeout_timer(const executor_type& ex) : timer(ex) {} + explicit basic_timeout_timer(const executor_type& ex, duration dur, + boost::intrusive_ptr stream) + : timer(ex), dur(dur), stream(std::move(stream)) + {} basic_timeout_timer(const basic_timeout_timer&) = delete; basic_timeout_timer& operator=(const basic_timeout_timer&) = delete; - template - void expires_after(Stream& stream, duration dur) { + void start() { timer.expires_after(dur); - timer.async_wait(timeout_handler{&stream}); + timer.async_wait(timeout_handler{stream}); } void cancel() { @@ -48,6 +54,8 @@ class basic_timeout_timer { using Timer = boost::asio::basic_waitable_timer, executor_type>; Timer timer; + duration dur; + boost::intrusive_ptr stream; }; } // namespace rgw -- 2.39.5