from ..services.exception import serialize_dashboard_exception
-def ApiController(path):
- def decorate(cls):
+class Controller(object):
+ def __init__(self, path, base_url=""):
+ self.path = path
+ self.base_url = base_url
+
+ def __call__(self, cls):
cls._cp_controller_ = True
- cls._cp_path_ = path
+ if self.base_url:
+ cls._cp_path_ = "{}/{}".format(self.base_url, self.path)
+ else:
+ cls._cp_path_ = self.path
config = {
'tools.sessions.on': True,
'tools.sessions.name': Session.NAME,
'tools.session_expire_at_browser_close.on': True,
- 'tools.dashboard_exception_handler.on': True,
+ 'tools.dashboard_exception_handler.on': True
}
if not hasattr(cls, '_cp_config'):
cls._cp_config = {}
config['tools.authenticate.on'] = False
cls._cp_config.update(config)
return cls
- return decorate
+
+
+class ApiController(Controller):
+ def __init__(self, path, version=1):
+ if version == 1:
+ base_url = "api"
+ else:
+ base_url = "api/v" + str(version)
+ super(ApiController, self).__init__(path, base_url)
+ self.version = version
+
+ def __call__(self, cls):
+ cls = super(ApiController, self).__call__(cls)
+ cls._api_version = self.version
+ return cls
def AuthRequired(enabled=True):
return controllers
+ENDPOINT_MAP = collections.defaultdict(list)
+
+
def generate_controller_routes(ctrl_class, mapper, base_url):
inst = ctrl_class()
- for methods, url_suffix, action, params in ctrl_class.endpoints():
- if not url_suffix:
- name = ctrl_class.__name__
- url = "{}/{}".format(base_url, ctrl_class._cp_path_)
+ endp_base_urls = set()
+
+ for endpoint in ctrl_class.endpoints():
+ conditions = dict(method=endpoint.methods) if endpoint.methods else None
+ endp_url = endpoint.url
+ if '/' in endp_url:
+ endp_base_urls.add(endp_url[:endp_url.find('/')])
else:
- name = "{}:{}".format(ctrl_class.__name__, url_suffix)
- url = "{}/{}/{}".format(base_url, ctrl_class._cp_path_, url_suffix)
+ endp_base_urls.add(endp_url)
+ url = "{}/{}".format(base_url, endp_url)
- if params:
- for param in params:
- url = "{}/:{}".format(url, param)
+ logger.debug("Mapped [%s] to %s:%s restricted to %s",
+ url, ctrl_class.__name__, endpoint.action,
+ endpoint.methods)
- conditions = dict(method=methods) if methods else None
+ ENDPOINT_MAP[endpoint.url].append(endpoint)
- logger.debug("Mapping [%s] to %s:%s restricted to %s",
- url, ctrl_class.__name__, action, methods)
- mapper.connect(name, url, controller=inst, action=action,
+ name = ctrl_class.__name__ + ":" + endpoint.action
+ mapper.connect(name, url, controller=inst, action=endpoint.action,
conditions=conditions)
# adding route with trailing slash
name += "/"
url += "/"
- mapper.connect(name, url, controller=inst, action=action,
+ mapper.connect(name, url, controller=inst, action=endpoint.action,
conditions=conditions)
+ return endp_base_urls
+
def generate_routes(url_prefix):
mapper = cherrypy.dispatch.RoutesDispatcher()
ctrls = load_controllers()
+
+ parent_urls = set()
for ctrl in ctrls:
- generate_controller_routes(ctrl, mapper, "{}/api".format(url_prefix))
+ parent_urls.update(generate_controller_routes(ctrl, mapper,
+ "{}".format(url_prefix)))
- return mapper
+ logger.debug("list of parent paths: %s", parent_urls)
+ return mapper, parent_urls
def json_error_page(status, message, traceback, version):
return wrapper
+def _get_function_params(func):
+ """
+ Retrieves the list of parameters declared in function.
+ Each parameter is represented as dict with keys:
+ * name (str): the name of the parameter
+ * required (bool): whether the parameter is required or not
+ * default (obj): the parameter's default value
+ """
+ fspec = getargspec(func)
+
+ func_params = []
+ nd = len(fspec.args) if not fspec.defaults else -len(fspec.defaults)
+ for param in fspec.args[1:nd]:
+ func_params.append({'name': param, 'required': True})
+
+ if fspec.defaults:
+ for param, val in zip(fspec.args[nd:], fspec.defaults):
+ func_params.append({
+ 'name': param,
+ 'required': False,
+ 'default': val
+ })
+
+ return func_params
+
+
class BaseController(object):
"""
Base class for all controllers providing API endpoints.
"""
+ class Endpoint(object):
+ """
+ An instance of this class represents an endpoint.
+ """
+ def __init__(self, ctrl, func, methods=None):
+ self.ctrl = ctrl
+ self.func = self._unwrap(func)
+ if methods is None:
+ methods = []
+ self.methods = methods
+
+ @classmethod
+ def _unwrap(cls, func):
+ while hasattr(func, "__wrapped__"):
+ func = func.__wrapped__
+ return func
+
+ @property
+ def url(self):
+ ctrl_path_params = self.ctrl.get_path_param_names()
+ if self.func.__name__ != '__call__':
+ url = "{}/{}".format(self.ctrl.get_path(), self.func.__name__)
+ else:
+ url = self.ctrl.get_path()
+ path_params = [
+ p['name'] for p in _get_function_params(self.func)
+ if p['required'] and p['name'] not in ctrl_path_params]
+ path_params = ["{{{}}}".format(p) for p in path_params]
+ if path_params:
+ url += "/{}".format("/".join(path_params))
+ return url
+
+ @property
+ def action(self):
+ return self.func.__name__
+
+ @property
+ def path_params(self):
+ return [p for p in _get_function_params(self.func) if p['required']]
+
+ @property
+ def query_params(self):
+ return [p for p in _get_function_params(self.func)
+ if not p['required']]
+
+ @property
+ def body_params(self):
+ return []
+
+ @property
+ def group(self):
+ return self.ctrl.__name__
+
+ @property
+ def is_api(self):
+ return hasattr(self.ctrl, '_api_version')
+
+ @property
+ def is_secure(self):
+ return self.ctrl._cp_config['tools.authenticate.on']
+
+ def __repr__(self):
+ return "Endpoint({}, {}, {})".format(self.url, self.methods,
+ self.action)
+
def __init__(self):
- logger.info('Initializing controller: %s -> /api/%s',
+ logger.info('Initializing controller: %s -> /%s',
self.__class__.__name__, self._cp_path_)
@classmethod
- def _parse_function_args(cls, func):
- args = getargspec(func)
- nd = len(args.args) if not args.defaults else -len(args.defaults)
- cargs = args.args[1:nd]
-
- # filter out controller path params
- for idx, step in enumerate(cls._cp_path_.split('/')):
+ def get_path_param_names(cls):
+ path_params = []
+ for step in cls._cp_path_.split('/'):
param = None
if step[0] == ':':
param = step[1:]
- if step[0] == '{' and step[-1] == '}' and ':' in step[1:-1]:
- param, _, _regex = step[1:-1].partition(':')
-
+ elif step[0] == '{' and step[-1] == '}':
+ param, _, _ = step[1:-1].partition(':')
if param:
- if param not in cargs:
- raise Exception("function '{}' does not have the"
- " positional argument '{}' in the {} "
- "position".format(func, param, idx))
- cargs.remove(param)
- return cargs
+ path_params.append(param)
+ return path_params
+
+ @classmethod
+ def get_path(cls):
+ return cls._cp_path_
@classmethod
def endpoints(cls):
"""
- The endpoints method returns a list of endpoints. Each endpoint
- consists of a tuple with methods, URL suffix, an action and its
- arguments.
-
- By default, endpoints will be methods of the BaseController class,
- which have been decorated by the @cherrpy.expose decorator. A method
- will also be considered an endpoint if the `exposed` attribute has been
- set on the method to a value which evaluates to True, which is
- basically what @cherrpy.expose does, too.
-
- :return: A tuple of methods, url_suffix, action and arguments of the
- function
- :rtype: list[tuple]
+ This method iterates over all the methods decorated with ``@endpoint``
+ and creates an Endpoint object for each one of the methods.
+
+ :return: A list of endpoint objects
+ :rtype: list[BaseController.Endpoint]
"""
result = []
- for name, func in inspect.getmembers(cls, predicate=callable):
+ for _, func in inspect.getmembers(cls, predicate=callable):
if hasattr(func, 'exposed') and func.exposed:
- args = cls._parse_function_args(func)
- methods = []
- url_suffix = name
- action = name
- if name == '__call__':
- url_suffix = None
- result.append((methods, url_suffix, action, args))
+ result.append(cls.Endpoint(cls, func))
return result
# resource id parameter for using in get, set, and delete methods
# should be overriden by subclasses.
# to specify a composite id (two parameters) use '/'. e.g., "param1/param2".
- # If subclasses don't override this property we try to infer the structure of
- # the resourse ID.
+ # If subclasses don't override this property we try to infer the structure
+ # of the resourse ID.
RESOURCE_ID = None
_method_mapping = collections.OrderedDict([
- (('GET', False), ('list', 200)),
- (('PUT', False), ('bulk_set', 200)),
- (('PATCH', False), ('bulk_set', 200)),
- (('POST', False), ('create', 201)),
- (('DELETE', False), ('bulk_delete', 204)),
- (('GET', True), ('get', 200)),
- (('DELETE', True), ('delete', 204)),
- (('PUT', True), ('set', 200)),
- (('PATCH', True), ('set', 200))
+ ('list', {'method': 'GET', 'resource': False, 'status': 200}),
+ ('create', {'method': 'POST', 'resource': False, 'status': 201}),
+ ('bulk_set', {'method': 'PUT', 'resource': False, 'status': 200}),
+ ('bulk_delete', {'method': 'DELETE', 'resource': False, 'status': 204}),
+ ('get', {'method': 'GET', 'resource': True, 'status': 200}),
+ ('delete', {'method': 'DELETE', 'resource': True, 'status': 204}),
+ ('set', {'method': 'PUT', 'resource': True, 'status': 200})
])
- @classmethod
- def endpoints(cls):
- # pylint: disable=too-many-branches
-
- result = []
- for attr, val in inspect.getmembers(cls, predicate=callable):
- if hasattr(val, 'exposed') and val.exposed and \
- attr != '_collection' and attr != '_element':
- result.append(([], attr, attr, cls._parse_function_args(val)))
-
- for k, v in cls._method_mapping.items():
- func = getattr(cls, v[0], None)
- if not k[1] and func:
- if k[0] != 'PATCH': # we already wrapped in PUT
- wrapper = cls._rest_request_wrapper(func, v[1])
- setattr(cls, v[0], wrapper)
- else:
- wrapper = func
- result.append(([k[0]], None, v[0], []))
+ class RESTEndpoint(BaseController.Endpoint):
+ def __init__(self, ctrl, func):
+ if func.__name__ in ctrl._method_mapping:
+ methods = [ctrl._method_mapping[func.__name__]['method']]
+ status = ctrl._method_mapping[func.__name__]['status']
+ elif hasattr(func, "_resource_method_"):
+ methods = func._resource_method_
+ status = 200
+ elif hasattr(func, "_collection_method_"):
+ methods = func._collection_method_
+ status = 200
+ else:
+ assert False
+
+ wrapper = ctrl._rest_request_wrapper(func, status)
+ setattr(ctrl, func.__name__, wrapper)
+
+ super(RESTController.RESTEndpoint, self).__init__(
+ ctrl, func, methods)
+
+ def get_resource_id_params(self):
+ if self.func.__name__ in self.ctrl._method_mapping:
+ if self.ctrl._method_mapping[self.func.__name__]['resource']:
+ resource_id_params = self.ctrl.infer_resource_id()
+ if resource_id_params:
+ return resource_id_params
+
+ if hasattr(self.func, '_resource_method_'):
+ resource_id_params = self.ctrl.infer_resource_id()
+ if resource_id_params:
+ return resource_id_params
+
+ return []
+
+ @property
+ def url(self):
+ url = self.ctrl.get_path()
+
+ res_id_params = self.get_resource_id_params()
+ if res_id_params:
+ res_id_params = ["{{{}}}".format(p) for p in res_id_params]
+ url += "/{}".format("/".join(res_id_params))
+
+ if hasattr(self.func, "_collection_method_") \
+ or hasattr(self.func, "_resource_method_"):
+ url += "/{}".format(self.func.__name__)
+ return url
+
+ @property
+ def path_params(self):
+ params = [{'name': p, 'required': True}
+ for p in self.ctrl.get_path_param_names()]
+ params.extend([{'name': p, 'required': True}
+ for p in self.get_resource_id_params()])
+ return params
+
+ @property
+ def query_params(self):
+ path_params_names = [p['name'] for p in self.path_params]
+ if 'GET' in self.methods or 'DELETE' in self.methods:
+ return [p for p in _get_function_params(self.func)
+ if p['name'] not in path_params_names]
+ return []
+
+ @property
+ def body_params(self):
+ path_params_names = [p['name'] for p in self.path_params]
+ if 'POST' in self.methods or 'PUT' in self.methods:
+ return [p for p in _get_function_params(self.func)
+ if p['name'] not in path_params_names]
+ return []
- args = []
+ @classmethod
+ def infer_resource_id(cls):
+ if cls.RESOURCE_ID is not None:
+ return cls.RESOURCE_ID.split('/')
for k, v in cls._method_mapping.items():
- func = getattr(cls, v[0], None)
- if k[1] and func:
- if k[0] != 'PATCH': # we already wrapped in PUT
- wrapper = cls._rest_request_wrapper(func, v[1])
- setattr(cls, v[0], wrapper)
- else:
- wrapper = func
- if not args:
- if cls.RESOURCE_ID is None:
- args = cls._parse_function_args(func)
- else:
- args = cls.RESOURCE_ID.split('/')
- result.append(([k[0]], None, v[0], args))
-
- for attr, val in inspect.getmembers(cls, predicate=callable):
- if hasattr(val, '_collection_method_'):
- wrapper = cls._rest_request_wrapper(val, 200)
- setattr(cls, attr, wrapper)
- result.append(
- (val._collection_method_, attr, attr, []))
-
- for attr, val in inspect.getmembers(cls, predicate=callable):
- if hasattr(val, '_resource_method_'):
- wrapper = cls._rest_request_wrapper(val, 200)
- setattr(cls, attr, wrapper)
- res_params = [":{}".format(arg) for arg in args]
- url_suffix = "{}/{}".format("/".join(res_params), attr)
- result.append(
- (val._resource_method_, url_suffix, attr, []))
+ func = getattr(cls, k, None)
+ while hasattr(func, "__wrapped__"):
+ func = func.__wrapped__
+ if v['resource'] and func:
+ path_params = cls.get_path_param_names()
+ params = _get_function_params(func)
+ return [p['name'] for p in params
+ if p['required'] and p['name'] not in path_params]
+ return None
+ @classmethod
+ def endpoints(cls):
+ result = []
+ for _, val in inspect.getmembers(cls, predicate=callable):
+ if val.__name__ in cls._method_mapping:
+ result.append(cls.RESTEndpoint(cls, val))
+ elif hasattr(val, "_collection_method_") \
+ or hasattr(val, "_resource_method_"):
+ result.append(cls.RESTEndpoint(cls, val))
+ elif hasattr(val, 'exposed') and val.exposed:
+ result.append(cls.Endpoint(cls, val))
return result
@classmethod
def _rest_request_wrapper(cls, func, status_code):
+ @wraps(func)
def wrapper(*vpath, **params):
method = func
if cherrypy.request.method not in ['GET', 'DELETE']:
cherrypy.response.status = status_code
return method(*vpath, **params)
+ if not hasattr(wrapper, '__wrapped__'):
+ wrapper.__wrapped__ = func
return wrapper
@staticmethod
@staticmethod
def _takes_json(func):
def inner(*args, **kwargs):
- if cherrypy.request.headers.get('Content-Type',
- '') == 'application/x-www-form-urlencoded':
+ if cherrypy.request.headers.get('Content-Type', '') == \
+ 'application/x-www-form-urlencoded':
return func(*args, **kwargs)
content_length = int(cherrypy.request.headers['Content-Length'])