git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/cephadm: implement 2-way ssl in mgr -> MgrListener comm line
author: Adam King <adking@redhat.com>
Thu, 26 Aug 2021 19:20:23 +0000 (15:20 -0400)
committer: Adam King <adking@redhat.com>
Fri, 24 Sep 2021 11:23:51 +0000 (07:23 -0400)
Signed-off-by: Adam King <adking@redhat.com>
src/cephadm/cephadm
src/pybind/mgr/cephadm/agent.py
src/pybind/mgr/cephadm/serve.py
src/pybind/mgr/cephadm/services/cephadmservice.py

index 5265f9c790a132dd43bc2d9632c3b5f1e3ac371f..2e4b0224f54371e2c0ce3effbf696abf03474631 100755 (executable)
@@ -3474,9 +3474,14 @@ class MgrListener(Thread):
         listenSocket.bind(('0.0.0.0', int(self.agent.listener_port)))
         listenSocket.settimeout(60)
         listenSocket.listen(1)
+        ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+        ssl_ctx.verify_mode = ssl.CERT_REQUIRED
+        ssl_ctx.load_cert_chain(self.agent.listener_cert_path, self.agent.listener_key_path)
+        ssl_ctx.load_verify_locations(self.agent.ca_path)
+        secureListenSocket = ssl_ctx.wrap_socket(listenSocket, server_side=True)
         while not self.stop:
             try:
-                conn, _ = listenSocket.accept()
+                conn, _ = secureListenSocket.accept()
             except socket.timeout:
                 continue
             try:
@@ -3484,6 +3489,7 @@ class MgrListener(Thread):
             except Exception as e:
                 err_str = f'Failed to extract length of payload from message: {e}'
                 conn.send(err_str.encode())
+                logger.error(err_str)
             while True:
                 payload = conn.recv(length).decode()
                 if not payload:
@@ -3528,6 +3534,8 @@ class CephadmAgent():
         self.config_path = os.path.join(self.daemon_dir, 'agent.json')
         self.keyring_path = os.path.join(self.daemon_dir, 'keyring')
         self.ca_path = os.path.join(self.daemon_dir, 'root_cert.pem')
+        self.listener_cert_path = os.path.join(self.daemon_dir, 'listener.crt')
+        self.listener_key_path = os.path.join(self.daemon_dir, 'listener.key')
         self.listener_port = ''
         self.ack = -1
         self.event = threading.Event()
index e1dc5e96ab5bbaa161795839a9a7b5f1f3211578..f7831c6fb3a6b0255a6a0eff0ef5ab4edbd73abb 100644 (file)
@@ -1,6 +1,7 @@
 import cherrypy
 import json
 import socket
+import ssl
 import tempfile
 import threading
 import time
@@ -149,6 +150,7 @@ class HostData:
                 return
 
             self.mgr.cache.agent_ports[host] = int(data['port'])
+            # update timestamp of most recent agent update
             self.mgr.cache.agent_timestamp[host] = datetime_now()
             up_to_date = False
 
@@ -161,15 +163,15 @@ class HostData:
                 self.mgr.log.debug(
                     f'Received old metadata from agent on host {host}. Requested up-to-date metadata.')
 
+
+            if 'ls' in data:
+                self.mgr._process_ls_output(host, data['ls'])
+            if 'networks' in data:
+                self.mgr.cache.update_host_networks(host, data['networks'])
+            if 'facts' in data:
+                self.mgr.cache.update_host_facts(host, json.loads(data['facts']))
+
             if up_to_date:
-                if 'ls' in data:
-                    self.mgr._process_ls_output(host, data['ls'])
-                if 'networks' in data:
-                    self.mgr.cache.update_host_networks(host, data['networks'])
-                if 'facts' in data:
-                    self.mgr.cache.update_host_facts(host, json.loads(data['facts']))
-
-                # update timestamp of most recent up-to-date agent update
                 self.mgr.cache.metadata_up_to_date[host] = True
                 self.mgr.log.debug(
                     f'Received up-to-date metadata from agent on host {host}.')
@@ -188,17 +190,50 @@ class AgentMessageThread(threading.Thread):
         super(AgentMessageThread, self).__init__(target=self.run)
 
     def run(self) -> None:
+        try:
+            assert self.mgr.cherrypy_thread
+            root_cert= self.mgr.cherrypy_thread.ssl_certs.get_root_cert()
+            root_cert_tmp = tempfile.NamedTemporaryFile()
+            root_cert_tmp.write(root_cert.encode('utf-8'))
+            root_cert_tmp.flush()
+            root_cert_fname = root_cert_tmp.name
+
+            cert, key = self.mgr.cherrypy_thread.ssl_certs.generate_cert()
+
+            cert_tmp = tempfile.NamedTemporaryFile()
+            cert_tmp.write(cert.encode('utf-8'))
+            cert_tmp.flush()
+            cert_fname = cert_tmp.name
+
+            key_tmp = tempfile.NamedTemporaryFile()
+            key_tmp.write(key.encode('utf-8'))
+            key_tmp.flush()
+            key_fname = key_tmp.name
+
+            ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=root_cert_fname)
+            ssl_ctx.verify_mode = ssl.CERT_REQUIRED
+            ssl_ctx.check_hostname = True
+            ssl_ctx.load_cert_chain(cert_fname, key_fname)
+        except Exception as e:
+            self.mgr.log.error(f'Failed to get certs for connecting to agent: {e}')
+            return
+        try:
+            bytes_len: str = str(len(self.data.encode('utf-8')))
+            if len(bytes_len.encode('utf-8')) > 10:
+                raise Exception(f'Message is too big to send to agent. Message size is {bytes_len} bytes!')
+            while len(bytes_len.encode('utf-8')) < 10:
+                bytes_len = '0' + bytes_len
+        except Exception as e:
+            self.mgr.log.error(f'Failed to get length of json payload: {e}')
+            return
         for retry_wait in [3, 5]:
             try:
                 agent_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                agent_socket.connect((self.mgr.inventory.get_addr(self.host), self.port))
-                bytes_len: str = str(len(self.data.encode('utf-8')))
-                if len(bytes_len.encode('utf-8')) > 10:
-                    raise Exception(f'Message is too big to send to agent. Message size is {bytes_len} bytes!')
-                while len(bytes_len.encode('utf-8')) < 10:
-                    bytes_len = '0' + bytes_len
-                agent_socket.sendall((bytes_len + self.data).encode('utf-8'))
-                agent_response = agent_socket.recv(1024).decode()
+                secure_agent_socket = ssl_ctx.wrap_socket(agent_socket, server_hostname=self.host)
+                secure_agent_socket.connect((self.mgr.inventory.get_addr(self.host), self.port))
+                msg = (bytes_len + self.data)
+                secure_agent_socket.sendall(msg.encode('utf-8'))
+                agent_response = secure_agent_socket.recv(1024).decode()
                 self.mgr.log.debug(f'Received "{agent_response}" from agent on host {self.host}')
                 return
             except ConnectionError as e:
@@ -209,8 +244,9 @@ class AgentMessageThread(threading.Thread):
                 time.sleep(retry_wait)
             except Exception as e:
                 # if it's not a connection error, something has gone wrong. Give up.
-                self.mgr.log.debug(f'Failed to contact agent on host {self.host}: {e}')
+                self.mgr.log.error(f'Failed to contact agent on host {self.host}: {e}')
                 return
+        self.mgr.log.error(f'Could not connect to agent on host {self.host}')
         return
 
 
@@ -218,12 +254,12 @@ class CephadmAgentHelpers:
     def __init__(self, mgr: "CephadmOrchestrator"):
         self.mgr: "CephadmOrchestrator" = mgr
 
-    def _request_agent_acks(self, hosts: Set[str]) -> None:
+    def _request_agent_acks(self, hosts: Set[str], increment: bool = False) -> None:
         for host in hosts:
             self.mgr.cache.metadata_up_to_date[host] = False
             if host not in self.mgr.cache.agent_counter:
                 self.mgr.cache.agent_counter[host] = 1
-            else:
+            elif increment:
                 self.mgr.cache.agent_counter[host] = self.mgr.cache.agent_counter[host] + 1
             message_thread = AgentMessageThread(
                 host, self.mgr.cache.agent_ports[host], {'counter': self.mgr.cache.agent_counter[host]}, self.mgr)
@@ -277,9 +313,10 @@ class SSLCerts:
 
         cert_str = crypto.dump_certificate(crypto.FILETYPE_PEM, self.root_cert).decode('utf-8')
         key_str = crypto.dump_privatekey(crypto.FILETYPE_PEM, self.root_key).decode('utf-8')
+
         return (cert_str, key_str)
 
-    def generate_cert(self) -> Tuple[str, str]:
+    def generate_cert(self, name: str = '') -> Tuple[str, str]:
         key = crypto.PKey()
         key.generate_key(crypto.TYPE_RSA, 2048)
 
@@ -287,7 +324,10 @@ class SSLCerts:
         cert.set_serial_number(int(uuid4()))
 
         subj = cert.get_subject()
-        subj.commonName = str(self.mgr.get_mgr_ip())
+        if not name:
+            subj.commonName = str(self.mgr.get_mgr_ip())
+        else:
+            subj.commonName = name
 
         cert.set_issuer(self.root_subj)
         cert.set_pubkey(key)
index 0840242ba8398ea209f39c05d791354ab52e5786..4c76a9833345f5c444e3cf9fe284254070481627 100644 (file)
@@ -816,6 +816,7 @@ class CephadmServe:
                         self._remove_daemon(d.name(), d.hostname)
                         daemons_to_remove.remove(d)
                         progress_done += 1
+                        hosts_altered.add(d.hostname)
                         break
 
                 # deploy new daemon
@@ -891,7 +892,7 @@ class CephadmServe:
             if self.mgr.use_agent:
                 # can only send ack to agents if we know for sure port they bound to
                 hosts_altered = set([h for h in hosts_altered if h in self.mgr.cache.agent_ports])
-                self.mgr.agent_helpers._request_agent_acks(hosts_altered)
+                self.mgr.agent_helpers._request_agent_acks(hosts_altered, increment=True)
 
         if r is None:
             r = False
index 3559b9651674b16c99f31d1c75761194f03696f0..c1cda43032f8fd5c08cbdb2a39aea55d128689ba 100644 (file)
@@ -1026,11 +1026,16 @@ class CephadmAgent(CephService):
                'host': daemon_spec.host}
 
         assert self.mgr.cherrypy_thread
+        assert self.mgr.cherrypy_thread.ssl_certs.get_root_cert()
+        listener_cert, listener_key = self.mgr.cherrypy_thread.ssl_certs.generate_cert(
+            daemon_spec.host)
         config = {
             'agent.json': json.dumps(cfg),
             'cephadm': self.mgr._cephadm,
             'keyring': daemon_spec.keyring,
             'root_cert.pem': self.mgr.cherrypy_thread.ssl_certs.get_root_cert(),
+            'listener.crt': listener_cert,
+            'listener.key': listener_key,
         }
 
         return config, sorted([str(self.mgr.get_mgr_ip()), str(self.mgr.endpoint_port), self.mgr.cherrypy_thread.ssl_certs.get_root_cert()])