# -*- coding: utf-8 -*-
-# pylint: disable=protected-access
+# pylint: disable=protected-access,too-many-branches
from __future__ import absolute_import
import collections
class Controller(object):
- def __init__(self, path, base_url=""):
+ def __init__(self, path, base_url=None):
self.path = path
self.base_url = base_url
+ if self.path and self.path[0] != "/":
+ self.path = "/" + self.path
+
+ if self.base_url is None:
+ self.base_url = ""
+ elif self.base_url == "/":
+ self.base_url = ""
+
+ if self.base_url == "" and self.path == "":
+ self.base_url = "/"
+
def __call__(self, cls):
cls._cp_controller_ = True
- if self.base_url:
- cls._cp_path_ = "{}/{}".format(self.base_url, self.path)
- else:
- cls._cp_path_ = self.path
+ cls._cp_path_ = "{}{}".format(self.base_url, self.path)
+
config = {
'tools.sessions.on': True,
'tools.sessions.name': Session.NAME,
class ApiController(Controller):
- def __init__(self, path, version=1):
- if version == 1:
- base_url = "api"
- else:
- base_url = "api/v" + str(version)
- super(ApiController, self).__init__(path, base_url)
- self.version = version
+ def __init__(self, path):
+ super(ApiController, self).__init__(path, base_url="/api")
def __call__(self, cls):
cls = super(ApiController, self).__call__(cls)
- cls._api_version = self.version
+ cls._api_endpoint = True
return cls
return decorate
+def Endpoint(method=None, path=None, path_params=None, query_params=None,
+ json_response=True, proxy=False):
+
+ if method is None:
+ method = 'GET'
+ elif not isinstance(method, str) or \
+ method.upper() not in ['GET', 'POST', 'DELETE', 'PUT']:
+ raise TypeError("Possible values for method are: 'GET', 'POST', "
+ "'DELETE', or 'PUT'")
+
+ method = method.upper()
+
+ if method in ['GET', 'DELETE']:
+ if path_params is not None:
+ raise TypeError("path_params should not be used for {} "
+ "endpoints. All function params are considered"
+ " path parameters by default".format(method))
+
+ if path_params is None:
+ if method in ['POST', 'PUT']:
+ path_params = []
+
+ if query_params is None:
+ query_params = []
+
+ def _wrapper(func):
+ if method in ['POST', 'PUT']:
+ func_params = _get_function_params(func)
+ for param in func_params:
+ if param['name'] in path_params and not param['required']:
+ raise TypeError("path_params can only reference "
+ "non-optional function parameters")
+
+ if func.__name__ == '__call__' and path is None:
+ e_path = ""
+ else:
+ e_path = path
+
+ if e_path is not None:
+ e_path = e_path.strip()
+ if e_path and e_path[0] != "/":
+ e_path = "/" + e_path
+ elif e_path == "/":
+ e_path = ""
+
+ func._endpoint = {
+ 'method': method,
+ 'path': e_path,
+ 'path_params': path_params,
+ 'query_params': query_params,
+ 'json_response': json_response,
+ 'proxy': proxy
+ }
+ return func
+ return _wrapper
+
+
+def Proxy(path=None):
+ if path is None:
+ path = ""
+ elif path == "/":
+ path = ""
+ path += "/{path:.*}"
+ return Endpoint(path=path, proxy=True)
+
+
def load_controllers():
# setting sys.path properly when not running under the mgr
controllers_dir = os.path.dirname(os.path.realpath(__file__))
endp_base_urls = set()
for endpoint in ctrl_class.endpoints():
- conditions = dict(method=endpoint.methods) if endpoint.methods else None
+ if endpoint.proxy:
+ conditions = None
+ else:
+ conditions = dict(method=[endpoint.method])
+
endp_url = endpoint.url
- if '/' in endp_url:
- endp_base_urls.add(endp_url[:endp_url.find('/')])
+ if base_url == "/":
+ base_url = ""
+ if endp_url == "/" and base_url:
+ endp_url = ""
+ url = "{}{}".format(base_url, endp_url)
+
+ if '/' in url[len(base_url)+1:]:
+ endp_base_urls.add(url[:len(base_url)+1+endp_url[1:].find('/')])
else:
- endp_base_urls.add(endp_url)
- url = "{}/{}".format(base_url, endp_url)
+ endp_base_urls.add(url)
logger.debug("Mapped [%s] to %s:%s restricted to %s",
url, ctrl_class.__name__, endpoint.action,
- endpoint.methods)
+ endpoint.method)
ENDPOINT_MAP[endpoint.url].append(endpoint)
"""
An instance of this class represents an endpoint.
"""
- def __init__(self, ctrl, func, methods=None):
+ def __init__(self, ctrl, func):
self.ctrl = ctrl
- self.func = self._unwrap(func)
- if methods is None:
- methods = []
- self.methods = methods
+ self.func = func
- @classmethod
- def _unwrap(cls, func):
- while hasattr(func, "__wrapped__"):
- func = func.__wrapped__
- return func
+ if not self.config['proxy']:
+ setattr(self.ctrl, func.__name__, self.function)
+
+ @property
+ def config(self):
+ func = self.func
+ while not hasattr(func, '_endpoint'):
+ if hasattr(func, "__wrapped__"):
+ func = func.__wrapped__
+ else:
+ return None
+ return func._endpoint
+
+ @property
+ def function(self):
+ return self.ctrl._request_wrapper(self.func, self.method,
+ self.config['json_response'])
+
+ @property
+ def method(self):
+ return self.config['method']
+
+ @property
+ def proxy(self):
+ return self.config['proxy']
@property
def url(self):
- ctrl_path_params = self.ctrl.get_path_param_names()
- if self.func.__name__ != '__call__':
- url = "{}/{}".format(self.ctrl.get_path(), self.func.__name__)
+ if self.config['path'] is not None:
+ url = "{}{}".format(self.ctrl.get_path(), self.config['path'])
else:
- url = self.ctrl.get_path()
- path_params = [
- p['name'] for p in _get_function_params(self.func)
- if p['required'] and p['name'] not in ctrl_path_params]
+ url = "{}/{}".format(self.ctrl.get_path(), self.func.__name__)
+
+ ctrl_path_params = self.ctrl.get_path_param_names(
+ self.config['path'])
+ path_params = [p['name'] for p in self.path_params
+ if p['name'] not in ctrl_path_params]
path_params = ["{{{}}}".format(p) for p in path_params]
if path_params:
url += "/{}".format("/".join(path_params))
+
return url
@property
@property
def path_params(self):
- return [p for p in _get_function_params(self.func) if p['required']]
+ ctrl_path_params = self.ctrl.get_path_param_names(
+ self.config['path'])
+ func_params = _get_function_params(self.func)
+
+ if self.method in ['GET', 'DELETE']:
+ assert self.config['path_params'] is None
+
+ return [p for p in func_params if p['name'] in ctrl_path_params
+ or (p['name'] not in self.config['query_params']
+ and p['required'])]
+
+ # elif self.method in ['POST', 'PUT']:
+ return [p for p in func_params if p['name'] in ctrl_path_params
+ or p['name'] in self.config['path_params']]
@property
def query_params(self):
- return [p for p in _get_function_params(self.func)
- if not p['required']]
+ if self.method in ['GET', 'DELETE']:
+ func_params = _get_function_params(self.func)
+ path_params = [p['name'] for p in self.path_params]
+ return [p for p in func_params if p['name'] not in path_params]
+
+ # elif self.method in ['POST', 'PUT']:
+ func_params = _get_function_params(self.func)
+ return [p for p in func_params
+ if p['name'] in self.config['query_params']]
@property
def body_params(self):
- return []
+ func_params = _get_function_params(self.func)
+ path_params = [p['name'] for p in self.path_params]
+ query_params = [p['name'] for p in self.query_params]
+ return [p for p in func_params
+ if p['name'] not in path_params and
+ p['name'] not in query_params]
@property
def group(self):
@property
def is_api(self):
- return hasattr(self.ctrl, '_api_version')
+ return hasattr(self.ctrl, '_api_endpoint')
@property
def is_secure(self):
return self.ctrl._cp_config['tools.authenticate.on']
def __repr__(self):
- return "Endpoint({}, {}, {})".format(self.url, self.methods,
+ return "Endpoint({}, {}, {})".format(self.url, self.method,
self.action)
def __init__(self):
- logger.info('Initializing controller: %s -> /%s',
+ logger.info('Initializing controller: %s -> %s',
self.__class__.__name__, self._cp_path_)
@classmethod
- def get_path_param_names(cls):
+ def get_path_param_names(cls, path_extension=None):
+ if path_extension is None:
+ path_extension = ""
+ full_path = cls._cp_path_[1:] + path_extension
path_params = []
- for step in cls._cp_path_.split('/'):
+ for step in full_path.split('/'):
param = None
+ if not step:
+ continue
if step[0] == ':':
param = step[1:]
elif step[0] == '{' and step[-1] == '}':
:rtype: list[BaseController.Endpoint]
"""
result = []
-
for _, func in inspect.getmembers(cls, predicate=callable):
- if hasattr(func, 'exposed') and func.exposed:
+ if hasattr(func, '_endpoint'):
result.append(cls.Endpoint(cls, func))
return result
+ @staticmethod
+ def _request_wrapper(func, method, json_response):
+ @wraps(func)
+ def inner(*args, **kwargs):
+ if method in ['GET', 'DELETE']:
+ ret = func(*args, **kwargs)
+
+ elif cherrypy.request.headers.get('Content-Type', '') == \
+ 'application/x-www-form-urlencoded':
+ ret = func(*args, **kwargs)
+
+ else:
+ content_length = int(cherrypy.request.headers['Content-Length'])
+ body = cherrypy.request.body.read(content_length)
+ if not body:
+ return func(*args, **kwargs)
+
+ try:
+ data = json.loads(body.decode('utf-8'))
+ except Exception as e:
+ raise cherrypy.HTTPError(400, 'Failed to decode JSON: {}'
+ .format(str(e)))
+ kwargs.update(data.items())
+ ret = func(*args, **kwargs)
+
+ if json_response:
+ cherrypy.response.headers['Content-Type'] = 'application/json'
+ return json.dumps(ret).encode('utf8')
+ return ret
+ return inner
+
class RESTController(BaseController):
"""
('set', {'method': 'PUT', 'resource': True, 'status': 200})
])
- class RESTEndpoint(BaseController.Endpoint):
- def __init__(self, ctrl, func):
- if func.__name__ in ctrl._method_mapping:
- methods = [ctrl._method_mapping[func.__name__]['method']]
- status = ctrl._method_mapping[func.__name__]['status']
- elif hasattr(func, "_resource_method_"):
- methods = func._resource_method_
- status = 200
- elif hasattr(func, "_collection_method_"):
- methods = func._collection_method_
- status = 200
- else:
- assert False
-
- wrapper = ctrl._rest_request_wrapper(func, status)
- setattr(ctrl, func.__name__, wrapper)
-
- super(RESTController.RESTEndpoint, self).__init__(
- ctrl, func, methods)
-
- def get_resource_id_params(self):
- if self.func.__name__ in self.ctrl._method_mapping:
- if self.ctrl._method_mapping[self.func.__name__]['resource']:
- resource_id_params = self.ctrl.infer_resource_id()
- if resource_id_params:
- return resource_id_params
-
- if hasattr(self.func, '_resource_method_'):
- resource_id_params = self.ctrl.infer_resource_id()
- if resource_id_params:
- return resource_id_params
-
- return []
-
- @property
- def url(self):
- url = self.ctrl.get_path()
-
- res_id_params = self.get_resource_id_params()
- if res_id_params:
- res_id_params = ["{{{}}}".format(p) for p in res_id_params]
- url += "/{}".format("/".join(res_id_params))
-
- if hasattr(self.func, "_collection_method_") \
- or hasattr(self.func, "_resource_method_"):
- url += "/{}".format(self.func.__name__)
- return url
-
- @property
- def path_params(self):
- params = [{'name': p, 'required': True}
- for p in self.ctrl.get_path_param_names()]
- params.extend([{'name': p, 'required': True}
- for p in self.get_resource_id_params()])
- return params
-
- @property
- def query_params(self):
- path_params_names = [p['name'] for p in self.path_params]
- if 'GET' in self.methods or 'DELETE' in self.methods:
- return [p for p in _get_function_params(self.func)
- if p['name'] not in path_params_names]
- return []
-
- @property
- def body_params(self):
- path_params_names = [p['name'] for p in self.path_params]
- if 'POST' in self.methods or 'PUT' in self.methods:
- return [p for p in _get_function_params(self.func)
- if p['name'] not in path_params_names]
- return []
-
@classmethod
def infer_resource_id(cls):
if cls.RESOURCE_ID is not None:
@classmethod
def endpoints(cls):
- result = []
- for _, val in inspect.getmembers(cls, predicate=callable):
- if val.__name__ in cls._method_mapping:
- result.append(cls.RESTEndpoint(cls, val))
- elif hasattr(val, "_collection_method_") \
- or hasattr(val, "_resource_method_"):
- result.append(cls.RESTEndpoint(cls, val))
- elif hasattr(val, 'exposed') and val.exposed:
- result.append(cls.Endpoint(cls, val))
- return result
+ result = super(RESTController, cls).endpoints()
+ for _, func in inspect.getmembers(cls, predicate=callable):
+ no_resource_id_params = False
+ status = 200
+ method = None
+ path = ""
+
+ if func.__name__ in cls._method_mapping:
+ meth = cls._method_mapping[func.__name__]
+
+ if meth['resource']:
+ res_id_params = cls.infer_resource_id()
+ if res_id_params is None:
+ no_resource_id_params = True
+ else:
+ res_id_params = ["{{{}}}".format(p) for p in res_id_params]
+ path += "/{}".format("/".join(res_id_params))
- @classmethod
- def _rest_request_wrapper(cls, func, status_code):
- @wraps(func)
- def wrapper(*vpath, **params):
- method = func
- if cherrypy.request.method not in ['GET', 'DELETE']:
- method = RESTController._takes_json(method)
+ status = meth['status']
+ method = meth['method']
- method = RESTController._returns_json(method)
+ elif hasattr(func, "_collection_method_"):
+ path = "/{}".format(func.__name__)
+ method = func._collection_method_
- cherrypy.response.status = status_code
+ elif hasattr(func, "_resource_method_"):
+ res_id_params = cls.infer_resource_id()
+ if res_id_params is None:
+ no_resource_id_params = True
+ else:
+ res_id_params = ["{{{}}}".format(p) for p in res_id_params]
+ path += "/{}".format("/".join(res_id_params))
+ path += "/{}".format(func.__name__)
- return method(*vpath, **params)
- if not hasattr(wrapper, '__wrapped__'):
- wrapper.__wrapped__ = func
- return wrapper
+ method = func._resource_method_
- @staticmethod
- def _function_args(func):
- return getargspec(func).args[1:]
+ else:
+ continue
- @staticmethod
- def _takes_json(func):
- def inner(*args, **kwargs):
- if cherrypy.request.headers.get('Content-Type', '') == \
- 'application/x-www-form-urlencoded':
- return func(*args, **kwargs)
+ if no_resource_id_params:
+ raise TypeError("Could not infer the resource ID parameters for"
+ " method {}. "
+ "Please specify the resource ID parameters "
+ "using the RESOURCE_ID class property"
+ .format(func.__name__))
- content_length = int(cherrypy.request.headers['Content-Length'])
- body = cherrypy.request.body.read(content_length)
- if not body:
- return func(*args, **kwargs)
+ func = cls._status_code_wrapper(func, status)
+ endp_func = Endpoint(method, path=path)(func)
+ result.append(cls.Endpoint(cls, endp_func))
- try:
- data = json.loads(body.decode('utf-8'))
- except Exception as e:
- raise cherrypy.HTTPError(400, 'Failed to decode JSON: {}'
- .format(str(e)))
+ return result
- kwargs.update(data.items())
- return func(*args, **kwargs)
- return inner
+ @classmethod
+ def _status_code_wrapper(cls, func, status_code):
+ @wraps(func)
+ def wrapper(*vpath, **params):
+ cherrypy.response.status = status_code
+ return func(*vpath, **params)
- @staticmethod
- def _returns_json(func):
- def inner(*args, **kwargs):
- cherrypy.response.headers['Content-Type'] = 'application/json'
- ret = func(*args, **kwargs)
- return json.dumps(ret).encode('utf8')
- return inner
+ return wrapper
@staticmethod
- def resource(methods=None):
- if not methods:
- methods = ['GET']
+ def Resource(method=None):
+ if not method:
+ method = 'GET'
def _wrapper(func):
- func._resource_method_ = methods
+ func._resource_method_ = method
return func
return _wrapper
@staticmethod
- def collection(methods=None):
- if not methods:
- methods = ['GET']
+ def Collection(method=None):
+ if not method:
+ method = 'GET'
def _wrapper(func):
- func._collection_method_ = methods
+ func._collection_method_ = method
return func
return _wrapper