class SSHManager:
+ # Retry count for connection/channel errors - easy to change
+ SSH_RETRY_COUNT = 3
+
+ # SSH Channel Open Error codes (from RFC 4254)
+ # Only retry on transient/recoverable errors
+ CHANNEL_OPEN_RECOVERABLE_CODES = {
+ 2, # OPEN_CONNECT_FAILED - connection to target failed (may be transient)
+ 4, # OPEN_RESOURCE_SHORTAGE - server lacks resources (may clear up)
+ }
+ # Retryable exception types for SSH command execution
+ RETRYABLE_ERRORS = (
+ asyncssh.ChannelOpenError,
+ asyncssh.ConnectionLost,
+ asyncssh.DisconnectError,
+ asyncio.TimeoutError,
+ OSError, # network-level OS errors
+ )
+
def __init__(self, mgr: "CephadmOrchestrator"):
self.mgr: "CephadmOrchestrator" = mgr
self.cons: Dict[str, "SSHClientConnection"] = {}
+ def _is_conn_valid(self, conn: "SSHClientConnection") -> bool:
+ """Safely check if an AsyncSSH connection is still valid and usable."""
+ try:
+ if conn is None:
+ return False
+ if hasattr(conn, "is_connected") and not conn.is_connected():
+ return False
+ if hasattr(conn, "is_closing") and conn.is_closing():
+ return False
+ if hasattr(conn, "is_closed") and callable(conn.is_closed) and conn.is_closed():
+ return False
+ return True
+ except Exception:
+ return False
+
async def _remote_connection(self,
host: str,
addr: Optional[str] = None,
) -> "SSHClientConnection":
- if not self.cons.get(host) or host not in self.mgr.inventory:
- if not addr and host in self.mgr.inventory:
- addr = self.mgr.inventory.get_addr(host)
-
- if not addr:
- raise OrchestratorError("host address is empty")
-
- assert self.mgr.ssh_user
- n = self.mgr.ssh_user + '@' + addr
- logger.debug("Opening connection to {} with ssh options '{}'".format(
- n, self.mgr._ssh_options))
-
- asyncssh.set_log_level('DEBUG')
- asyncssh.set_debug_level(3)
-
- with self.redirect_log(host, addr):
- try:
- ssh_options = asyncssh.SSHClientConnectionOptions(
- keepalive_interval=self.mgr.ssh_keepalive_interval,
- keepalive_count_max=self.mgr.ssh_keepalive_count_max
- )
- conn = await asyncssh.connect(addr, username=self.mgr.ssh_user, client_keys=[self.mgr.tkey.name],
- known_hosts=None, config=[self.mgr.ssh_config_fname],
- preferred_auth=['publickey'], options=ssh_options)
- except OSError:
- raise
- except asyncssh.Error:
- raise
- except Exception:
- raise
- self.cons[host] = conn
+ existing_conn = self.cons.get(host)
+ # Check if we have a valid existing connection
+ if existing_conn and host in self.mgr.inventory and self._is_conn_valid(existing_conn):
+ self.mgr.offline_hosts_remove(host)
+ return existing_conn
+
+ if existing_conn and not self._is_conn_valid(existing_conn):
+ logger.debug(f'Existing connection to {host} is invalid, creating new connection')
+ await self._reset_con(host)
+
+ if not addr and host in self.mgr.inventory:
+ addr = self.mgr.inventory.get_addr(host)
+ if not addr:
+ raise OrchestratorError("host address is empty")
+
+ assert self.mgr.ssh_user
+ n = self.mgr.ssh_user + '@' + addr
+ logger.debug("Opening connection to {} with ssh options '{}'".format(
+ n, self.mgr._ssh_options))
+
+ asyncssh.set_log_level('DEBUG')
+ asyncssh.set_debug_level(3)
+
+ with self.redirect_log(host, addr):
+ try:
+ ssh_options = asyncssh.SSHClientConnectionOptions(
+ keepalive_interval=self.mgr.ssh_keepalive_interval,
+ keepalive_count_max=self.mgr.ssh_keepalive_count_max
+ )
+ conn = await asyncssh.connect(addr, username=self.mgr.ssh_user, client_keys=[self.mgr.tkey.name],
+ known_hosts=None, config=[self.mgr.ssh_config_fname],
+ preferred_auth=['publickey'], options=ssh_options)
+ except OSError:
+ raise
+ except asyncssh.Error:
+ raise
+ except Exception:
+ raise
+ self.cons[host] = conn
self.mgr.offline_hosts_remove(host)
self.mgr.offline_hosts.add(host)
log_content = log_string.getvalue()
msg = f'Failed to connect to {host} ({addr}). {str(e)}' + '\n' + f'Log: {log_content}'
- logger.debug(msg)
+ logger.exception(msg)
raise HostConnectionError(msg, host, addr)
except Exception as e:
self.mgr.offline_hosts.add(host)
address = host
if log_command:
logger.debug(f'Running command: {rcmd}')
- try:
- r = await conn.run(str(rcmd), input=stdin)
- # handle these Exceptions otherwise you might get a weird error like
- # TypeError: __init__() missing 1 required positional argument: 'reason' (due to the asyncssh error interacting with raise_if_exception)
- except asyncssh.ChannelOpenError as e:
- # SSH connection closed or broken, will create new connection next call
- logger.debug(f'Connection to {host} failed. {str(e)}')
- await self._reset_con(host)
- self.mgr.offline_hosts.add(host)
- raise HostConnectionError(f'Unable to reach remote host {host}. {str(e)}', host, address)
- except asyncssh.ProcessError as e:
- msg = f"Cannot execute the command '{rcmd}' on the {host}. {str(e.stderr)}."
- logger.debug(msg)
- await self._reset_con(host)
- self.mgr.offline_hosts.add(host)
- raise HostConnectionError(msg, host, address)
- except Exception as e:
- msg = f"Generic error while executing command '{rcmd}' on the host {host}. {str(e)}."
- logger.debug(msg)
- await self._reset_con(host)
- self.mgr.offline_hosts.add(host)
- raise HostConnectionError(msg, host, address)
+
+ # Retry logic for transient connection/channel errors
+ for attempt in range(self.SSH_RETRY_COUNT):
+ try:
+ r = await conn.run(str(rcmd), input=stdin)
+ break # Success, exit retry loop
+ # Handle retryable exceptions (connection/channel errors)
+ # Note: handle these Exceptions otherwise you might get a weird error like
+ # TypeError: __init__() missing 1 required positional argument: 'reason'
+ # (due to the asyncssh error interacting with raise_if_exception)
+ except self.RETRYABLE_ERRORS as e:
+ error_type = type(e).__name__
+ logger.exception('Command exection failed with %s', error_type)
+ # For ChannelOpenError, check if the error code is recoverable
+ if isinstance(e, asyncssh.ChannelOpenError):
+ error_code = getattr(e, 'code', None)
+ logger.debug(
+ f'{error_type} (code={error_code}) on attempt '
+ f'{attempt + 1}/{self.SSH_RETRY_COUNT} '
+ f'for host {host}: {str(e)}')
+ # Check if this error code is recoverable/retryable
+ if error_code not in self.CHANNEL_OPEN_RECOVERABLE_CODES:
+ # Non-recoverable error code, don't retry
+ logger.debug(
+ f'ChannelOpenError code {error_code} is not recoverable, '
+ f'not retrying for host {host}')
+ await self._reset_con(host)
+ self.mgr.offline_hosts.add(host)
+ raise HostConnectionError(
+ f'Unable to reach remote host {host}. {str(e)}',
+ host, address)
+ else:
+ logger.debug(
+ f'{error_type} on attempt {attempt + 1}/{self.SSH_RETRY_COUNT} '
+ f'for host {host}: {str(e)}')
+
+ # Reset connection and get a new one for retry
+ await self._reset_con(host)
+ if attempt < self.SSH_RETRY_COUNT - 1:
+ # Not the last attempt, try to get a new connection
+ try:
+ conn = await self._remote_connection(host, addr)
+ except Exception as conn_e:
+ logger.debug(
+ f'Failed to re-establish connection to {host} '
+ f'on retry: {str(conn_e)}')
+ # Continue to next attempt, connection will be retried
+ continue
+ else:
+ # Last attempt failed, raise the error
+ self.mgr.offline_hosts.add(host)
+ raise HostConnectionError(
+ f'Unable to reach remote host {host} after '
+ f'{self.SSH_RETRY_COUNT} attempts. {str(e)}',
+ host, address)
+ except asyncssh.ProcessError as e:
+ msg = f"ProcessError cannot execute the command '{rcmd}' on the {host}. {str(e.stderr)}."
+ logger.exception(msg)
+ await self._reset_con(host)
+ self.mgr.offline_hosts.add(host)
+ raise HostConnectionError(msg, host, address)
+ except Exception as e:
+ error_type = type(e).__name__
+ msg = (f"Generic error {error_type} while executing command '{rcmd}' "
+ f"on the host {host}. {str(e)}.")
+ logger.exception(msg)
+ await self._reset_con(host)
+ self.mgr.offline_hosts.add(host)
+ raise HostConnectionError(msg, host, address)
def _rstrip(v: Union[bytes, str, None]) -> str:
if not v: