diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index d59627a..a4c3c4b 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -88,7 +88,8 @@ class HTTPApp(ASGIApp): sresp: SerializedResponse resp: BasicResponse try: - handler = self.router.match(method, path) + matches, handler = self.router.match(method, path) + req.path_matches = matches await handler(send, req) return diff --git a/src/turbosloth/interfaces/base.py b/src/turbosloth/interfaces/base.py index ce63b87..c7bc417 100644 --- a/src/turbosloth/interfaces/base.py +++ b/src/turbosloth/interfaces/base.py @@ -17,6 +17,7 @@ class BasicRequest: headers: CaseInsensitiveDict[str, str] query: dict[str, list[Any] | Any] body: bytes + path_matches: dict[str, str] def __init__(self, method: MethodType, @@ -29,6 +30,7 @@ class BasicRequest: self.headers = CaseInsensitiveDict(headers) self.query = query self.body = body + self.path_matches = {} @classmethod def from_scope(cls, scope: Scope) -> BasicRequest: diff --git a/src/turbosloth/router.py b/src/turbosloth/router.py index 09ea926..e9817a1 100644 --- a/src/turbosloth/router.py +++ b/src/turbosloth/router.py @@ -1,11 +1,10 @@ from __future__ import annotations import re +from re import Pattern import typing from typing import Optional, Sequence -from pathspec import Pattern - from .exceptions import MethodNotAllowedException, NotFoundException from .types import InternalHandlerType from .internal_types import MethodType @@ -13,7 +12,7 @@ from .internal_types import MethodType class Route: static_subroutes: dict[str, Route] - regexp_subroutes: list[tuple[Pattern, Route]] + regexp_subroutes: list[tuple[Pattern, list[str], Route]] handler: dict[MethodType, InternalHandlerType] def __init__(self) -> None: @@ -21,15 +20,15 @@ class Route: self.regexp_subroutes = [] self.handler = {} - def _find_regexp_subroute(self, p: Pattern) -> Optional[Route]: - for _p, _r in self.regexp_subroutes: + def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]: + for _p, _n, _r in self.regexp_subroutes: if _p == p: - return _r + return _n, _r return None - def _add_regexp_subroute(self, p: Pattern, r: Route): + def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route): if self._find_regexp_subroute(p) is None: - self.regexp_subroutes.append((p, r)) + self.regexp_subroutes.append((p, names, r)) self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True) def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None: @@ -43,13 +42,16 @@ class Route: if '}' not in part: raise ValueError(f'Invalid subpath substitute placeholder: {part}') re_part = part.replace('(', '\\(').replace(')', '\\)') - re_part = re.sub(r'\{.+}', r'(.+)', re_part) + names = re.findall(r'\{(.*)}', part) + re_part = re.sub(r'\{.*}', r'(.*)', re_part) re_part = re.compile('^' + re_part + '$') - subroute = self._find_regexp_subroute(re_part) - if subroute is None: + d = self._find_regexp_subroute(re_part) + if d is None: subroute = Route() - self._add_regexp_subroute(re_part, subroute) + self._add_regexp_subroute(re_part, names, subroute) + else: + _, subroute = d else: subroute = self.static_subroutes.get(part) if subroute is None: @@ -58,26 +60,35 @@ class Route: subroute.add(method, sequence[1:], handler) - def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]: + def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]: if len(sequence) == 0: if len(self.handler) == 0: raise NotFoundException('') ret = self.handler.get(method) if ret is None: raise MethodNotAllowedException(', '.join(map(str, self.handler.keys()))) - return ret + return {}, ret subroute = self.static_subroutes.get(sequence[0]) + matches = {} if subroute is None: - for p, sr in self.regexp_subroutes: - m = p.match(sequence[0]) - if m is not None: + for p, subst_names, sr in self.regexp_subroutes: + m = p.findall(sequence[0]) + if len(m) > 0: + if len(m) != len(subst_names): + raise RuntimeError('Unable to match substitutes') + subroute = sr + for k, v in zip(subst_names, m): + matches[k] = v + break if subroute is None: raise NotFoundException('/'.join(sequence)) - return subroute.get(method, sequence[1:]) + submatches, handler = subroute.get(method, sequence[1:]) + submatches |= matches + return submatches, handler class Router: @@ -87,14 +98,22 @@ class Router: self._root = Route() def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None: - assert method.upper() == method + method = typing.cast(MethodType, method.upper()) + + substitutes = re.findall(r'\{(.*)}', path_pattern) + if len(substitutes) > 0: + subst = [] + for s in substitutes: + subst += list(s) + if len(subst) != len(set(subst)): + raise ValueError('Duplicate path substitute names are prohibited') segments = path_pattern.split('/') while len(segments) > 0 and len(segments[0]) == 0: segments = segments[1:] self._root.add(method, segments, handler) - def match(self, method: MethodType, path: str) -> InternalHandlerType: + def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]: method = typing.cast(MethodType, method.upper()) segments = path.split('/') diff --git a/tests/test_router.py b/tests/test_router.py index 4cd14b1..040fcba 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -46,7 +46,7 @@ def test_router_match(): r = Router() r.add('GET', 'asdf', f) - assert r.match('GET', '/asdf') + assert r.match('GET', '/asdf') == ({}, f) with pytest.raises(NotFoundException, match='404\tNot Found: asd'): r.match('GET', 'asd') with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'): @@ -64,9 +64,9 @@ def test_router_pattern_match(): r.add('GET', '/{some}/asdf', f) r.add('GET', '/{some}/b{some1}c', f) - assert r.match('GET', '/1234/asdf') - assert r.match('GET', '/ /asdf') - assert r.match('GET', '/ /basdfc') + assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f) + assert r.match('GET', '/ /asdf') == ({'some': ' '}, f) + assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f) with pytest.raises(NotFoundException, match='404\tNot Found: asd'): r.match('GET', 'asd')