]> git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
mgr/dashboard: fix query parameters in task annotated endpoints
authorRicardo Dias <rdias@suse.com>
Wed, 25 Jul 2018 13:13:20 +0000 (14:13 +0100)
committerRicardo Dias <rdias@suse.com>
Wed, 25 Jul 2018 13:16:13 +0000 (14:16 +0100)
Fixes: http://tracker.ceph.com/issues/25096
Signed-off-by: Ricardo Dias <rdias@suse.com>
src/pybind/mgr/dashboard/controllers/__init__.py
src/pybind/mgr/dashboard/tests/test_rest_tasks.py

index 9c782f51e96aab1864748365cd34c80cf99b25a5..caa696a4ad64a241c4067dc98526a0ac2cf5d3cc 100644 (file)
@@ -257,6 +257,32 @@ def json_error_page(status, message, traceback, version):
                            version=version))
 
 
+def _get_function_params(func):
+    """
+    Retrieves the list of parameters declared in function.
+    Each parameter is represented as dict with keys:
+      * name (str): the name of the parameter
+      * required (bool): whether the parameter is required or not
+      * default (obj): the parameter's default value
+    """
+    fspec = getargspec(func)
+
+    func_params = []
+    nd = len(fspec.args) if not fspec.defaults else -len(fspec.defaults)
+    for param in fspec.args[1:nd]:
+        func_params.append({'name': param, 'required': True})
+
+    if fspec.defaults:
+        for param, val in zip(fspec.args[nd:], fspec.defaults):
+            func_params.append({
+                'name': param,
+                'required': False,
+                'default': val
+            })
+
+    return func_params
+
+
 class Task(object):
     def __init__(self, name, metadata, wait_for=5.0, exception_handler=None):
         self.name = name
@@ -268,24 +294,23 @@ class Task(object):
         self.exception_handler = exception_handler
 
     def _gen_arg_map(self, func, args, kwargs):
-        # pylint: disable=deprecated-method
         arg_map = {}
-        if sys.version_info > (3, 0):  # pylint: disable=no-else-return
-            sig = inspect.signature(func)
-            arg_list = [a for a in sig.parameters]
-        else:
-            sig = getargspec(func)
-            arg_list = [a for a in sig.args]
+        params = _get_function_params(func)
 
-        for idx, arg in enumerate(arg_list):
+        args = args[1:]  # exclude self
+        for idx, param in enumerate(params):
             if idx < len(args):
-                arg_map[arg] = args[idx]
+                arg_map[param['name']] = args[idx]
             else:
-                if arg in kwargs:
-                    arg_map[arg] = kwargs[arg]
-            if arg in arg_map:
+                if param['name'] in kwargs:
+                    arg_map[param['name']] = kwargs[param['name']]
+                else:
+                    assert not param['required']
+                    arg_map[param['name']] = param['default']
+
+            if param['name'] in arg_map:
                 # This is not a type error. We are using the index here.
-                arg_map[idx] = arg_map[arg]
+                arg_map[idx+1] = arg_map[param['name']]
 
         return arg_map
 
@@ -325,32 +350,6 @@ class Task(object):
         return wrapper
 
 
-def _get_function_params(func):
-    """
-    Retrieves the list of parameters declared in function.
-    Each parameter is represented as dict with keys:
-      * name (str): the name of the parameter
-      * required (bool): whether the parameter is required or not
-      * default (obj): the parameter's default value
-    """
-    fspec = getargspec(func)
-
-    func_params = []
-    nd = len(fspec.args) if not fspec.defaults else -len(fspec.defaults)
-    for param in fspec.args[1:nd]:
-        func_params.append({'name': param, 'required': True})
-
-    if fspec.defaults:
-        for param, val in zip(fspec.args[nd:], fspec.defaults):
-            func_params.append({
-                'name': param,
-                'required': False,
-                'default': val
-            })
-
-    return func_params
-
-
 class BaseController(object):
     """
     Base class for all controllers providing API endpoints.
index 6e8f01b40b7738da07e1fa728c4f2a2383ae8dd6..ad871c9410e9e216801f986ea8cda633d18382de 100644 (file)
@@ -38,6 +38,11 @@ class TaskTest(RESTController):
     def bar(self, key, param=None):
         return {'my_param': param, 'key': key}
 
+    @Task('task/query', ['{param}'])
+    @RESTController.Collection('POST', query_params=['param'])
+    def query(self, param=None):
+        return {'my_param': param}
+
 
 class TaskControllerTest(ControllerTestCase):
     @classmethod
@@ -75,3 +80,7 @@ class TaskControllerTest(ControllerTestCase):
     def test_bar_task(self):
         self._task_put('/test/task/3/bar', {'param': 'hello'})
         self.assertJsonBody({'my_param': 'hello', 'key': '3'})
+
+    def test_query_param(self):
+        self._task_post('/test/task/query')
+        self.assertJsonBody({'my_param': None})