From: Pritha Srivastava Date: Thu, 13 Feb 2025 11:18:43 +0000 (+0530) Subject: rgw/sts: adding validation of jwks_uri cert according X-Git-Url: http://git.apps.os.sepia.ceph.com/?a=commitdiff_plain;h=d970f62e3e264644ea474ad4ee513caac977268f;p=ceph.git rgw/sts: adding validation of jwks_uri cert according to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_create_oidc_verify-thumbprint.html for n&e which can be later used for all key types (x5c, n&e). Signed-off-by: Pritha Srivastava --- diff --git a/src/rgw/rgw_rest_sts.cc b/src/rgw/rgw_rest_sts.cc index c2c979ca9ce8f..0b57f2ea858ba 100644 --- a/src/rgw/rgw_rest_sts.cc +++ b/src/rgw/rgw_rest_sts.cc @@ -1,5 +1,10 @@ // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*- // vim: ts=8 sw=2 smarttab ft=cpp +#include +#include +#include +#include + #include #include #include @@ -38,6 +43,7 @@ #include "rgw_sts.h" #include "rgw_rest_oidc_provider.h" +#include "rgw_asio_thread.h" #define dout_context g_ceph_context @@ -247,7 +253,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t int r = load_provider(dpp, y, role_arn, iss, provider); if (r < 0) { ldpp_dout(dpp, 0) << "Couldn't get oidc provider info using input iss" << iss << dendl; - throw -EACCES; + throw std::system_error(EACCES, std::system_category()); } if (decoded.has_payload_claim(string(princTagsNamespace))) { auto& cl = decoded.get_payload_claim(string(princTagsNamespace)); @@ -258,7 +264,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t } } else { ldpp_dout(dpp, 0) << "Malformed principal tags" << cl.as_string() << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } } if (! provider.client_ids.empty()) { @@ -271,7 +277,7 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t } if (! found && ! is_client_id_valid(provider.client_ids, client_id) && ! is_client_id_valid(provider.client_ids, azp)) { ldpp_dout(dpp, 0) << "Client id in token doesn't match with that registered with oidc provider" << dendl; - throw -EACCES; + throw std::system_error(EACCES, std::system_category()); } } //Validate signature @@ -279,20 +285,13 @@ WebTokenEngine::get_from_jwt(const DoutPrefixProvider* dpp, const std::string& t auto& algorithm = decoded.get_algorithm(); try { validate_signature(dpp, decoded, algorithm, iss, provider.thumbprints, y); - } catch (...) { - throw -EACCES; + } catch (const std::exception& e) { + throw std::system_error(EACCES, std::system_category()); } } else { return {boost::none, boost::none}; } - } catch (int error) { - if (error == -EACCES) { - throw -EACCES; - } - ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl; - return {boost::none, boost::none}; - } - catch (...) { + } catch (const std::exception& e) { ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl; return {boost::none, boost::none}; } @@ -339,6 +338,220 @@ WebTokenEngine::get_cert_url(const string& iss, const DoutPrefixProvider *dpp, o return cert_url; } +std::string +WebTokenEngine::get_top_level_domain_from_host(const DoutPrefixProvider* dpp, const std::string& hostname) const +{ + std::string host = hostname; + //get top level domain only, removing https etc + auto pos = host.find("http://"); + if (pos == std::string::npos) { + pos = host.find("https://"); + if (pos != std::string::npos) { + host.erase(pos, 8); + } else { + pos = host.find("www."); + if (pos != std::string::npos) { + host.erase(pos, 4); + } + } + } else { + host.erase(pos, 7); + } + + pos = host.find("/"); + if (pos != std::string::npos) { + host.erase(pos, (host.length() - 1)); + } + + ldpp_dout(dpp, 20) << "Top level domain name of the host is: " << host << dendl; + return host; +} + +int +WebTokenEngine::create_connection(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const +{ + struct hostent* host = gethostbyname(hostname.c_str()); + if (!host) { + ldpp_dout(dpp, 0) << "gethostbyname failed for host: " << hostname << dendl; + return -1; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + memcpy(&addr.sin_addr, host->h_addr, host->h_length); + + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + ldpp_dout(dpp, 10) << "creation of socket failed: " << sock << dendl; + return -1; + } + + int ret = connect(sock, (struct sockaddr*)&addr, sizeof(addr)); + if (ret != 0) { + ldpp_dout(dpp, 10) << "connection to socket failed: " << ret << dendl; + close(sock); + return -1; + } + + return sock; +} + +std::string +WebTokenEngine::extract_last_certificate(const DoutPrefixProvider* dpp, const std::string& pem_chain) const +{ + const std::string BEGIN_MARKER = "-----BEGIN CERTIFICATE-----"; + const std::string END_MARKER = "-----END CERTIFICATE-----"; + + // Find the last occurrence of BEGIN marker + size_t begin_pos = pem_chain.rfind(BEGIN_MARKER); + if (begin_pos == std::string::npos) { + ldpp_dout(dpp, 10) << "No BEGIN marker found in certificate chain" << dendl; + throw std::runtime_error("No BEGIN marker found in certificate chain"); + } + + // Find the END marker that comes after the last BEGIN marker + size_t end_pos = pem_chain.find(END_MARKER, begin_pos); + if (end_pos == std::string::npos) { + ldpp_dout(dpp, 10) << "No matching END marker found after last BEGIN marker" << dendl; + throw std::runtime_error("No matching END marker found after last BEGIN marker"); + } + + // Calculate the start and length of the complete certificate (including markers) + size_t cert_length = (end_pos + END_MARKER.length()) - begin_pos; + + // Extract the complete certificate + std::string last_cert = pem_chain.substr(begin_pos, cert_length); + + return last_cert; +} + +void +WebTokenEngine::shutdown_ssl(const DoutPrefixProvider* dpp, SSL* ssl, SSL_CTX* ctx) const +{ + int status = SSL_shutdown(ssl); + //status = 0, we have issued shutdown but not acknowledged by remote connection + //status = 1, remote connection has shutdown + //status !=1 && != 0, error + if (status == 0) { + status = SSL_shutdown(ssl); + } + if (status != 1) { + auto error = SSL_get_error(ssl, status); + ldpp_dout(dpp, 10) << "SSL shutdown failed with error: "<< error << dendl; + } + SSL_free(ssl); // This also frees cert chains + SSL_CTX_free(ctx); +} + +std::string +WebTokenEngine::connect_to_host_get_cert_chain(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const +{ + maybe_warn_about_blocking(dpp); + + // Create SSL context + SSL_CTX* ctx = SSL_CTX_new(TLS_client_method()); + if (!ctx) { + throw std::runtime_error("Failed to create SSL context"); + } + + // Create SSL connection + SSL* ssl = SSL_new(ctx); + if (!ssl) { + SSL_CTX_free(ctx); + throw std::runtime_error("Failed to create SSL object"); + } + + // Create socket and connect + int sock = create_connection(dpp, hostname, port); + if (sock < 0) { + SSL_CTX_free(ctx); + throw std::runtime_error("Failed to connect to host:" + hostname); + } + + SSL_set_fd(ssl, sock); + + // Set SNI hostname + SSL_set_tlsext_host_name(ssl, hostname.c_str()); + + // Perform handshake + if (SSL_connect(ssl) != 1) { + SSL_free(ssl); + SSL_CTX_free(ctx); + close(sock); + throw std::runtime_error("SSL handshake failed"); + } + + std::string chain_pem; + + // Get the peer certificate (server's certificate) + X509* cert = SSL_get_peer_certificate(ssl); + if (!cert) { + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("No certificate was presented"); + } + + // Get the chain + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) { + X509_free(cert); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("Failed to get certificate chain"); + } + + // Create BIO for PEM output + BIO* bio = BIO_new(BIO_s_mem()); + if (!bio) { + X509_free(cert); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("Failed to to create BIO for PEM"); + } + + // Write the server's certificate first + if (!PEM_write_bio_X509(bio, cert)) { + BIO_free(bio); + X509_free(cert); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("Failed to write server certificate to BIO"); + } + X509_free(cert); + + // Write the rest of the chain + int chain_length = sk_X509_num(chain); + for (int i = 0; i < chain_length; i++) { + X509* chain_cert = sk_X509_value(chain, i); + if (!chain_cert) { + BIO_free(bio); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("NULL certificate encountered in chain at position " + std::to_string(i)); + } + if (!PEM_write_bio_X509(bio, chain_cert)) { + BIO_free(bio); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + throw std::runtime_error("Failed to write chain certificate to BIO at position " + std::to_string(i)); + } + } + + // Get the PEM data + char* pem_data; + long pem_size = BIO_get_mem_data(bio, &pem_data); + chain_pem = std::string(pem_data, pem_size); + + // Cleanup + BIO_free(bio); + shutdown_ssl(dpp, ssl, ctx); + close(sock); + + return chain_pem; +} + void WebTokenEngine::validate_signature_using_n_e(const DoutPrefixProvider* dpp, const jwt::decoded_jwt& decoded, const std::string &algorithm, const std::string& n, const std::string& e) const { @@ -357,6 +570,7 @@ WebTokenEngine::validate_signature_using_n_e(const DoutPrefixProvider* dpp, cons verifier.verify(decoded); } } catch (const std::exception& e) { + ldpp_dout(dpp, 10) << std::string("Signature validation using n, e failed: ") + e.what() << dendl; throw std::system_error(EACCES, std::system_category(), std::string("Signature validation using n, e failed: ") + e.what()); } ldpp_dout(dpp, 10) << "Verified signature using n and e"<< dendl; @@ -369,7 +583,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec if (algorithm != "HS256" && algorithm != "HS384" && algorithm != "HS512") { string cert_url = get_cert_url(iss, dpp, y); if (cert_url.empty()) { - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } // Get certificate @@ -381,7 +595,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec int res = cert_req.process(dpp, y); if (res < 0) { ldpp_dout(dpp, 10) << "HTTP request res: " << res << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } //Debug only ldpp_dout(dpp, 20) << "HTTP status: " << cert_req.get_http_status() << dendl; @@ -421,7 +635,7 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec } if (! found_valid_cert) { ldpp_dout(dpp, 0) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } try { //verify method takes care of expired tokens also @@ -481,44 +695,63 @@ WebTokenEngine::validate_signature(const DoutPrefixProvider* dpp, const jwt::dec return; } else { ldpp_dout(dpp, 0) << "Unsupported algorithm: " << algorithm << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } - } catch (std::runtime_error& e) { - ldpp_dout(dpp, 0) << "Signature validation failed: " << e.what() << dendl; - throw; } - catch (...) { - ldpp_dout(dpp, 0) << "Signature validation failed" << dendl; - throw; + catch (const std::exception& e) { + ldpp_dout(dpp, 0) << "Signature validation using x5c failed" << e.what() << dendl; + throw std::system_error(EACCES, std::system_category()); } } else { if (algorithm == "RS256" || algorithm == "RS384" || algorithm == "RS512") { string n, e; //modulus and exponent - if (JSONDecoder::decode_json("n", n, &parser) && JSONDecoder::decode_json("e", e, &parser)) { + if (JSONDecoder::decode_json("n", n, &k_parser) && JSONDecoder::decode_json("e", e, &k_parser)) { + if (skip == true) { + continue; + } + //Fetch and verify cert according to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_create_oidc_verify-thumbprint.html + //and the same must be installed as part of create oidc provider in rgw + //this can be made common to all types of keys(x5c, n&e), making thumbprint validation similar to + //AWS + std::string hostname = get_top_level_domain_from_host(dpp, cert_url); + //connect to host and get back cert chain from it + std::string cert_chain = connect_to_host_get_cert_chain(dpp, hostname, 443); + std::string cert; + try { + cert = extract_last_certificate(dpp, cert_chain); + ldpp_dout(dpp, 20) << "last cert: " << cert << dendl; + } catch(const std::exception& e) { + ldpp_dout(dpp, 20) << "Extracting last cert of jwks uri failed with: " << e.what() << dendl; + throw std::system_error(EINVAL, std::system_category()); + } + if (!is_cert_valid(thumbprints, cert)) { + ldpp_dout(dpp, 20) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl; + throw std::system_error(EINVAL, std::system_category()); + } validate_signature_using_n_e(dpp, decoded, algorithm, n, e); return; } ldpp_dout(dpp, 0) << "x5c not present or n, e not present" << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } else { throw std::system_error(EINVAL, std::system_category(), "Invalid algorithm: " + algorithm); } } ldpp_dout(dpp, 0) << "Signature can not be validated with the input given in keys: "<< dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } //end k_parser.parse }//end for iterate through keys } else { //end val->is_array ldpp_dout(dpp, 0) << "keys not present in JSON" << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } } else { ldpp_dout(dpp, 0) << "Malformed json returned while fetching cert" << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } //if-else get-data } else { ldpp_dout(dpp, 0) << "JWT signed by HMAC algos are currently not supported" << dendl; - throw -EINVAL; + throw std::system_error(EINVAL, std::system_category()); } } @@ -578,7 +811,7 @@ WebTokenEngine::authenticate( const DoutPrefixProvider* dpp, } return result_t::deny(-EACCES); } - catch (...) { + catch (const std::exception& e) { return result_t::deny(-EACCES); } } @@ -715,7 +948,7 @@ int RGWSTSAssumeRoleWithWebIdentity::get_params() aud = s->info.args.get("aud"); if (roleArn.empty() || roleSessionName.empty() || sub.empty() || aud.empty()) { - ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or token is empty" << dendl; + ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or sub or aud is empty" << dendl; return -EINVAL; } diff --git a/src/rgw/rgw_rest_sts.h b/src/rgw/rgw_rest_sts.h index 8d25f6784329e..befa93169b4ea 100644 --- a/src/rgw/rgw_rest_sts.h +++ b/src/rgw/rgw_rest_sts.h @@ -65,6 +65,12 @@ class WebTokenEngine : public rgw::auth::Engine { void recurse_and_insert(const string& key, const jwt::claim& c, T& t) const; WebTokenEngine::token_t get_token_claims(const jwt::decoded_jwt& decoded) const; + int create_connection(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const; + std::string connect_to_host_get_cert_chain(const DoutPrefixProvider* dpp, const std::string& hostname, int port = 443) const; + std::string get_top_level_domain_from_host(const DoutPrefixProvider* dpp, const std::string& hostname) const; + std::string extract_last_certificate(const DoutPrefixProvider* dpp, const std::string& pem_chain) const; + void shutdown_ssl(const DoutPrefixProvider* dpp, SSL* ssl, SSL_CTX* ctx) const; + public: WebTokenEngine(CephContext* const cct, rgw::sal::Driver* driver,