diff --git a/src/turbosloth/__main__.py b/src/turbosloth/__main__.py index 27074f7..bebb391 100644 --- a/src/turbosloth/__main__.py +++ b/src/turbosloth/__main__.py @@ -7,14 +7,14 @@ import uvicorn from turbosloth import SlothApp from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest -from turbosloth.internal_types import QTYPE, BTYPE +from turbosloth.internal_types import QTYPE, BTYPE, PTYPE from turbosloth.req_schema import UnwrappedRequest app = SlothApp() @app.get("/") -async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse: +async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse: return SerializedResponse(200, {}, 'Hello, ASGI Router!') @@ -24,7 +24,7 @@ class UserIdSchema: @app.get("/user/") -async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse: +async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse: print(req) resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} return SerializedResponse(200, {}, resp) @@ -53,15 +53,22 @@ def foo() -> SomeInternalData: return s -@app.post("/user") -async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse: +@dataclass +class PTYPESchema: + user_id: int + + +@app.post("/user/u{user_id}r") +async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema], + dat: SomeInternalData) -> SerializedResponse: print(req) print(dat) resp: dict[str, Any] = { 'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data, - 'inj': dat.a + 'inj': dat.a, + 'user_id': req.path_matches.user_id } return SerializedResponse(200, {}, resp) diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index d59627a..476be8e 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse from .req_schema import UnwrappedRequest from .router import Router from .types import HandlerType, InternalHandlerType, ContentType -from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE +from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE from breakshaft.convertor import ConvRepo from .util import parse_content_type @@ -88,7 +88,8 @@ class HTTPApp(ASGIApp): sresp: SerializedResponse resp: BasicResponse try: - handler = self.router.match(method, path) + matches, handler = self.router.match(method, path) + req.path_matches = matches await handler(send, req) return @@ -187,12 +188,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): self.inj_repo = ConvRepo() @self.inj_repo.mark_injector() - def extract_query(req: BasicRequest) -> QTYPE: + def extract_query(req: BasicRequest | SerializedRequest) -> QTYPE: return req.query @self.inj_repo.mark_injector() - def extract_query(req: SerializedRequest) -> QTYPE: - return req.query + def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE: + return req.path_matches @self.inj_repo.mark_injector() def extract_body(req: SerializedRequest) -> BTYPE: @@ -226,15 +227,26 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): if req_schema is None: raise ValueError(f'Unable to find request schema in handler {fn}') - query_type, body_type = get_args(req_schema) + query_type, body_type, path_type = get_args(req_schema) q_inflator = None b_inflator = None + p_inflator = None fork_with = set() def none_generator(*args) -> None: return None + if path_type not in [PTYPE, None, Any]: + p_inflator = self.infl_generator.schema_to_inflator( + path_type, + strict_mode_override=False, + from_type_override=QTYPE + ) + fork_with.add(ConversionPoint(p_inflator, path_type, (PTYPE,))) + else: + fork_with.add(ConversionPoint(none_generator, path_type, (PTYPE,))) + if query_type not in [QTYPE, None, Any]: q_inflator = self.infl_generator.schema_to_inflator( query_type, @@ -254,11 +266,16 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): else: fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,))) - def construct_unwrap(q, b) -> UnwrappedRequest: - print(f'unwrapping {query_type} and {body_type} with {q} and {b}') - return UnwrappedRequest(q, b) + def construct_unwrap(q, b, p) -> UnwrappedRequest: + return UnwrappedRequest(q, b, p) - fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))} + fork_with |= { + ConversionPoint( + construct_unwrap, + req_schema, + ((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE)) + ) + } tmp_repo = self.inj_repo.fork(fork_with) diff --git a/src/turbosloth/interfaces/base.py b/src/turbosloth/interfaces/base.py index ce63b87..c7bc417 100644 --- a/src/turbosloth/interfaces/base.py +++ b/src/turbosloth/interfaces/base.py @@ -17,6 +17,7 @@ class BasicRequest: headers: CaseInsensitiveDict[str, str] query: dict[str, list[Any] | Any] body: bytes + path_matches: dict[str, str] def __init__(self, method: MethodType, @@ -29,6 +30,7 @@ class BasicRequest: self.headers = CaseInsensitiveDict(headers) self.query = query self.body = body + self.path_matches = {} @classmethod def from_scope(cls, scope: Scope) -> BasicRequest: diff --git a/src/turbosloth/interfaces/serialized/base.py b/src/turbosloth/interfaces/serialized/base.py index 414d17b..3cc4caf 100644 --- a/src/turbosloth/interfaces/serialized/base.py +++ b/src/turbosloth/interfaces/serialized/base.py @@ -30,6 +30,10 @@ class SerializedRequest: def headers(self) -> CaseInsensitiveDict[str, str]: return self.basic.headers + @property + def path_matches(self) -> dict[str, str]: + return self.basic.path_matches + @classmethod def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: raise NotImplementedError() diff --git a/src/turbosloth/internal_types.py b/src/turbosloth/internal_types.py index 74df546..8c2c390 100644 --- a/src/turbosloth/internal_types.py +++ b/src/turbosloth/internal_types.py @@ -19,3 +19,4 @@ type MethodType = ( type QTYPE = Annotated[dict[str, Any], 'query_params'] type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body'] +type PTYPE = Annotated[dict[str, str], 'path_matches'] diff --git a/src/turbosloth/req_schema.py b/src/turbosloth/req_schema.py index 6b5ef88..c33466e 100644 --- a/src/turbosloth/req_schema.py +++ b/src/turbosloth/req_schema.py @@ -1,13 +1,13 @@ from dataclasses import dataclass -from typing import Any, TypeVar - -from mypy.visitor import Generic +from typing import Any, TypeVar, Generic Q = TypeVar('Q') B = TypeVar('B') +P = TypeVar('P') @dataclass -class UnwrappedRequest(Generic[Q, B]): +class UnwrappedRequest(Generic[Q, B, P]): query: Q body: B + path_matches: P diff --git a/src/turbosloth/router.py b/src/turbosloth/router.py index 9201cc6..c125180 100644 --- a/src/turbosloth/router.py +++ b/src/turbosloth/router.py @@ -1,5 +1,7 @@ from __future__ import annotations +import re +from re import Pattern import typing from typing import Optional, Sequence @@ -10,34 +12,83 @@ from .internal_types import MethodType class Route: static_subroutes: dict[str, Route] + regexp_subroutes: list[tuple[Pattern, list[str], Route]] handler: dict[MethodType, InternalHandlerType] def __init__(self) -> None: self.static_subroutes = {} + self.regexp_subroutes = [] self.handler = {} + def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]: + for _p, _n, _r in self.regexp_subroutes: + if _p == p: + return _n, _r + return None + + def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route): + if self._find_regexp_subroute(p) is None: + self.regexp_subroutes.append((p, names, r)) + self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True) + def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None: if len(sequence) == 0: self.handler[method] = handler return - subroute = self.static_subroutes.get(sequence[0]) - if subroute is None: - subroute = Route() - self.static_subroutes[sequence[0]] = subroute + part = sequence[0] + + if '{' in part: + if '}' not in part: + raise ValueError(f'Invalid subpath substitute placeholder: {part}') + re_part = part.replace('(', '\\(').replace(')', '\\)') + names = re.findall(r'\{(.*?)}', part) + re_part = re.sub(r'\{.*?}', r'(.*?)', re_part) + re_part = re.compile('^' + re_part + '$') + + d = self._find_regexp_subroute(re_part) + if d is None: + subroute = Route() + self._add_regexp_subroute(re_part, names, subroute) + else: + _, subroute = d + else: + subroute = self.static_subroutes.get(part) + if subroute is None: + subroute = Route() + self.static_subroutes[part] = subroute subroute.add(method, sequence[1:], handler) - def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]: + def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]: if len(sequence) == 0: + if len(self.handler) == 0: + raise NotFoundException('') ret = self.handler.get(method) if ret is None: raise MethodNotAllowedException(', '.join(map(str, self.handler.keys()))) - return ret + return {}, ret + subroute = self.static_subroutes.get(sequence[0]) + + matches = {} + if subroute is None: + for p, subst_names, sr in self.regexp_subroutes: + m = p.findall(sequence[0]) + if len(m) > 0: + if len(m) != len(subst_names): + raise RuntimeError('Unable to match substitutes') + + subroute = sr + for k, v in zip(subst_names, m): + matches[k] = v + break + if subroute is None: raise NotFoundException('/'.join(sequence)) - return subroute.get(method, sequence[1:]) + submatches, handler = subroute.get(method, sequence[1:]) + submatches |= matches + return submatches, handler class Router: @@ -47,14 +98,24 @@ class Router: self._root = Route() def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None: - assert method.upper() == method + method = typing.cast(MethodType, method.upper()) + + substitutes = re.findall(r'\{(.*?)}', path_pattern) + if len(substitutes) > 0: + subst = [] + for s in substitutes: + if isinstance(s, str): + s = [s] + subst += list(s) + if len(subst) != len(set(subst)): + raise ValueError('Duplicate path substitute names are prohibited') segments = path_pattern.split('/') while len(segments) > 0 and len(segments[0]) == 0: segments = segments[1:] self._root.add(method, segments, handler) - def match(self, method: MethodType, path: str) -> InternalHandlerType: + def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]: method = typing.cast(MethodType, method.upper()) segments = path.split('/') diff --git a/tests/test_router.py b/tests/test_router.py index a5b225f..040fcba 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -5,13 +5,13 @@ from src.turbosloth.router import Router def test_router_root_handler(): - async def f(_: dict, *args): + async def f(*args): pass - async def d(_: dict, *args): + async def d(*args): pass - async def a(_: dict, *args): + async def a(*args): pass r = Router() @@ -38,19 +38,44 @@ def test_router_root_handler(): def test_router_match(): - async def f(_: dict, *args): + async def f(*args): pass - async def d(_: dict, *args): + async def d(*args): pass r = Router() r.add('GET', 'asdf', f) - assert r.match('GET', '/asdf') + assert r.match('GET', '/asdf') == ({}, f) with pytest.raises(NotFoundException, match='404\tNot Found: asd'): r.match('GET', 'asd') with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'): r.match('POST', 'asdf') - with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /, allowed: none'): + with pytest.raises(NotFoundException, match=f'404\tNot Found'): r.match('POST', '') + + +def test_router_pattern_match(): + async def f(*args): + pass + + r = Router() + r.add('GET', '/{some}/asdf', f) + r.add('GET', '/{some}/b{some1}c', f) + + assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f) + assert r.match('GET', '/ /asdf') == ({'some': ' '}, f) + assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f) + + with pytest.raises(NotFoundException, match='404\tNot Found: asd'): + r.match('GET', 'asd') + + with pytest.raises(NotFoundException, match='404\tNot Found: asd'): + r.match('GET', 'asd/') + + with pytest.raises(NotFoundException, match='404\tNot Found: asd'): + r.match('GET', 'asd/b') + + with pytest.raises(NotFoundException, match='404\tNot Found: asd'): + r.match('GET', 'asd/basdf')