// -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*-
// vim: ts=8 sw=2 smarttab ft=cpp
+#include <openssl/ssl.h>
+#include <openssl/x509.h>
+#include <openssl/pem.h>
+#include <openssl/bio.h>
+
#include <vector>
#include <string>
#include <array>
#include "rgw_sts.h"
#include "rgw_rest_oidc_provider.h"
+#include "rgw_asio_thread.h"
#define dout_context g_ceph_context
int r = load_provider(dpp, y, role_arn, iss, provider);
if (r < 0) {
ldpp_dout(dpp, 0) << "Couldn't get oidc provider info using input iss" << iss << dendl;
- throw -EACCES;
+ throw std::system_error(EACCES, std::system_category());
}
if (decoded.has_payload_claim(string(princTagsNamespace))) {
auto& cl = decoded.get_payload_claim(string(princTagsNamespace));
}
} else {
ldpp_dout(dpp, 0) << "Malformed principal tags" << cl.as_string() << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
}
if (! provider.client_ids.empty()) {
}
if (! found && ! is_client_id_valid(provider.client_ids, client_id) && ! is_client_id_valid(provider.client_ids, azp)) {
ldpp_dout(dpp, 0) << "Client id in token doesn't match with that registered with oidc provider" << dendl;
- throw -EACCES;
+ throw std::system_error(EACCES, std::system_category());
}
}
//Validate signature
auto& algorithm = decoded.get_algorithm();
try {
validate_signature(dpp, decoded, algorithm, iss, provider.thumbprints, y);
- } catch (...) {
- throw -EACCES;
+ } catch (const std::exception& e) {
+ throw std::system_error(EACCES, std::system_category());
}
} else {
return {boost::none, boost::none};
}
- } catch (int error) {
- if (error == -EACCES) {
- throw -EACCES;
- }
- ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl;
- return {boost::none, boost::none};
- }
- catch (...) {
+ } catch (const std::exception& e) {
ldpp_dout(dpp, 5) << "Invalid JWT token" << dendl;
return {boost::none, boost::none};
}
return cert_url;
}
+std::string
+WebTokenEngine::get_top_level_domain_from_host(const DoutPrefixProvider* dpp, const std::string& hostname) const
+{
+ std::string host = hostname;
+ //get top level domain only, removing https etc
+ auto pos = host.find("http://");
+ if (pos == std::string::npos) {
+ pos = host.find("https://");
+ if (pos != std::string::npos) {
+ host.erase(pos, 8);
+ } else {
+ pos = host.find("www.");
+ if (pos != std::string::npos) {
+ host.erase(pos, 4);
+ }
+ }
+ } else {
+ host.erase(pos, 7);
+ }
+
+ pos = host.find("/");
+ if (pos != std::string::npos) {
+ host.erase(pos, (host.length() - 1));
+ }
+
+ ldpp_dout(dpp, 20) << "Top level domain name of the host is: " << host << dendl;
+ return host;
+}
+
+int
+WebTokenEngine::create_connection(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const
+{
+ struct hostent* host = gethostbyname(hostname.c_str());
+ if (!host) {
+ ldpp_dout(dpp, 0) << "gethostbyname failed for host: " << hostname << dendl;
+ return -1;
+ }
+
+ struct sockaddr_in addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sin_family = AF_INET;
+ addr.sin_port = htons(port);
+ memcpy(&addr.sin_addr, host->h_addr, host->h_length);
+
+ int sock = socket(AF_INET, SOCK_STREAM, 0);
+ if (sock < 0) {
+ ldpp_dout(dpp, 10) << "creation of socket failed: " << sock << dendl;
+ return -1;
+ }
+
+ int ret = connect(sock, (struct sockaddr*)&addr, sizeof(addr));
+ if (ret != 0) {
+ ldpp_dout(dpp, 10) << "connection to socket failed: " << ret << dendl;
+ close(sock);
+ return -1;
+ }
+
+ return sock;
+}
+
+std::string
+WebTokenEngine::extract_last_certificate(const DoutPrefixProvider* dpp, const std::string& pem_chain) const
+{
+ const std::string BEGIN_MARKER = "-----BEGIN CERTIFICATE-----";
+ const std::string END_MARKER = "-----END CERTIFICATE-----";
+
+ // Find the last occurrence of BEGIN marker
+ size_t begin_pos = pem_chain.rfind(BEGIN_MARKER);
+ if (begin_pos == std::string::npos) {
+ ldpp_dout(dpp, 10) << "No BEGIN marker found in certificate chain" << dendl;
+ throw std::runtime_error("No BEGIN marker found in certificate chain");
+ }
+
+ // Find the END marker that comes after the last BEGIN marker
+ size_t end_pos = pem_chain.find(END_MARKER, begin_pos);
+ if (end_pos == std::string::npos) {
+ ldpp_dout(dpp, 10) << "No matching END marker found after last BEGIN marker" << dendl;
+ throw std::runtime_error("No matching END marker found after last BEGIN marker");
+ }
+
+ // Calculate the start and length of the complete certificate (including markers)
+ size_t cert_length = (end_pos + END_MARKER.length()) - begin_pos;
+
+ // Extract the complete certificate
+ std::string last_cert = pem_chain.substr(begin_pos, cert_length);
+
+ return last_cert;
+}
+
+void
+WebTokenEngine::shutdown_ssl(const DoutPrefixProvider* dpp, SSL* ssl, SSL_CTX* ctx) const
+{
+ int status = SSL_shutdown(ssl);
+ //status = 0, we have issued shutdown but not acknowledged by remote connection
+ //status = 1, remote connection has shutdown
+ //status !=1 && != 0, error
+ if (status == 0) {
+ status = SSL_shutdown(ssl);
+ }
+ if (status != 1) {
+ auto error = SSL_get_error(ssl, status);
+ ldpp_dout(dpp, 10) << "SSL shutdown failed with error: "<< error << dendl;
+ }
+ SSL_free(ssl); // This also frees cert chains
+ SSL_CTX_free(ctx);
+}
+
+std::string
+WebTokenEngine::connect_to_host_get_cert_chain(const DoutPrefixProvider* dpp, const std::string& hostname, int port) const
+{
+ maybe_warn_about_blocking(dpp);
+
+ // Create SSL context
+ SSL_CTX* ctx = SSL_CTX_new(TLS_client_method());
+ if (!ctx) {
+ throw std::runtime_error("Failed to create SSL context");
+ }
+
+ // Create SSL connection
+ SSL* ssl = SSL_new(ctx);
+ if (!ssl) {
+ SSL_CTX_free(ctx);
+ throw std::runtime_error("Failed to create SSL object");
+ }
+
+ // Create socket and connect
+ int sock = create_connection(dpp, hostname, port);
+ if (sock < 0) {
+ SSL_CTX_free(ctx);
+ throw std::runtime_error("Failed to connect to host:" + hostname);
+ }
+
+ SSL_set_fd(ssl, sock);
+
+ // Set SNI hostname
+ SSL_set_tlsext_host_name(ssl, hostname.c_str());
+
+ // Perform handshake
+ if (SSL_connect(ssl) != 1) {
+ SSL_free(ssl);
+ SSL_CTX_free(ctx);
+ close(sock);
+ throw std::runtime_error("SSL handshake failed");
+ }
+
+ std::string chain_pem;
+
+ // Get the peer certificate (server's certificate)
+ X509* cert = SSL_get_peer_certificate(ssl);
+ if (!cert) {
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("No certificate was presented");
+ }
+
+ // Get the chain
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
+ if (!chain) {
+ X509_free(cert);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("Failed to get certificate chain");
+ }
+
+ // Create BIO for PEM output
+ BIO* bio = BIO_new(BIO_s_mem());
+ if (!bio) {
+ X509_free(cert);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("Failed to to create BIO for PEM");
+ }
+
+ // Write the server's certificate first
+ if (!PEM_write_bio_X509(bio, cert)) {
+ BIO_free(bio);
+ X509_free(cert);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("Failed to write server certificate to BIO");
+ }
+ X509_free(cert);
+
+ // Write the rest of the chain
+ int chain_length = sk_X509_num(chain);
+ for (int i = 0; i < chain_length; i++) {
+ X509* chain_cert = sk_X509_value(chain, i);
+ if (!chain_cert) {
+ BIO_free(bio);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("NULL certificate encountered in chain at position " + std::to_string(i));
+ }
+ if (!PEM_write_bio_X509(bio, chain_cert)) {
+ BIO_free(bio);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+ throw std::runtime_error("Failed to write chain certificate to BIO at position " + std::to_string(i));
+ }
+ }
+
+ // Get the PEM data
+ char* pem_data;
+ long pem_size = BIO_get_mem_data(bio, &pem_data);
+ chain_pem = std::string(pem_data, pem_size);
+
+ // Cleanup
+ BIO_free(bio);
+ shutdown_ssl(dpp, ssl, ctx);
+ close(sock);
+
+ return chain_pem;
+}
+
void
WebTokenEngine::validate_signature_using_n_e(const DoutPrefixProvider* dpp, const jwt::decoded_jwt& decoded, const std::string &algorithm, const std::string& n, const std::string& e) const
{
verifier.verify(decoded);
}
} catch (const std::exception& e) {
+ ldpp_dout(dpp, 10) << std::string("Signature validation using n, e failed: ") + e.what() << dendl;
throw std::system_error(EACCES, std::system_category(), std::string("Signature validation using n, e failed: ") + e.what());
}
ldpp_dout(dpp, 10) << "Verified signature using n and e"<< dendl;
if (algorithm != "HS256" && algorithm != "HS384" && algorithm != "HS512") {
string cert_url = get_cert_url(iss, dpp, y);
if (cert_url.empty()) {
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
// Get certificate
int res = cert_req.process(dpp, y);
if (res < 0) {
ldpp_dout(dpp, 10) << "HTTP request res: " << res << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
//Debug only
ldpp_dout(dpp, 20) << "HTTP status: " << cert_req.get_http_status() << dendl;
}
if (! found_valid_cert) {
ldpp_dout(dpp, 0) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
try {
//verify method takes care of expired tokens also
return;
} else {
ldpp_dout(dpp, 0) << "Unsupported algorithm: " << algorithm << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
- } catch (std::runtime_error& e) {
- ldpp_dout(dpp, 0) << "Signature validation failed: " << e.what() << dendl;
- throw;
}
- catch (...) {
- ldpp_dout(dpp, 0) << "Signature validation failed" << dendl;
- throw;
+ catch (const std::exception& e) {
+ ldpp_dout(dpp, 0) << "Signature validation using x5c failed" << e.what() << dendl;
+ throw std::system_error(EACCES, std::system_category());
}
} else {
if (algorithm == "RS256" || algorithm == "RS384" || algorithm == "RS512") {
string n, e; //modulus and exponent
- if (JSONDecoder::decode_json("n", n, &parser) && JSONDecoder::decode_json("e", e, &parser)) {
+ if (JSONDecoder::decode_json("n", n, &k_parser) && JSONDecoder::decode_json("e", e, &k_parser)) {
+ if (skip == true) {
+ continue;
+ }
+ //Fetch and verify cert according to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_create_oidc_verify-thumbprint.html
+ //and the same must be installed as part of create oidc provider in rgw
+ //this can be made common to all types of keys(x5c, n&e), making thumbprint validation similar to
+ //AWS
+ std::string hostname = get_top_level_domain_from_host(dpp, cert_url);
+ //connect to host and get back cert chain from it
+ std::string cert_chain = connect_to_host_get_cert_chain(dpp, hostname, 443);
+ std::string cert;
+ try {
+ cert = extract_last_certificate(dpp, cert_chain);
+ ldpp_dout(dpp, 20) << "last cert: " << cert << dendl;
+ } catch(const std::exception& e) {
+ ldpp_dout(dpp, 20) << "Extracting last cert of jwks uri failed with: " << e.what() << dendl;
+ throw std::system_error(EINVAL, std::system_category());
+ }
+ if (!is_cert_valid(thumbprints, cert)) {
+ ldpp_dout(dpp, 20) << "Cert doesn't match that with the thumbprints registered with oidc provider: " << cert.c_str() << dendl;
+ throw std::system_error(EINVAL, std::system_category());
+ }
validate_signature_using_n_e(dpp, decoded, algorithm, n, e);
return;
}
ldpp_dout(dpp, 0) << "x5c not present or n, e not present" << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
} else {
throw std::system_error(EINVAL, std::system_category(), "Invalid algorithm: " + algorithm);
}
}
ldpp_dout(dpp, 0) << "Signature can not be validated with the input given in keys: "<< dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
} //end k_parser.parse
}//end for iterate through keys
} else { //end val->is_array
ldpp_dout(dpp, 0) << "keys not present in JSON" << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
} else {
ldpp_dout(dpp, 0) << "Malformed json returned while fetching cert" << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
} //if-else get-data
} else {
ldpp_dout(dpp, 0) << "JWT signed by HMAC algos are currently not supported" << dendl;
- throw -EINVAL;
+ throw std::system_error(EINVAL, std::system_category());
}
}
}
return result_t::deny(-EACCES);
}
- catch (...) {
+ catch (const std::exception& e) {
return result_t::deny(-EACCES);
}
}
aud = s->info.args.get("aud");
if (roleArn.empty() || roleSessionName.empty() || sub.empty() || aud.empty()) {
- ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or token is empty" << dendl;
+ ldpp_dout(this, 0) << "ERROR: one of role arn or role session name or sub or aud is empty" << dendl;
return -EINVAL;
}