]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
mgr/cephadm: Added retry logic for execute command if command fails with connection...
authorShweta Bhosale <Shweta.Bhosale1@ibm.com>
Wed, 10 Dec 2025 09:43:41 +0000 (15:13 +0530)
committerShweta Bhosale <Shweta.Bhosale1@ibm.com>
Wed, 10 Dec 2025 18:18:54 +0000 (23:48 +0530)
Fixes: https://tracker.ceph.com/issues/74179
Signed-off-by: Shweta Bhosale <Shweta.Bhosale1@ibm.com>
src/pybind/mgr/cephadm/ssh.py
src/pybind/mgr/cephadm/tests/test_ssh.py

index acb5a77c51b9ed60865d631370c1ca303dd7d27a..95a06e0d8727550c8d8341ce608a5abd20f8c1b6 100644 (file)
@@ -142,45 +142,86 @@ class EventLoopThread(Thread):
 
 class SSHManager:
 
+    # Retry count for connection/channel errors - easy to change
+    SSH_RETRY_COUNT = 3
+
+    # SSH Channel Open Error codes (from RFC 4254)
+    # Only retry on transient/recoverable errors
+    CHANNEL_OPEN_RECOVERABLE_CODES = {
+        2,  # OPEN_CONNECT_FAILED - connection to target failed (may be transient)
+        4,  # OPEN_RESOURCE_SHORTAGE - server lacks resources (may clear up)
+    }
+    # Retryable exception types for SSH command execution
+    RETRYABLE_ERRORS = (
+        asyncssh.ChannelOpenError,
+        asyncssh.ConnectionLost,
+        asyncssh.DisconnectError,
+        asyncio.TimeoutError,
+        OSError,  # network-level OS errors
+    )
+
     def __init__(self, mgr: "CephadmOrchestrator"):
         self.mgr: "CephadmOrchestrator" = mgr
         self.cons: Dict[str, "SSHClientConnection"] = {}
 
+    def _is_conn_valid(self, conn: "SSHClientConnection") -> bool:
+        """Safely check if an AsyncSSH connection is still valid and usable."""
+        try:
+            if conn is None:
+                return False
+            if hasattr(conn, "is_connected") and not conn.is_connected():
+                return False
+            if hasattr(conn, "is_closing") and conn.is_closing():
+                return False
+            if hasattr(conn, "is_closed") and callable(conn.is_closed) and conn.is_closed():
+                return False
+            return True
+        except Exception:
+            return False
+
     async def _remote_connection(self,
                                  host: str,
                                  addr: Optional[str] = None,
                                  ) -> "SSHClientConnection":
-        if not self.cons.get(host) or host not in self.mgr.inventory:
-            if not addr and host in self.mgr.inventory:
-                addr = self.mgr.inventory.get_addr(host)
-
-            if not addr:
-                raise OrchestratorError("host address is empty")
-
-            assert self.mgr.ssh_user
-            n = self.mgr.ssh_user + '@' + addr
-            logger.debug("Opening connection to {} with ssh options '{}'".format(
-                n, self.mgr._ssh_options))
-
-            asyncssh.set_log_level('DEBUG')
-            asyncssh.set_debug_level(3)
-
-            with self.redirect_log(host, addr):
-                try:
-                    ssh_options = asyncssh.SSHClientConnectionOptions(
-                        keepalive_interval=self.mgr.ssh_keepalive_interval,
-                        keepalive_count_max=self.mgr.ssh_keepalive_count_max
-                    )
-                    conn = await asyncssh.connect(addr, username=self.mgr.ssh_user, client_keys=[self.mgr.tkey.name],
-                                                  known_hosts=None, config=[self.mgr.ssh_config_fname],
-                                                  preferred_auth=['publickey'], options=ssh_options)
-                except OSError:
-                    raise
-                except asyncssh.Error:
-                    raise
-                except Exception:
-                    raise
-            self.cons[host] = conn
+        existing_conn = self.cons.get(host)
+        # Check if we have a valid existing connection
+        if existing_conn and host in self.mgr.inventory and self._is_conn_valid(existing_conn):
+            self.mgr.offline_hosts_remove(host)
+            return existing_conn
+
+        if existing_conn and not self._is_conn_valid(existing_conn):
+            logger.debug(f'Existing connection to {host} is invalid, creating new connection')
+            await self._reset_con(host)
+
+        if not addr and host in self.mgr.inventory:
+            addr = self.mgr.inventory.get_addr(host)
+        if not addr:
+            raise OrchestratorError("host address is empty")
+
+        assert self.mgr.ssh_user
+        n = self.mgr.ssh_user + '@' + addr
+        logger.debug("Opening connection to {} with ssh options '{}'".format(
+            n, self.mgr._ssh_options))
+
+        asyncssh.set_log_level('DEBUG')
+        asyncssh.set_debug_level(3)
+
+        with self.redirect_log(host, addr):
+            try:
+                ssh_options = asyncssh.SSHClientConnectionOptions(
+                    keepalive_interval=self.mgr.ssh_keepalive_interval,
+                    keepalive_count_max=self.mgr.ssh_keepalive_count_max
+                )
+                conn = await asyncssh.connect(addr, username=self.mgr.ssh_user, client_keys=[self.mgr.tkey.name],
+                                              known_hosts=None, config=[self.mgr.ssh_config_fname],
+                                              preferred_auth=['publickey'], options=ssh_options)
+            except OSError:
+                raise
+            except asyncssh.Error:
+                raise
+            except Exception:
+                raise
+        self.cons[host] = conn
 
         self.mgr.offline_hosts_remove(host)
 
@@ -205,7 +246,7 @@ class SSHManager:
             self.mgr.offline_hosts.add(host)
             log_content = log_string.getvalue()
             msg = f'Failed to connect to {host} ({addr}). {str(e)}' + '\n' + f'Log: {log_content}'
-            logger.debug(msg)
+            logger.exception(msg)
             raise HostConnectionError(msg, host, addr)
         except Exception as e:
             self.mgr.offline_hosts.add(host)
@@ -241,28 +282,75 @@ class SSHManager:
             address = host
         if log_command:
             logger.debug(f'Running command: {rcmd}')
-        try:
-            r = await conn.run(str(rcmd), input=stdin)
-        # handle these Exceptions otherwise you might get a weird error like
-        # TypeError: __init__() missing 1 required positional argument: 'reason' (due to the asyncssh error interacting with raise_if_exception)
-        except asyncssh.ChannelOpenError as e:
-            # SSH connection closed or broken, will create new connection next call
-            logger.debug(f'Connection to {host} failed. {str(e)}')
-            await self._reset_con(host)
-            self.mgr.offline_hosts.add(host)
-            raise HostConnectionError(f'Unable to reach remote host {host}. {str(e)}', host, address)
-        except asyncssh.ProcessError as e:
-            msg = f"Cannot execute the command '{rcmd}' on the {host}. {str(e.stderr)}."
-            logger.debug(msg)
-            await self._reset_con(host)
-            self.mgr.offline_hosts.add(host)
-            raise HostConnectionError(msg, host, address)
-        except Exception as e:
-            msg = f"Generic error while executing command '{rcmd}' on the host {host}. {str(e)}."
-            logger.debug(msg)
-            await self._reset_con(host)
-            self.mgr.offline_hosts.add(host)
-            raise HostConnectionError(msg, host, address)
+
+        # Retry logic for transient connection/channel errors
+        for attempt in range(self.SSH_RETRY_COUNT):
+            try:
+                r = await conn.run(str(rcmd), input=stdin)
+                break  # Success, exit retry loop
+            # Handle retryable exceptions (connection/channel errors)
+            # Note: handle these Exceptions otherwise you might get a weird error like
+            # TypeError: __init__() missing 1 required positional argument: 'reason'
+            # (due to the asyncssh error interacting with raise_if_exception)
+            except self.RETRYABLE_ERRORS as e:
+                error_type = type(e).__name__
+                logger.exception('Command exection failed with %s', error_type)
+                # For ChannelOpenError, check if the error code is recoverable
+                if isinstance(e, asyncssh.ChannelOpenError):
+                    error_code = getattr(e, 'code', None)
+                    logger.debug(
+                        f'{error_type} (code={error_code}) on attempt '
+                        f'{attempt + 1}/{self.SSH_RETRY_COUNT} '
+                        f'for host {host}: {str(e)}')
+                    # Check if this error code is recoverable/retryable
+                    if error_code not in self.CHANNEL_OPEN_RECOVERABLE_CODES:
+                        # Non-recoverable error code, don't retry
+                        logger.debug(
+                            f'ChannelOpenError code {error_code} is not recoverable, '
+                            f'not retrying for host {host}')
+                        await self._reset_con(host)
+                        self.mgr.offline_hosts.add(host)
+                        raise HostConnectionError(
+                            f'Unable to reach remote host {host}. {str(e)}',
+                            host, address)
+                else:
+                    logger.debug(
+                        f'{error_type} on attempt {attempt + 1}/{self.SSH_RETRY_COUNT} '
+                        f'for host {host}: {str(e)}')
+
+                # Reset connection and get a new one for retry
+                await self._reset_con(host)
+                if attempt < self.SSH_RETRY_COUNT - 1:
+                    # Not the last attempt, try to get a new connection
+                    try:
+                        conn = await self._remote_connection(host, addr)
+                    except Exception as conn_e:
+                        logger.debug(
+                            f'Failed to re-establish connection to {host} '
+                            f'on retry: {str(conn_e)}')
+                        # Continue to next attempt, connection will be retried
+                        continue
+                else:
+                    # Last attempt failed, raise the error
+                    self.mgr.offline_hosts.add(host)
+                    raise HostConnectionError(
+                        f'Unable to reach remote host {host} after '
+                        f'{self.SSH_RETRY_COUNT} attempts. {str(e)}',
+                        host, address)
+            except asyncssh.ProcessError as e:
+                msg = f"ProcessError cannot execute the command '{rcmd}' on the {host}. {str(e.stderr)}."
+                logger.exception(msg)
+                await self._reset_con(host)
+                self.mgr.offline_hosts.add(host)
+                raise HostConnectionError(msg, host, address)
+            except Exception as e:
+                error_type = type(e).__name__
+                msg = (f"Generic error {error_type} while executing command '{rcmd}' "
+                       f"on the host {host}. {str(e)}.")
+                logger.exception(msg)
+                await self._reset_con(host)
+                self.mgr.offline_hosts.add(host)
+                raise HostConnectionError(msg, host, address)
 
         def _rstrip(v: Union[bytes, str, None]) -> str:
             if not v:
index 44ef3d429b75ce07b9033aa422f705445bf07d5a..95cd87716f23c297d1ccb4bb555395813d9df19b 100644 (file)
@@ -94,9 +94,9 @@ class TestWithSSH:
                                                                    exit_status="",
                                                                    exit_signal="",
                                                                    stderr=stderr,
-                                                                   stdout="")), f"Cannot execute the command.+{stderr}")
+                                                                   stdout="")), f"cannot execute the command.+{stderr}")
         # Test case 4: generic error
-        run_test('test4', FakeConn(exception=Exception), "Generic error while executing command.+")
+        run_test('test4', FakeConn(exception=Exception), "Generic error Exception while executing command.+")
 
 
 @pytest.mark.skipif(ConnectionLost is not None, reason='asyncssh')