From dcbae67e9870cc59218ddf3133c5550aa430a820 Mon Sep 17 00:00:00 2001 From: Ricardo Dias Date: Wed, 25 Jul 2018 14:13:20 +0100 Subject: [PATCH] mgr/dashboard: fix query parameters in task annotated endpoints Fixes: http://tracker.ceph.com/issues/25096 Signed-off-by: Ricardo Dias --- .../mgr/dashboard/controllers/__init__.py | 77 +++++++++---------- .../mgr/dashboard/tests/test_rest_tasks.py | 9 +++ 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/pybind/mgr/dashboard/controllers/__init__.py b/src/pybind/mgr/dashboard/controllers/__init__.py index 9c782f51e96aa..caa696a4ad64a 100644 --- a/src/pybind/mgr/dashboard/controllers/__init__.py +++ b/src/pybind/mgr/dashboard/controllers/__init__.py @@ -257,6 +257,32 @@ def json_error_page(status, message, traceback, version): version=version)) +def _get_function_params(func): + """ + Retrieves the list of parameters declared in function. + Each parameter is represented as dict with keys: + * name (str): the name of the parameter + * required (bool): whether the parameter is required or not + * default (obj): the parameter's default value + """ + fspec = getargspec(func) + + func_params = [] + nd = len(fspec.args) if not fspec.defaults else -len(fspec.defaults) + for param in fspec.args[1:nd]: + func_params.append({'name': param, 'required': True}) + + if fspec.defaults: + for param, val in zip(fspec.args[nd:], fspec.defaults): + func_params.append({ + 'name': param, + 'required': False, + 'default': val + }) + + return func_params + + class Task(object): def __init__(self, name, metadata, wait_for=5.0, exception_handler=None): self.name = name @@ -268,24 +294,23 @@ class Task(object): self.exception_handler = exception_handler def _gen_arg_map(self, func, args, kwargs): - # pylint: disable=deprecated-method arg_map = {} - if sys.version_info > (3, 0): # pylint: disable=no-else-return - sig = inspect.signature(func) - arg_list = [a for a in sig.parameters] - else: - sig = getargspec(func) - arg_list = [a for a in sig.args] + params = _get_function_params(func) - for idx, arg in enumerate(arg_list): + args = args[1:] # exclude self + for idx, param in enumerate(params): if idx < len(args): - arg_map[arg] = args[idx] + arg_map[param['name']] = args[idx] else: - if arg in kwargs: - arg_map[arg] = kwargs[arg] - if arg in arg_map: + if param['name'] in kwargs: + arg_map[param['name']] = kwargs[param['name']] + else: + assert not param['required'] + arg_map[param['name']] = param['default'] + + if param['name'] in arg_map: # This is not a type error. We are using the index here. - arg_map[idx] = arg_map[arg] + arg_map[idx+1] = arg_map[param['name']] return arg_map @@ -325,32 +350,6 @@ class Task(object): return wrapper -def _get_function_params(func): - """ - Retrieves the list of parameters declared in function. - Each parameter is represented as dict with keys: - * name (str): the name of the parameter - * required (bool): whether the parameter is required or not - * default (obj): the parameter's default value - """ - fspec = getargspec(func) - - func_params = [] - nd = len(fspec.args) if not fspec.defaults else -len(fspec.defaults) - for param in fspec.args[1:nd]: - func_params.append({'name': param, 'required': True}) - - if fspec.defaults: - for param, val in zip(fspec.args[nd:], fspec.defaults): - func_params.append({ - 'name': param, - 'required': False, - 'default': val - }) - - return func_params - - class BaseController(object): """ Base class for all controllers providing API endpoints. diff --git a/src/pybind/mgr/dashboard/tests/test_rest_tasks.py b/src/pybind/mgr/dashboard/tests/test_rest_tasks.py index 6e8f01b40b773..ad871c9410e9e 100644 --- a/src/pybind/mgr/dashboard/tests/test_rest_tasks.py +++ b/src/pybind/mgr/dashboard/tests/test_rest_tasks.py @@ -38,6 +38,11 @@ class TaskTest(RESTController): def bar(self, key, param=None): return {'my_param': param, 'key': key} + @Task('task/query', ['{param}']) + @RESTController.Collection('POST', query_params=['param']) + def query(self, param=None): + return {'my_param': param} + class TaskControllerTest(ControllerTestCase): @classmethod @@ -75,3 +80,7 @@ class TaskControllerTest(ControllerTestCase): def test_bar_task(self): self._task_put('/test/task/3/bar', {'param': 'hello'}) self.assertJsonBody({'my_param': 'hello', 'key': '3'}) + + def test_query_param(self): + self._task_post('/test/task/query') + self.assertJsonBody({'my_param': None}) -- 2.39.5