import json
from cherrypy.test import helper
+from more_itertools import always_iterable
from ..module import Module
def _put(self, url, data=None):
self._request(url, 'PUT', data)
- def assertJsonBody(self, data):
- self.assertBody(json.dumps(data))
+ def assertJsonBody(self, data, msg=None):
+ """Fail if value != self.body."""
+ body_str = self.body.decode('utf-8') if isinstance(self.body, bytes) else self.body
+ json_body = json.loads(body_str)
+ if data != json_body:
+ if msg is None:
+ msg = 'expected body:\n%r\n\nactual body:\n%r' % (
+ data, json_body)
+ self._handlewebError(msg)
class ControllerTestCase(helper.CPWebCase, RequestHelper):
from mock import patch
from .helper import RequestHelper
-from ..tools import RESTController
+from ..tools import RESTController, detail_route
# pylint: disable=W0613
def bulk_delete(self):
FooResource.elems = []
+ def set(self, data, key):
+ FooResource.elems[int(key)] = data
+ return dict(key=key, **data)
+
+ @detail_route(methods=['get'])
+ def detail(self, key):
+ return {'detail': key}
+
class FooArgs(RESTController):
@RESTController.args_from_json
self.assertHeader('Content-Type', 'application/json')
self.assertJsonBody([data] * 5)
+ self._put('/foo/0', {'newdata': 'newdata'})
+ self.assertStatus('200 OK')
+ self.assertHeader('Content-Type', 'application/json')
+ self.assertJsonBody({'newdata': 'newdata', 'key': '0'})
+
def test_not_implemented(self):
self._put("/foo")
self.assertStatus(405)
def test_args_from_json(self):
self._put("/fooargs/hello", {'name': 'world'})
self.assertJsonBody({'code': 'hello', 'name': 'world'})
+
+ def test_detail_route(self):
+ self._get('/foo/1/detail')
+ self.assertJsonBody({'detail': '1'})
+
+ self._post('/foo/1/detail', 'post-data')
+ self.assertStatus(405)
"""
- def _not_implemented(self, is_element):
- methods = [method
- for ((method, _is_element), (meth, _))
- in self._method_mapping.items()
- if _is_element == is_element and hasattr(self, meth)]
+ def _not_implemented(self, obj_key, detail_route_name):
+ if detail_route_name:
+ try:
+ methods = getattr(getattr(self, detail_route_name), 'detail_route_methods')
+ except AttributeError:
+ raise cherrypy.NotFound()
+ else:
+ methods = [method
+ for ((method, _is_element), (meth, _))
+ in self._method_mapping.items()
+ if _is_element == obj_key is not None and hasattr(self, meth)]
cherrypy.response.headers['Allow'] = ','.join(methods)
raise cherrypy.HTTPError(405, 'Method not implemented.')
('DELETE', True): ('delete', 204),
}
+ def _get_method(self, obj_key, detail_route_name):
+ if detail_route_name:
+ try:
+ method = getattr(self, detail_route_name)
+ if not getattr(method, 'detail_route'):
+ self._not_implemented(obj_key, detail_route_name)
+ if cherrypy.request.method not in getattr(method, 'detail_route_methods'):
+ self._not_implemented(obj_key, detail_route_name)
+ return method, 200
+ except AttributeError:
+ self._not_implemented(obj_key, detail_route_name)
+ else:
+ method_name, status_code = self._method_mapping[
+ (cherrypy.request.method, obj_key is not None)]
+ method = getattr(self, method_name, None)
+ if not method:
+ self._not_implemented(obj_key, detail_route_name)
+ return method, status_code
+
@cherrypy.expose
def default(self, *vpath, **params):
cherrypy.config.update({
'error_page.default': _json_error_page})
- is_element = len(vpath) > 0
-
- (method_name, status_code) = self._method_mapping[
- (cherrypy.request.method, is_element)]
- method = getattr(self, method_name, None)
- if not method:
- self._not_implemented(is_element)
+ obj_key, detail_route_name = self.split_vpath(vpath)
+ method, status_code = self._get_method(obj_key, detail_route_name)
if cherrypy.request.method not in ['GET', 'DELETE']:
method = RESTController._takes_json(method)
cherrypy.response.status = status_code
- return method(*vpath, **params)
+ obj_key_args = [obj_key] if obj_key else []
+ return method(*obj_key_args, **params)
@staticmethod
def args_from_json(func):
ret = func(*args, **kwargs)
return json.dumps(ret).encode('utf8')
return inner
+
+ @staticmethod
+ def split_vpath(vpath):
+ if not vpath:
+ return None, None
+ if len(vpath) == 1:
+ return vpath[0], None
+ return vpath[0], vpath[1]
+
+
+def detail_route(methods):
+ def decorator(func):
+ func.detail_route = True
+ func.detail_route_methods = [m.upper() for m in methods]
+ return func
+ return decorator