Allow router to raise MethodNotAllowed
This commit is contained in:
@@ -8,56 +8,57 @@ from .types import HandlerType, MethodType
|
||||
|
||||
class Route:
|
||||
static_subroutes: dict[str, Route]
|
||||
handler: Optional[HandlerType]
|
||||
handler: dict[MethodType, HandlerType]
|
||||
|
||||
def __init__(self):
|
||||
self.static_subroutes = {}
|
||||
self.handler = None
|
||||
self.handler = {}
|
||||
|
||||
def add(self, sequence: Sequence[str], handler: HandlerType):
|
||||
def add(self, method: MethodType, sequence: Sequence[str], handler: HandlerType):
|
||||
if len(sequence) == 0:
|
||||
self.handler = handler
|
||||
self.handler[method] = handler
|
||||
return
|
||||
|
||||
subroute = self.static_subroutes.get(sequence[0])
|
||||
if subroute is None:
|
||||
subroute = Route()
|
||||
self.static_subroutes[sequence[0]] = subroute
|
||||
subroute.add(sequence[1:], handler)
|
||||
|
||||
def get(self, sequence: Sequence[str]) -> Optional[HandlerType]:
|
||||
subroute.add(method, sequence[1:], handler)
|
||||
|
||||
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[HandlerType]:
|
||||
if len(sequence) == 0:
|
||||
return self.handler
|
||||
ret = self.handler.get(method)
|
||||
if ret is None:
|
||||
# TODO: extract exceptions
|
||||
raise ValueError('405')
|
||||
return ret
|
||||
subroute = self.static_subroutes.get(sequence[0])
|
||||
if subroute is None:
|
||||
raise ValueError('404')
|
||||
return subroute.get(sequence[1:])
|
||||
return subroute.get(method, sequence[1:])
|
||||
|
||||
|
||||
class Router:
|
||||
_routes: dict[MethodType, Route]
|
||||
_root: Route
|
||||
|
||||
def __init__(self):
|
||||
self._routes = {}
|
||||
self._root = Route()
|
||||
|
||||
def add(self, method: MethodType, path_pattern: str, handler: HandlerType):
|
||||
assert method.upper() == method
|
||||
|
||||
root = self._routes.get(method)
|
||||
if root is None:
|
||||
root = Route()
|
||||
self._routes[method] = root
|
||||
|
||||
segments = path_pattern.split('/')
|
||||
while len(segments) > 0 and len(segments[0]) == 0:
|
||||
segments = segments[1:]
|
||||
root.add(segments, handler)
|
||||
self._root.add(method, segments, handler)
|
||||
|
||||
def match(self, method: MethodType, path: str) -> HandlerType:
|
||||
method = typing.cast(MethodType, method.upper())
|
||||
root = self._routes[method]
|
||||
segments = path.split('/')
|
||||
|
||||
while len(segments) > 0 and len(segments[0]) == 0:
|
||||
segments = segments[1:]
|
||||
h = root.get(segments)
|
||||
|
||||
h = self._root.get(method, segments)
|
||||
return h
|
||||
|
||||
@@ -10,20 +10,30 @@ def test_router_root_handler():
|
||||
async def d(_: dict, *args):
|
||||
pass
|
||||
|
||||
async def a(_: dict, *args):
|
||||
pass
|
||||
|
||||
r = Router()
|
||||
r.add('GET', '/', f)
|
||||
r.add('GET', '', f)
|
||||
assert len(r._routes.keys()) == 1
|
||||
assert len(r._routes['GET'].static_subroutes.keys()) == 0
|
||||
assert r._routes['GET'].handler is not None
|
||||
assert r._routes['GET'].handler == f
|
||||
assert len(r._root.static_subroutes) == 0
|
||||
assert len(r._root.handler) == 1
|
||||
assert r._root.handler['GET'] == f
|
||||
|
||||
r.add('POST', '/', a)
|
||||
assert len(r._root.static_subroutes) == 0
|
||||
assert len(r._root.handler) == 2
|
||||
assert r._root.handler['GET'] == f
|
||||
assert r._root.handler['POST'] == a
|
||||
|
||||
r.add('GET', '/asdf', d)
|
||||
assert len(r._routes.keys()) == 1
|
||||
assert len(r._routes['GET'].static_subroutes.keys()) == 1
|
||||
assert r._routes['GET'].handler is not None
|
||||
assert r._routes['GET'].handler == f
|
||||
assert r._routes['GET'].static_subroutes['asdf'].handler == d
|
||||
assert len(r._root.static_subroutes) == 1
|
||||
assert len(r._root.handler) == 2
|
||||
assert r._root.handler['GET'] == f
|
||||
assert r._root.handler['POST'] == a
|
||||
assert len(r._root.static_subroutes['asdf'].static_subroutes) == 0
|
||||
assert len(r._root.static_subroutes['asdf'].handler) == 1
|
||||
assert r._root.static_subroutes['asdf'].handler['GET'] == d
|
||||
|
||||
|
||||
def test_router_match():
|
||||
@@ -36,5 +46,7 @@ def test_router_match():
|
||||
r = Router()
|
||||
r.add('GET', 'asdf', f)
|
||||
assert r.match('GET', '/asdf')
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match='404'):
|
||||
r.match('GET', 'asd')
|
||||
with pytest.raises(ValueError, match='405'):
|
||||
r.match('POST', 'asdf')
|
||||
|
||||
Reference in New Issue
Block a user