]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
cephadm: move low level networking funcs to net_utils.py
authorJohn Mulligan <jmulligan@redhat.com>
Thu, 17 Aug 2023 18:18:02 +0000 (14:18 -0400)
committerJohn Mulligan <jmulligan@redhat.com>
Wed, 30 Aug 2023 18:00:47 +0000 (14:00 -0400)
Signed-off-by: John Mulligan <jmulligan@redhat.com>
Pair-programmed-with: Adam King <adking@redhat.com>
Co-authored-by: Adam King <adking@redhat.com>
src/cephadm/cephadm.py
src/cephadm/cephadmlib/net_utils.py [new file with mode: 0644]
src/cephadm/tests/fixtures.py
src/cephadm/tests/test_cephadm.py

index b2871017a9470dcc6fa977b0b769f4e5d47a0333..9fb67e9ab1177253ca1f07219a73f81bce403cde 100755 (executable)
@@ -21,7 +21,6 @@ import sys
 import tempfile
 import time
 import errno
-import struct
 import ssl
 from enum import Enum
 from typing import Dict, List, Tuple, Optional, Union, Any, Callable, IO, Sequence, TypeVar, cast, Set, Iterable, TextIO
@@ -91,7 +90,6 @@ from cephadmlib.context import CephadmContext
 from cephadmlib.exceptions import (
     ClusterAlreadyExists,
     Error,
-    PortOccupiedError,
     UnauthorizedRegistryError,
 )
 from cephadmlib.exe_utils import find_executable, find_program
@@ -119,6 +117,22 @@ from cephadmlib.file_utils import (
     write_new,
     write_tmp,
 )
+from cephadmlib.net_utils import (
+    EndPoint,
+    check_ip_port,
+    check_subnet,
+    get_fqdn,
+    get_hostname,
+    get_ip_addresses,
+    get_ipv4_address,
+    get_ipv6_address,
+    get_short_hostname,
+    ip_in_subnets,
+    is_ipv6,
+    port_in_use,
+    unwrap_ipv6,
+    wrap_ipv6,
+)
 
 FuncT = TypeVar('FuncT', bound=Callable)
 
@@ -151,20 +165,6 @@ cached_stdin = None
 ##################################
 
 
-class EndPoint:
-    """EndPoint representing an ip:port format"""
-
-    def __init__(self, ip: str, port: int) -> None:
-        self.ip = ip
-        self.port = port
-
-    def __str__(self) -> str:
-        return f'{self.ip}:{self.port}'
-
-    def __repr__(self) -> str:
-        return f'{self.ip}:{self.port}'
-
-
 class ContainerInfo:
     def __init__(self, container_id: str,
                  image_name: str,
@@ -1378,71 +1378,6 @@ def get_supported_daemons():
 ##################################
 
 
-def attempt_bind(ctx, s, address, port):
-    # type: (CephadmContext, socket.socket, str, int) -> None
-    try:
-        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        s.bind((address, port))
-    except OSError as e:
-        if e.errno == errno.EADDRINUSE:
-            msg = 'Cannot bind to IP %s port %d: %s' % (address, port, e)
-            logger.warning(msg)
-            raise PortOccupiedError(msg)
-        else:
-            raise e
-    except Exception as e:
-        raise Error(e)
-    finally:
-        s.close()
-
-
-def port_in_use(ctx: CephadmContext, endpoint: EndPoint) -> bool:
-    """Detect whether a port is in use on the local machine - IPv4 and IPv6"""
-    logger.info('Verifying port %s ...' % str(endpoint))
-
-    def _port_in_use(af: socket.AddressFamily, address: str) -> bool:
-        try:
-            s = socket.socket(af, socket.SOCK_STREAM)
-            attempt_bind(ctx, s, address, endpoint.port)
-        except PortOccupiedError:
-            return True
-        except OSError as e:
-            if e.errno in (errno.EAFNOSUPPORT, errno.EADDRNOTAVAIL):
-                # Ignore EAFNOSUPPORT and EADDRNOTAVAIL as two interfaces are
-                # being tested here and one might be intentionally be disabled.
-                # In that case no error should be raised.
-                return False
-            else:
-                raise e
-        return False
-
-    if endpoint.ip != '0.0.0.0' and endpoint.ip != '::':
-        if is_ipv6(endpoint.ip):
-            return _port_in_use(socket.AF_INET6, endpoint.ip)
-        else:
-            return _port_in_use(socket.AF_INET, endpoint.ip)
-
-    return any(_port_in_use(af, address) for af, address in (
-        (socket.AF_INET, '0.0.0.0'),
-        (socket.AF_INET6, '::')
-    ))
-
-
-def check_ip_port(ctx, ep):
-    # type: (CephadmContext, EndPoint) -> None
-    if not ctx.skip_ping_check:
-        logger.info(f'Verifying IP {ep.ip} port {ep.port} ...')
-        if is_ipv6(ep.ip):
-            s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-            ip = unwrap_ipv6(ep.ip)
-        else:
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            ip = ep.ip
-        attempt_bind(ctx, s, ip, ep.port)
-
-##################################
-
-
 # this is an abbreviated version of
 # https://github.com/benediktschmitt/py-filelock/blob/master/filelock.py
 # that drops all of the compatibility (this is Unix/Linux only).
@@ -1739,30 +1674,6 @@ def try_convert_datetime(s):
     return None
 
 
-def get_hostname():
-    # type: () -> str
-    return socket.gethostname()
-
-
-def get_short_hostname():
-    # type: () -> str
-    return get_hostname().split('.', 1)[0]
-
-
-def get_fqdn():
-    # type: () -> str
-    return socket.getfqdn() or socket.gethostname()
-
-
-def get_ip_addresses(hostname: str) -> Tuple[List[str], List[str]]:
-    items = socket.getaddrinfo(hostname, None,
-                               flags=socket.AI_CANONNAME,
-                               type=socket.SOCK_STREAM)
-    ipv4_addresses = [i[4][0] for i in items if i[0] == socket.AF_INET]
-    ipv6_addresses = [i[4][0] for i in items if i[0] == socket.AF_INET6]
-    return ipv4_addresses, ipv6_addresses
-
-
 def get_arch():
     # type: () -> str
     return platform.uname().machine
@@ -4997,76 +4908,6 @@ def get_image_info_from_inspect(out, image):
 ##################################
 
 
-def check_subnet(subnets: str) -> Tuple[int, List[int], str]:
-    """Determine whether the given string is a valid subnet
-
-    :param subnets: subnet string, a single definition or comma separated list of CIDR subnets
-    :returns: return code, IP version list of the subnets and msg describing any errors validation errors
-    """
-
-    rc = 0
-    versions = set()
-    errors = []
-    subnet_list = subnets.split(',')
-    for subnet in subnet_list:
-        # ensure the format of the string is as expected address/netmask
-        subnet = subnet.strip()
-        if not re.search(r'\/\d+$', subnet):
-            rc = 1
-            errors.append(f'{subnet} is not in CIDR format (address/netmask)')
-            continue
-        try:
-            v = ipaddress.ip_network(subnet).version
-            versions.add(v)
-        except ValueError as e:
-            rc = 1
-            errors.append(f'{subnet} invalid: {str(e)}')
-
-    return rc, list(versions), ', '.join(errors)
-
-
-def unwrap_ipv6(address):
-    # type: (str) -> str
-    if address.startswith('[') and address.endswith(']'):
-        return address[1: -1]
-    return address
-
-
-def wrap_ipv6(address):
-    # type: (str) -> str
-
-    # We cannot assume it's already wrapped or even an IPv6 address if
-    # it's already wrapped it'll not pass (like if it's a hostname) and trigger
-    # the ValueError
-    try:
-        if ipaddress.ip_address(address).version == 6:
-            return f'[{address}]'
-    except ValueError:
-        pass
-
-    return address
-
-
-def is_ipv6(address):
-    # type: (str) -> bool
-    address = unwrap_ipv6(address)
-    try:
-        return ipaddress.ip_address(address).version == 6
-    except ValueError:
-        logger.warning('Address: {} is not a valid IP address'.format(address))
-        return False
-
-
-def ip_in_subnets(ip_addr: str, subnets: str) -> bool:
-    """Determine if the ip_addr belongs to any of the subnets list."""
-    subnet_list = [x.strip() for x in subnets.split(',')]
-    for subnet in subnet_list:
-        ip_address = unwrap_ipv6(ip_addr) if is_ipv6(ip_addr) else ip_addr
-        if ipaddress.ip_address(ip_address) in ipaddress.ip_network(subnet):
-            return True
-    return False
-
-
 def parse_mon_addrv(addrv_arg: str) -> List[EndPoint]:
     """Parse mon-addrv param into a list of mon end points."""
     r = re.compile(r':(\d+)$')
@@ -8629,50 +8470,6 @@ def command_rescan_disks(ctx: CephadmContext) -> str:
 ##################################
 
 
-def get_ipv4_address(ifname):
-    # type: (str) -> str
-    def _extract(sock: socket.socket, offset: int) -> str:
-        return socket.inet_ntop(
-            socket.AF_INET,
-            fcntl.ioctl(
-                sock.fileno(),
-                offset,
-                struct.pack('256s', bytes(ifname[:15], 'utf-8'))
-            )[20:24])
-
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        addr = _extract(s, 35093)  # '0x8915' = SIOCGIFADDR
-        dq_mask = _extract(s, 35099)  # 0x891b = SIOCGIFNETMASK
-    except OSError:
-        # interface does not have an ipv4 address
-        return ''
-
-    dec_mask = sum([bin(int(i)).count('1')
-                    for i in dq_mask.split('.')])
-    return '{}/{}'.format(addr, dec_mask)
-
-
-def get_ipv6_address(ifname):
-    # type: (str) -> str
-    if not os.path.exists('/proc/net/if_inet6'):
-        return ''
-
-    raw = read_file(['/proc/net/if_inet6'])
-    data = raw.splitlines()
-    # based on docs @ https://www.tldp.org/HOWTO/Linux+IPv6-HOWTO/ch11s04.html
-    # field 0 is ipv6, field 2 is scope
-    for iface_setting in data:
-        field = iface_setting.split()
-        if field[-1] == ifname:
-            ipv6_raw = field[0]
-            ipv6_fmtd = ':'.join([ipv6_raw[_p:_p + 4] for _p in range(0, len(field[0]), 4)])
-            # apply naming rules using ipaddress module
-            ipv6 = ipaddress.ip_address(ipv6_fmtd)
-            return '{}/{}'.format(str(ipv6), int('0x{}'.format(field[2]), 16))
-    return ''
-
-
 def bytes_to_human(num, mode='decimal'):
     # type: (float, str) -> str
     """Convert a bytes value into it's human-readable form.
diff --git a/src/cephadm/cephadmlib/net_utils.py b/src/cephadm/cephadmlib/net_utils.py
new file mode 100644 (file)
index 0000000..2650e8f
--- /dev/null
@@ -0,0 +1,233 @@
+# net_utils.py - Generic networking utility functions
+
+import errno
+import fcntl
+import ipaddress
+import logging
+import os
+import re
+import socket
+import struct
+
+from typing import Tuple, List
+
+from .context import CephadmContext
+from .exceptions import Error, PortOccupiedError
+from .file_utils import read_file
+
+logger = logging.getLogger()
+
+
+class EndPoint:
+    """EndPoint representing an ip:port format"""
+
+    def __init__(self, ip: str, port: int) -> None:
+        self.ip = ip
+        self.port = port
+
+    def __str__(self) -> str:
+        return f'{self.ip}:{self.port}'
+
+    def __repr__(self) -> str:
+        return f'{self.ip}:{self.port}'
+
+
+def attempt_bind(ctx, s, address, port):
+    # type: (CephadmContext, socket.socket, str, int) -> None
+    try:
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        s.bind((address, port))
+    except OSError as e:
+        if e.errno == errno.EADDRINUSE:
+            msg = 'Cannot bind to IP %s port %d: %s' % (address, port, e)
+            logger.warning(msg)
+            raise PortOccupiedError(msg)
+        else:
+            raise e
+    except Exception as e:
+        raise Error(e)
+    finally:
+        s.close()
+
+
+def port_in_use(ctx: CephadmContext, endpoint: EndPoint) -> bool:
+    """Detect whether a port is in use on the local machine - IPv4 and IPv6"""
+    logger.info('Verifying port %s ...' % str(endpoint))
+
+    def _port_in_use(af: socket.AddressFamily, address: str) -> bool:
+        try:
+            s = socket.socket(af, socket.SOCK_STREAM)
+            attempt_bind(ctx, s, address, endpoint.port)
+        except PortOccupiedError:
+            return True
+        except OSError as e:
+            if e.errno in (errno.EAFNOSUPPORT, errno.EADDRNOTAVAIL):
+                # Ignore EAFNOSUPPORT and EADDRNOTAVAIL as two interfaces are
+                # being tested here and one might be intentionally be disabled.
+                # In that case no error should be raised.
+                return False
+            else:
+                raise e
+        return False
+
+    if endpoint.ip != '0.0.0.0' and endpoint.ip != '::':
+        if is_ipv6(endpoint.ip):
+            return _port_in_use(socket.AF_INET6, endpoint.ip)
+        else:
+            return _port_in_use(socket.AF_INET, endpoint.ip)
+
+    return any(_port_in_use(af, address) for af, address in (
+        (socket.AF_INET, '0.0.0.0'),
+        (socket.AF_INET6, '::')
+    ))
+
+
+def check_ip_port(ctx, ep):
+    # type: (CephadmContext, EndPoint) -> None
+    if not ctx.skip_ping_check:
+        logger.info(f'Verifying IP {ep.ip} port {ep.port} ...')
+        if is_ipv6(ep.ip):
+            s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+            ip = unwrap_ipv6(ep.ip)
+        else:
+            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            ip = ep.ip
+        attempt_bind(ctx, s, ip, ep.port)
+
+
+def check_subnet(subnets: str) -> Tuple[int, List[int], str]:
+    """Determine whether the given string is a valid subnet
+
+    :param subnets: subnet string, a single definition or comma separated list of CIDR subnets
+    :returns: return code, IP version list of the subnets and msg describing any errors validation errors
+    """
+
+    rc = 0
+    versions = set()
+    errors = []
+    subnet_list = subnets.split(',')
+    for subnet in subnet_list:
+        # ensure the format of the string is as expected address/netmask
+        subnet = subnet.strip()
+        if not re.search(r'\/\d+$', subnet):
+            rc = 1
+            errors.append(f'{subnet} is not in CIDR format (address/netmask)')
+            continue
+        try:
+            v = ipaddress.ip_network(subnet).version
+            versions.add(v)
+        except ValueError as e:
+            rc = 1
+            errors.append(f'{subnet} invalid: {str(e)}')
+
+    return rc, list(versions), ', '.join(errors)
+
+
+def unwrap_ipv6(address):
+    # type: (str) -> str
+    if address.startswith('[') and address.endswith(']'):
+        return address[1: -1]
+    return address
+
+
+def wrap_ipv6(address):
+    # type: (str) -> str
+
+    # We cannot assume it's already wrapped or even an IPv6 address if
+    # it's already wrapped it'll not pass (like if it's a hostname) and trigger
+    # the ValueError
+    try:
+        if ipaddress.ip_address(address).version == 6:
+            return f'[{address}]'
+    except ValueError:
+        pass
+
+    return address
+
+
+def is_ipv6(address):
+    # type: (str) -> bool
+    address = unwrap_ipv6(address)
+    try:
+        return ipaddress.ip_address(address).version == 6
+    except ValueError:
+        logger.warning('Address: {} is not a valid IP address'.format(address))
+        return False
+
+
+def ip_in_subnets(ip_addr: str, subnets: str) -> bool:
+    """Determine if the ip_addr belongs to any of the subnets list."""
+    subnet_list = [x.strip() for x in subnets.split(',')]
+    for subnet in subnet_list:
+        ip_address = unwrap_ipv6(ip_addr) if is_ipv6(ip_addr) else ip_addr
+        if ipaddress.ip_address(ip_address) in ipaddress.ip_network(subnet):
+            return True
+    return False
+
+
+def get_ipv4_address(ifname):
+    # type: (str) -> str
+    def _extract(sock: socket.socket, offset: int) -> str:
+        return socket.inet_ntop(
+            socket.AF_INET,
+            fcntl.ioctl(
+                sock.fileno(),
+                offset,
+                struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+            )[20:24])
+
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        addr = _extract(s, 35093)  # '0x8915' = SIOCGIFADDR
+        dq_mask = _extract(s, 35099)  # 0x891b = SIOCGIFNETMASK
+    except OSError:
+        # interface does not have an ipv4 address
+        return ''
+
+    dec_mask = sum([bin(int(i)).count('1')
+                    for i in dq_mask.split('.')])
+    return '{}/{}'.format(addr, dec_mask)
+
+
+def get_ipv6_address(ifname):
+    # type: (str) -> str
+    if not os.path.exists('/proc/net/if_inet6'):
+        return ''
+
+    raw = read_file(['/proc/net/if_inet6'])
+    data = raw.splitlines()
+    # based on docs @ https://www.tldp.org/HOWTO/Linux+IPv6-HOWTO/ch11s04.html
+    # field 0 is ipv6, field 2 is scope
+    for iface_setting in data:
+        field = iface_setting.split()
+        if field[-1] == ifname:
+            ipv6_raw = field[0]
+            ipv6_fmtd = ':'.join([ipv6_raw[_p:_p + 4] for _p in range(0, len(field[0]), 4)])
+            # apply naming rules using ipaddress module
+            ipv6 = ipaddress.ip_address(ipv6_fmtd)
+            return '{}/{}'.format(str(ipv6), int('0x{}'.format(field[2]), 16))
+    return ''
+
+
+def get_hostname():
+    # type: () -> str
+    return socket.gethostname()
+
+
+def get_short_hostname():
+    # type: () -> str
+    return get_hostname().split('.', 1)[0]
+
+
+def get_fqdn():
+    # type: () -> str
+    return socket.getfqdn() or socket.gethostname()
+
+
+def get_ip_addresses(hostname: str) -> Tuple[List[str], List[str]]:
+    items = socket.getaddrinfo(hostname, None,
+                               flags=socket.AI_CANONNAME,
+                               type=socket.SOCK_STREAM)
+    ipv4_addresses = [i[4][0] for i in items if i[0] == socket.AF_INET]
+    ipv6_addresses = [i[4][0] for i in items if i[0] == socket.AF_INET6]
+    return ipv4_addresses, ipv6_addresses
index 83cddf331d8ac0c6f86db38a4362fa1b4f0c842e..c82a138cbaed310948198681f7666f085852dd67 100644 (file)
@@ -145,7 +145,7 @@ def with_cephadm_ctx(
         hostname = 'host1'
 
     _cephadm = import_cephadm()
-    with mock.patch('cephadm.attempt_bind'), \
+    with mock.patch('cephadmlib.net_utils.attempt_bind'), \
          mock.patch('cephadmlib.call_wrappers.call', return_value=('', '', 0)), \
          mock.patch('cephadmlib.call_wrappers.call_timeout', return_value=0), \
          mock.patch('cephadm.call', return_value=('', '', 0)), \
index a1ee4495e96643157ffd98e0e8ada179b2b4d8e6..40f442454e42f306bf977f21774ea7033061de79 100644 (file)
@@ -47,6 +47,8 @@ class TestCephAdm(object):
 
     @mock.patch('cephadm.logger')
     def test_attempt_bind(self, _logger):
+        from cephadmlib.net_utils import PortOccupiedError, attempt_bind
+
         ctx = None
         address = None
         port = 0
@@ -57,7 +59,7 @@ class TestCephAdm(object):
             return _os_error
 
         for side_effect, expected_exception in (
-            (os_error(errno.EADDRINUSE), _cephadm.PortOccupiedError),
+            (os_error(errno.EADDRINUSE), PortOccupiedError),
             (os_error(errno.EAFNOSUPPORT), OSError),
             (os_error(errno.EADDRNOTAVAIL), OSError),
             (None, None),
@@ -65,36 +67,39 @@ class TestCephAdm(object):
             _socket = mock.Mock()
             _socket.bind.side_effect = side_effect
             try:
-                _cephadm.attempt_bind(ctx, _socket, address, port)
+                attempt_bind(ctx, _socket, address, port)
             except Exception as e:
                 assert isinstance(e, expected_exception)
             else:
                 if expected_exception is not None:
                     assert False
 
-    @mock.patch('cephadm.attempt_bind')
+    @mock.patch('cephadmlib.net_utils.attempt_bind')
     @mock.patch('cephadm.logger')
     def test_port_in_use(self, _logger, _attempt_bind):
+        from cephadmlib.net_utils import PortOccupiedError, port_in_use
+
         empty_ctx = None
 
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
 
-        _attempt_bind.side_effect = _cephadm.PortOccupiedError('msg')
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == True
+        _attempt_bind.side_effect = PortOccupiedError('msg')
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == True
 
         os_error = OSError()
         os_error.errno = errno.EADDRNOTAVAIL
         _attempt_bind.side_effect = os_error
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
 
         os_error = OSError()
         os_error.errno = errno.EAFNOSUPPORT
         _attempt_bind.side_effect = os_error
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('0.0.0.0', 9100)) == False
 
     @mock.patch('cephadm.socket.socket.bind')
     @mock.patch('cephadm.logger')
     def test_port_in_use_special_cases(self, _logger, _bind):
+        from cephadmlib.net_utils import PortOccupiedError, port_in_use
         # port_in_use has special handling for
         # EAFNOSUPPORT and EADDRNOTAVAIL errno OSErrors.
         # If we get those specific errors when attempting
@@ -107,26 +112,28 @@ class TestCephAdm(object):
             return _os_error
 
         _bind.side_effect = os_error(errno.EADDRNOTAVAIL)
-        in_use = _cephadm.port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
+        in_use = port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
         assert in_use == False
 
         _bind.side_effect = os_error(errno.EAFNOSUPPORT)
-        in_use = _cephadm.port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
+        in_use = port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
         assert in_use == False
 
         # this time, have it raise the actual port taken error
         # so it should report the port is in use
         _bind.side_effect = os_error(errno.EADDRINUSE)
-        in_use = _cephadm.port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
+        in_use = port_in_use(None, _cephadm.EndPoint('1.2.3.4', 10000))
         assert in_use == True
 
-    @mock.patch('cephadm.attempt_bind')
+    @mock.patch('cephadmlib.net_utils.attempt_bind')
     @mock.patch('cephadm.logger')
     def test_port_in_use_with_specific_ips(self, _logger, _attempt_bind):
+        from cephadmlib.net_utils import PortOccupiedError, port_in_use
+
         empty_ctx = None
 
         def _fake_attempt_bind(ctx, s: socket.socket, addr: str, port: int) -> None:
-            occupied_error = _cephadm.PortOccupiedError('msg')
+            occupied_error = PortOccupiedError('msg')
             if addr.startswith('200'):
                 raise occupied_error
             if addr.startswith('100'):
@@ -135,10 +142,10 @@ class TestCephAdm(object):
 
         _attempt_bind.side_effect = _fake_attempt_bind
 
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('200.0.0.0', 9100)) == True
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('100.0.0.0', 9100)) == False
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('100.0.0.0', 4567)) == True
-        assert _cephadm.port_in_use(empty_ctx, _cephadm.EndPoint('155.0.0.0', 4567)) == False
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('200.0.0.0', 9100)) == True
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('100.0.0.0', 9100)) == False
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('100.0.0.0', 4567)) == True
+        assert port_in_use(empty_ctx, _cephadm.EndPoint('155.0.0.0', 4567)) == False
 
     @mock.patch('socket.socket')
     @mock.patch('cephadm.logger')
@@ -160,6 +167,8 @@ class TestCephAdm(object):
     @mock.patch('socket.socket')
     @mock.patch('cephadm.logger')
     def test_check_ip_port_failure(self, _logger, _socket):
+        from cephadmlib.net_utils import PortOccupiedError
+
         ctx = _cephadm.CephadmContext()
         ctx.skip_ping_check = False  # enables executing port check with `check_ip_port`
 
@@ -173,7 +182,7 @@ class TestCephAdm(object):
             ('::', socket.AF_INET6),
         ):
             for side_effect, expected_exception in (
-                (os_error(errno.EADDRINUSE), _cephadm.PortOccupiedError),
+                (os_error(errno.EADDRINUSE), PortOccupiedError),
                 (os_error(errno.EADDRNOTAVAIL), OSError),
                 (os_error(errno.EAFNOSUPPORT), OSError),
                 (None, None),