]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
mgr/orchestrator: Add mypy static type checking
authorSebastian Wagner <sebastian.wagner@suse.com>
Wed, 4 Dec 2019 17:07:42 +0000 (18:07 +0100)
committerSebastian Wagner <sebastian.wagner@suse.com>
Wed, 4 Dec 2019 19:45:49 +0000 (20:45 +0100)
static type checking is a good way to quickly increase the code coverage.
without creating lots of unit tests.

Signed-off-by: Sebastian Wagner <sebastian.wagner@suse.com>
src/pybind/mgr/mgr_module.py
src/pybind/mgr/mgr_util.py
src/pybind/mgr/orchestrator.py
src/pybind/mgr/ssh/tests/fixtures.py
src/pybind/mgr/tox.ini

index 060288f44714993e35b9baf1ef54a582220bc4de..4159d5044e050871bf23ab9d686b4256dcf35fad 100644 (file)
@@ -1,7 +1,7 @@
 import ceph_module  # noqa
 
 try:
-    from typing import Set, Tuple, Iterator, Any
+    from typing import Set, Tuple, Iterator, Any, Dict, Optional, Callable, List
 except ImportError:
     # just for type checking
     pass
@@ -251,7 +251,7 @@ class CRUSHMap(ceph_module.BasePyCRUSH):
         return osd_list
 
     def device_class_counts(self):
-        result = defaultdict(int)
+        result = defaultdict(int)  # type: Dict[str, int]
         # TODO don't abuse dump like this
         d = self.dump()
         for device in d['devices']:
@@ -262,7 +262,7 @@ class CRUSHMap(ceph_module.BasePyCRUSH):
 
 
 class CLICommand(object):
-    COMMANDS = {}
+    COMMANDS = {}  # type: Dict[str, CLICommand]
 
     def __init__(self, prefix, args="", desc="", perm="rw"):
         self.prefix = prefix
@@ -270,7 +270,7 @@ class CLICommand(object):
         self.args_dict = {}
         self.desc = desc
         self.perm = perm
-        self.func = None
+        self.func = None  # type: Optional[Callable]
         self._parse_args()
 
     def _parse_args(self):
@@ -300,6 +300,7 @@ class CLICommand(object):
             kwargs[a.replace("-", "_")] = cmd_dict[a]
         if inbuf:
             kwargs['inbuf'] = inbuf
+        assert self.func
         return self.func(mgr, **kwargs)
 
     @classmethod
@@ -528,8 +529,8 @@ class MgrStandbyModule(ceph_module.BaseMgrStandbyModule, MgrModuleLoggingMixin):
     from their active peer), and to configuration settings (read only).
     """
 
-    MODULE_OPTIONS = []
-    MODULE_OPTION_DEFAULTS = {}
+    MODULE_OPTIONS = []  # type: List[Dict[str, Any]]
+    MODULE_OPTION_DEFAULTS = {}  # type: Dict[str, Any]
 
     def __init__(self, module_name, capsule):
         super(MgrStandbyModule, self).__init__(capsule)
@@ -605,9 +606,9 @@ class MgrStandbyModule(ceph_module.BaseMgrStandbyModule, MgrModuleLoggingMixin):
 
 
 class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
-    COMMANDS = []
-    MODULE_OPTIONS = []
-    MODULE_OPTION_DEFAULTS = {}
+    COMMANDS = []  # type: List[Any]
+    MODULE_OPTIONS = []  # type: List[dict]
+    MODULE_OPTION_DEFAULTS = {}  # type: Dict[str, Any]
 
     # Priority definitions for perf counters
     PRIO_CRITICAL = 10
@@ -810,8 +811,9 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         return ''
 
     def _perfpath_to_path_labels(self, daemon, path):
-        label_names = ("ceph_daemon",)
-        labels = (daemon,)
+        # type: (str, str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]
+        label_names = ("ceph_daemon",)  # type: Tuple[str, ...]
+        labels = (daemon,)  # type: Tuple[str, ...]
 
         if daemon.startswith('rbd-mirror.'):
             match = re.match(
@@ -1284,7 +1286,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         value.
         """
 
-        result = defaultdict(dict)
+        result = defaultdict(dict)  # type: Dict[str, dict]
 
         for server in self.list_servers():
             for service in server['services']:
index ea219c117d2d64b4af9fd5464c3e914f49022905..7cd598c85cba8b6d875efb2c1dbe32efdeb8d975 100644 (file)
@@ -108,10 +108,10 @@ def get_default_addr():
            return False
 
     try:
-        return get_default_addr.result
+        return get_default_addr.result  # type: ignore
     except AttributeError:
         result = '::' if is_ipv6_enabled() else '0.0.0.0'
-        get_default_addr.result = result
+        get_default_addr.result = result  # type: ignore
         return result
 
 
index 6ccefeebfff08487d20ac5ad6f5e78d69f9173e2..460ad26b955f684021c5ed5ce8f7687848506c1b 100644 (file)
@@ -76,7 +76,7 @@ def parse_host_specs(host, require_network=True):
         return host_spec
 
     from ipaddress import ip_network, ip_address
-    networks = list()
+    networks = list()  # type: List[str]
     network = host_spec.network
     # in case we have [v2:1.2.3.4:3000,v1:1.2.3.4:6478]
     if ',' in network:
@@ -173,7 +173,7 @@ class _Promise(object):
 
     def __init__(self,
                  _first_promise=None,  # type: Optional["_Promise"]
-                 value=NO_RESULT,  # type: Optional
+                 value=NO_RESULT,  # type: Optional[Any]
                  on_complete=None,    # type: Optional[Callable]
                  name=None,  # type: Optional[str]
                  ):
@@ -189,7 +189,7 @@ class _Promise(object):
 
         # _Promise is not a continuation monad, as `_result` is of type
         # T instead of (T -> r) -> r. Therefore we need to store the first promise here.
-        self._first_promise = _first_promise or self  # type: 'Completion'
+        self._first_promise = _first_promise or self  # type: '_Promise'
 
     def __repr__(self):
         name = self._name or getattr(self._on_complete, '__name__', '??') if self._on_complete else 'None'
@@ -208,8 +208,6 @@ class _Promise(object):
         else:
             name = self._on_complete.__class__.__name__
         val = repr(self._value) if self._value not in (self.NO_RESULT, self.ASYNC_RESULT) else '...'
-        if hasattr(val, 'debug_str'):
-            val = val.debug_str()
         prefix = {
             self.INITIALIZED: '      ',
             self.RUNNING:     '   >>>',
@@ -278,6 +276,7 @@ class _Promise(object):
             assert self not in next_result
             next_result._append_promise(self._next_promise)
             self._set_next_promise(next_result)
+            assert self._next_promise
             if self._next_promise._value is self.NO_RESULT:
                 self._next_promise._value = self._value
             self.propagate_to_next()
@@ -453,7 +452,7 @@ class Completion(_Promise):
     def __init__(self,
                  _first_promise=None,  # type: Optional["Completion"]
                  value=_Promise.NO_RESULT,  # type: Any
-                 on_complete=None,  # type: Optional[Callable],
+                 on_complete=None,  # type: Optional[Callable]
                  name=None,  # type: Optional[str]
                  ):
         super(Completion, self).__init__(_first_promise, value, on_complete, name)
@@ -462,7 +461,7 @@ class Completion(_Promise):
     def _progress_reference(self):
         # type: () -> Optional[ProgressReference]
         if hasattr(self._on_complete, 'progress_id'):
-            return self._on_complete
+            return self._on_complete  # type: ignore
         return None
 
     @property
@@ -483,7 +482,7 @@ class Completion(_Promise):
     def with_progress(cls,  # type: Any
                       message,  # type: str
                       mgr,
-                      _first_promise=None,  # type: Optional["Completions"]
+                      _first_promise=None,  # type: Optional["Completion"]
                       value=_Promise.NO_RESULT,  # type: Any
                       on_complete=None,  # type: Optional[Callable]
                       calc_percent=None  # type: Optional[Callable[[], Any]]
@@ -788,21 +787,21 @@ class Orchestrator(object):
         return self.get_inventory()
 
     def add_host_label(self, host, label):
-        # type: (str) -> WriteCompletion
+        # type: (str, str) -> Completion
         """
         Add a host label
         """
-        return NotImplementedError()
+        raise NotImplementedError()
 
     def remove_host_label(self, host, label):
-        # type: (str) -> WriteCompletion
+        # type: (str, str) -> Completion
         """
         Remove a host label
         """
-        return NotImplementedError()
+        raise NotImplementedError()
 
     def get_inventory(self, node_filter=None, refresh=False):
-        # type: (InventoryFilter, bool) -> Completion
+        # type: (Optional[InventoryFilter], bool) -> Completion
         """
         Returns something that was created by `ceph-volume inventory`.
 
@@ -826,7 +825,7 @@ class Orchestrator(object):
         raise NotImplementedError()
 
     def service_action(self, action, service_type, service_name=None, service_id=None):
-        # type: (str, str, str, str) -> Completion
+        # type: (str, str, Optional[str], Optional[str]) -> Completion
         """
         Perform an action (start/stop/reload) on a service.
 
@@ -878,7 +877,7 @@ class Orchestrator(object):
         raise NotImplementedError()
 
     def blink_device_light(self, ident_fault, on, locations):
-        # type: (str, bool, List[DeviceLightLoc]) -> WriteCompletion
+        # type: (str, bool, List[DeviceLightLoc]) -> Completion
         """
         Instructs the orchestrator to enable or disable either the ident or the fault LED.
 
@@ -1237,13 +1236,13 @@ class RGWSpec(StatelessServiceSpec):
     @property
     def rgw_multisite_endpoint_addr(self):
         """Returns the first host. Not supported for Rook."""
-        return self.hosts[0]
+        return self.placement.hosts[0]
 
     @property
     def rgw_multisite_endpoints_list(self):
         return ",".join(["{}://{}:{}".format(self.rgw_multisite_proto,
                              host,
-                             self.rgw_frontend_port) for host in self.hosts])
+                             self.rgw_frontend_port) for host in self.placement.hosts])
 
     def genkey(self, nchars):
         """ Returns a random string of nchars
@@ -1282,11 +1281,13 @@ class InventoryFilter(object):
 
     """
     def __init__(self, labels=None, nodes=None):
-        # type: (List[str], List[str]) -> None
-        self.labels = labels  # Optional: get info about nodes matching labels
-        self.nodes = nodes  # Optional: get info about certain named nodes only
+        # type: (Optional[List[str]], Optional[List[str]]) -> None
 
+        #: Optional: get info about nodes matching labels
+        self.labels = labels
 
+        #: Optional: get info about certain named nodes only
+        self.nodes = nodes
 
 
 class InventoryNode(object):
@@ -1295,7 +1296,7 @@ class InventoryNode(object):
     InventoryNode.
     """
     def __init__(self, name, devices=None, labels=None):
-        # type: (str, inventory.Devices, List[str]) -> None
+        # type: (str, Optional[inventory.Devices], Optional[List[str]]) -> None
         if devices is None:
             devices = inventory.Devices([])
         if labels is None:
@@ -1459,13 +1460,13 @@ class OutdatableData(object):
         # type: (Optional[dict], Optional[datetime.datetime]) -> None
         self._data = data
         if data is not None and last_refresh is None:
-            self.last_refresh = datetime.datetime.utcnow()
+            self.last_refresh = datetime.datetime.utcnow()  # type: Optional[datetime.datetime]
         else:
             self.last_refresh = last_refresh
 
     def json(self):
         if self.last_refresh is not None:
-            timestr = self.last_refresh.strftime(self.DATEFMT)
+            timestr = self.last_refresh.strftime(self.DATEFMT)  # type: Optional[str]
         else:
             timestr = None
 
@@ -1515,16 +1516,16 @@ class OutdatableDictMixin(object):
 
     def __getitem__(self, item):
         # type: (str) -> OutdatableData
-        return OutdatableData.from_json(super(OutdatableDictMixin, self).__getitem__(item))
+        return OutdatableData.from_json(super(OutdatableDictMixin, self).__getitem__(item))  # type: ignore
 
     def __setitem__(self, key, value):
         # type: (str, OutdatableData) -> None
         val = None if value is None else value.json()
-        super(OutdatableDictMixin, self).__setitem__(key, val)
+        super(OutdatableDictMixin, self).__setitem__(key, val)  # type: ignore
 
     def items(self):
-        # type: () -> Iterator[Tuple[str, OutdatableData]]
-        for item in super(OutdatableDictMixin, self).items():
+        ## type: () -> Iterator[Tuple[str, OutdatableData]]
+        for item in super(OutdatableDictMixin, self).items():  # type: ignore
             k, v = item
             yield k, OutdatableData.from_json(v)
 
@@ -1543,7 +1544,7 @@ class OutdatableDictMixin(object):
     def remove_outdated(self):
         outdated = [item[0] for item in self.items() if item[1].outdated()]
         for o in outdated:
-            del self[o]
+            del self[o]  # type: ignore
 
     def invalidate(self, key):
         self[key] = OutdatableData(self[key].data,
index 273aae244b4d134bb026f349dd02e5bda1944e2c..124cef13ab0125f364262124b4ca079184ab5603 100644 (file)
@@ -35,6 +35,7 @@ def ssh_module():
             mock.patch("ssh.module.SSHOrchestrator.get_store_prefix", get_store_prefix):
         SSHOrchestrator._register_commands('')
         m = SSHOrchestrator.__new__ (SSHOrchestrator)
+        m._root_logger = mock.MagicMock()
         m._store = {
             'ssh_config': '',
             'ssh_identity_key': '',
index c97800f33b2d783288d4ed99a3b5f643cfdc0ded..f6f305d0f9bdcf827ff7b473cbc42ff2358bcc07 100644 (file)
@@ -1,8 +1,15 @@
 [tox]
-envlist = py3
+envlist = py3, mypy
 skipsdist = true
 
 [testenv]
 setenv = UNITTEST = true
-deps = -rrequirements.txt
-commands = pytest -v --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
\ No newline at end of file
+deps = -r requirements.txt
+commands = pytest -v --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
+
+[testenv:mypy]
+basepython = python3
+deps =
+    -r requirements.txt
+    mypy
+commands = mypy --config-file=../../mypy.ini orchestrator.py
\ No newline at end of file