]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
rgw/sts: adding validation of jwks_uri cert according
authorPritha Srivastava <prsrivas@redhat.com>
Thu, 13 Feb 2025 11:18:43 +0000 (16:48 +0530)
committerPritha Srivastava <prsrivas@redhat.com>
Sat, 26 Apr 2025 06:02:04 +0000 (11:32 +0530)
to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_create_oidc_verify-thumbprint.html
for n&e which can be later used for all key types
(x5c, n&e).

Signed-off-by: Pritha Srivastava <prsrivas@redhat.com>
src/rgw/rgw_rest_sts.cc
src/rgw/rgw_rest_sts.h

index c2c979ca9ce8f518e20dc5479f682a3b73227818..0b57f2ea858bafc6a8df83ee3d31c15678c64ed7 100644 (file)
@@ -1,5 +1,10 @@
 // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
 // vim: ts=8 sw=2 smarttab ft=cpp
+#include <openssl/ssl.h>
+#include <openssl/x509.h>
+#include <openssl/pem.h>
+#include <openssl/bio.h>
+
 #include <vector>
 #include <string>
 #include <array>
@@ -38,6 +43,7 @@
 
 #include "rgw_sts.h"
 #include "rgw_rest_oidc_provider.h"
+#include "rgw_asio_thread.h"
 
 
 #define dout_context g_ceph_context
@@ -247,7 +253,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t
     int r = load_provider(dpp, y, role_arn, iss, provider);
     if (r < 0) {
       ldpp_dout(dpp, 0) << "Couldn't get oidc provider info using input iss" << iss << dendl;
-      throw -EACCES;
+      throw std::system_error(EACCES, std::system_category());
     }
     if (decoded.has_payload_claim(string(princTagsNamespace))) {
       auto& cl = decoded.get_payload_claim(string(princTagsNamespace));
@@ -258,7 +264,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t
         }
       } else {
         ldpp_dout(dpp, 0) << "Malformed principal tags" << cl.as_string() << dendl;
-        throw -EINVAL;
+        throw std::system_error(EINVAL, std::system_category());
       }
     }
     if (! provider.client_ids.empty()) {
@@ -271,7 +277,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t
       }
       if (! found && ! is_client_id_valid(provider.client_ids, client_id) && ! is_client_id_valid(provider.client_ids, azp)) {
         ldpp_dout(dpp, 0) << "Client id in token doesn't match with that registered with oidc provider" << dendl;
-        throw -EACCES;
+        throw std::system_error(EACCES, std::system_category());
       }
     }
     //Validate signature
@@ -279,20 +285,13 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t
       auto& algorithm = decoded.get_algorithm();
       try {
         validate_signature(dpp, decoded, algorithm, iss, provider.thumbprints, y);
-      } catch (...) {
-        throw -EACCES;
+      } catch (const std::exception& e) {
+        throw std::system_error(EACCES, std::system_category());
       }
     } else {
       return {boost::none, boost::none};
     }
-  } catch (int error) {
-    if (error == -EACCES) {
-      throw -EACCES;
-    }
-    ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl;
-    return {boost::none, boost::none};
-  }
-  catch (...) {
+  } catch (const std::exception& e) {
     ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl;
     return {boost::none, boost::none};
   }
@@ -339,6 +338,220 @@ WebTokenEngine::get_cert_url(const string& iss, const DoutPrefixProvider *dpp, o
   return cert_url;
 }
 
+std::string
+WebTokenEngine::get_top_level_domain_from_host(const DoutPrefixProvider* dpp, const std::string& hostname) const
+{
+  std::string host = hostname;
+  //get top level domain only, removing https etc
+  auto pos = host.find("http://");
+  if (pos == std::string::npos) {
+    pos = host.find("https://");
+    if (pos != std::string::npos) {
+      host.erase(pos, 8);
+    } else {
+      pos = host.find("www.");
+      if (pos != std::string::npos) {
+        host.erase(pos, 4);
+      }
+    }
+  } else {
+    host.erase(pos, 7);
+  }
+
+  pos = host.find("/");
+  if (pos != std::string::npos) {
+    host.erase(pos, (host.length() - 1));
+  }
+
+  ldpp_dout(dpp, 20) << "Top level domain name of the host is: " << host << dendl;
+  return host;
+}
+
+int
+WebTokenEngine::create_connection(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const
+{
+  struct hostent* host = gethostbyname(hostname.c_str());
+  if (!host) {
+    ldpp_dout(dpp, 0) << "gethostbyname failed for host: " << hostname << dendl;
+    return -1;
+  }
+
+  struct sockaddr_in addr;
+  memset(&addr, 0, sizeof(addr));
+  addr.sin_family = AF_INET;
+  addr.sin_port = htons(port);
+  memcpy(&addr.sin_addr, host->h_addr, host->h_length);
+
+  int sock = socket(AF_INET, SOCK_STREAM, 0);
+  if (sock < 0) {
+    ldpp_dout(dpp, 10) << "creation of socket failed: " << sock << dendl;
+    return -1;
+  }
+
+  int ret = connect(sock, (struct sockaddr*)&addr, sizeof(addr));
+  if (ret != 0) {
+    ldpp_dout(dpp, 10) << "connection to socket failed: " << ret << dendl;
+    close(sock);
+    return -1;
+  }
+
+  return sock;
+}
+
+std::string
+WebTokenEngine::extract_last_certificate(const DoutPrefixProvider* dpp, const std::string& pem_chain) const
+{
+  const std::string BEGIN_MARKER = "-----BEGIN CERTIFICATE-----";
+  const std::string END_MARKER = "-----END CERTIFICATE-----";
+
+  // Find the last occurrence of BEGIN marker
+  size_t begin_pos = pem_chain.rfind(BEGIN_MARKER);
+  if (begin_pos == std::string::npos) {
+    ldpp_dout(dpp, 10) << "No BEGIN marker found in certificate chain" << dendl;
+    throw std::runtime_error("No BEGIN marker found in certificate chain");
+  }
+
+  // Find the END marker that comes after the last BEGIN marker
+  size_t end_pos = pem_chain.find(END_MARKER, begin_pos);
+  if (end_pos == std::string::npos) {
+    ldpp_dout(dpp, 10) << "No matching END marker found after last BEGIN marker" << dendl;
+    throw std::runtime_error("No matching END marker found after last BEGIN marker");
+  }
+
+  // Calculate the start and length of the complete certificate (including markers)
+  size_t cert_length = (end_pos + END_MARKER.length()) - begin_pos;
+
+  // Extract the complete certificate
+  std::string last_cert = pem_chain.substr(begin_pos, cert_length);
+
+  return last_cert;
+}
+
+void
+WebTokenEngine::shutdown_ssl(const DoutPrefixProvider* dpp, SSL* ssl, SSL_CTX* ctx) const
+{
+  int status = SSL_shutdown(ssl);
+  //status = 0, we have issued shutdown but not acknowledged by remote connection
+  //status = 1, remote connection has shutdown
+  //status !=1 && != 0, error
+  if (status == 0) {
+    status = SSL_shutdown(ssl);
+  }
+  if (status != 1) {
+    auto error = SSL_get_error(ssl, status);
+    ldpp_dout(dpp, 10) << "SSL shutdown failed with error: "<< error << dendl;
+  }
+  SSL_free(ssl); // This also frees cert chains
+  SSL_CTX_free(ctx);
+}
+
+std::string
+WebTokenEngine::connect_to_host_get_cert_chain(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const
+{
+  maybe_warn_about_blocking(dpp);
+
+  // Create SSL context
+  SSL_CTX* ctx = SSL_CTX_new(TLS_client_method());
+  if (!ctx) {
+    throw std::runtime_error("Failed to create SSL context");
+  }
+
+  // Create SSL connection
+  SSL* ssl = SSL_new(ctx);
+  if (!ssl) {
+    SSL_CTX_free(ctx);
+    throw std::runtime_error("Failed to create SSL object");
+  }
+
+  // Create socket and connect
+  int sock = create_connection(dpp, hostname, port);
+  if (sock < 0) {
+    SSL_CTX_free(ctx);
+    throw std::runtime_error("Failed to connect to host:" + hostname);
+  }
+
+  SSL_set_fd(ssl, sock);
+
+  // Set SNI hostname
+  SSL_set_tlsext_host_name(ssl, hostname.c_str());
+
+  // Perform handshake
+  if (SSL_connect(ssl) != 1) {
+    SSL_free(ssl);
+    SSL_CTX_free(ctx);
+    close(sock);
+    throw std::runtime_error("SSL handshake failed");
+  }
+
+  std::string chain_pem;
+
+  // Get the peer certificate (server's certificate)
+  X509* cert = SSL_get_peer_certificate(ssl);
+  if (!cert) {
+    shutdown_ssl(dpp, ssl, ctx);
+    close(sock);
+    throw std::runtime_error("No certificate was presented");
+  }
+
+  // Get the chain
+  STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
+  if (!chain) {
+    X509_free(cert);
+    shutdown_ssl(dpp, ssl, ctx);
+    close(sock);
+    throw std::runtime_error("Failed to get certificate chain");
+  }
+
+  // Create BIO for PEM output
+  BIO* bio = BIO_new(BIO_s_mem());
+  if (!bio) {
+    X509_free(cert);
+    shutdown_ssl(dpp, ssl, ctx);
+    close(sock);
+    throw std::runtime_error("Failed to to create BIO for PEM");
+  }
+
+  // Write the server's certificate first
+  if (!PEM_write_bio_X509(bio, cert)) {
+    BIO_free(bio);
+    X509_free(cert);
+    shutdown_ssl(dpp, ssl, ctx);
+    close(sock);
+    throw std::runtime_error("Failed to write server certificate to BIO");
+  }
+  X509_free(cert);
+
+  // Write the rest of the chain
+  int chain_length = sk_X509_num(chain);
+  for (int i = 0; i < chain_length; i++) {
+    X509* chain_cert = sk_X509_value(chain, i);
+    if (!chain_cert) {
+      BIO_free(bio);
+      shutdown_ssl(dpp, ssl, ctx);
+      close(sock);
+      throw std::runtime_error("NULL certificate encountered in chain at position " + std::to_string(i));
+    }
+    if (!PEM_write_bio_X509(bio, chain_cert)) {
+      BIO_free(bio);
+      shutdown_ssl(dpp, ssl, ctx);
+      close(sock);
+      throw std::runtime_error("Failed to write chain certificate to BIO at position " + std::to_string(i));
+    }
+  }
+
+  // Get the PEM data
+  char* pem_data;
+  long pem_size = BIO_get_mem_data(bio, &pem_data);
+  chain_pem = std::string(pem_data, pem_size);
+
+  // Cleanup
+  BIO_free(bio);
+  shutdown_ssl(dpp, ssl, ctx);
+  close(sock);
+
+  return chain_pem;
+}
+
 void
 WebTokenEngine::validate_signature_using_n_e(const DoutPrefixProvider* dpp, const jwt::decoded_jwt& decoded, const std::string &algorithm, const std::string& n, const std::string& e) const
 {
@@ -357,6 +570,7 @@ WebTokenEngine::validate_signature_using_n_e(const DoutPrefixProvider* dpp, cons
       verifier.verify(decoded);
     }
   } catch (const std::exception& e) {
+    ldpp_dout(dpp, 10) << std::string("Signature validation using n, e failed: ") + e.what() << dendl;
     throw std::system_error(EACCES, std::system_category(), std::string("Signature validation using n, e failed: ") + e.what());
   }
   ldpp_dout(dpp, 10) << "Verified signature using n and e"<< dendl;
@@ -369,7 +583,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec
   if (algorithm != "HS256" && algorithm != "HS384" && algorithm != "HS512") {
     string cert_url = get_cert_url(iss, dpp, y);
     if (cert_url.empty()) {
-      throw -EINVAL;
+      throw std::system_error(EINVAL, std::system_category());
     }
 
     // Get certificate
@@ -381,7 +595,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec
     int res = cert_req.process(dpp, y);
     if (res < 0) {
       ldpp_dout(dpp, 10) << "HTTP request res: " << res << dendl;
-      throw -EINVAL;
+      throw std::system_error(EINVAL, std::system_category());
     }
     //Debug only
     ldpp_dout(dpp, 20) << "HTTP status: " << cert_req.get_http_status() << dendl;
@@ -421,7 +635,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec
               }
               if (! found_valid_cert) {
                 ldpp_dout(dpp, 0) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl;
-                throw -EINVAL;
+                throw std::system_error(EINVAL, std::system_category());
               }
               try {
                 //verify method takes care of expired tokens also
@@ -481,44 +695,63 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec
                   return;
                 } else {
                   ldpp_dout(dpp, 0) << "Unsupported algorithm: " << algorithm << dendl;
-                  throw -EINVAL;
+                  throw std::system_error(EINVAL, std::system_category());
                 }
-              } catch (std::runtime_error& e) {
-                ldpp_dout(dpp, 0) << "Signature validation failed: " << e.what() << dendl;
-                throw;
               }
-              catch (...) {
-                ldpp_dout(dpp, 0) << "Signature validation failed" << dendl;
-                throw;
+              catch (const std::exception& e) {
+                ldpp_dout(dpp, 0) << "Signature validation using x5c failed" << e.what() << dendl;
+                throw std::system_error(EACCES, std::system_category());
               }
             } else {
               if (algorithm == "RS256" || algorithm == "RS384" || algorithm == "RS512") {
                 string n, e; //modulus and exponent
-                if (JSONDecoder::decode_json("n", n, &parser) && JSONDecoder::decode_json("e", e, &parser)) {
+                if (JSONDecoder::decode_json("n", n, &k_parser) && JSONDecoder::decode_json("e", e, &k_parser)) {
+                  if (skip == true) {
+                    continue;
+                  }
+                  //Fetch and verify cert according to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_create_oidc_verify-thumbprint.html
+                  //and the same must be installed as part of create oidc provider in rgw
+                  //this can be made common to all types of keys(x5c, n&e), making thumbprint validation similar to
+                  //AWS
+                  std::string hostname = get_top_level_domain_from_host(dpp, cert_url);
+                  //connect to host and get back cert chain from it
+                  std::string cert_chain = connect_to_host_get_cert_chain(dpp, hostname, 443);
+                  std::string cert;
+                  try {
+                    cert = extract_last_certificate(dpp, cert_chain);
+                    ldpp_dout(dpp, 20) << "last cert: " << cert << dendl;
+                  } catch(const std::exception& e) {
+                    ldpp_dout(dpp, 20) << "Extracting last cert of jwks uri failed with: " << e.what() << dendl;
+                    throw std::system_error(EINVAL, std::system_category());
+                  }
+                  if (!is_cert_valid(thumbprints, cert)) {
+                    ldpp_dout(dpp, 20) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl;
+                    throw std::system_error(EINVAL, std::system_category());
+                  }
                   validate_signature_using_n_e(dpp, decoded, algorithm, n, e);
                   return;
                 }
                 ldpp_dout(dpp, 0) << "x5c not present or n, e not present" << dendl;
-                throw -EINVAL;
+                throw std::system_error(EINVAL, std::system_category());
               } else {
                 throw std::system_error(EINVAL, std::system_category(), "Invalid algorithm: " + algorithm);
               }
             }
             ldpp_dout(dpp, 0) << "Signature can not be validated with the input given in keys: "<< dendl;
-            throw -EINVAL;
+            throw std::system_error(EINVAL, std::system_category());
           } //end k_parser.parse
         }//end for iterate through keys
       } else { //end val->is_array
         ldpp_dout(dpp, 0) << "keys not present in JSON" << dendl;
-        throw -EINVAL;
+        throw std::system_error(EINVAL, std::system_category());
       }
     } else {
       ldpp_dout(dpp, 0) << "Malformed json returned while fetching cert" << dendl;
-      throw -EINVAL;
+      throw std::system_error(EINVAL, std::system_category());
     } //if-else get-data
   } else {
     ldpp_dout(dpp, 0) << "JWT signed by HMAC algos are currently not supported" << dendl;
-    throw -EINVAL;
+    throw std::system_error(EINVAL, std::system_category());
   }
 }
 
@@ -578,7 +811,7 @@ WebTokenEngine::authenticate( const DoutPrefixProvider* dpp,
     }
     return result_t::deny(-EACCES);
   }
-  catch (...) {
+  catch (const std::exception& e) {
     return result_t::deny(-EACCES);
   }
 }
@@ -715,7 +948,7 @@ int RGWSTSAssumeRoleWithWebIdentity::get_params()
   aud = s->info.args.get("aud");
 
   if (roleArn.empty() || roleSessionName.empty() || sub.empty() || aud.empty()) {
-    ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or token is empty" << dendl;
+    ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or sub or aud is empty" << dendl;
     return -EINVAL;
   }
 
index 8d25f6784329e165643c4987ff2cc40b44f9ba2a..befa93169b4ead0e4077f8d8a96e6ca470177ba8 100644 (file)
@@ -65,6 +65,12 @@ class WebTokenEngine : public rgw::auth::Engine {
   void recurse_and_insert(const string& key, const jwt::claim& c, T& t) const;
   WebTokenEngine::token_t get_token_claims(const jwt::decoded_jwt& decoded) const;
 
+  int create_connection(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const;
+  std::string connect_to_host_get_cert_chain(const DoutPrefixProvider* dpp, const std::string& hostname, int port = 443) const;
+  std::string get_top_level_domain_from_host(const DoutPrefixProvider* dpp, const std::string& hostname) const;
+  std::string extract_last_certificate(const DoutPrefixProvider* dpp, const std::string& pem_chain) const;
+  void shutdown_ssl(const DoutPrefixProvider* dpp, SSL* ssl, SSL_CTX* ctx) const;
+
 public:
   WebTokenEngine(CephContext* const cct,
                     rgw::sal::Driver* driver,