diff --git a/src/turbosloth/router.py b/src/turbosloth/router.py index d8c5eb0..01980cc 100644 --- a/src/turbosloth/router.py +++ b/src/turbosloth/router.py @@ -8,56 +8,57 @@ from .types import HandlerType, MethodType class Route: static_subroutes: dict[str, Route] - handler: Optional[HandlerType] + handler: dict[MethodType, HandlerType] def __init__(self): self.static_subroutes = {} - self.handler = None + self.handler = {} - def add(self, sequence: Sequence[str], handler: HandlerType): + def add(self, method: MethodType, sequence: Sequence[str], handler: HandlerType): if len(sequence) == 0: - self.handler = handler + self.handler[method] = handler return subroute = self.static_subroutes.get(sequence[0]) if subroute is None: subroute = Route() self.static_subroutes[sequence[0]] = subroute - subroute.add(sequence[1:], handler) - def get(self, sequence: Sequence[str]) -> Optional[HandlerType]: + subroute.add(method, sequence[1:], handler) + + def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[HandlerType]: if len(sequence) == 0: - return self.handler + ret = self.handler.get(method) + if ret is None: + # TODO: extract exceptions + raise ValueError('405') + return ret subroute = self.static_subroutes.get(sequence[0]) if subroute is None: raise ValueError('404') - return subroute.get(sequence[1:]) + return subroute.get(method, sequence[1:]) class Router: - _routes: dict[MethodType, Route] + _root: Route def __init__(self): - self._routes = {} + self._root = Route() def add(self, method: MethodType, path_pattern: str, handler: HandlerType): assert method.upper() == method - root = self._routes.get(method) - if root is None: - root = Route() - self._routes[method] = root - segments = path_pattern.split('/') while len(segments) > 0 and len(segments[0]) == 0: segments = segments[1:] - root.add(segments, handler) + self._root.add(method, segments, handler) def match(self, method: MethodType, path: str) -> HandlerType: method = typing.cast(MethodType, method.upper()) - root = self._routes[method] segments = path.split('/') + while len(segments) > 0 and len(segments[0]) == 0: segments = segments[1:] - h = root.get(segments) + + h = self._root.get(method, segments) return h diff --git a/tests/test_router.py b/tests/test_router.py index c41bd94..6367499 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -10,20 +10,30 @@ def test_router_root_handler(): async def d(_: dict, *args): pass + async def a(_: dict, *args): + pass + r = Router() r.add('GET', '/', f) r.add('GET', '', f) - assert len(r._routes.keys()) == 1 - assert len(r._routes['GET'].static_subroutes.keys()) == 0 - assert r._routes['GET'].handler is not None - assert r._routes['GET'].handler == f + assert len(r._root.static_subroutes) == 0 + assert len(r._root.handler) == 1 + assert r._root.handler['GET'] == f + + r.add('POST', '/', a) + assert len(r._root.static_subroutes) == 0 + assert len(r._root.handler) == 2 + assert r._root.handler['GET'] == f + assert r._root.handler['POST'] == a r.add('GET', '/asdf', d) - assert len(r._routes.keys()) == 1 - assert len(r._routes['GET'].static_subroutes.keys()) == 1 - assert r._routes['GET'].handler is not None - assert r._routes['GET'].handler == f - assert r._routes['GET'].static_subroutes['asdf'].handler == d + assert len(r._root.static_subroutes) == 1 + assert len(r._root.handler) == 2 + assert r._root.handler['GET'] == f + assert r._root.handler['POST'] == a + assert len(r._root.static_subroutes['asdf'].static_subroutes) == 0 + assert len(r._root.static_subroutes['asdf'].handler) == 1 + assert r._root.static_subroutes['asdf'].handler['GET'] == d def test_router_match(): @@ -36,5 +46,7 @@ def test_router_match(): r = Router() r.add('GET', 'asdf', f) assert r.match('GET', '/asdf') - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='404'): r.match('GET', 'asd') + with pytest.raises(ValueError, match='405'): + r.match('POST', 'asdf')