Allow router to raise MethodNotAllowed

This commit is contained in:
2025-07-16 02:23:02 +03:00
parent d22c753022
commit 939ef6073f
2 changed files with 41 additions and 28 deletions

View File

@@ -8,56 +8,57 @@ from .types import HandlerType, MethodType
class Route:
static_subroutes: dict[str, Route]
handler: Optional[HandlerType]
handler: dict[MethodType, HandlerType]
def __init__(self):
self.static_subroutes = {}
self.handler = None
self.handler = {}
def add(self, sequence: Sequence[str], handler: HandlerType):
def add(self, method: MethodType, sequence: Sequence[str], handler: HandlerType):
if len(sequence) == 0:
self.handler = handler
self.handler[method] = handler
return
subroute = self.static_subroutes.get(sequence[0])
if subroute is None:
subroute = Route()
self.static_subroutes[sequence[0]] = subroute
subroute.add(sequence[1:], handler)
def get(self, sequence: Sequence[str]) -> Optional[HandlerType]:
subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[HandlerType]:
if len(sequence) == 0:
return self.handler
ret = self.handler.get(method)
if ret is None:
# TODO: extract exceptions
raise ValueError('405')
return ret
subroute = self.static_subroutes.get(sequence[0])
if subroute is None:
raise ValueError('404')
return subroute.get(sequence[1:])
return subroute.get(method, sequence[1:])
class Router:
_routes: dict[MethodType, Route]
_root: Route
def __init__(self):
self._routes = {}
self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: HandlerType):
assert method.upper() == method
root = self._routes.get(method)
if root is None:
root = Route()
self._routes[method] = root
segments = path_pattern.split('/')
while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:]
root.add(segments, handler)
self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> HandlerType:
method = typing.cast(MethodType, method.upper())
root = self._routes[method]
segments = path.split('/')
while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:]
h = root.get(segments)
h = self._root.get(method, segments)
return h

View File

@@ -10,20 +10,30 @@ def test_router_root_handler():
async def d(_: dict, *args):
pass
async def a(_: dict, *args):
pass
r = Router()
r.add('GET', '/', f)
r.add('GET', '', f)
assert len(r._routes.keys()) == 1
assert len(r._routes['GET'].static_subroutes.keys()) == 0
assert r._routes['GET'].handler is not None
assert r._routes['GET'].handler == f
assert len(r._root.static_subroutes) == 0
assert len(r._root.handler) == 1
assert r._root.handler['GET'] == f
r.add('POST', '/', a)
assert len(r._root.static_subroutes) == 0
assert len(r._root.handler) == 2
assert r._root.handler['GET'] == f
assert r._root.handler['POST'] == a
r.add('GET', '/asdf', d)
assert len(r._routes.keys()) == 1
assert len(r._routes['GET'].static_subroutes.keys()) == 1
assert r._routes['GET'].handler is not None
assert r._routes['GET'].handler == f
assert r._routes['GET'].static_subroutes['asdf'].handler == d
assert len(r._root.static_subroutes) == 1
assert len(r._root.handler) == 2
assert r._root.handler['GET'] == f
assert r._root.handler['POST'] == a
assert len(r._root.static_subroutes['asdf'].static_subroutes) == 0
assert len(r._root.static_subroutes['asdf'].handler) == 1
assert r._root.static_subroutes['asdf'].handler['GET'] == d
def test_router_match():
@@ -36,5 +46,7 @@ def test_router_match():
r = Router()
r.add('GET', 'asdf', f)
assert r.match('GET', '/asdf')
with pytest.raises(ValueError):
with pytest.raises(ValueError, match='404'):
r.match('GET', 'asd')
with pytest.raises(ValueError, match='405'):
r.match('POST', 'asdf')