NamedTuple, Optional, Type, get_args, get_origin
from ..exceptions import DashboardException
-from .nvmeof_conf import NvmeofGatewaysConfig
+from .nvmeof_conf import NvmeofGatewaysConfig, is_mtls_enabled
logger = logging.getLogger("nvmeof_client")
if matched_gateway:
self.gateway_addr = matched_gateway.get('service_url')
logger.debug("Gateway address set to: %s", self.gateway_addr)
-
- root_ca_cert = NvmeofGatewaysConfig.get_root_ca_cert(service_name)
- if root_ca_cert:
+ enable_auth = is_mtls_enabled(service_name)
+ if enable_auth:
client_key = NvmeofGatewaysConfig.get_client_key(service_name)
client_cert = NvmeofGatewaysConfig.get_client_cert(service_name)
-
- if root_ca_cert and client_key and client_cert:
+ server_cert = NvmeofGatewaysConfig.get_server_cert(service_name)
logger.info('Securely connecting to: %s', self.gateway_addr)
credentials = grpc.ssl_channel_credentials(
- root_certificates=root_ca_cert,
+ root_certificates=server_cert,
private_key=client_key,
certificate_chain=client_cert,
)
@classmethod
def get_root_ca_cert(cls, service_name: str):
root_ca_cert = cls.from_cert_store('nvmeof_root_ca_cert', service_name)
- # If root_ca_cert is not set, use server_cert as root_ca_cert
- return root_ca_cert.encode() if root_ca_cert else cls.get_server_cert(service_name)
+ return root_ca_cert.encode() if root_ca_cert else None
@classmethod
def get_server_cert(cls, service_name: str):
service_name = gateway_keys[0]
return service_name, gateways[service_name][0]['service_url']
return None
+
+
+def is_mtls_enabled(service_name: str):
+ try:
+ orch = OrchClient.instance()
+ return orch.services.get(service_name)[0].spec.enable_auth
+ except OrchestratorError:
+ return False