pb2 = pb2
def __init__(self, gw_group: Optional[str] = None, server_address: Optional[str] = None):
+
+ def encode_tls_bundle(bundle: Dict[str, str]) -> Dict[str, bytes]:
+ """Encode TLS bundle string values to bytes for gRPC."""
+ encoded: Dict[str, bytes] = {}
+ for key, value in bundle.items():
+ if isinstance(value, str):
+ encoded[key] = value.encode('utf-8')
+ else:
+ encoded[key] = value
+ return encoded
+
logger.info("Initiating nvmeof gateway connection...")
try:
if not gw_group:
logger.debug("Gateway address set to: %s", self.gateway_addr)
enable_auth = is_mtls_enabled(service_name)
if enable_auth:
- client_key = NvmeofGatewaysConfig.get_client_key(service_name)
- client_cert = NvmeofGatewaysConfig.get_client_cert(service_name)
- server_cert = NvmeofGatewaysConfig.get_ssl_cert(service_name)
- logger.info('Securely connecting to: %s', self.gateway_addr)
- credentials = grpc.ssl_channel_credentials(
- root_certificates=server_cert,
- private_key=client_key,
- certificate_chain=client_cert,
- )
- self.channel = grpc.secure_channel(self.gateway_addr, credentials)
+ tls_bundle = NvmeofGatewaysConfig.get_nvmeof_tls_bundle(service_name)
+ if tls_bundle:
+ logger.info('Securely connecting to: %s', self.gateway_addr)
+ encoded_tls_bundle = encode_tls_bundle(tls_bundle)
+ credentials = grpc.ssl_channel_credentials(
+ root_certificates=encoded_tls_bundle['server_cert'],
+ private_key=encoded_tls_bundle['client_key'],
+ certificate_chain=encoded_tls_bundle['client_cert'],
+ )
+ self.channel = grpc.secure_channel(self.gateway_addr, credentials)
+ else:
+ self.channel = None
+ logger.error("Cannot obtain nvmeof TLS bundle for the service %s (gw: %s)",
+ service_name, self.gateway_addr)
else:
logger.info("Insecurely connecting to: %s", self.gateway_addr)
self.channel = grpc.insecure_channel(self.gateway_addr)
- self.stub = pb2_grpc.GatewayStub(self.channel)
+
+ if self.channel is not None:
+ self.stub = pb2_grpc.GatewayStub(self.channel)
Model = Dict[str, Any]
Collection = List[Model]
)
@classmethod
- def get_client_cert(cls, service_name: str):
- client_cert = cls.from_cert_store('nvmeof_client_cert', service_name)
- return client_cert.encode() if client_cert else None
-
- @classmethod
- def get_client_key(cls, service_name: str):
- client_key = cls.from_cert_store('nvmeof_client_key', service_name, key=True)
- return client_key.encode() if client_key else None
-
- @classmethod
- def get_root_ca_cert(cls, service_name: str):
- root_ca_cert = cls.from_cert_store('nvmeof_root_ca_cert', service_name)
- return root_ca_cert.encode() if root_ca_cert else None
-
- @classmethod
- def get_ssl_cert(cls, service_name: str):
- server_cert = cls.from_cert_store('nvmeof_ssl_cert', service_name)
- return server_cert.encode() if server_cert else None
-
- @classmethod
- def from_cert_store(cls, entity: str, service_name: str, key=False):
+ def get_nvmeof_tls_bundle(cls, service_name: str):
try:
orch = OrchClient.instance()
if orch.available():
- if key:
- return orch.cert_store.get_key(entity, service_name,
- ignore_missing_exception=True)
- return orch.cert_store.get_cert(entity, service_name,
- ignore_missing_exception=True)
+ return orch.cert_store.get_nvmeof_tls_bundle(service_name)
return None
except OrchestratorError:
# just return None if any orchestrator error is raised