]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
msg/async: msgr2: authentication phase
authorRicardo Dias <rdias@suse.com>
Wed, 12 Sep 2018 07:28:32 +0000 (08:28 +0100)
committerRicardo Dias <rdias@suse.com>
Wed, 23 Jan 2019 13:59:23 +0000 (13:59 +0000)
Signed-off-by: Ricardo Dias <rdias@suse.com>
src/msg/async/ProtocolV2.cc
src/msg/async/ProtocolV2.h

index f0da5aa5ceb4a5bf63388fdf83abad46dc428376..8a8cbdb7632589b3745d4133c869e7e8868ed06a 100644 (file)
@@ -31,7 +31,8 @@ ProtocolV2::ProtocolV2(AsyncConnection *connection)
     : Protocol(2, connection),
       temp_buffer(nullptr),
       state(NONE),
-      bannerExchangeCallback(nullptr) {
+      bannerExchangeCallback(nullptr),
+      next_frame_len(0) {
   temp_buffer = new char[4096];
 }
 
@@ -43,9 +44,35 @@ void ProtocolV2::accept() { state = START_ACCEPT; }
 
 bool ProtocolV2::is_connected() { return false; }
 
-void ProtocolV2::stop() {}
+void ProtocolV2::stop() {
+  ldout(cct, 2) << __func__ << dendl;
+  if (state == CLOSED) {
+    return;
+  }
+
+  if (connection->delay_state) connection->delay_state->flush();
+
+  connection->_stop();
+
+  state = CLOSED;
+}
 
-void ProtocolV2::fault() { _fault(); }
+void ProtocolV2::fault() {
+  ldout(cct, 10) << __func__ << dendl;
+
+  if (state == CLOSED || state == NONE) {
+    ldout(cct, 10) << __func__ << " connection is already closed" << dendl;
+    return;
+  }
+
+  if (connection->policy.lossy && state != START_CONNECT &&
+      state != CONNECTING) {
+    ldout(cct, 1) << __func__ << " on lossy channel, failing" << dendl;
+    stop();
+    connection->dispatch_queue->queue_reset(connection);
+    return;
+  }
+}
 
 void ProtocolV2::send_message(Message *m) {}
 
@@ -220,6 +247,108 @@ unsigned banner_prefix_len = strlen(CEPH_BANNER_V2_PREFIX);
   return callback;
 }
 
+CtPtr ProtocolV2::read_frame() {
+  ldout(cct, 20) << __func__ << dendl;
+  return READ(sizeof(__le32), handle_read_frame_length);
+}
+
+CtPtr ProtocolV2::handle_read_frame_length(char *buffer, int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " read frame length failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  next_frame_len = *(__le32 *)buffer;
+
+  return READ(next_frame_len, handle_frame);
+}
+
+CtPtr ProtocolV2::handle_frame(char *buffer, int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " read frame payload failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  Tag tag = static_cast<Tag>(*(__le32 *)buffer);
+  buffer += sizeof(__le32);
+  uint32_t payload_len = next_frame_len - sizeof(__le32);
+
+  ldout(cct, 10) << __func__ << " tag=" << static_cast<uint32_t>(tag) << dendl;
+
+  switch (tag) {
+    case Tag::AUTH_REQUEST:
+      return handle_auth_request(buffer, payload_len);
+    case Tag::AUTH_BAD_METHOD:
+      return handle_auth_bad_method(buffer, payload_len);
+    case Tag::AUTH_BAD_AUTH:
+      return handle_auth_bad_auth(buffer, payload_len);
+    case Tag::AUTH_MORE:
+      return handle_auth_more(buffer, payload_len);
+    case Tag::AUTH_DONE:
+      return handle_auth_done(buffer, payload_len);
+    default:
+      ceph_abort();
+  }
+  return nullptr;
+}
+
+CtPtr ProtocolV2::handle_auth_more(char *payload, uint32_t length) {
+  ldout(cct, 20) << __func__ << " payload_len=" << length << dendl;
+
+  AuthMoreFrame auth_more(payload, length);
+  ldout(cct, 1) << __func__ << " auth more len=" << auth_more.len << dendl;
+
+  /* BEGIN TO REMOVE */
+  auto p = auth_more.auth_payload.cbegin();
+  int32_t i;
+  std::string s;
+  try {
+    decode(i, p);
+    decode(s, p);
+  } catch (const buffer::error &e) {
+    lderr(cct) << __func__ << " decode auth_payload failed" << dendl;
+    return _fault();
+  }
+
+  ldout(cct, 10) << __func__ << " (TO REMOVE) auth_more (" << (int32_t)i << ", "
+                 << s << ")" << dendl;
+
+  if (i == 45 && s == "hello server more") {
+    bufferlist auth_bl;
+    encode((int32_t)55, auth_bl, 0);
+    std::string hello("hello client more");
+    encode(hello, auth_bl, 0);
+    /* END TO REMOVE */
+    AuthMoreFrame more(auth_bl);
+    bufferlist bl = more.to_bufferlist();
+    return WRITE(bl, handle_auth_more_write);
+  }
+  /* END TO REMOVE */
+
+  AuthDoneFrame auth_done(0);
+
+  auto bl = auth_done.to_bufferlist();
+  return WRITE(bl, handle_auth_done_write);
+}
+
+CtPtr ProtocolV2::handle_auth_more_write(int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " auth more write failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  return CONTINUE(read_frame);
+}
+
 /* Client Protocol Methods */
 
 CtPtr ProtocolV2::start_client_banner_exchange() {
@@ -235,7 +364,82 @@ CtPtr ProtocolV2::post_client_banner_exchange() {
   // at this point we can change how the client protocol behaves based on
   // this->peer_required_features
 
-  ceph_abort();
+  return send_auth_request();
+}
+
+CtPtr ProtocolV2::send_auth_request(std::vector<__u32> allowed_methods) {
+  ldout(cct, 20) << __func__ << dendl;
+
+  // We need to get an authorizer at this point.
+  // this->messenger->get_authorizer(...)
+
+  bufferlist auth_bl;
+  /* BEGIN TO REMOVE */
+  encode((int32_t)35, auth_bl, 0);
+  std::string hello("hello");
+  encode(hello, auth_bl, 0);
+  /* END TO REMOVE */
+  __le32 method;
+  if (allowed_methods.empty()) {
+    // choose client's preferred method
+    method = 23;  // 23 is just for testing purposes (REMOVE THIS)
+  } else {
+    // choose one of the allowed methods
+    method = allowed_methods[0];
+  }
+  AuthRequestFrame authFrame(method, auth_bl);
+
+  bufferlist bl = authFrame.to_bufferlist();
+  return WRITE(bl, handle_auth_request_write);
+}
+
+CtPtr ProtocolV2::handle_auth_request_write(int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " auth request write failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  return CONTINUE(read_frame);
+}
+
+CtPtr ProtocolV2::handle_auth_bad_method(char *payload, uint32_t length) {
+  ldout(cct, 20) << __func__ << " payload_len=" << length << dendl;
+
+  AuthBadMethodFrame bad_method(payload, length);
+  ldout(cct, 1) << __func__ << " auth method=" << bad_method.method
+                << " rejected, allowed methods=" << bad_method.allowed_methods
+                << dendl;
+
+  return send_auth_request(bad_method.allowed_methods);
+}
+
+CtPtr ProtocolV2::handle_auth_bad_auth(char *payload, uint32_t length) {
+  ldout(cct, 20) << __func__ << " payload_len=" << length << dendl;
+
+  AuthBadAuthFrame bad_auth(payload, length);
+  ldout(cct, 1) << __func__ << " authentication failed"
+                << " error code=" << bad_auth.error_code
+                << " error message=" << bad_auth.error_msg << dendl;
+
+  return _fault();
+}
+
+CtPtr ProtocolV2::handle_auth_done(char *payload, uint32_t length) {
+  ldout(cct, 20) << __func__ << " payload_len=" << length << dendl;
+
+  AuthDoneFrame auth_done(payload, length);
+  ldout(cct, 1) << __func__ << " authentication done,"
+                << " flags=" << auth_done.flags << dendl;
+
+  return send_client_ident();
+}
+
+CtPtr ProtocolV2::send_client_ident() {
+  ldout(cct, 20) << __func__ << dendl;
+
   return nullptr;
 }
 
@@ -254,7 +458,115 @@ CtPtr ProtocolV2::post_server_banner_exchange() {
   // at this point we can change how the server protocol behaves based on
   // this->peer_required_features
 
-  ceph_abort();
-  return nullptr;
+  return CONTINUE(read_frame);
 }
 
+CtPtr ProtocolV2::handle_auth_request(char *payload, uint32_t length) {
+  ldout(cct, 20) << __func__ << " payload_len=" << length << dendl;
+
+  AuthRequestFrame auth_request(payload, length);
+
+  ldout(cct, 10) << __func__ << " AuthRequest(method=" << auth_request.method
+                 << ", auth_len=" << auth_request.len << ")" << dendl;
+
+  /* BEGIN TO REMOVE */
+  auto p = auth_request.auth_payload.cbegin();
+  int32_t i;
+  std::string s;
+  try {
+    decode(i, p);
+    decode(s, p);
+  } catch (const buffer::error &e) {
+    lderr(cct) << __func__ << " decode auth_payload failed" << dendl;
+    return _fault();
+  }
+
+  ldout(cct, 10) << __func__ << " (TO REMOVE) auth_payload (" << (int32_t)i
+                 << ", " << s << ")" << dendl;
+
+  /* END TO REMOVE */
+
+  /*
+   * Get the auth methods from somewhere.
+   * In V1 the allowed auth methods depend on the peer_type.
+   * In V2, at this stage, we still don't know the peer_type so either
+   * we define the set of allowed auth methods for any entity type,
+   * or we need to exchange the entity type before reaching this point.
+   */
+
+  std::vector<__u32> allowed_methods = {CEPH_AUTH_NONE, CEPH_AUTH_CEPHX};
+
+  bool found = false;
+  for (const auto &a_method : allowed_methods) {
+    if (a_method == auth_request.method) {
+      // auth method allowed by the server
+      found = true;
+      break;
+    }
+  }
+
+  if (!found) {
+    ldout(cct, 1) << __func__ << " auth method=" << auth_request.method
+                  << " not allowed" << dendl;
+    AuthBadMethodFrame bad_method(auth_request.method, allowed_methods);
+    bufferlist bl = bad_method.to_bufferlist();
+    return WRITE(bl, handle_auth_bad_method_write);
+  }
+
+  ldout(cct, 10) << __func__ << " auth method=" << auth_request.method
+                 << " accepted" << dendl;
+  // verify authorization blob
+  bool valid = i == 35;
+
+  if (!valid) {
+    AuthBadAuthFrame bad_auth(12, "Permission denied");
+    bufferlist bl = bad_auth.to_bufferlist();
+    return WRITE(bl, handle_auth_bad_auth_write);
+  }
+
+  bufferlist auth_bl;
+  /* BEGIN TO REMOVE */
+  encode((int32_t)45, auth_bl, 0);
+  std::string hello("hello server more");
+  encode(hello, auth_bl, 0);
+  /* END TO REMOVE */
+  AuthMoreFrame more(auth_bl);
+  bufferlist bl = more.to_bufferlist();
+  return WRITE(bl, handle_auth_more_write);
+}
+
+CtPtr ProtocolV2::handle_auth_bad_method_write(int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " auth bad method write failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  return CONTINUE(read_frame);
+}
+
+CtPtr ProtocolV2::handle_auth_bad_auth_write(int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " auth bad auth write failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  return CONTINUE(read_frame);
+}
+
+CtPtr ProtocolV2::handle_auth_done_write(int r) {
+  ldout(cct, 20) << __func__ << " r=" << r << dendl;
+
+  if (r < 0) {
+    ldout(cct, 1) << __func__ << " auth done write failed r=" << r << " ("
+                  << cpp_strerror(r) << ")" << dendl;
+    return _fault();
+  }
+
+  return nullptr;
+}
index 0fff889059d1c501b793a8b34f882c66c0438206..eb0c17635da9561aa8f223ea89d7d5e705ef29c9 100644 (file)
@@ -24,6 +24,137 @@ private:
     return statenames[state];
   }
 
+  enum class Tag : __le32 {
+    AUTH_REQUEST,
+    AUTH_BAD_METHOD,
+    AUTH_BAD_AUTH,
+    AUTH_MORE,
+    AUTH_DONE
+  };
+
+  struct Frame {
+    __le32 frame_len;
+    __le32 tag;
+    bufferlist payload;
+
+    Frame(Tag tag, __le32 payload_len)
+        : frame_len(sizeof(__le32) + payload_len),
+          tag(static_cast<__le32>(tag)) {}
+
+    bufferlist to_bufferlist() {
+      ceph_assert(payload.length() == (frame_len - sizeof(__le32)));
+      bufferlist bl;
+      encode(frame_len, bl, 0);
+      encode(tag, bl, 0);
+      bl.claim_append(payload);
+      return bl;
+    }
+  };
+
+  struct AuthRequestFrame : public Frame {
+    __le32 method;
+    __le32 len;
+    bufferlist auth_payload;
+
+    AuthRequestFrame(__le32 method, bufferlist &auth_payload)
+        : Frame(Tag::AUTH_REQUEST, sizeof(__le32) * 2 + auth_payload.length()),
+          method(method),
+          len(auth_payload.length()),
+          auth_payload(auth_payload) {
+      encode(method, payload, 0);
+      encode(len, payload, 0);
+      payload.claim_append(auth_payload);
+    }
+
+    AuthRequestFrame(char *payload, uint32_t length)
+        : Frame(Tag::AUTH_REQUEST, length) {
+      method = *(__le32 *)payload;
+      len = *(__le32 *)(payload + sizeof(__le32));
+      ceph_assert((length - (sizeof(__le32) * 2)) == len);
+      auth_payload.append((payload + (sizeof(__le32) * 2)), len);
+    }
+  };
+
+  struct AuthBadMethodFrame : public Frame {
+    __le32 method;
+    std::vector<__u32> allowed_methods;
+
+    AuthBadMethodFrame(__le32 method, std::vector<__u32> methods)
+        : Frame(Tag::AUTH_BAD_METHOD, sizeof(__le32) * (2 + methods.size())),
+          method(method),
+          allowed_methods(methods) {
+      encode(method, payload, 0);
+      encode((__le32)allowed_methods.size(), payload, 0);
+      for (const auto &a_meth : allowed_methods) {
+        encode(a_meth, payload, 0);
+      }
+    }
+
+    AuthBadMethodFrame(char *payload, uint32_t length)
+        : Frame(Tag::AUTH_BAD_METHOD, length) {
+      method = *(__le32 *)payload;
+      __le32 num_methods = *(__le32 *)(payload + sizeof(__le32));
+      payload += sizeof(__le32) * 2;
+      for (unsigned i = 0; i < num_methods; ++i) {
+        allowed_methods.push_back(*(__le32 *)(payload + sizeof(__le32) * i));
+      }
+    }
+  };
+
+  struct AuthBadAuthFrame : public Frame {
+    __le32 error_code;
+    std::string error_msg;
+
+    AuthBadAuthFrame(__le32 error_code, std::string error_msg)
+        : Frame(Tag::AUTH_BAD_AUTH, sizeof(__le32) * 2 + error_msg.size()),
+          error_code(error_code),
+          error_msg(error_msg) {
+      encode(error_code, payload, 0);
+      encode(error_msg, payload, 0);
+    }
+
+    AuthBadAuthFrame(char *payload, uint32_t length)
+        : Frame(Tag::AUTH_BAD_AUTH, length) {
+      error_code = *(__le32 *)payload;
+      __le32 len = *(__le32 *)(payload + sizeof(__le32));
+      error_msg = std::string(payload + sizeof(__le32) * 2, len);
+    }
+  };
+
+  struct AuthMoreFrame : public Frame {
+    __le32 len;
+    bufferlist auth_payload;
+
+    AuthMoreFrame(bufferlist &auth_payload)
+        : Frame(Tag::AUTH_MORE, sizeof(__le32) + auth_payload.length()),
+          len(auth_payload.length()),
+          auth_payload(auth_payload) {
+      encode(len, payload, 0);
+      payload.claim_append(auth_payload);
+    }
+
+    AuthMoreFrame(char *payload, uint32_t length)
+        : Frame(Tag::AUTH_BAD_AUTH, length) {
+      len = *(__le32 *)payload;
+      ceph_assert((length - sizeof(__le32)) == len);
+      auth_payload.append(payload + sizeof(__le32), len);
+    }
+  };
+
+  struct AuthDoneFrame : public Frame {
+    __le64 flags;
+
+    AuthDoneFrame(uint64_t flags)
+        : Frame(Tag::AUTH_DONE, sizeof(__le64)), flags(flags) {
+      encode(flags, payload, 0);
+    }
+
+    AuthDoneFrame(char *payload, uint32_t length)
+        : Frame(Tag::AUTH_DONE, length) {
+      flags = *(__le64 *)payload;
+    }
+  };
+
   char *temp_buffer;
   State state;
 
@@ -40,6 +171,7 @@ private:
                         bufferlist &bl);
 
   inline Ct<ProtocolV2> *_fault() {
+    fault();
     return nullptr;
   }
 
@@ -51,6 +183,18 @@ private:
   Ct<ProtocolV2> *_banner_exchange_handle_write(int r);
   Ct<ProtocolV2> *_banner_exchange_handle_peer_banner(char *buffer, int r);
 
+  uint32_t next_frame_len;
+  CONTINUATION_DECL(ProtocolV2, read_frame);
+  READ_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_read_frame_length);
+  READ_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_frame);
+  WRITE_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_auth_more_write);
+
+  Ct<ProtocolV2> *read_frame();
+  Ct<ProtocolV2> *handle_read_frame_length(char *buffer, int r);
+  Ct<ProtocolV2> *handle_frame(char *buffer, int r);
+  Ct<ProtocolV2> *handle_auth_more(char *payload, uint32_t length);
+  Ct<ProtocolV2> *handle_auth_more_write(int r);
+
 public:
   ProtocolV2(AsyncConnection *connection);
   virtual ~ProtocolV2();
@@ -71,16 +215,30 @@ private:
   // Client Protocol
   CONTINUATION_DECL(ProtocolV2, start_client_banner_exchange);
   CONTINUATION_DECL(ProtocolV2, post_client_banner_exchange);
+  WRITE_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_auth_request_write);
 
   Ct<ProtocolV2> *start_client_banner_exchange();
   Ct<ProtocolV2> *post_client_banner_exchange();
+  Ct<ProtocolV2> *send_auth_request(std::vector<__u32> allowed_methods = {});
+  Ct<ProtocolV2> *handle_auth_request_write(int r);
+  Ct<ProtocolV2> *handle_auth_bad_method(char *payload, uint32_t length);
+  Ct<ProtocolV2> *handle_auth_bad_auth(char *payload, uint32_t length);
+  Ct<ProtocolV2> *handle_auth_done(char *payload, uint32_t length);
+  Ct<ProtocolV2> *send_client_ident();
 
   // Server Protocol
   CONTINUATION_DECL(ProtocolV2, start_server_banner_exchange);
   CONTINUATION_DECL(ProtocolV2, post_server_banner_exchange);
+  WRITE_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_auth_bad_method_write);
+  WRITE_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_auth_bad_auth_write);
+  WRITE_HANDLER_CONTINUATION_DECL(ProtocolV2, handle_auth_done_write);
 
   Ct<ProtocolV2> *start_server_banner_exchange();
   Ct<ProtocolV2> *post_server_banner_exchange();
+  Ct<ProtocolV2> *handle_auth_request(char *payload, uint32_t length);
+  Ct<ProtocolV2> *handle_auth_bad_method_write(int r);
+  Ct<ProtocolV2> *handle_auth_bad_auth_write(int r);
+  Ct<ProtocolV2> *handle_auth_done_write(int r);
 };
 
 #endif /* _MSG_ASYNC_PROTOCOL_V2_ */