]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/cli Redo cli api mgr module.
authorWaad AlKhoury <waadalkhoury@localhost.localdomain>
Wed, 24 Nov 2021 09:12:41 +0000 (10:12 +0100)
committerPere Diaz Bou <pdiazbou@redhat.com>
Tue, 1 Feb 2022 14:14:16 +0000 (15:14 +0100)
Signed-off-by: Waad AlKhoury <walkhour@redhat.com>
Signed-off-by: Pere Diaz Bou <pdiazbou@redhat.com>
src/pybind/mgr/CMakeLists.txt
src/pybind/mgr/cli_api/__init__.py
src/pybind/mgr/cli_api/module.py
src/pybind/mgr/cli_api/tests/test_cli_api.py [new file with mode: 0644]
src/pybind/mgr/cli_api/tests/test_cliapi.py [deleted file]
src/pybind/mgr/mgr_module.py

index 4756f69aa724b086e77e1c7b162d4488b44637e8..4b915219a365b2e9c36a3405bdd0009c7b657582 100644 (file)
@@ -18,6 +18,7 @@ install(DIRECTORY
   REGEX "\\.gitignore" EXCLUDE
   REGEX ".*\\.pyi" EXCLUDE
   REGEX "hello/.*" EXCLUDE
+  REGEX "cli_api/.*" EXCLUDE
   REGEX "tests/.*" EXCLUDE
   REGEX "rook/rook-client-python.*" EXCLUDE
   REGEX "osd_perf_query/.*" EXCLUDE
index 6b71ccbef682dccb7e9a62e5bed6bd264458eaaf..a52284054d240057ee07e467979333842fcf5d5f 100644 (file)
@@ -1,6 +1,10 @@
-import os
-from .module import CLI  # noqa # pylint: disable=unused-import
+from .module import CLI
 
+__all__ = [
+    "CLI",
+]
 
+import os
 if 'UNITTEST' in os.environ:
     import tests  # noqa # pylint: disable=unused-import
+    __all__.append(tests.__name__)
index 527b8930c67706735fd3ee99cd3a7649bb5902ff..79b042eb0e9d6e2d8d458af8abfbf5142a72a46c 100755 (executable)
-import json
+import concurrent.futures
+import functools
+import inspect
 import logging
-import threading
 import time
-from functools import partial
-from queue import Queue
+import errno
+from typing import Any, Callable, Dict, List
 
-from mgr_module import CLICommand, HandleCommandResult, MgrModule
+from mgr_module import MgrModule, HandleCommandResult, CLICommand, API
 
 logger = logging.getLogger()
+get_time = time.perf_counter
 
 
-class CLI(MgrModule):
-
-    @CLICommand('mgr api get')
-    def api_get(self, arg: str):
-        '''
-        Called by the plugin to fetch named cluster-wide objects from ceph-mgr.
-        :param str data_name: Valid things to fetch are osd_crush_map_text,
-                osd_map, osd_map_tree, osd_map_crush, config, mon_map, fs_map,
-                osd_metadata, pg_summary, io_rate, pg_dump, df, osd_stats,
-                health, mon_status, devices, device <devid>, pg_stats,
-                pool_stats, pg_ready, osd_ping_times.
-        Note:
-            All these structures have their own JSON representations: experiment
-            or look at the C++ ``dump()`` methods to learn about them.
-        '''
-        t1_start = time.time()
-        str_arg = self.get(arg)
-        t1_end = time.time()
-        time_final = (t1_end - t1_start)
-        return HandleCommandResult(0, json.dumps(str_arg), str(time_final))
-
-    @CLICommand('mgr api benchmark get')
-    def api_get_benchmark(self, arg: str, number_of_total_calls: int,
-                          number_of_parallel_calls: int):
-        benchmark_runner = ThreadedBenchmarkRunner(number_of_total_calls, number_of_parallel_calls)
-        benchmark_runner.start(partial(self.get, arg))
-        benchmark_runner.join()
-        stats = benchmark_runner.get_stats()
-        return HandleCommandResult(0, json.dumps(stats), "")
-
-
-class ThreadedBenchmarkRunner:
-    def __init__(self, number_of_total_calls, number_of_parallel_calls):
-        self._number_of_parallel_calls = number_of_parallel_calls
-        self._number_of_total_calls = number_of_total_calls
-        self._threads = []
-        self._jobs: Queue = Queue()
-        self._time = 0.0
-        self._self_time = []
-        self._lock = threading.Lock()
-
-    def start(self, func):
-        if(self._number_of_total_calls and self._number_of_parallel_calls):
-            for thread_id in range(self._number_of_parallel_calls):
-                new_thread = threading.Thread(target=ThreadedBenchmarkRunner.timer,
-                                              args=(self, self._jobs, func,))
-                self._threads.append(new_thread)
-            for job_id in range(self._number_of_total_calls):
-                self._jobs.put(job_id)
-            for thread in self._threads:
-                thread.start()
+def pretty_json(obj: Any) -> Any:
+    import json
+    return json.dumps(obj, sort_keys=True, indent=2)
+
+
+class CephCommander:
+    """
+    Utility class to inspect Python functions and generate corresponding
+    CephCommand signatures (see src/mon/MonCommand.h for details)
+    """
+
+    def __init__(self, func: Callable):
+        self.func = func
+        self.signature = inspect.signature(func)
+        self.params = self.signature.parameters
+
+    def to_ceph_signature(self) -> Dict[str, str]:
+        """
+        Generate CephCommand signature (dict-like)
+        """
+        return {
+            'prefix': f'mgr cli {self.func.__name__}',
+            'perm': API.perm.get(self.func)
+        }
+
+
+class MgrAPIReflector(type):
+    """
+    Metaclass to register COMMANDS and Command Handlers via CLICommand
+    decorator
+    """
+
+    def __new__(cls, name, bases, dct):  # type: ignore
+        klass = super().__new__(cls, name, bases, dct)
+        cls.threaded_benchmark_runner = None
+        for base in bases:
+            for name, func in inspect.getmembers(base, cls.is_public):
+                # However not necessary (CLICommand uses a registry)
+                # save functions to klass._cli_{n}() methods. This
+                # can help on unit testing
+                wrapper = cls.func_wrapper(func)
+                command = CLICommand(**CephCommander(func).to_ceph_signature())(  # type: ignore
+                    wrapper)
+                setattr(
+                    klass,
+                    f'_cli_{name}',
+                    command)
+        return klass
+
+    @staticmethod
+    def is_public(func: Callable) -> bool:
+        return (
+            inspect.isfunction(func)
+            and not func.__name__.startswith('_')
+            and API.expose.get(func)
+        )
+
+    @staticmethod
+    def func_wrapper(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs) -> HandleCommandResult:  # type: ignore
+            return HandleCommandResult(stdout=pretty_json(
+                func(self, *args, **kwargs)))
+
+        # functools doesn't change the signature when wrapping a function
+        # so we do it manually
+        signature = inspect.signature(func)
+        wrapper.__signature__ = signature  # type: ignore
+        return wrapper
+
+
+class CLI(MgrModule, metaclass=MgrAPIReflector):
+    @CLICommand('mgr cli_benchmark')
+    def benchmark(self, iterations: int, threads: int, func_name: str,
+                  func_args: List[str] = None) -> HandleCommandResult:  # type: ignore
+        func_args = () if func_args is None else func_args
+        if iterations and threads:
+            try:
+                func = getattr(self, func_name)
+            except AttributeError:
+                return HandleCommandResult(errno.EINVAL,
+                                           stderr="Could not find the public "
+                                           "function you are requesting")
         else:
-            raise BenchmarkException("Number of Total and number of parallel calls must be greater than 0")
+            raise BenchmarkException("Number of calls and number "
+                                     "of parallel calls must be greater than 0")
+
+        def timer(*args: Any) -> float:
+            time_start = get_time()
+            func(*func_args)
+            return get_time() - time_start
 
-    def join(self):
-        for thread in self._threads:
-            thread.join()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
+            results_iter = executor.map(timer, range(iterations))
+        results = list(results_iter)
 
-    def get_stats(self):
         stats = {
-            "avg": (self._time / self._number_of_total_calls),
-            "min": min(self._self_time),
-            "max": max(self._self_time)
+            "avg": sum(results) / len(results),
+            "max": max(results),
+            "min": min(results),
         }
-        return stats
-
-    def timer(self, jobs, func):
-        self._lock.acquire()
-        while not self._jobs.empty():
-            jobs.get()
-            time_start = time.time()
-            func()
-            time_end = time.time()
-            self._self_time.append(time_end - time_start)
-            self._time += (time_end - time_start)
-            self._jobs.task_done()
-        self._lock.release()
+        return HandleCommandResult(stdout=pretty_json(stats))
 
 
 class BenchmarkException(Exception):
diff --git a/src/pybind/mgr/cli_api/tests/test_cli_api.py b/src/pybind/mgr/cli_api/tests/test_cli_api.py
new file mode 100644 (file)
index 0000000..ee42dc9
--- /dev/null
@@ -0,0 +1,40 @@
+import unittest
+
+from ..module import CLI, BenchmarkException, HandleCommandResult
+
+
+class BenchmarkRunnerTest(unittest.TestCase):
+    def setUp(self):
+        self.cli = CLI('CLI', 0, 0)
+
+    def test_number_of_calls_on_start_fails(self):
+        with self.assertRaises(BenchmarkException) as ctx:
+            self.cli.benchmark(0, 10, 'list_servers', [])
+        self.assertEqual(str(ctx.exception),
+                         "Number of calls and number "
+                         "of parallel calls must be greater than 0")
+
+    def test_number_of_parallel_calls_on_start_fails(self):
+        with self.assertRaises(BenchmarkException) as ctx:
+            self.cli.benchmark(100, 0, 'list_servers', [])
+        self.assertEqual(str(ctx.exception),
+                         "Number of calls and number "
+                         "of parallel calls must be greater than 0")
+
+    def test_number_of_parallel_calls_on_start_works(self):
+        CLI.benchmark(10, 10, "get", "osd_map")
+
+    def test_function_name_fails(self):
+        for iterations in [0, 1]:
+            threads = 0 if iterations else 1
+            with self.assertRaises(BenchmarkException) as ctx:
+                self.cli.benchmark(iterations, threads, 'fake_method', [])
+            self.assertEqual(str(ctx.exception),
+                             "Number of calls and number "
+                             "of parallel calls must be greater than 0")
+        result: HandleCommandResult = self.cli.benchmark(1, 1, 'fake_method', [])
+        self.assertEqual(result.stderr, "Could not find the public "
+                         "function you are requesting")
+
+    def test_function_name_works(self):
+        CLI.benchmark(10, 10, "get", "osd_map")
diff --git a/src/pybind/mgr/cli_api/tests/test_cliapi.py b/src/pybind/mgr/cli_api/tests/test_cliapi.py
deleted file mode 100644 (file)
index 2136161..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-import unittest
-
-from ..module import ThreadedBenchmarkRunner, BenchmarkException
-
-
-class ThreadedBenchmarkRunnerTest(unittest.TestCase):
-    def test_number_of_calls_on_start_fails(self):
-        class_threadbenchmarkrunner = ThreadedBenchmarkRunner(0, 10)
-        with self.assertRaises(BenchmarkException):
-            class_threadbenchmarkrunner.start(None)
-
-    def test_number_of_parallel_calls_on_start_fails(self):
-        class_threadbenchmarkrunner = ThreadedBenchmarkRunner(10, 0)
-        with self.assertRaises(BenchmarkException):
-            class_threadbenchmarkrunner.start(None)
-
-    def test_number_of_parallel_calls_on_start_works(self):
-        class_threadbenchmarkrunner = ThreadedBenchmarkRunner(10, 10)
-
-        def dummy_function():
-            pass
-        class_threadbenchmarkrunner.start(dummy_function)
-        assert len(class_threadbenchmarkrunner._self_time) > 0
-        assert sum(class_threadbenchmarkrunner._self_time) > 0
-
-    def test_get_stats_works(self):
-        class_threadbenchmarkrunner = ThreadedBenchmarkRunner(10, 10)
-
-        def dummy_function():
-            for i in range(10):
-                pass
-        class_threadbenchmarkrunner.start(dummy_function)
-        stats = class_threadbenchmarkrunner.get_stats()
-        assert stats['avg'] > 0
-        assert stats['min'] > 0
-        assert stats['max'] > 0
index 17880a1f99557fc0ebe3317f5d240c89764ba20c..519ab0add833e05ff65d3694f54a9e18ea82d952 100644 (file)
@@ -830,6 +830,28 @@ ServerInfoT = Dict[str, Union[str, List[ServiceInfoT]]]
 PerfCounterT = Dict[str, Any]
 
 
+class API:
+    def DecoratorFactory(attr: str, default: Any):  # type: ignore
+        class DecoratorClass:
+            _ATTR_TOKEN = f'__ATTR_{attr.upper()}__'
+
+            def __init__(self, value: Any=default) -> None:
+                self.value = value
+
+            def __call__(self, func: Callable) -> Any:
+                setattr(func, self._ATTR_TOKEN, self.value)
+                return func
+
+            @classmethod
+            def get(cls, func: Callable) -> Any:
+                return getattr(func, cls._ATTR_TOKEN, default)
+
+        return DecoratorClass
+
+    perm = DecoratorFactory('perm', default='r')
+    expose = DecoratorFactory('expose', default=False)(True)
+
+
 class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
     COMMANDS = []  # type: List[Any]
     MODULE_OPTIONS: List[Option] = []
@@ -950,6 +972,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_get_release_name()
 
+    @API.expose
     def lookup_release_name(self, major: int) -> str:
         return self._ceph_lookup_release_name(major)
 
@@ -1031,7 +1054,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
             self._rados.shutdown()
             self._ceph_unregister_client(addrs)
 
-    def get(self, data_name: str):
+    @API.expose
+    def get(self, data_name: str) -> Any:
         """
         Called by the plugin to fetch named cluster-wide objects from ceph-mgr.
 
@@ -1155,7 +1179,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
 
         return ret
 
-    def get_server(self, hostname) -> ServerInfoT:
+    @API.expose
+    def get_server(self, hostname: str) -> ServerInfoT:
         """
         Called by the plugin to fetch metadata about a particular hostname from
         ceph-mgr.
@@ -1167,6 +1192,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return cast(ServerInfoT, self._ceph_get_server(hostname))
 
+    @API.expose
     def get_perf_schema(self,
                         svc_type: str,
                         svc_name: str) -> Dict[str,
@@ -1182,6 +1208,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_get_perf_schema(svc_type, svc_name)
 
+    @API.expose
     def get_counter(self,
                     svc_type: str,
                     svc_name: str,
@@ -1200,6 +1227,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_get_counter(svc_type, svc_name, path)
 
+    @API.expose
     def get_latest_counter(self,
                            svc_type: str,
                            svc_name: str,
@@ -1219,6 +1247,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_get_latest_counter(svc_type, svc_name, path)
 
+    @API.expose
     def list_servers(self) -> List[ServerInfoT]:
         """
         Like ``get_server``, but gives information about all servers (i.e. all
@@ -1250,6 +1279,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
             return default
         return metadata
 
+    @API.expose
     def get_daemon_status(self, svc_type: str, svc_id: str) -> Dict[str, str]:
         """
         Fetch the latest status for a particular service daemon.
@@ -1422,18 +1452,22 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_get_mgr_id()
 
+    @API.expose
     def get_ceph_conf_path(self) -> str:
         return self._ceph_get_ceph_conf_path()
 
+    @API.expose
     def get_mgr_ip(self) -> str:
         ips = self.get("mgr_ips").get('ips', [])
         if not ips:
             return socket.gethostname()
         return ips[0]
 
+    @API.expose
     def get_ceph_option(self, key: str) -> OptionValue:
         return self._ceph_get_option(key)
 
+    @API.expose
     def get_foreign_ceph_option(self, entity: str, key: str) -> OptionValue:
         return self._ceph_get_foreign_option(entity, key)
 
@@ -1483,6 +1517,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         r = self._ceph_get_module_option(module, key)
         return default if r is None else r
 
+    @API.expose
     def get_store_prefix(self, key_prefix: str) -> Dict[str, str]:
         """
         Retrieve a dict of KV store keys to values, where the keys
@@ -1533,6 +1568,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
             self._validate_module_option(key)
         return self._ceph_set_module_option(module, key, str(val))
 
+    @API.perm('w')
+    @API.expose
     def set_localized_module_option(self, key: str, val: Optional[str]) -> None:
         """
         Set localized configuration for this ceph-mgr instance
@@ -1543,6 +1580,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         self._validate_module_option(key)
         return self._set_localized(key, val, self._set_module_option)
 
+    @API.perm('w')
+    @API.expose
     def set_store(self, key: str, val: Optional[str]) -> None:
         """
         Set a value in this module's persistent key value store.
@@ -1550,6 +1589,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         self._ceph_set_store(key, val)
 
+    @API.expose
     def get_store(self, key: str, default: Optional[str] = None) -> Optional[str]:
         """
         Get a value from this module's persistent key value store
@@ -1560,6 +1600,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         else:
             return r
 
+    @API.expose
     def get_localized_store(self, key: str, default: Optional[str] = None) -> Optional[str]:
         r = self._ceph_get_store(_get_localized_key(self.get_mgr_id(), key))
         if r is None:
@@ -1568,6 +1609,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
                 r = default
         return r
 
+    @API.perm('w')
+    @API.expose
     def set_localized_store(self, key: str, val: Optional[str]) -> None:
         return self._set_localized(key, val, self.set_store)
 
@@ -1593,6 +1636,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return cast(OSDMap, self._ceph_get_osdmap())
 
+    @API.expose
     def get_latest(self, daemon_type: str, daemon_name: str, counter: str) -> int:
         data = self.get_latest_counter(
             daemon_type, daemon_name, counter)[counter]
@@ -1601,6 +1645,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         else:
             return 0
 
+    @API.expose
     def get_latest_avg(self, daemon_type: str, daemon_name: str, counter: str) -> Tuple[int, int]:
         data = self.get_latest_counter(
             daemon_type, daemon_name, counter)[counter]
@@ -1611,6 +1656,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         else:
             return 0, 0
 
+    @API.expose
     @profile_method()
     def get_all_perf_counters(self, prio_limit: int = PRIO_USEFUL,
                               services: Sequence[str] = ("mds", "mon", "osd",
@@ -1683,6 +1729,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
 
         return result
 
+    @API.expose
     def set_uri(self, uri: str) -> None:
         """
         If the module exposes a service, then call this to publish the
@@ -1692,9 +1739,12 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_set_uri(uri)
 
+    @API.perm('w')
+    @API.expose
     def set_device_wear_level(self, devid: str, wear_level: float) -> None:
         return self._ceph_set_device_wear_level(devid, wear_level)
 
+    @API.expose
     def have_mon_connection(self) -> bool:
         """
         Check whether this ceph-mgr daemon has an open connection
@@ -1712,9 +1762,13 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
                               add_to_ceph_s: bool) -> None:
         return self._ceph_update_progress_event(evid, desc, progress, add_to_ceph_s)
 
+    @API.perm('w')
+    @API.expose
     def complete_progress_event(self, evid: str) -> None:
         return self._ceph_complete_progress_event(evid)
 
+    @API.perm('w')
+    @API.expose
     def clear_all_progress_events(self) -> None:
         return self._ceph_clear_all_progress_events()
 
@@ -1749,6 +1803,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
 
         return True, ""
 
+    @API.expose
     def remote(self, module_name: str, method_name: str, *args: Any, **kwargs: Any) -> Any:
         """
         Invoke a method on another module.  All arguments, and the return
@@ -1803,6 +1858,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_add_osd_perf_query(query)
 
+    @API.perm('w')
+    @API.expose
     def remove_osd_perf_query(self, query_id: int) -> None:
         """
         Unregister an OSD perf query.
@@ -1811,6 +1868,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_remove_osd_perf_query(query_id)
 
+    @API.expose
     def get_osd_perf_counters(self, query_id: int) -> Optional[Dict[str, List[PerfCounterT]]]:
         """
         Get stats collected for an OSD perf query.
@@ -1848,6 +1906,8 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_add_mds_perf_query(query)
 
+    @API.perm('w')
+    @API.expose
     def remove_mds_perf_query(self, query_id: int) -> None:
         """
         Unregister an MDS perf query.
@@ -1856,6 +1916,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_remove_mds_perf_query(query_id)
 
+    @API.expose
     def get_mds_perf_counters(self, query_id: int) -> Optional[Dict[str, List[PerfCounterT]]]:
         """
         Get stats collected for an MDS perf query.
@@ -1875,6 +1936,7 @@ class MgrModule(ceph_module.BaseMgrModule, MgrModuleLoggingMixin):
         """
         return self._ceph_is_authorized(arguments)
 
+    @API.expose
     def send_rgwadmin_command(self, args: List[str],
                               stdout_as_json: bool = True) -> Tuple[int, Union[str, dict], str]:
         try: