]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/dashboard: Task decorator for controller endpoints
authorRicardo Dias <rdias@suse.com>
Wed, 11 Apr 2018 11:42:41 +0000 (12:42 +0100)
committerRicardo Dias <rdias@suse.com>
Fri, 13 Apr 2018 14:58:48 +0000 (15:58 +0100)
Signed-off-by: Ricardo Dias <rdias@suse.com>
src/pybind/mgr/dashboard/controllers/__init__.py
src/pybind/mgr/dashboard/tests/helper.py
src/pybind/mgr/dashboard/tests/test_rest_tasks.py [new file with mode: 0644]

index d695578c6cba4e333e36f18f96636a7916158279..7c36866c246298b19b3a7eabc1d72d663526384b 100644 (file)
@@ -5,6 +5,7 @@ from __future__ import absolute_import
 import collections
 from datetime import datetime, timedelta
 import fnmatch
+from functools import wraps
 import importlib
 import inspect
 import json
@@ -20,7 +21,7 @@ from six import add_metaclass
 
 from .. import logger
 from ..settings import Settings
-from ..tools import Session
+from ..tools import Session, TaskManager
 
 
 def ApiController(path):
@@ -259,6 +260,67 @@ def browsable_api_view(meth):
     return wrapper
 
 
+class Task(object):
+    def __init__(self, name, metadata, wait_for=5.0, exception_handler=None):
+        self.name = name
+        if isinstance(metadata, list):
+            self.metadata = dict([(e[1:-1], e) for e in metadata])
+        else:
+            self.metadata = metadata
+        self.wait_for = wait_for
+        self.exception_handler = exception_handler
+
+    def _gen_arg_map(self, func, args, kwargs):
+        # pylint: disable=deprecated-method
+        arg_map = {}
+        if sys.version_info > (3, 0):  # pylint: disable=no-else-return
+            sig = inspect.signature(func)
+            arg_list = [a for a in sig.parameters]
+        else:
+            sig = inspect.getargspec(func)
+            arg_list = [a for a in sig.args]
+
+        for idx, arg in enumerate(arg_list):
+            if idx < len(args):
+                arg_map[arg] = args[idx]
+            else:
+                if arg in kwargs:
+                    arg_map[arg] = kwargs[arg]
+            if arg in arg_map:
+                arg_map[idx] = arg_map[arg]
+
+        return arg_map
+
+    def __call__(self, func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            arg_map = self._gen_arg_map(func, args, kwargs)
+            md = {}
+            for k, v in self.metadata.items():
+                if isinstance(v, str) and v and v[0] == '{' and v[-1] == '}':
+                    param = v[1:-1]
+                    try:
+                        pos = int(param)
+                        md[k] = arg_map[pos]
+                    except ValueError:
+                        md[k] = arg_map[v[1:-1]]
+                else:
+                    md[k] = v
+            task = TaskManager.run(self.name, md, func, args, kwargs)
+            try:
+                status, value = task.wait(self.wait_for)
+            except Exception as ex:
+                if self.exception_handler:
+                    return self.exception_handler(ex)
+                raise ex
+            if status == TaskManager.VALUE_EXECUTING:
+                cherrypy.response.status = 202
+                return {'name': self.name, 'metadata': md}
+            return value
+        wrapper.__wrapped__ = func
+        return wrapper
+
+
 class BaseControllerMeta(type):
     def __new__(mcs, name, bases, dct):
         new_cls = type.__new__(mcs, name, bases, dct)
@@ -294,6 +356,7 @@ class BaseController(object):
                      (v.kind == inspect.Parameter.POSITIONAL_ONLY or
                       v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD)]
         else:
+            func = getattr(func, '__wrapped__', func)
             args = inspect.getargspec(func)
             nd = len(args.args) if not args.defaults else -len(args.defaults)
             cargs = args.args[1:nd]
index a503c45006733dcff1be34175aa80a9c602ea1e1..473276f59310b2016b4f4caf6a389c4f878c97df 100644 (file)
@@ -1,12 +1,15 @@
 # -*- coding: utf-8 -*-
-# pylint: disable=W0212
+# pylint: disable=W0212,too-many-arguments
 from __future__ import absolute_import
 
 import json
+import threading
+import time
 
 import cherrypy
 from cherrypy.test import helper
 
+from .. import logger
 from ..controllers.auth import Auth
 from ..controllers import json_error_page, generate_controller_routes
 from ..tools import SessionExpireAtBrowserCloseTool
@@ -53,6 +56,69 @@ class ControllerTestCase(helper.CPWebCase):
     def _put(self, url, data=None):
         self._request(url, 'PUT', data)
 
+    def _task_request(self, method, url, data, timeout):
+        self._request(url, method, data)
+        if self.status != '202 Accepted':
+            logger.info("task finished immediately")
+            return
+
+        res = self.jsonBody()
+        self.assertIsInstance(res, dict)
+        self.assertIn('name', res)
+        self.assertIn('metadata', res)
+
+        task_name = res['name']
+        task_metadata = res['metadata']
+
+        class Waiter(threading.Thread):
+            def __init__(self, task_name, task_metadata, tc):
+                super(Waiter, self).__init__()
+                self.task_name = task_name
+                self.task_metadata = task_metadata
+                self.ev = threading.Event()
+                self.abort = False
+                self.res_task = None
+                self.tc = tc
+
+            def run(self):
+                running = True
+                while running and not self.abort:
+                    logger.info("task (%s, %s) is still executing", self.task_name,
+                                self.task_metadata)
+                    time.sleep(1)
+                    self.tc._get('/task?name={}'.format(self.task_name))
+                    res = self.tc.jsonBody()
+                    for task in res['finished_tasks']:
+                        if task['metadata'] == self.task_metadata:
+                            # task finished
+                            running = False
+                            self.res_task = task
+                            self.ev.set()
+
+        thread = Waiter(task_name, task_metadata, self)
+        thread.start()
+        status = thread.ev.wait(timeout)
+        if not status:
+            # timeout expired
+            thread.abort = True
+            thread.join()
+            raise Exception("Waiting for task ({}, {}) to finish timed out"
+                            .format(task_name, task_metadata))
+        logger.info("task (%s, %s) finished", task_name, task_metadata)
+        if thread.res_task['success']:
+            self.body = json.dumps(thread.res_task['ret_value'])
+            return
+        raise Exception(thread.res_task['exception'])
+
+    def _task_post(self, url, data=None, timeout=60):
+        self._task_request('POST', url, data, timeout)
+
+    def _task_delete(self, url, timeout=60):
+        self._task_request('DELETE', url, None, timeout)
+
+    def _task_put(self, url, data=None, timeout=60):
+        self._task_request('PUT', url, data, timeout)
+
     def jsonBody(self):
         body_str = self.body.decode('utf-8') if isinstance(self.body, bytes) else self.body
         return json.loads(body_str)
diff --git a/src/pybind/mgr/dashboard/tests/test_rest_tasks.py b/src/pybind/mgr/dashboard/tests/test_rest_tasks.py
new file mode 100644 (file)
index 0000000..1811902
--- /dev/null
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# pylint: disable=blacklisted-name
+
+import time
+
+from .helper import ControllerTestCase
+from ..controllers import ApiController, RESTController, Task
+from ..controllers.task import Task as TaskController
+from ..tools import NotificationQueue, TaskManager
+
+
+@ApiController('test/task')
+class TaskTest(RESTController):
+    sleep_time = 0.0
+
+    @Task('task/create', {'param': '{param}'}, wait_for=1.0)
+    @RESTController.args_from_json
+    def create(self, param):
+        time.sleep(TaskTest.sleep_time)
+        return {'my_param': param}
+
+    @Task('task/set', {'param': '{2}'}, wait_for=1.0)
+    @RESTController.args_from_json
+    def set(self, key, param=None):
+        time.sleep(TaskTest.sleep_time)
+        return {'key': key, 'my_param': param}
+
+    @Task('task/delete', ['{key}'], wait_for=1.0)
+    @RESTController.args_from_json
+    def delete(self, key):
+        # pylint: disable=unused-argument
+        time.sleep(TaskTest.sleep_time)
+
+
+class TaskControllerTest(ControllerTestCase):
+    @classmethod
+    def setup_server(cls):
+        # pylint: disable=protected-access
+        NotificationQueue.start_queue()
+        TaskManager.init()
+        TaskController._cp_config['tools.authenticate.on'] = False
+        cls.setup_controllers([TaskTest, TaskController])
+
+    @classmethod
+    def tearDownClass(cls):
+        NotificationQueue.stop()
+
+    def setUp(self):
+        TaskTest.sleep_time = 0.0
+
+    def test_create_task(self):
+        self._task_post('/test/task', {'param': 'hello'})
+        self.assertJsonBody({'my_param': 'hello'})
+
+    def test_long_set_task(self):
+        TaskTest.sleep_time = 2.0
+        self._task_put('/test/task/2', {'param': 'hello'})
+        self.assertJsonBody({'key': '2', 'my_param': 'hello'})
+
+    def test_delete_task(self):
+        self._task_delete('/test/task/hello')