git-server-git.apps.pok.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/rbd_support: add type annotation
author: Kefu Chai <kchai@redhat.com>
Wed, 10 Feb 2021 13:38:33 +0000 (21:38 +0800)
committer: Kefu Chai <kchai@redhat.com>
Thu, 18 Feb 2021 14:46:51 +0000 (22:46 +0800)
Signed-off-by: Kefu Chai <kchai@redhat.com>
src/mypy.ini
src/pybind/mgr/rbd_support/common.py
src/pybind/mgr/rbd_support/mirror_snapshot_schedule.py
src/pybind/mgr/rbd_support/module.py
src/pybind/mgr/rbd_support/perf.py
src/pybind/mgr/rbd_support/schedule.py
src/pybind/mgr/rbd_support/task.py
src/pybind/mgr/rbd_support/trash_purge_schedule.py
src/pybind/mgr/tox.ini

index 2d859f6006c6fdaf4bc43195e94d8d4483969056..2c3f396fd1c7a2a4ccea51afeeebdab5c5dd9f13 100755 (executable)
@@ -49,6 +49,9 @@ disallow_untyped_defs = True
 [mypy-orchestrator.*]
 disallow_untyped_defs = True
 
+[mypy-rbd_support.*]
+disallow_untyped_defs = True
+
 [mypy-rook.*]
 disallow_untyped_defs = True
 
index f6bac8f39226df469aa347385237040192a205d6..9c9c9248fb571118528e7eb7e41008ee45041122 100644 (file)
@@ -1,23 +1,37 @@
 import re
 
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
+
+
 GLOBAL_POOL_KEY = (None, None)
 
 class NotAuthorizedError(Exception):
     pass
 
 
-def is_authorized(module, pool, namespace):
+if TYPE_CHECKING:
+    from rbd_support.module import Module
+
+
+def is_authorized(module: 'Module',
+                  pool: Optional[str],
+                  namespace: Optional[str]) -> bool:
     return module.is_authorized({"pool": pool or '',
                                  "namespace": namespace or ''})
 
 
-def authorize_request(module, pool, namespace):
+def authorize_request(module: 'Module',
+                      pool: Optional[str],
+                      namespace: Optional[str]) -> None:
     if not is_authorized(module, pool, namespace):
         raise NotAuthorizedError("not authorized on pool={}, namespace={}".format(
             pool, namespace))
 
 
-def extract_pool_key(pool_spec):
+PoolKeyT = Union[Tuple[str, str], Tuple[None, None]]
+
+
+def extract_pool_key(pool_spec: Optional[str]) -> PoolKeyT:
     if not pool_spec:
         return GLOBAL_POOL_KEY
 
@@ -27,7 +41,7 @@ def extract_pool_key(pool_spec):
     return (match.group(1), match.group(2) or '')
 
 
-def get_rbd_pools(module):
+def get_rbd_pools(module: 'Module') -> Dict[int, str]:
     osd_map = module.get('osd_map')
     return {pool['pool']: pool['pool_name'] for pool in osd_map['pools']
             if 'rbd' in pool.get('application_metadata', {})}
index fa8bb019410a8f949b414770615c4feca6152fa3..1b8266b8821dc03e27fe3cc8a753989d07bd7806 100644 (file)
@@ -7,19 +7,20 @@ import traceback
 
 from datetime import datetime
 from threading import Condition, Lock, Thread
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
 
 from .common import get_rbd_pools
 from .schedule import LevelSpec, Interval, StartTime, Schedule, Schedules
 
 MIRRORING_OID = "rbd_mirroring"
 
-def namespace_validator(ioctx):
+def namespace_validator(ioctx: rados.Ioctx) -> None:
     mode = rbd.RBD().mirror_mode_get(ioctx)
     if mode != rbd.RBD_MIRROR_MODE_IMAGE:
         raise ValueError("namespace {} is not in mirror image mode".format(
             ioctx.get_namespace()))
 
-def image_validator(image):
+def image_validator(image: rbd.Image) -> None:
     mode = image.mirror_image_get_mode()
     if mode != rbd.RBD_MIRROR_IMAGE_MODE_SNAPSHOT:
         raise rbd.InvalidArgument("Invalid mirror image mode")
@@ -28,18 +29,18 @@ class Watchers:
 
     lock = Lock()
 
-    def __init__(self, handler):
+    def __init__(self, handler: Any) -> None:
         self.rados = handler.module.rados
         self.log = handler.log
-        self.watchers = {}
-        self.updated = {}
-        self.error = {}
-        self.epoch = {}
+        self.watchers: Dict[Tuple[str, str], rados.Watch] = {}
+        self.updated: Dict[int, bool] = {}
+        self.error: Dict[int, str] = {}
+        self.epoch: Dict[int, int] = {}
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.unregister_all()
 
-    def _clean_watcher(self, pool_id, namespace, watch_id):
+    def _clean_watcher(self, pool_id: str, namespace: str, watch_id: int) -> None:
         assert self.lock.locked()
 
         del self.watchers[pool_id, namespace]
@@ -47,7 +48,7 @@ class Watchers:
         self.error.pop(watch_id, None)
         self.epoch.pop(watch_id, None)
 
-    def check(self, pool_id, namespace, epoch):
+    def check(self, pool_id: str, namespace: str, epoch: int) -> bool:
         error = None
         with self.lock:
             watch = self.watchers.get((pool_id, namespace))
@@ -66,16 +67,16 @@ class Watchers:
         else:
             return True
 
-    def register(self, pool_id, namespace):
+    def register(self, pool_id: str, namespace: str) -> bool:
 
-        def callback(notify_id, notifier_id, watch_id, data):
+        def callback(notify_id: str, notifier_id: str, watch_id: int, data: str) -> None:
             self.log.debug("watcher {}: got notify {} from {}".format(
                 watch_id, notify_id, notifier_id))
 
             with self.lock:
                 self.updated[watch_id] = True
 
-        def error_callback(watch_id, error):
+        def error_callback(watch_id: int, error: str) -> None:
             self.log.debug("watcher {}: got errror {}".format(
                 watch_id, error))
 
@@ -100,7 +101,7 @@ class Watchers:
             self.updated[watch.get_id()] = True
         return True
 
-    def unregister(self, pool_id, namespace):
+    def unregister(self, pool_id: str, namespace: str) -> None:
 
         with self.lock:
             watch = self.watchers[pool_id, namespace]
@@ -121,14 +122,14 @@ class Watchers:
         with self.lock:
             self._clean_watcher(pool_id, namespace, watch_id)
 
-    def unregister_all(self):
+    def unregister_all(self) -> None:
         with self.lock:
             watchers = list(self.watchers)
 
         for pool_id, namespace in watchers:
             self.unregister(pool_id, namespace)
 
-    def unregister_stale(self, current_epoch):
+    def unregister_stale(self, current_epoch: int) -> None:
         with self.lock:
             watchers = list(self.watchers)
 
@@ -144,28 +145,30 @@ class Watchers:
             self.unregister(pool_id, namespace)
 
 
+ImageSpecT = Tuple[str, str, str]
+
 class CreateSnapshotRequests:
 
     lock = Lock()
     condition = Condition(lock)
 
-    def __init__(self, handler):
+    def __init__(self, handler: Any) -> None:
         self.handler = handler
         self.rados = handler.module.rados
         self.log = handler.log
-        self.pending = set()
-        self.queue = []
-        self.ioctxs = {}
+        self.pending: Set[ImageSpecT] = set()
+        self.queue: List[ImageSpecT] = []
+        self.ioctxs: Dict[Tuple[str, str], Tuple[rados.Ioctx, Set[ImageSpecT]]] = {}
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.wait_for_pending()
 
-    def wait_for_pending(self):
+    def wait_for_pending(self) -> None:
         with self.lock:
             while self.pending:
                 self.condition.wait()
 
-    def add(self, pool_id, namespace, image_id):
+    def add(self, pool_id: str, namespace: str, image_id: str) -> None:
         image_spec = (pool_id, namespace, image_id)
 
         self.log.debug("CreateSnapshotRequests.add: {}/{}/{}".format(
@@ -189,7 +192,7 @@ class CreateSnapshotRequests:
 
         self.open_image(image_spec)
 
-    def open_image(self, image_spec):
+    def open_image(self, image_spec: ImageSpecT) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug("CreateSnapshotRequests.open_image: {}/{}/{}".format(
@@ -198,7 +201,7 @@ class CreateSnapshotRequests:
         try:
             ioctx = self.get_ioctx(image_spec)
 
-            def cb(comp, image):
+            def cb(comp: rados.Completion, image: rbd.Image) -> None:
                 self.handle_open_image(image_spec, comp, image)
 
             rbd.RBD().aio_open_image(cb, ioctx, image_id=image_id)
@@ -208,7 +211,10 @@ class CreateSnapshotRequests:
                     pool_id, namespace, image_id, e))
             self.finish(image_spec)
 
-    def handle_open_image(self, image_spec, comp, image):
+    def handle_open_image(self,
+                          image_spec: ImageSpecT,
+                          comp: rados.Completion,
+                          image: rbd.Image) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
@@ -224,13 +230,13 @@ class CreateSnapshotRequests:
 
         self.get_mirror_mode(image_spec, image)
 
-    def get_mirror_mode(self, image_spec, image):
+    def get_mirror_mode(self, image_spec: ImageSpecT, image: rbd.Image) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug("CreateSnapshotRequests.get_mirror_mode: {}/{}/{}".format(
             pool_id, namespace, image_id))
 
-        def cb(comp, mode):
+        def cb(comp: rados.Completion, mode: str) -> None:
             self.handle_get_mirror_mode(image_spec, image, comp, mode)
 
         try:
@@ -241,7 +247,11 @@ class CreateSnapshotRequests:
                     pool_id, namespace, image_id, e))
             self.close_image(image_spec, image)
 
-    def handle_get_mirror_mode(self, image_spec, image, comp, mode):
+    def handle_get_mirror_mode(self,
+                               image_spec: ImageSpecT,
+                               image: rbd.Image,
+                               comp: rados.Completion,
+                               mode: str) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
@@ -265,13 +275,13 @@ class CreateSnapshotRequests:
 
         self.get_mirror_info(image_spec, image)
 
-    def get_mirror_info(self, image_spec, image):
+    def get_mirror_info(self, image_spec: ImageSpecT, image: rbd.Image) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug("CreateSnapshotRequests.get_mirror_info: {}/{}/{}".format(
             pool_id, namespace, image_id))
 
-        def cb(comp, info):
+        def cb(comp: rados.Completion, info: str) -> None:
             self.handle_get_mirror_info(image_spec, image, comp, info)
 
         try:
@@ -282,7 +292,11 @@ class CreateSnapshotRequests:
                     pool_id, namespace, image_id, e))
             self.close_image(image_spec, image)
 
-    def handle_get_mirror_info(self, image_spec, image, comp, info):
+    def handle_get_mirror_info(self,
+                               image_spec: ImageSpecT,
+                               image: rbd.Image,
+                               comp: rados.Completion,
+                               info: str) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
@@ -298,14 +312,14 @@ class CreateSnapshotRequests:
 
         self.create_snapshot(image_spec, image)
 
-    def create_snapshot(self, image_spec, image):
+    def create_snapshot(self, image_spec: ImageSpecT, image: rbd.Image) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
             "CreateSnapshotRequests.create_snapshot for {}/{}/{}".format(
                 pool_id, namespace, image_id))
 
-        def cb(comp, snap_id):
+        def cb(comp: rados.Completion, snap_id: str) -> None:
             self.handle_create_snapshot(image_spec, image, comp, snap_id)
 
         try:
@@ -317,7 +331,11 @@ class CreateSnapshotRequests:
             self.close_image(image_spec, image)
 
 
-    def handle_create_snapshot(self, image_spec, image, comp, snap_id):
+    def handle_create_snapshot(self,
+                               image_spec: ImageSpecT,
+                               image: rbd.Image,
+                               comp: rados.Completion,
+                               snap_id: str) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
@@ -331,14 +349,14 @@ class CreateSnapshotRequests:
 
         self.close_image(image_spec, image)
 
-    def close_image(self, image_spec, image):
+    def close_image(self, image_spec: ImageSpecT, image: rbd.Image) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
             "CreateSnapshotRequests.close_image {}/{}/{}".format(
                 pool_id, namespace, image_id))
 
-        def cb(comp):
+        def cb(comp: rados.Completion) -> None:
             self.handle_close_image(image_spec, comp)
 
         try:
@@ -349,7 +367,9 @@ class CreateSnapshotRequests:
                     pool_id, namespace, image_id, e))
             self.finish(image_spec)
 
-    def handle_close_image(self, image_spec, comp):
+    def handle_close_image(self,
+                           image_spec: ImageSpecT,
+                           comp: rados.Completion) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug(
@@ -363,7 +383,7 @@ class CreateSnapshotRequests:
 
         self.finish(image_spec)
 
-    def finish(self, image_spec):
+    def finish(self, image_spec: ImageSpecT) -> None:
         pool_id, namespace, image_id = image_spec
 
         self.log.debug("CreateSnapshotRequests.finish: {}/{}/{}".format(
@@ -379,7 +399,7 @@ class CreateSnapshotRequests:
 
         self.open_image(image_spec)
 
-    def get_ioctx(self, image_spec):
+    def get_ioctx(self, image_spec: ImageSpecT) -> rados.Ioctx:
         pool_id, namespace, image_id = image_spec
         nspec = (pool_id, namespace)
 
@@ -390,11 +410,12 @@ class CreateSnapshotRequests:
                 ioctx.set_namespace(namespace)
                 images = set()
                 self.ioctxs[nspec] = (ioctx, images)
+            assert images is not None
             images.add(image_spec)
 
         return ioctx
 
-    def put_ioctx(self, image_spec):
+    def put_ioctx(self, image_spec: ImageSpecT) -> None:
         pool_id, namespace, image_id = image_spec
         nspec = (pool_id, namespace)
 
@@ -414,7 +435,7 @@ class MirrorSnapshotScheduleHandler:
     condition = Condition(lock)
     thread = None
 
-    def __init__(self, module):
+    def __init__(self, module: Any) -> None:
         self.module = module
         self.log = module.log
         self.last_refresh_images = datetime(1970, 1, 1)
@@ -425,11 +446,11 @@ class MirrorSnapshotScheduleHandler:
         self.thread = Thread(target=self.run)
         self.thread.start()
 
-    def _cleanup(self):
+    def _cleanup(self) -> None:
         self.watchers.unregister_all()
         self.create_snapshot_requests.wait_for_pending()
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.log.info("MirrorSnapshotScheduleHandler: starting")
             while True:
@@ -448,14 +469,16 @@ class MirrorSnapshotScheduleHandler:
             self.log.fatal("Fatal runtime error: {}\n{}".format(
                 ex, traceback.format_exc()))
 
-    def init_schedule_queue(self):
-        self.queue = {}
-        self.images = {}
+    def init_schedule_queue(self) -> None:
+        # schedule_time => image_spec
+        self.queue: Dict[str, List[ImageSpecT]] = {}
+        # pool_id => {namespace => image_id}
+        self.images: Dict[str, Dict[str, Dict[str, str]]] = {}
         self.watchers = Watchers(self)
         self.refresh_images()
         self.log.debug("scheduler queue is initialized")
 
-    def load_schedules(self):
+    def load_schedules(self) -> None:
         self.log.info("MirrorSnapshotScheduleHandler: load_schedules")
 
         schedules = Schedules(self)
@@ -463,7 +486,7 @@ class MirrorSnapshotScheduleHandler:
         with self.lock:
             self.schedules = schedules
 
-    def refresh_images(self):
+    def refresh_images(self) -> None:
         if (datetime.now() - self.last_refresh_images).seconds < 60:
             return
 
@@ -480,7 +503,7 @@ class MirrorSnapshotScheduleHandler:
                 return
 
         epoch = int(datetime.now().strftime('%s'))
-        images = {}
+        images: Dict[str, Dict[str, Dict[str, str]]] = {}
 
         for pool_id, pool_name in get_rbd_pools(self.module).items():
             if not self.schedules.intersects(
@@ -496,7 +519,10 @@ class MirrorSnapshotScheduleHandler:
         self.watchers.unregister_stale(epoch)
         self.last_refresh_images = datetime.now()
 
-    def load_pool_images(self, ioctx, epoch, images):
+    def load_pool_images(self,
+                         ioctx: rados.Ioctx,
+                         epoch: int,
+                         images: Dict[str, Dict[str, Dict[str, str]]]) -> None:
         pool_id = str(ioctx.get_pool_id())
         pool_name = ioctx.get_pool_name()
         images[pool_id] = {}
@@ -507,7 +533,7 @@ class MirrorSnapshotScheduleHandler:
             namespaces = [''] + rbd.RBD().namespace_list(ioctx)
             for namespace in namespaces:
                 if not self.schedules.intersects(
-                        LevelSpec.from_pool_spec(pool_id, pool_name, namespace)):
+                        LevelSpec.from_pool_spec(int(pool_id), pool_name, namespace)):
                     continue
                 self.log.debug("load_pool_images: pool={}, namespace={}".format(
                     pool_name, namespace))
@@ -546,7 +572,7 @@ class MirrorSnapshotScheduleHandler:
                 "load_pool_images: exception when scanning pool {}: {}".format(
                     pool_name, e))
 
-    def rebuild_queue(self):
+    def rebuild_queue(self) -> None:
         with self.lock:
             now = datetime.now()
 
@@ -567,7 +593,8 @@ class MirrorSnapshotScheduleHandler:
 
             self.condition.notify()
 
-    def refresh_queue(self, current_images):
+    def refresh_queue(self,
+                      current_images: Dict[str, Dict[str, Dict[str, str]]]) -> None:
         now = datetime.now()
 
         for pool_id in self.images:
@@ -588,7 +615,7 @@ class MirrorSnapshotScheduleHandler:
 
         self.condition.notify()
 
-    def enqueue(self, now, pool_id, namespace, image_id):
+    def enqueue(self, now: datetime, pool_id: str, namespace: str, image_id: str) -> None:
 
         schedule = self.schedules.find(pool_id, namespace, image_id)
         if not schedule:
@@ -603,9 +630,9 @@ class MirrorSnapshotScheduleHandler:
         if image_spec not in self.queue[schedule_time]:
             self.queue[schedule_time].append((pool_id, namespace, image_id))
 
-    def dequeue(self):
+    def dequeue(self) -> Tuple[Optional[Tuple[str, str, str]], float]:
         if not self.queue:
-            return None, 1000
+            return None, 1000.0
 
         now = datetime.now()
         schedule_time = sorted(self.queue)[0]
@@ -619,9 +646,9 @@ class MirrorSnapshotScheduleHandler:
         image = images.pop(0)
         if not images:
             del self.queue[schedule_time]
-        return image, 0
+        return image, 0.0
 
-    def remove_from_queue(self, pool_id, namespace, image_id):
+    def remove_from_queue(self, pool_id: str, namespace: str, image_id: str) -> None:
         empty_slots = []
         for schedule_time, images in self.queue.items():
             if (pool_id, namespace, image_id) in images:
@@ -631,7 +658,10 @@ class MirrorSnapshotScheduleHandler:
         for schedule_time in empty_slots:
             del self.queue[schedule_time]
 
-    def add_schedule(self, level_spec, interval, start_time):
+    def add_schedule(self,
+                     level_spec: LevelSpec,
+                     interval: str,
+                     start_time: Optional[str]) -> Tuple[int, str, str]:
         self.log.debug(
             "add_schedule: level_spec={}, interval={}, start_time={}".format(
                 level_spec.name, interval, start_time))
@@ -643,7 +673,10 @@ class MirrorSnapshotScheduleHandler:
         self.rebuild_queue()
         return 0, "", ""
 
-    def remove_schedule(self, level_spec, interval, start_time):
+    def remove_schedule(self,
+                        level_spec: LevelSpec,
+                        interval: Optional[str],
+                        start_time: Optional[str]) -> Tuple[int, str, str]:
         self.log.debug(
             "remove_schedule: level_spec={}, interval={}, start_time={}".format(
                 level_spec.name, interval, start_time))
@@ -655,7 +688,7 @@ class MirrorSnapshotScheduleHandler:
         self.rebuild_queue()
         return 0, "", ""
 
-    def list(self, level_spec):
+    def list(self, level_spec: LevelSpec) -> Tuple[int, str, str]:
         self.log.debug("list: level_spec={}".format(level_spec.name))
 
         with self.lock:
@@ -663,7 +696,7 @@ class MirrorSnapshotScheduleHandler:
 
         return 0, json.dumps(result, indent=4, sort_keys=True), ""
 
-    def status(self, level_spec):
+    def status(self, level_spec: LevelSpec) -> Tuple[int, str, str]:
         self.log.debug("status: level_spec={}".format(level_spec.name))
 
         scheduled_images = []
index c8526cd9eec78542a988b206a88c3ac698615ad5..2b29b1f1a74a7172dd444759e7d04a908e4c8359 100644 (file)
@@ -73,12 +73,7 @@ class Module(MgrModule):
         Option(name=TrashPurgeScheduleHandler.MODULE_OPTION_NAME),
     ]
 
-    mirror_snapshot_schedule = None
-    perf = None
-    task = None
-    trash_purge_schedule = None
-
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(Module, self).__init__(*args, **kwargs)
         self.rados.wait_for_latest_osdmap()
         self.mirror_snapshot_schedule = MirrorSnapshotScheduleHandler(self)
index 525ed35d8c94befdb5429c3b3abf3e26eb3b45a3..572d75f5b8b9b77b88d35e9aa19518c2a4b58fcd 100644 (file)
@@ -7,9 +7,10 @@ import traceback
 
 from datetime import datetime, timedelta
 from threading import Condition, Lock, Thread
+from typing import cast, Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from .common import (GLOBAL_POOL_KEY, authorize_request, extract_pool_key,
-                     get_rbd_pools)
+                     get_rbd_pools, PoolKeyT)
 
 QUERY_POOL_ID = "pool_id"
 QUERY_POOL_ID_MAP = "pool_id_map"
@@ -38,24 +39,52 @@ STATS_RATE_INTERVAL = timedelta(minutes=1)
 REPORT_MAX_RESULTS = 64
 
 
+# {(pool_id, namespace)...}
+ResolveImageNamesT = Set[Tuple[int, str]]
+
+# (time, [value,...])
+PerfCounterT = Tuple[int, List[int]]
+# current, previous
+RawImageCounterT = Tuple[PerfCounterT, Optional[PerfCounterT]]
+# image_id => perf_counter
+RawImagesCounterT = Dict[str, RawImageCounterT]
+# namespace_counters => raw_images
+RawNamespacesCountersT = Dict[str, RawImagesCounterT]
+# pool_id => namespaces_counters
+RawPoolCountersT = Dict[int, RawNamespacesCountersT]
+
+SumImageCounterT = List[int]
+# image_id => sum_image
+SumImagesCounterT = Dict[str, SumImageCounterT]
+# namespace => sum_images
+SumNamespacesCountersT = Dict[str, SumImagesCounterT]
+# pool_id, sum_namespaces
+SumPoolCountersT = Dict[int, SumNamespacesCountersT]
+
+ExtractDataFuncT = Callable[[int, Optional[RawImageCounterT], SumImageCounterT], float]
+
+
 class PerfHandler:
-    user_queries = {}
-    image_cache = {}
+    user_queries: Dict[PoolKeyT, Dict[str, Any]] = {}
+    image_cache: Dict[str, str] = {}
 
     lock = Lock()
     query_condition = Condition(lock)
     refresh_condition = Condition(lock)
     thread = None
 
-    image_name_cache = {}
+    image_name_cache: Dict[Tuple[int, str], Dict[str, str]] = {}
     image_name_refresh_time = datetime.fromtimestamp(0)
 
     @classmethod
-    def prepare_regex(cls, value):
+    def prepare_regex(cls, value: Any) -> str:
         return '^({})$'.format(value)
 
     @classmethod
-    def prepare_osd_perf_query(cls, pool_id, namespace, counter_type):
+    def prepare_osd_perf_query(cls,
+                               pool_id: Optional[int],
+                               namespace: Optional[str],
+                               counter_type: str) -> Dict[str, Any]:
         pool_id_regex = OSD_PERF_QUERY_REGEX_MATCH_ALL
         namespace_regex = OSD_PERF_QUERY_REGEX_MATCH_ALL
         if pool_id:
@@ -76,23 +105,23 @@ class PerfHandler:
         }
 
     @classmethod
-    def pool_spec_search_keys(cls, pool_key):
+    def pool_spec_search_keys(cls, pool_key: str) -> List[str]:
         return [pool_key[0:len(pool_key) - x]
                 for x in range(0, len(pool_key) + 1)]
 
     @classmethod
-    def submatch_pool_key(cls, pool_key, search_key):
+    def submatch_pool_key(cls, pool_key: PoolKeyT, search_key: str) -> bool:
         return ((pool_key[1] == search_key[1] or not search_key[1])
                 and (pool_key[0] == search_key[0] or not search_key[0]))
 
-    def __init__(self, module):
+    def __init__(self, module: Any) -> None:
         self.module = module
         self.log = module.log
 
         self.thread = Thread(target=self.run)
         self.thread.start()
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.log.info("PerfHandler: starting")
             while True:
@@ -110,12 +139,15 @@ class PerfHandler:
             self.log.fatal("Fatal runtime error: {}\n{}".format(
                 ex, traceback.format_exc()))
 
-    def merge_raw_osd_perf_counters(self, pool_key, query, now_ts,
-                                    resolve_image_names):
+    def merge_raw_osd_perf_counters(self,
+                                    pool_key: PoolKeyT,
+                                    query: Dict[str, Any],
+                                    now_ts: int,
+                                    resolve_image_names: ResolveImageNamesT) -> RawPoolCountersT:
         pool_id_map = query[QUERY_POOL_ID_MAP]
 
         # collect and combine the raw counters from all sort orders
-        raw_pool_counters = query.setdefault(QUERY_RAW_POOL_COUNTERS, {})
+        raw_pool_counters: Dict[int, Dict[str, Dict[str, Any]]] = query.setdefault(QUERY_RAW_POOL_COUNTERS, {})
         for query_id in query[QUERY_IDS]:
             res = self.module.get_osd_perf_counters(query_id)
             for counter in res['counters']:
@@ -141,19 +173,23 @@ class PerfHandler:
                 # if we haven't already processed it for this round
                 raw_namespaces = raw_pool_counters.setdefault(pool_id, {})
                 raw_images = raw_namespaces.setdefault(namespace, {})
-                raw_image = raw_images.setdefault(image_id, [None, None])
-
+                raw_image = raw_images.get(image_id)
                 # save the last two perf counters for each image
-                if raw_image[0] and raw_image[0][0] < now_ts:
-                    raw_image[1] = raw_image[0]
-                    raw_image[0] = None
-                if not raw_image[0]:
-                    raw_image[0] = [now_ts, [int(x[0]) for x in counter['c']]]
+                new_current = (now_ts, [int(x[0]) for x in counter['c']])
+                if raw_image:
+                    old_current, _ = raw_image
+                    if old_current[0] < now_ts:
+                        raw_images[image_id] = (new_current, old_current)
+                else:
+                    raw_images[image_id] = (new_current, None)
 
         self.log.debug("merge_raw_osd_perf_counters: {}".format(raw_pool_counters))
         return raw_pool_counters
 
-    def sum_osd_perf_counters(self, query, raw_pool_counters, now_ts):
+    def sum_osd_perf_counters(self,
+                              query: Dict[str, dict],
+                              raw_pool_counters: RawPoolCountersT,
+                              now_ts: int) -> SumPoolCountersT:
         # update the cumulative counters for each image
         sum_pool_counters = query.setdefault(QUERY_SUM_POOL_COUNTERS, {})
         for pool_id, raw_namespaces in raw_pool_counters.items():
@@ -164,12 +200,13 @@ class PerfHandler:
                     # zero-out non-updated raw counters
                     if not raw_image[0]:
                         continue
-                    elif raw_image[0][0] < now_ts:
-                        raw_image[1] = raw_image[0]
-                        raw_image[0] = [now_ts, [0 for x in raw_image[1][1]]]
+                    old_current, _ = raw_image
+                    if old_current[0] < now_ts:
+                        new_current = (now_ts, [0] * len(old_current[1]))
+                        raw_images[image_id] = (new_current, old_current)
                         continue
 
-                    counters = raw_image[0][1]
+                    counters = old_current[1]
 
                     # copy raw counters if this is a newly discovered image or
                     # increment existing counters
@@ -183,7 +220,7 @@ class PerfHandler:
         self.log.debug("sum_osd_perf_counters: {}".format(sum_pool_counters))
         return sum_pool_counters
 
-    def refresh_image_names(self, resolve_image_names):
+    def refresh_image_names(self, resolve_image_names: ResolveImageNamesT) -> None:
         for pool_id, namespace in resolve_image_names:
             image_key = (pool_id, namespace)
             images = self.image_name_cache.setdefault(image_key, {})
@@ -193,7 +230,7 @@ class PerfHandler:
                     images[image_meta['id']] = image_meta['name']
             self.log.debug("resolve_image_names: {}={}".format(image_key, images))
 
-    def scrub_missing_images(self):
+    def scrub_missing_images(self) -> None:
         for pool_key, query in self.user_queries.items():
             raw_pool_counters = query.get(QUERY_RAW_POOL_COUNTERS, {})
             sum_pool_counters = query.get(QUERY_SUM_POOL_COUNTERS, {})
@@ -213,7 +250,7 @@ class PerfHandler:
                             if image_id in raw_images:
                                 del raw_images[image_id]
 
-    def process_raw_osd_perf_counters(self):
+    def process_raw_osd_perf_counters(self) -> None:
         now = datetime.now()
         now_ts = int(now.strftime("%s"))
 
@@ -223,7 +260,7 @@ class PerfHandler:
             self.log.debug("process_raw_osd_perf_counters: expiring image name cache")
             self.image_name_cache = {}
 
-        resolve_image_names = set()
+        resolve_image_names: Set[Tuple[int, str]] = set()
         for pool_key, query in self.user_queries.items():
             if not query[QUERY_IDS]:
                 continue
@@ -239,14 +276,14 @@ class PerfHandler:
         elif not self.image_name_cache:
             self.scrub_missing_images()
 
-    def resolve_pool_id(self, pool_name):
+    def resolve_pool_id(self, pool_name: str) -> int:
         pool_id = self.module.rados.pool_lookup(pool_name)
         if not pool_id:
             raise rados.ObjectNotFound("Pool '{}' not found".format(pool_name),
                                        errno.ENOENT)
         return pool_id
 
-    def scrub_expired_queries(self):
+    def scrub_expired_queries(self) -> None:
         # perf counters need to be periodically refreshed to continue
         # to be registered
         expire_time = datetime.now() - QUERY_EXPIRE_INTERVAL
@@ -256,7 +293,9 @@ class PerfHandler:
                 self.unregister_osd_perf_queries(pool_key, user_query[QUERY_IDS])
                 del self.user_queries[pool_key]
 
-    def register_osd_perf_queries(self, pool_id, namespace):
+    def register_osd_perf_queries(self,
+                                  pool_id: Optional[int],
+                                  namespace: Optional[str]) -> List[int]:
         query_ids = []
         try:
             for counter in OSD_PERF_QUERY_COUNTERS:
@@ -275,23 +314,24 @@ class PerfHandler:
 
         return query_ids
 
-    def unregister_osd_perf_queries(self, pool_key, query_ids):
+    def unregister_osd_perf_queries(self, pool_key: PoolKeyT, query_ids: List[int]) -> None:
         self.log.info("unregister_osd_perf_queries: pool_key={}, query_ids={}".format(
             pool_key, query_ids))
         for query_id in query_ids:
             self.module.remove_osd_perf_query(query_id)
         query_ids[:] = []
 
-    def register_query(self, pool_key):
+    def register_query(self, pool_key: PoolKeyT) -> Dict[str, Any]:
         if pool_key not in self.user_queries:
+            pool_name, namespace = pool_key
             pool_id = None
-            if pool_key[0]:
-                pool_id = self.resolve_pool_id(pool_key[0])
+            if pool_name:
+                pool_id = self.resolve_pool_id(cast(str, pool_name))
 
             user_query = {
                 QUERY_POOL_ID: pool_id,
-                QUERY_POOL_ID_MAP: {pool_id: pool_key[0]},
-                QUERY_IDS: self.register_osd_perf_queries(pool_id, pool_key[1]),
+                QUERY_POOL_ID_MAP: {pool_id: pool_name},
+                QUERY_IDS: self.register_osd_perf_queries(pool_id, namespace),
                 QUERY_LAST_REQUEST: datetime.now()
             }
 
@@ -319,18 +359,22 @@ class PerfHandler:
 
         return user_query
 
-    def extract_stat(self, index, raw_image, sum_image):
+    def extract_stat(self,
+                     index: int,
+                     raw_image: Optional[RawImageCounterT],
+                     sum_image: Any) -> float:
         # require two raw counters between a fixed time window
         if not raw_image or not raw_image[0] or not raw_image[1]:
             return 0
 
-        current_time = raw_image[0][0]
-        previous_time = raw_image[1][0]
+        current_counter, previous_counter = cast(Tuple[PerfCounterT, PerfCounterT], raw_image)
+        current_time = current_counter[0]
+        previous_time = previous_counter[0]
         if current_time <= previous_time or \
                 current_time - previous_time > STATS_RATE_INTERVAL.total_seconds():
             return 0
 
-        current_value = raw_image[0][1][index]
+        current_value = current_counter[1][index]
         instant_rate = float(current_value) / (current_time - previous_time)
 
         # convert latencies from sum to average per op
@@ -346,15 +390,28 @@ class PerfHandler:
 
         return instant_rate
 
-    def extract_counter(self, index, raw_image, sum_image):
+    def extract_counter(self,
+                        index: int,
+                        raw_image: Optional[RawImageCounterT],
+                        sum_image: List[int]) -> int:
         if sum_image:
             return sum_image[index]
         return 0
 
-    def generate_report(self, query, sort_by, extract_data):
-        pool_id_map = query[QUERY_POOL_ID_MAP]
-        sum_pool_counters = query.setdefault(QUERY_SUM_POOL_COUNTERS, {})
-        raw_pool_counters = query.setdefault(QUERY_RAW_POOL_COUNTERS, {})
+    def generate_report(self,
+                        query: Dict[str, Union[Dict[str, str],
+                                               Dict[int, Dict[str, dict]]]],
+                        sort_by: str,
+                        extract_data: ExtractDataFuncT) -> Tuple[Dict[int, str],
+                                                                 List[Dict[str, List[float]]]]:
+        pool_id_map = cast(Dict[int, str], query[QUERY_POOL_ID_MAP])
+        sum_pool_counters = cast(SumPoolCountersT,
+                                 query.setdefault(QUERY_SUM_POOL_COUNTERS,
+                                                  cast(SumPoolCountersT, {})))
+        # pool_id => {namespace => {image_id => [counter..] }
+        raw_pool_counters = cast(RawPoolCountersT,
+                                 query.setdefault(QUERY_RAW_POOL_COUNTERS,
+                                                  cast(RawPoolCountersT, {})))
 
         sort_by_index = OSD_PERF_QUERY_COUNTERS.index(sort_by)
 
@@ -363,20 +420,20 @@ class PerfHandler:
         for pool_id, sum_namespaces in sum_pool_counters.items():
             if pool_id not in pool_id_map:
                 continue
-            raw_namespaces = raw_pool_counters.get(pool_id, {})
+            raw_namespaces: RawNamespacesCountersT = raw_pool_counters.get(pool_id, {})
             for namespace, sum_images in sum_namespaces.items():
                 raw_images = raw_namespaces.get(namespace, {})
                 for image_id, sum_image in sum_images.items():
-                    raw_image = raw_images.get(image_id, [])
+                    raw_image = raw_images.get(image_id)
 
                     # always sort by recent IO activity
-                    results.append([(pool_id, namespace, image_id),
+                    results.append(((pool_id, namespace, image_id),
                                     self.extract_stat(sort_by_index, raw_image,
-                                                      sum_image)])
+                                                      sum_image)))
         results = sorted(results, key=lambda x: x[1], reverse=True)[:REPORT_MAX_RESULTS]
 
         # build the report in sorted order
-        pool_descriptors = {}
+        pool_descriptors: Dict[str, int] = {}
         counters = []
         for key, _ in results:
             pool_id = key[0]
@@ -389,7 +446,7 @@ class PerfHandler:
 
             raw_namespaces = raw_pool_counters.get(pool_id, {})
             raw_images = raw_namespaces.get(namespace, {})
-            raw_image = raw_images.get(image_id, [])
+            raw_image = raw_images.get(image_id)
 
             sum_namespaces = sum_pool_counters[pool_id]
             sum_images = sum_namespaces[namespace]
@@ -414,14 +471,17 @@ class PerfHandler:
                 in pool_descriptors.items()}, \
             counters
 
-    def get_perf_data(self, report, pool_spec, sort_by, extract_data):
+    def get_perf_data(self,
+                      report: str,
+                      pool_spec: Optional[str],
+                      sort_by: str,
+                      extract_data: ExtractDataFuncT) -> Tuple[int, str, str]:
         self.log.debug("get_perf_{}s: pool_spec={}, sort_by={}".format(
             report, pool_spec, sort_by))
         self.scrub_expired_queries()
 
         pool_key = extract_pool_key(pool_spec)
         authorize_request(self.module, pool_key[0], pool_key[1])
-
         user_query = self.register_query(pool_key)
 
         now = datetime.now()
@@ -437,10 +497,14 @@ class PerfHandler:
 
         return 0, json.dumps(report), ""
 
-    def get_perf_stats(self, pool_spec, sort_by):
+    def get_perf_stats(self,
+                       pool_spec: Optional[str],
+                       sort_by: str) -> Tuple[int, str, str]:
         return self.get_perf_data(
             "stat", pool_spec, sort_by, self.extract_stat)
 
-    def get_perf_counters(self, pool_spec, sort_by):
+    def get_perf_counters(self,
+                          pool_spec: Optional[str],
+                          sort_by: str) -> Tuple[int, str, str]:
         return self.get_perf_data(
             "counter", pool_spec, sort_by, self.extract_counter)
index 50e0f6cdc4481e26e3a5258cc85944ef98978f8e..ffcb78a0741226d057c8d17e00d1194b6c1ba7cf 100644 (file)
@@ -5,8 +5,11 @@ import rbd
 import re
 
 from dateutil.parser import parse
+from typing import cast, Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 
 from .common import get_rbd_pools
+if TYPE_CHECKING:
+    from .module import Module
 
 SCHEDULE_INTERVAL = "interval"
 SCHEDULE_START_TIME = "start_time"
@@ -14,17 +17,22 @@ SCHEDULE_START_TIME = "start_time"
 
 class LevelSpec:
 
-    def __init__(self, name, id, pool_id, namespace, image_id=None):
+    def __init__(self,
+                 name: str,
+                 id: str,
+                 pool_id: Optional[str],
+                 namespace: Optional[str],
+                 image_id: Optional[str] = None) -> None:
         self.name = name
         self.id = id
         self.pool_id = pool_id
         self.namespace = namespace
         self.image_id = image_id
 
-    def __eq__(self, level_spec):
+    def __eq__(self, level_spec: Any) -> bool:
         return self.id == level_spec.id
 
-    def is_child_of(self, level_spec):
+    def is_child_of(self, level_spec: 'LevelSpec') -> bool:
         if level_spec.is_global():
             return not self.is_global()
         if level_spec.pool_id != self.pool_id:
@@ -37,13 +45,16 @@ class LevelSpec:
             return self.image_id is not None
         return False
 
-    def is_global(self):
+    def is_global(self) -> bool:
         return self.pool_id is None
 
-    def get_pool_id(self):
+    def get_pool_id(self) -> Optional[str]:
         return self.pool_id
 
-    def matches(self, pool_id, namespace, image_id=None):
+    def matches(self,
+                pool_id: str,
+                namespace: str,
+                image_id: Optional[str] = None) -> bool:
         if self.pool_id and self.pool_id != pool_id:
             return False
         if self.namespace and self.namespace != namespace:
@@ -52,7 +63,7 @@ class LevelSpec:
             return False
         return True
 
-    def intersects(self, level_spec):
+    def intersects(self, level_spec: 'LevelSpec') -> bool:
         if self.pool_id is None or level_spec.pool_id is None:
             return True
         if self.pool_id != level_spec.pool_id:
@@ -68,11 +79,14 @@ class LevelSpec:
         return True
 
     @classmethod
-    def make_global(cls):
+    def make_global(cls) -> 'LevelSpec':
         return LevelSpec("", "", None, None, None)
 
     @classmethod
-    def from_pool_spec(cls, pool_id, pool_name, namespace=None):
+    def from_pool_spec(cls,
+                       pool_id: int,
+                       pool_name: str,
+                       namespace: Optional[str] = None) -> 'LevelSpec':
         if namespace is None:
             id = "{}".format(pool_id)
             name = "{}/".format(pool_name)
@@ -82,8 +96,12 @@ class LevelSpec:
         return LevelSpec(name, id, str(pool_id), namespace, None)
 
     @classmethod
-    def from_name(cls, module, name, namespace_validator=None,
-                  image_validator=None, allow_image_level=True):
+    def from_name(cls,
+                  module: 'Module',
+                  name: str,
+                  namespace_validator: Optional[Callable] = None,
+                  image_validator: Optional[Callable] = None,
+                  allow_image_level: bool = True) -> 'LevelSpec':
         # parse names like:
         # '', 'rbd/', 'rbd/ns/', 'rbd//image', 'rbd/image', 'rbd/ns/image'
         match = re.match(r'^(?:([^/]+)/(?:(?:([^/]*)/|)(?:([^/@]+))?)?)?$',
@@ -107,8 +125,7 @@ class LevelSpec:
                     raise ValueError("pool {} does not exist".format(pool_name))
                 if pool_id not in get_rbd_pools(module):
                     raise ValueError("{} is not an RBD pool".format(pool_name))
-                pool_id = str(pool_id)
-                id += pool_id
+                id += str(pool_id)
                 if match.group(2) is not None or match.group(3):
                     id += "/"
                     with module.rados.open_ioctx(pool_name) as ioctx:
@@ -150,8 +167,11 @@ class LevelSpec:
         return LevelSpec(name, id, pool_id, namespace, image_id)
 
     @classmethod
-    def from_id(cls, handler, id, namespace_validator=None,
-                image_validator=None):
+    def from_id(cls,
+                handler: Any,
+                id: str,
+                namespace_validator: Optional[Callable] = None,
+                image_validator: Optional[Callable] = None) -> 'LevelSpec':
         # parse ids like:
         # '', '123', '123/', '123/ns', '123//image_id', '123/ns/image_id'
         match = re.match(r'^(?:(\d+)(?:/([^/]*)(?:/([^/@]+))?)?)?$', id)
@@ -209,16 +229,16 @@ class LevelSpec:
 
 class Interval:
 
-    def __init__(self, minutes):
+    def __init__(self, minutes: int) -> None:
         self.minutes = minutes
 
-    def __eq__(self, interval):
+    def __eq__(self, interval: Any) -> bool:
         return self.minutes == interval.minutes
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.minutes)
 
-    def to_string(self):
+    def to_string(self) -> str:
         if self.minutes % (60 * 24) == 0:
             interval = int(self.minutes / (60 * 24))
             units = 'd'
@@ -232,7 +252,7 @@ class Interval:
         return "{}{}".format(interval, units)
 
     @classmethod
-    def from_string(cls, interval):
+    def from_string(cls, interval: str) -> 'Interval':
         match = re.match(r'^(\d+)(d|h|m)?$', interval)
         if not match:
             raise ValueError("Invalid interval ({})".format(interval))
@@ -248,23 +268,27 @@ class Interval:
 
 class StartTime:
 
-    def __init__(self, hour, minute, tzinfo):
+    def __init__(self,
+                 hour: int,
+                 minute: int,
+                 tzinfo: Optional[datetime.tzinfo]) -> None:
         self.time = datetime.time(hour, minute, tzinfo=tzinfo)
         self.minutes = self.time.hour * 60 + self.time.minute
         if self.time.tzinfo:
-            self.minutes += int(self.time.utcoffset().seconds / 60)
+            utcoffset = cast(datetime.timedelta, self.time.utcoffset())
+            self.minutes += int(utcoffset.seconds / 60)
 
-    def __eq__(self, start_time):
+    def __eq__(self, start_time: Any) -> bool:
         return self.minutes == start_time.minutes
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.minutes)
 
-    def to_string(self):
+    def to_string(self) -> str:
         return self.time.isoformat()
 
     @classmethod
-    def from_string(cls, start_time):
+    def from_string(cls, start_time: Optional[str]) -> Optional['StartTime']:
         if not start_time:
             return None
 
@@ -278,26 +302,31 @@ class StartTime:
 
 class Schedule:
 
-    def __init__(self, name):
+    def __init__(self, name: str) -> None:
         self.name = name
-        self.items = set()
+        self.items: Set[Tuple[Interval, Optional[StartTime]]] = set()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.items)
 
-    def add(self, interval, start_time=None):
+    def add(self,
+            interval: Interval,
+            start_time: Optional[StartTime] = None) -> None:
         self.items.add((interval, start_time))
 
-    def remove(self, interval, start_time=None):
+    def remove(self,
+               interval: Interval,
+               start_time: Optional[StartTime] = None) -> None:
         self.items.discard((interval, start_time))
 
-    def next_run(self, now):
+    def next_run(self, now: datetime.datetime) -> str:
         schedule_time = None
-        for item in self.items:
-            period = datetime.timedelta(minutes=item[0].minutes)
+        for interval, opt_start in self.items:
+            period = datetime.timedelta(minutes=interval.minutes)
             start_time = datetime.datetime(1970, 1, 1)
-            if item[1]:
-                start_time += datetime.timedelta(minutes=item[1].minutes)
+            if opt_start:
+                start = cast(StartTime, opt_start)
+                start_time += datetime.timedelta(minutes=start.minutes)
             time = start_time + \
                 (int((now - start_time) / period) + 1) * period
             if schedule_time is None or time < schedule_time:
@@ -306,16 +335,23 @@ class Schedule:
             raise ValueError('no items is added')
         return datetime.datetime.strftime(schedule_time, "%Y-%m-%d %H:%M:00")
 
-    def to_list(self):
-        return [{SCHEDULE_INTERVAL: i[0].to_string(),
-                 SCHEDULE_START_TIME: i[1] and i[1].to_string() or None}
-                for i in self.items]
+    def to_list(self) -> List[Dict[str, Optional[str]]]:
+        def item_to_dict(interval: Interval,
+                         start_time: Optional[StartTime]) -> Dict[str, Optional[str]]:
+            if start_time:
+                schedule_start_time: Optional[str] = start_time.to_string()
+            else:
+                schedule_start_time = None
+            return {SCHEDULE_INTERVAL: interval.to_string(),
+                    SCHEDULE_START_TIME: schedule_start_time}
+        return [item_to_dict(interval, start_time)
+                for interval, start_time in self.items]
 
-    def to_json(self):
+    def to_json(self) -> str:
         return json.dumps(self.to_list(), indent=4, sort_keys=True)
 
     @classmethod
-    def from_json(cls, name, val):
+    def from_json(cls, name: str, val: str) -> 'Schedule':
         try:
             items = json.loads(val)
             schedule = Schedule(name)
@@ -333,17 +369,20 @@ class Schedule:
         except TypeError as e:
             raise ValueError("Invalid schedule format ({})".format(str(e)))
 
+
 class Schedules:
 
-    def __init__(self, handler):
+    def __init__(self, handler: Any) -> None:
         self.handler = handler
-        self.level_specs = {}
-        self.schedules = {}
+        self.level_specs: Dict[str, LevelSpec] = {}
+        self.schedules: Dict[str, Schedule] = {}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.schedules)
 
-    def load(self, namespace_validator=None, image_validator=None):
+    def load(self,
+             namespace_validator: Optional[Callable] = None,
+             image_validator: Optional[Callable] = None) -> None:
 
         schedule_cfg = self.handler.module.get_module_option(
             self.handler.MODULE_OPTION_NAME, '')
@@ -380,10 +419,13 @@ class Schedules:
                     "Failed to load schedules for pool {}: {}".format(
                         pool_name, e))
 
-    def load_from_pool(self, ioctx, namespace_validator, image_validator):
+    def load_from_pool(self,
+                       ioctx: rados.Ioctx,
+                       namespace_validator: Optional[Callable],
+                       image_validator: Optional[Callable]) -> None:
         pool_id = ioctx.get_pool_id()
         pool_name = ioctx.get_pool_name()
-        stale_keys = ()
+        stale_keys = []
         start_after = ''
         try:
             while True:
@@ -409,7 +451,7 @@ class Schedules:
                                 self.handler.log.debug(
                                     "Stale schedule key %s in pool %s",
                                     k, pool_name)
-                                stale_keys += (k,)
+                                stale_keys.append(k)
                                 continue
 
                             self.level_specs[level_spec.id] = level_spec
@@ -432,7 +474,7 @@ class Schedules:
                 ioctx.remove_omap_keys(write_op, stale_keys)
                 ioctx.operate_write_op(write_op, self.handler.SCHEDULE_OID)
 
-    def save(self, level_spec, schedule):
+    def save(self, level_spec: LevelSpec, schedule: Optional[Schedule]) -> None:
         if level_spec.is_global():
             schedule_cfg = schedule and schedule.to_json() or None
             self.handler.module.set_module_option(
@@ -440,6 +482,7 @@ class Schedules:
             return
 
         pool_id = level_spec.get_pool_id()
+        assert pool_id
         with self.handler.module.rados.open_ioctx2(int(pool_id)) as ioctx:
             with rados.WriteOpCtx() as write_op:
                 if schedule:
@@ -449,8 +492,10 @@ class Schedules:
                     ioctx.remove_omap_keys(write_op, (level_spec.id, ))
                 ioctx.operate_write_op(write_op, self.handler.SCHEDULE_OID)
 
-
-    def add(self, level_spec, interval, start_time):
+    def add(self,
+            level_spec: LevelSpec,
+            interval: str,
+            start_time: Optional[str]) -> None:
         schedule = self.schedules.get(level_spec.id, Schedule(level_spec.name))
         schedule.add(Interval.from_string(interval),
                      StartTime.from_string(start_time))
@@ -458,7 +503,10 @@ class Schedules:
         self.level_specs[level_spec.id] = level_spec
         self.save(level_spec, schedule)
 
-    def remove(self, level_spec, interval, start_time):
+    def remove(self,
+               level_spec: LevelSpec,
+               interval: Optional[str],
+               start_time: Optional[str]) -> None:
         schedule = self.schedules.pop(level_spec.id, None)
         if schedule:
             if interval is None:
@@ -472,7 +520,10 @@ class Schedules:
                 del self.level_specs[level_spec.id]
         self.save(level_spec, schedule)
 
-    def find(self, pool_id, namespace, image_id=None):
+    def find(self,
+             pool_id: str,
+             namespace: str,
+             image_id: Optional[str] = None) -> Optional['Schedule']:
         levels = [pool_id, namespace]
         if image_id:
             levels.append(image_id)
@@ -486,15 +537,15 @@ class Schedules:
             nr_levels -= 1
         return None
 
-    def intersects(self, level_spec):
+    def intersects(self, level_spec: LevelSpec) -> bool:
         for ls in self.level_specs.values():
             if ls.intersects(level_spec):
                 return True
         return False
 
-    def to_list(self, level_spec):
+    def to_list(self, level_spec: LevelSpec) -> Dict[str, dict]:
         if level_spec.id in self.schedules:
-            parent = level_spec
+            parent: Optional[LevelSpec] = level_spec
         else:
             # try to find existing parent
             parent = None
@@ -519,4 +570,3 @@ class Schedules:
                     'schedule' : schedule.to_list(),
                 }
         return result
-
index fcec7dcdc46355180bb02763ac75c64321da7d80..d283962a365e3f07f1c84fa9da352360d0caf97b 100644 (file)
@@ -10,6 +10,7 @@ from contextlib import contextmanager
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from threading import Condition, Lock, Thread
+from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
 
 from .common import (authorize_request, extract_pool_key, get_rbd_pools,
                      is_authorized, GLOBAL_POOL_KEY)
@@ -53,52 +54,59 @@ TASK_MAX_RETRY_INTERVAL = timedelta(seconds=300)
 MAX_COMPLETED_TASKS = 50
 
 
+T = TypeVar('T')
+FuncT = TypeVar('FuncT', bound=Callable[..., Any])
+
+
 class Throttle:
-    def __init__(self, throttle_period):
+    def __init__(self: Any, throttle_period: timedelta) -> None:
         self.throttle_period = throttle_period
         self.time_of_last_call = datetime.min
 
-    def __call__(self, fn):
+    def __call__(self: 'Throttle', fn: FuncT) -> FuncT:
         @wraps(fn)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
             now = datetime.now()
             if self.time_of_last_call + self.throttle_period <= now:
                 self.time_of_last_call = now
                 return fn(*args, **kwargs)
-        return wrapper
+        return cast(FuncT, wrapper)
+
+
+TaskRefsT = Dict[str, str]
 
 
 class Task:
-    def __init__(self, sequence, task_id, message, refs):
+    def __init__(self, sequence: int, task_id: str, message: str, refs: TaskRefsT):
         self.sequence = sequence
         self.task_id = task_id
         self.message = message
         self.refs = refs
-        self.retry_message = None
+        self.retry_message: Optional[str] = None
         self.retry_attempts = 0
-        self.retry_time = None
+        self.retry_time: Optional[datetime] = None
         self.in_progress = False
         self.progress = 0.0
         self.canceled = False
         self.failed = False
         self.progress_posted = False
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.to_json()
 
     @property
-    def sequence_key(self):
-        return "{0:016X}".format(self.sequence)
+    def sequence_key(self) -> bytes:
+        return "{0:016X}".format(self.sequence).encode()
 
-    def cancel(self):
+    def cancel(self) -> None:
         self.canceled = True
         self.fail("Operation canceled")
 
-    def fail(self, message):
+    def fail(self, message: str) -> None:
         self.failed = True
         self.failure_message = message
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = {TASK_SEQUENCE: self.sequence,
              TASK_ID: self.task_id,
              TASK_MESSAGE: self.message,
@@ -117,11 +125,11 @@ class Task:
             d[TASK_CANCELED] = True
         return d
 
-    def to_json(self):
+    def to_json(self) -> str:
         return str(json.dumps(self.to_dict()))
 
     @classmethod
-    def from_json(cls, val):
+    def from_json(cls, val: str) -> 'Task':
         try:
             d = json.loads(val)
             action = d.get(TASK_REFS, {}).get(TASK_REF_ACTION)
@@ -135,20 +143,26 @@ class Task:
             raise ValueError("Invalid task format (missing key {})".format(str(e)))
 
 
+# pool_name, namespace, image_name
+ImageSpecT = Tuple[str, str, str]
+# pool_name, namespace
+PoolSpecT = Tuple[str, str]
+MigrationStatusT = Dict[str, str]
+
 class TaskHandler:
     lock = Lock()
     condition = Condition(lock)
     thread = None
 
     in_progress_task = None
-    tasks_by_sequence = dict()
-    tasks_by_id = dict()
+    tasks_by_sequence: Dict[int, Task] = dict()
+    tasks_by_id: Dict[str, Task] = dict()
 
-    completed_tasks = []
+    completed_tasks: List[Task] = []
 
     sequence = 0
 
-    def __init__(self, module):
+    def __init__(self, module: Any) -> None:
         self.module = module
         self.log = module.log
 
@@ -159,16 +173,16 @@ class TaskHandler:
         self.thread.start()
 
     @property
-    def default_pool_name(self):
+    def default_pool_name(self) -> str:
         return self.module.get_ceph_option("rbd_default_pool")
 
-    def extract_pool_spec(self, pool_spec):
+    def extract_pool_spec(self, pool_spec: str) -> PoolSpecT:
         pool_spec = extract_pool_key(pool_spec)
         if pool_spec == GLOBAL_POOL_KEY:
             pool_spec = (self.default_pool_name, '')
-        return pool_spec
+        return cast(PoolSpecT, pool_spec)
 
-    def extract_image_spec(self, image_spec):
+    def extract_image_spec(self, image_spec: str) -> ImageSpecT:
         match = re.match(r'^(?:([^/]+)/(?:([^/]+)/)?)?([^/@]+)$',
                          image_spec or '')
         if not match:
@@ -176,7 +190,7 @@ class TaskHandler:
         return (match.group(1) or self.default_pool_name, match.group(2) or '',
                 match.group(3))
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.log.info("TaskHandler: starting")
             while True:
@@ -195,7 +209,7 @@ class TaskHandler:
                 ex, traceback.format_exc()))
 
     @contextmanager
-    def open_ioctx(self, spec):
+    def open_ioctx(self, spec: PoolSpecT) -> Iterator[rados.Ioctx]:
         try:
             with self.module.rados.open_ioctx(spec[0]) as ioctx:
                 ioctx.set_namespace(spec[1])
@@ -205,7 +219,7 @@ class TaskHandler:
             raise
 
     @classmethod
-    def format_image_spec(cls, image_spec):
+    def format_image_spec(cls, image_spec: ImageSpecT) -> str:
         image = image_spec[2]
         if image_spec[1]:
             image = "{}/{}".format(image_spec[1], image)
@@ -213,7 +227,7 @@ class TaskHandler:
             image = "{}/{}".format(image_spec[0], image)
         return image
 
-    def init_task_queue(self):
+    def init_task_queue(self) -> None:
         for pool_id, pool_name in get_rbd_pools(self.module).items():
             try:
                 with self.module.rados.open_ioctx2(int(pool_id)) as ioctx:
@@ -239,7 +253,7 @@ class TaskHandler:
         self.log.debug("sequence={}, tasks_by_sequence={}, tasks_by_id={}".format(
             self.sequence, str(self.tasks_by_sequence), str(self.tasks_by_id)))
 
-    def load_task_queue(self, ioctx, pool_name):
+    def load_task_queue(self, ioctx: rados.Ioctx, pool_name: str) -> None:
         pool_spec = pool_name
         if ioctx.nspace:
             pool_spec += "/{}".format(ioctx.nspace)
@@ -274,11 +288,11 @@ class TaskHandler:
             # rbd_task DNE
             pass
 
-    def append_task(self, task):
+    def append_task(self, task: Task) -> None:
         self.tasks_by_sequence[task.sequence] = task
         self.tasks_by_id[task.task_id] = task
 
-    def task_refs_match(self, task_refs, refs):
+    def task_refs_match(self, task_refs: TaskRefsT, refs: TaskRefsT) -> bool:
         if TASK_REF_IMAGE_ID not in refs and TASK_REF_IMAGE_ID in task_refs:
             task_refs = task_refs.copy()
             del task_refs[TASK_REF_IMAGE_ID]
@@ -286,7 +300,7 @@ class TaskHandler:
         self.log.debug("task_refs_match: ref1={}, ref2={}".format(task_refs, refs))
         return task_refs == refs
 
-    def find_task(self, refs):
+    def find_task(self, refs: TaskRefsT) -> Optional[Task]:
         self.log.debug("find_task: refs={}".format(refs))
 
         # search for dups and return the original
@@ -299,8 +313,13 @@ class TaskHandler:
         for task in reversed(self.completed_tasks):
             if self.task_refs_match(task.refs, refs):
                 return task
+        else:
+            return None
 
-    def add_task(self, ioctx, message, refs):
+    def add_task(self,
+                 ioctx: rados.Ioctx,
+                 message: str,
+                 refs: TaskRefsT) -> str:
         self.log.debug("add_task: message={}, refs={}".format(message, refs))
 
         # ensure unique uuid across all pools
@@ -328,7 +347,10 @@ class TaskHandler:
         self.condition.notify()
         return task_json
 
-    def remove_task(self, ioctx, task, remove_in_memory=True):
+    def remove_task(self,
+                    ioctx: rados.Ioctx,
+                    task: Task,
+                    remove_in_memory: bool = True) -> None:
         self.log.info("remove_task: task={}".format(str(task)))
         omap_keys = (task.sequence_key, )
         try:
@@ -353,7 +375,7 @@ class TaskHandler:
             except KeyError:
                 pass
 
-    def execute_task(self, sequence):
+    def execute_task(self, sequence: int) -> None:
         task = self.tasks_by_sequence[sequence]
         self.log.info("execute_task: task={}".format(str(task)))
 
@@ -416,7 +438,7 @@ class TaskHandler:
                 TASK_RETRY_INTERVAL * task.retry_attempts,
                 TASK_MAX_RETRY_INTERVAL)
 
-    def progress_callback(self, task, current, total):
+    def progress_callback(self, task: Task, current: int, total: int) -> int:
         progress = float(current) / float(total)
         self.log.debug("progress_callback: task={}, progress={}".format(
             str(task), progress))
@@ -440,7 +462,7 @@ class TaskHandler:
 
         return 0
 
-    def execute_flatten(self, ioctx, task):
+    def execute_flatten(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_flatten: task={}".format(str(task)))
 
         try:
@@ -453,7 +475,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_remove(self, ioctx, task):
+    def execute_remove(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_remove: task={}".format(str(task)))
 
         try:
@@ -463,7 +485,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_trash_remove(self, ioctx, task):
+    def execute_trash_remove(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_trash_remove: task={}".format(str(task)))
 
         try:
@@ -473,7 +495,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_execute(self, ioctx, task):
+    def execute_migration_execute(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_execute: task={}".format(str(task)))
 
         try:
@@ -486,7 +508,7 @@ class TaskHandler:
             task.fail("Image is not migrating")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_commit(self, ioctx, task):
+    def execute_migration_commit(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_commit: task={}".format(str(task)))
 
         try:
@@ -499,7 +521,7 @@ class TaskHandler:
             task.fail("Image is not migrating or migration not executed")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_abort(self, ioctx, task):
+    def execute_migration_abort(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_abort: task={}".format(str(task)))
 
         try:
@@ -512,7 +534,7 @@ class TaskHandler:
             task.fail("Image is not migrating")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def complete_progress(self, task):
+    def complete_progress(self, task: Task) -> None:
         if not task.progress_posted:
             # ensure progress event exists before we complete/fail it
             self.post_progress(task, 0)
@@ -528,7 +550,7 @@ class TaskHandler:
             # progress module is disabled
             pass
 
-    def _update_progress(self, task, progress):
+    def _update_progress(self, task: Task, progress: float) -> None:
         self.log.debug("update_progress: task={}, progress={}".format(str(task), progress))
         try:
             refs = {"origin": "rbd_support"}
@@ -540,19 +562,19 @@ class TaskHandler:
             # progress module is disabled
             pass
 
-    def post_progress(self, task, progress):
+    def post_progress(self, task: Task, progress: float) -> None:
         self._update_progress(task, progress)
         task.progress_posted = True
 
-    def update_progress(self, task, progress):
+    def update_progress(self, task: Task, progress: float) -> None:
         if task.progress_posted:
             self._update_progress(task, progress)
 
     @Throttle(timedelta(seconds=1))
-    def throttled_update_progress(self, task, progress):
+    def throttled_update_progress(self, task: Task, progress: float) -> None:
         self.update_progress(task, progress)
 
-    def queue_flatten(self, image_spec):
+    def queue_flatten(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -563,7 +585,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             try:
                 with rbd.Image(ioctx, image_spec[2]) as image:
                     refs[TASK_REF_IMAGE_ID] = image.id()
@@ -592,7 +614,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ""
 
-    def queue_remove(self, image_spec):
+    def queue_remove(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -603,7 +625,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             try:
                 with rbd.Image(ioctx, image_spec[2]) as image:
                     refs[TASK_REF_IMAGE_ID] = image.id()
@@ -628,7 +650,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def queue_trash_remove(self, image_id_spec):
+    def queue_trash_remove(self, image_id_spec: str) -> Tuple[int, str, str]:
         image_id_spec = self.extract_image_spec(image_id_spec)
 
         authorize_request(self.module, image_id_spec[0], image_id_spec[1])
@@ -643,7 +665,7 @@ class TaskHandler:
             return 0, task.to_json(), ''
 
         # verify that image exists in trash
-        with self.open_ioctx(image_id_spec) as ioctx:
+        with self.open_ioctx(image_id_spec[:2]) as ioctx:
             rbd.RBD().trash_get(ioctx, image_id_spec[2])
 
             return 0, self.add_task(ioctx,
@@ -651,25 +673,29 @@ class TaskHandler:
                                         self.format_image_spec(image_id_spec)),
                                     refs), ''
 
-    def get_migration_status(self, ioctx, image_spec):
+    def get_migration_status(self,
+                             ioctx: rados.Ioctx,
+                             image_spec: ImageSpecT) -> Optional[MigrationStatusT]:
         try:
             return rbd.RBD().migration_status(ioctx, image_spec[2])
         except (rbd.InvalidArgument, rbd.ImageNotFound):
             return None
 
-    def validate_image_migrating(self, image_spec, migration_status):
+    def validate_image_migrating(self,
+                                 image_spec: ImageSpecT,
+                                 migration_status: Optional[MigrationStatusT]) -> None:
         if not migration_status:
             raise rbd.InvalidArgument("Image {} is not migrating".format(
                 self.format_image_spec(image_spec)), errno=errno.EINVAL)
 
-    def resolve_pool_name(self, pool_id):
+    def resolve_pool_name(self, pool_id: str) -> str:
         osd_map = self.module.get('osd_map')
         for pool in osd_map['pools']:
             if pool['pool'] == pool_id:
                 return pool['pool_name']
         return '<unknown>'
 
-    def queue_migration_execute(self, image_spec):
+    def queue_migration_execute(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -680,7 +706,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -690,6 +716,7 @@ class TaskHandler:
                 return 0, task.to_json(), ''
 
             self.validate_image_migrating(image_spec, status)
+            assert status
             if status['state'] not in [rbd.RBD_IMAGE_MIGRATION_STATE_PREPARED,
                                        rbd.RBD_IMAGE_MIGRATION_STATE_EXECUTING]:
                 raise rbd.InvalidArgument("Image {} is not in ready state".format(
@@ -707,7 +734,7 @@ class TaskHandler:
                                                                 status['dest_image_name']))),
                                     refs), ''
 
-    def queue_migration_commit(self, image_spec):
+    def queue_migration_commit(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -718,7 +745,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -728,6 +755,7 @@ class TaskHandler:
                 return 0, task.to_json(), ''
 
             self.validate_image_migrating(image_spec, status)
+            assert status
             if status['state'] != rbd.RBD_IMAGE_MIGRATION_STATE_EXECUTED:
                 raise rbd.InvalidArgument("Image {} has not completed migration".format(
                     self.format_image_spec(image_spec)), errno=errno.EINVAL)
@@ -737,7 +765,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def queue_migration_abort(self, image_spec):
+    def queue_migration_abort(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -748,7 +776,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -763,7 +791,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def task_cancel(self, task_id):
+    def task_cancel(self, task_id: str) -> Tuple[int, str, str]:
         self.log.info("task_cancel: {}".format(task_id))
 
         task = self.tasks_by_id.get(task_id)
@@ -789,7 +817,7 @@ class TaskHandler:
 
         return 0, "", ""
 
-    def task_list(self, task_id):
+    def task_list(self, task_id: Optional[str]) -> Tuple[int, str, str]:
         self.log.info("task_list: {}".format(task_id))
 
         if task_id:
@@ -799,14 +827,14 @@ class TaskHandler:
                                              task.refs[TASK_REF_POOL_NAMESPACE]):
                 return -errno.ENOENT, '', "No such task {}".format(task_id)
 
-            result = task.to_dict()
+            return 0, json.dumps(task.to_dict(), indent=4, sort_keys=True), ""
         else:
-            result = []
+            tasks = []
             for sequence in sorted(self.tasks_by_sequence.keys()):
                 task = self.tasks_by_sequence[sequence]
                 if is_authorized(self.module,
                                  task.refs[TASK_REF_POOL_NAME],
                                  task.refs[TASK_REF_POOL_NAMESPACE]):
-                    result.append(task.to_dict())
+                    tasks.append(task.to_dict())
 
-        return 0, json.dumps(result, indent=4, sort_keys=True), ""
+            return 0, json.dumps(tasks, indent=4, sort_keys=True), ""
index 7eb96e258e11d42cc91115f326f38f82e42a0dbc..2eaad833c9b6f0c13859cd6cd1ae6fe0548d2489 100644 (file)
@@ -7,6 +7,7 @@ import traceback
 
 from datetime import datetime
 from threading import Condition, Lock, Thread
+from typing import Any, Dict, List, Optional, Tuple
 
 from .common import get_rbd_pools
 from .schedule import LevelSpec, Interval, StartTime, Schedule, Schedules
@@ -20,7 +21,7 @@ class TrashPurgeScheduleHandler:
     condition = Condition(lock)
     thread = None
 
-    def __init__(self, module):
+    def __init__(self, module: Any) -> None:
         self.module = module
         self.log = module.log
         self.last_refresh_pools = datetime(1970, 1, 1)
@@ -30,7 +31,7 @@ class TrashPurgeScheduleHandler:
         self.thread = Thread(target=self.run)
         self.thread.start()
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.log.info("TrashPurgeScheduleHandler: starting")
             while True:
@@ -49,7 +50,7 @@ class TrashPurgeScheduleHandler:
             self.log.fatal("Fatal runtime error: {}\n{}".format(
                 ex, traceback.format_exc()))
 
-    def trash_purge(self, pool_id, namespace):
+    def trash_purge(self, pool_id: str, namespace: str) -> None:
         try:
             with self.module.rados.open_ioctx2(int(pool_id)) as ioctx:
                 ioctx.set_namespace(namespace)
@@ -58,14 +59,14 @@ class TrashPurgeScheduleHandler:
             self.log.error("exception when purgin {}/{}: {}".format(
                 pool_id, namespace, e))
 
-
-    def init_schedule_queue(self):
-        self.queue = {}
-        self.pools = {}
+    def init_schedule_queue(self) -> None:
+        self.queue: Dict[str, List[Tuple[str, str]]] = {}
+        # pool_id => {namespace => pool_name}
+        self.pools: Dict[str, Dict[str, str]] = {}
         self.refresh_pools()
         self.log.debug("scheduler queue is initialized")
 
-    def load_schedules(self):
+    def load_schedules(self) -> None:
         self.log.info("TrashPurgeScheduleHandler: load_schedules")
 
         schedules = Schedules(self)
@@ -73,7 +74,7 @@ class TrashPurgeScheduleHandler:
         with self.lock:
             self.schedules = schedules
 
-    def refresh_pools(self):
+    def refresh_pools(self) -> None:
         if (datetime.now() - self.last_refresh_pools).seconds < 60:
             return
 
@@ -81,7 +82,7 @@ class TrashPurgeScheduleHandler:
 
         self.load_schedules()
 
-        pools = {}
+        pools: Dict[str, Dict[str, str]] = {}
 
         for pool_id, pool_name in get_rbd_pools(self.module).items():
             if not self.schedules.intersects(
@@ -96,7 +97,7 @@ class TrashPurgeScheduleHandler:
 
         self.last_refresh_pools = datetime.now()
 
-    def load_pool(self, ioctx, pools):
+    def load_pool(self, ioctx: rados.Ioctx, pools: Dict[str, Dict[str, str]]) -> None:
         pool_id = str(ioctx.get_pool_id())
         pool_name = ioctx.get_pool_name()
         pools[pool_id] = {}
@@ -115,7 +116,7 @@ class TrashPurgeScheduleHandler:
         for namespace in pool_namespaces:
             pools[pool_id][namespace] = pool_name
 
-    def rebuild_queue(self):
+    def rebuild_queue(self) -> None:
         with self.lock:
             now = datetime.now()
 
@@ -135,7 +136,7 @@ class TrashPurgeScheduleHandler:
 
             self.condition.notify()
 
-    def refresh_queue(self, current_pools):
+    def refresh_queue(self, current_pools: Dict[str, Dict[str, str]]) -> None:
         now = datetime.now()
 
         for pool_id, namespaces in self.pools.items():
@@ -152,7 +153,7 @@ class TrashPurgeScheduleHandler:
 
         self.condition.notify()
 
-    def enqueue(self, now, pool_id, namespace):
+    def enqueue(self, now: datetime, pool_id: str, namespace: str) -> None:
 
         schedule = self.schedules.find(pool_id, namespace)
         if not schedule:
@@ -167,9 +168,9 @@ class TrashPurgeScheduleHandler:
         if ns_spec not in self.queue[schedule_time]:
             self.queue[schedule_time].append((pool_id, namespace))
 
-    def dequeue(self):
+    def dequeue(self) -> Tuple[Optional[Tuple[str, str]], float]:
         if not self.queue:
-            return None, 1000
+            return None, 1000.0
 
         now = datetime.now()
         schedule_time = sorted(self.queue)[0]
@@ -183,9 +184,9 @@ class TrashPurgeScheduleHandler:
         namespace = namespaces.pop(0)
         if not namespaces:
             del self.queue[schedule_time]
-        return namespace, 0
+        return namespace, 0.0
 
-    def remove_from_queue(self, pool_id, namespace):
+    def remove_from_queue(self, pool_id: str, namespace: str) -> None:
         empty_slots = []
         for schedule_time, namespaces in self.queue.items():
             if (pool_id, namespace) in namespaces:
@@ -195,7 +196,10 @@ class TrashPurgeScheduleHandler:
         for schedule_time in empty_slots:
             del self.queue[schedule_time]
 
-    def add_schedule(self, level_spec, interval, start_time):
+    def add_schedule(self,
+                     level_spec: LevelSpec,
+                     interval: str,
+                     start_time: Optional[str]) -> Tuple[int, str, str]:
         self.log.debug(
             "add_schedule: level_spec={}, interval={}, start_time={}".format(
                 level_spec.name, interval, start_time))
@@ -207,7 +211,10 @@ class TrashPurgeScheduleHandler:
         self.rebuild_queue()
         return 0, "", ""
 
-    def remove_schedule(self, level_spec, interval, start_time):
+    def remove_schedule(self,
+                        level_spec: LevelSpec,
+                        interval: Optional[str],
+                        start_time: Optional[str]) -> Tuple[int, str, str]:
         self.log.debug(
             "remove_schedule: level_spec={}, interval={}, start_time={}".format(
                 level_spec.name, interval, start_time))
@@ -219,7 +226,7 @@ class TrashPurgeScheduleHandler:
         self.rebuild_queue()
         return 0, "", ""
 
-    def list(self, level_spec):
+    def list(self, level_spec: LevelSpec) -> Tuple[int, str, str]:
         self.log.debug("list: level_spec={}".format(level_spec.name))
 
         with self.lock:
@@ -227,7 +234,7 @@ class TrashPurgeScheduleHandler:
 
         return 0, json.dumps(result, indent=4, sort_keys=True), ""
 
-    def status(self, level_spec):
+    def status(self, level_spec: LevelSpec) -> Tuple[int, str, str]:
         self.log.debug("status: level_spec={}".format(level_spec.name))
 
         scheduled = []
index c554de29f89075626919bbbcaf6ef4b81b4bd794..a17a24e4e19df31801e78cc1a66a2ea7d73c4f07 100644 (file)
@@ -75,6 +75,7 @@ commands =
            -m orchestrator \
            -m progress \
            -m prometheus \
+           -m rbd_support \
            -m rook \
            -m snap_schedule \
            -m stats \