from unittest import mock
-import copy, json, os, threading
+import copy, json, os, socket, threading
import pytest
from tests.fixtures import with_cephadm_ctx, cephadm_fs, import_cephadm
+from typing import Optional
+
_cephadm = import_cephadm()
agent.listener_port = None
with pytest.raises(Exception, match='Failed to pick port for agent to listen on: All 1000 ports starting at 7770 taken.'):
agent.run()
+
+
+@mock.patch("cephadm.CephadmAgent.pull_conf_settings")
+@mock.patch("cephadm.CephadmAgent.wakeup")
+def test_mgr_listener_handle_json_payload(_agent_wakeup, _pull_conf_settings, cephadm_fs):
+ with with_cephadm_ctx([]) as ctx:
+ ctx.fsid = FSID
+ agent = _cephadm.CephadmAgent(ctx, FSID, AGENT_ID)
+ cephadm_fs.create_dir(AGENT_DIR)
+
+ data_no_config = {
+ 'counter': 7
+ }
+ agent.mgr_listener.handle_json_payload(data_no_config)
+ _agent_wakeup.assert_not_called()
+ _pull_conf_settings.assert_not_called()
+ assert not any(os.path.exists(os.path.join(AGENT_DIR, s)) for s in agent.required_files)
+
+ data_with_config = {
+ 'counter': 7,
+ 'config': {
+ 'unrequired-file': 'unrequired-text'
+ }
+ }
+ data_with_config['config'].update({s: f'{s} text' for s in agent.required_files if s != agent.required_files[2]})
+ agent.mgr_listener.handle_json_payload(data_with_config)
+ _agent_wakeup.assert_called()
+ _pull_conf_settings.assert_called()
+ assert all(os.path.exists(os.path.join(AGENT_DIR, s)) for s in agent.required_files if s != agent.required_files[2])
+ assert not os.path.exists(os.path.join(AGENT_DIR, agent.required_files[2]))
+ assert not os.path.exists(os.path.join(AGENT_DIR, 'unrequired-file'))
+
+
+@mock.patch("socket.socket")
+@mock.patch("ssl.SSLContext.wrap_socket")
+@mock.patch("cephadm.MgrListener.handle_json_payload")
+@mock.patch("ssl.SSLContext.load_verify_locations")
+@mock.patch("ssl.SSLContext.load_cert_chain")
+def test_mgr_listener_run(_load_cert_chain, _load_verify_locations, _handle_json_payload,
+ _wrap_context, _socket, cephadm_fs):
+
+ with with_cephadm_ctx([]) as ctx:
+ ctx.fsid = FSID
+ agent = _cephadm.CephadmAgent(ctx, FSID, AGENT_ID)
+ cephadm_fs.create_dir(AGENT_DIR)
+
+ payload = json.dumps({'counter': 3,
+ 'config': {s: f'{s} text' for s in agent.required_files if s != agent.required_files[1]}})
+
+ class FakeSocket:
+
+ def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, fileno=None):
+ self.family = family
+ self.type = type
+
+ def bind(*args, **kwargs):
+ return
+
+ def settimeout(*args, **kwargs):
+ return
+
+ def listen(*args, **kwargs):
+ return
+
+ class FakeSecureSocket:
+
+ def __init__(self, pload):
+ self.payload = pload
+ self._conn = FakeConn(self.payload)
+ self.accepted = False
+
+ def accept(self):
+ # to make mgr listener run loop stop running,
+ # set it to stop after accepting a "connection"
+ # on our fake socket so only one iteration of the loop
+ # actually happens
+ agent.mgr_listener.stop = True
+ accepted = True
+ return self._conn, None
+
+ def load_cert_chain(*args, **kwargs):
+ return
+
+ def load_verify_locations(*args, **kwargs):
+ return
+
+ class FakeConn:
+
+ def __init__(self, payload: str = ''):
+ payload_len_str = str(len(payload.encode('utf-8')))
+ while len(payload_len_str.encode('utf-8')) < 10:
+ payload_len_str = '0' + payload_len_str
+ self.payload = (payload_len_str + payload).encode('utf-8')
+ self.buffer_len = len(self.payload)
+
+ def recv(self, len: Optional[int] = None):
+ if not len or len >= self.buffer_len:
+ ret = self.payload
+ self.payload = b''
+ self.buffer_len = 0
+ return ret
+ else:
+ ret = self.payload[:len]
+ self.payload = self.payload[len:]
+ self.buffer_len = self.buffer_len - len
+ return ret
+
+ FSS_good_data = FakeSecureSocket(payload)
+ FSS_bad_json = FakeSecureSocket('bad json')
+ _socket = FakeSocket
+ agent.listener_port = 7777
+
+ # first run, should successfully receive properly structured json payload
+ _wrap_context.side_effect = [FSS_good_data]
+ agent.mgr_listener.stop = False
+ FakeConn.send = mock.Mock(return_value=None)
+ agent.mgr_listener.run()
+
+ # verify payload was correctly extracted
+ assert _handle_json_payload.called_with(json.loads(payload))
+ FakeConn.send.assert_called_once_with(b'ACK')
+
+ # second run, with bad json data received
+ _wrap_context.side_effect = [FSS_bad_json]
+ agent.mgr_listener.stop = False
+ FakeConn.send = mock.Mock(return_value=None)
+ agent.mgr_listener.run()
+ FakeConn.send.assert_called_once_with(b'Failed to extract json payload from message: Expecting value: line 1 column 1 (char 0)')
+
+ # third run, no proper length as beginning og payload
+ FSS_no_length = FakeSecureSocket(payload)
+ FSS_no_length.payload = FSS_no_length.payload[10:]
+ FSS_no_length._conn.payload = FSS_no_length._conn.payload[10:]
+ FSS_no_length._conn.buffer_len -= 10
+ _wrap_context.side_effect = [FSS_no_length]
+ agent.mgr_listener.stop = False
+ FakeConn.send = mock.Mock(return_value=None)
+ agent.mgr_listener.run()
+ FakeConn.send.assert_called_once_with(b'Failed to extract length of payload from message: invalid literal for int() with base 10: \'{"counter"\'')
+
+ # some exception handling for full coverage
+ FSS_exc_testing = FakeSecureSocket(payload)
+ FSS_exc_testing.accept = mock.MagicMock()
+
+ def _accept(*args, **kwargs):
+ if not FSS_exc_testing.accepted:
+ FSS_exc_testing.accepted = True
+ raise socket.timeout()
+ else:
+ agent.mgr_listener.stop = True
+ raise Exception()
+
+ FSS_exc_testing.accept.side_effect = _accept
+ _wrap_context.side_effect = [FSS_exc_testing]
+ agent.mgr_listener.stop = False
+ FakeConn.send = mock.Mock(return_value=None)
+ agent.mgr_listener.run()
+ FakeConn.send.assert_not_called()
+ FSS_exc_testing.accept.call_count == 3