import collections
from datetime import datetime, timedelta
import fnmatch
+from functools import wraps
import importlib
import inspect
import json
from .. import logger
from ..settings import Settings
-from ..tools import Session
+from ..tools import Session, TaskManager
def ApiController(path):
return wrapper
+class Task(object):
+ def __init__(self, name, metadata, wait_for=5.0, exception_handler=None):
+ self.name = name
+ if isinstance(metadata, list):
+ self.metadata = dict([(e[1:-1], e) for e in metadata])
+ else:
+ self.metadata = metadata
+ self.wait_for = wait_for
+ self.exception_handler = exception_handler
+
+ def _gen_arg_map(self, func, args, kwargs):
+ # pylint: disable=deprecated-method
+ arg_map = {}
+ if sys.version_info > (3, 0): # pylint: disable=no-else-return
+ sig = inspect.signature(func)
+ arg_list = [a for a in sig.parameters]
+ else:
+ sig = inspect.getargspec(func)
+ arg_list = [a for a in sig.args]
+
+ for idx, arg in enumerate(arg_list):
+ if idx < len(args):
+ arg_map[arg] = args[idx]
+ else:
+ if arg in kwargs:
+ arg_map[arg] = kwargs[arg]
+ if arg in arg_map:
+ arg_map[idx] = arg_map[arg]
+
+ return arg_map
+
+ def __call__(self, func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ arg_map = self._gen_arg_map(func, args, kwargs)
+ md = {}
+ for k, v in self.metadata.items():
+ if isinstance(v, str) and v and v[0] == '{' and v[-1] == '}':
+ param = v[1:-1]
+ try:
+ pos = int(param)
+ md[k] = arg_map[pos]
+ except ValueError:
+ md[k] = arg_map[v[1:-1]]
+ else:
+ md[k] = v
+ task = TaskManager.run(self.name, md, func, args, kwargs)
+ try:
+ status, value = task.wait(self.wait_for)
+ except Exception as ex:
+ if self.exception_handler:
+ return self.exception_handler(ex)
+ raise ex
+ if status == TaskManager.VALUE_EXECUTING:
+ cherrypy.response.status = 202
+ return {'name': self.name, 'metadata': md}
+ return value
+ wrapper.__wrapped__ = func
+ return wrapper
+
+
class BaseControllerMeta(type):
def __new__(mcs, name, bases, dct):
new_cls = type.__new__(mcs, name, bases, dct)
(v.kind == inspect.Parameter.POSITIONAL_ONLY or
v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD)]
else:
+ func = getattr(func, '__wrapped__', func)
args = inspect.getargspec(func)
nd = len(args.args) if not args.defaults else -len(args.defaults)
cargs = args.args[1:nd]
# -*- coding: utf-8 -*-
-# pylint: disable=W0212
+# pylint: disable=W0212,too-many-arguments
from __future__ import absolute_import
import json
+import threading
+import time
import cherrypy
from cherrypy.test import helper
+from .. import logger
from ..controllers.auth import Auth
from ..controllers import json_error_page, generate_controller_routes
from ..tools import SessionExpireAtBrowserCloseTool
def _put(self, url, data=None):
self._request(url, 'PUT', data)
+ def _task_request(self, method, url, data, timeout):
+ self._request(url, method, data)
+ if self.status != '202 Accepted':
+ logger.info("task finished immediately")
+ return
+
+ res = self.jsonBody()
+ self.assertIsInstance(res, dict)
+ self.assertIn('name', res)
+ self.assertIn('metadata', res)
+
+ task_name = res['name']
+ task_metadata = res['metadata']
+
+ class Waiter(threading.Thread):
+ def __init__(self, task_name, task_metadata, tc):
+ super(Waiter, self).__init__()
+ self.task_name = task_name
+ self.task_metadata = task_metadata
+ self.ev = threading.Event()
+ self.abort = False
+ self.res_task = None
+ self.tc = tc
+
+ def run(self):
+ running = True
+ while running and not self.abort:
+ logger.info("task (%s, %s) is still executing", self.task_name,
+ self.task_metadata)
+ time.sleep(1)
+ self.tc._get('/task?name={}'.format(self.task_name))
+ res = self.tc.jsonBody()
+ for task in res['finished_tasks']:
+ if task['metadata'] == self.task_metadata:
+ # task finished
+ running = False
+ self.res_task = task
+ self.ev.set()
+
+ thread = Waiter(task_name, task_metadata, self)
+ thread.start()
+ status = thread.ev.wait(timeout)
+ if not status:
+ # timeout expired
+ thread.abort = True
+ thread.join()
+ raise Exception("Waiting for task ({}, {}) to finish timed out"
+ .format(task_name, task_metadata))
+ logger.info("task (%s, %s) finished", task_name, task_metadata)
+ if thread.res_task['success']:
+ self.body = json.dumps(thread.res_task['ret_value'])
+ return
+ raise Exception(thread.res_task['exception'])
+
+ def _task_post(self, url, data=None, timeout=60):
+ self._task_request('POST', url, data, timeout)
+
+ def _task_delete(self, url, timeout=60):
+ self._task_request('DELETE', url, None, timeout)
+
+ def _task_put(self, url, data=None, timeout=60):
+ self._task_request('PUT', url, data, timeout)
+
def jsonBody(self):
body_str = self.body.decode('utf-8') if isinstance(self.body, bytes) else self.body
return json.loads(body_str)
--- /dev/null
+# -*- coding: utf-8 -*-
+# pylint: disable=blacklisted-name
+
+import time
+
+from .helper import ControllerTestCase
+from ..controllers import ApiController, RESTController, Task
+from ..controllers.task import Task as TaskController
+from ..tools import NotificationQueue, TaskManager
+
+
+@ApiController('test/task')
+class TaskTest(RESTController):
+ sleep_time = 0.0
+
+ @Task('task/create', {'param': '{param}'}, wait_for=1.0)
+ @RESTController.args_from_json
+ def create(self, param):
+ time.sleep(TaskTest.sleep_time)
+ return {'my_param': param}
+
+ @Task('task/set', {'param': '{2}'}, wait_for=1.0)
+ @RESTController.args_from_json
+ def set(self, key, param=None):
+ time.sleep(TaskTest.sleep_time)
+ return {'key': key, 'my_param': param}
+
+ @Task('task/delete', ['{key}'], wait_for=1.0)
+ @RESTController.args_from_json
+ def delete(self, key):
+ # pylint: disable=unused-argument
+ time.sleep(TaskTest.sleep_time)
+
+
+class TaskControllerTest(ControllerTestCase):
+ @classmethod
+ def setup_server(cls):
+ # pylint: disable=protected-access
+ NotificationQueue.start_queue()
+ TaskManager.init()
+ TaskController._cp_config['tools.authenticate.on'] = False
+ cls.setup_controllers([TaskTest, TaskController])
+
+ @classmethod
+ def tearDownClass(cls):
+ NotificationQueue.stop()
+
+ def setUp(self):
+ TaskTest.sleep_time = 0.0
+
+ def test_create_task(self):
+ self._task_post('/test/task', {'param': 'hello'})
+ self.assertJsonBody({'my_param': 'hello'})
+
+ def test_long_set_task(self):
+ TaskTest.sleep_time = 2.0
+ self._task_put('/test/task/2', {'param': 'hello'})
+ self.assertJsonBody({'key': '2', 'my_param': 'hello'})
+
+ def test_delete_task(self):
+ self._task_delete('/test/task/hello')