Compare commits

...

2 Commits

Author SHA1 Message Date
15f6438407 Add path matches support into an UnwrappedRequest 2025-07-19 04:02:04 +03:00
75824d1893 Add key-value path subst matching 2025-07-19 03:45:11 +03:00
8 changed files with 94 additions and 44 deletions

View File

@@ -7,14 +7,14 @@ import uvicorn
from turbosloth import SlothApp from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE from turbosloth.internal_types import QTYPE, BTYPE, PTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
app = SlothApp() app = SlothApp()
@app.get("/") @app.get("/")
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse: async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!') return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -24,7 +24,7 @@ class UserIdSchema:
@app.get("/user/") @app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse: async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse:
print(req) print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)
@@ -53,15 +53,22 @@ def foo() -> SomeInternalData:
return s return s
@app.post("/user") @dataclass
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse: class PTYPESchema:
user_id: int
@app.post("/user/u{user_id}r")
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema],
dat: SomeInternalData) -> SerializedResponse:
print(req) print(req)
print(dat) print(dat)
resp: dict[str, Any] = { resp: dict[str, Any] = {
'message': f'Hello, User {req.body.user_id}!', 'message': f'Hello, User {req.body.user_id}!',
'from': 'server', 'from': 'server',
'data': req.body.data, 'data': req.body.data,
'inj': dat.a 'inj': dat.a,
'user_id': req.path_matches.user_id
} }
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)

View File

@@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest from .req_schema import UnwrappedRequest
from .router import Router from .router import Router
from .types import HandlerType, InternalHandlerType, ContentType from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE
from breakshaft.convertor import ConvRepo from breakshaft.convertor import ConvRepo
from .util import parse_content_type from .util import parse_content_type
@@ -88,7 +88,8 @@ class HTTPApp(ASGIApp):
sresp: SerializedResponse sresp: SerializedResponse
resp: BasicResponse resp: BasicResponse
try: try:
handler = self.router.match(method, path) matches, handler = self.router.match(method, path)
req.path_matches = matches
await handler(send, req) await handler(send, req)
return return
@@ -187,12 +188,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self.inj_repo = ConvRepo() self.inj_repo = ConvRepo()
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_query(req: BasicRequest) -> QTYPE: def extract_query(req: BasicRequest | SerializedRequest) -> QTYPE:
return req.query return req.query
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_query(req: SerializedRequest) -> QTYPE: def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE:
return req.query return req.path_matches
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE: def extract_body(req: SerializedRequest) -> BTYPE:
@@ -226,15 +227,26 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
if req_schema is None: if req_schema is None:
raise ValueError(f'Unable to find request schema in handler {fn}') raise ValueError(f'Unable to find request schema in handler {fn}')
query_type, body_type = get_args(req_schema) query_type, body_type, path_type = get_args(req_schema)
q_inflator = None q_inflator = None
b_inflator = None b_inflator = None
p_inflator = None
fork_with = set() fork_with = set()
def none_generator(*args) -> None: def none_generator(*args) -> None:
return None return None
if path_type not in [PTYPE, None, Any]:
p_inflator = self.infl_generator.schema_to_inflator(
path_type,
strict_mode_override=False,
from_type_override=QTYPE
)
fork_with.add(ConversionPoint(p_inflator, path_type, (PTYPE,)))
else:
fork_with.add(ConversionPoint(none_generator, path_type, (PTYPE,)))
if query_type not in [QTYPE, None, Any]: if query_type not in [QTYPE, None, Any]:
q_inflator = self.infl_generator.schema_to_inflator( q_inflator = self.infl_generator.schema_to_inflator(
query_type, query_type,
@@ -254,11 +266,16 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
else: else:
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,))) fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
def construct_unwrap(q, b) -> UnwrappedRequest: def construct_unwrap(q, b, p) -> UnwrappedRequest:
print(f'unwrapping {query_type} and {body_type} with {q} and {b}') return UnwrappedRequest(q, b, p)
return UnwrappedRequest(q, b)
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))} fork_with |= {
ConversionPoint(
construct_unwrap,
req_schema,
((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE))
)
}
tmp_repo = self.inj_repo.fork(fork_with) tmp_repo = self.inj_repo.fork(fork_with)

View File

@@ -17,6 +17,7 @@ class BasicRequest:
headers: CaseInsensitiveDict[str, str] headers: CaseInsensitiveDict[str, str]
query: dict[str, list[Any] | Any] query: dict[str, list[Any] | Any]
body: bytes body: bytes
path_matches: dict[str, str]
def __init__(self, def __init__(self,
method: MethodType, method: MethodType,
@@ -29,6 +30,7 @@ class BasicRequest:
self.headers = CaseInsensitiveDict(headers) self.headers = CaseInsensitiveDict(headers)
self.query = query self.query = query
self.body = body self.body = body
self.path_matches = {}
@classmethod @classmethod
def from_scope(cls, scope: Scope) -> BasicRequest: def from_scope(cls, scope: Scope) -> BasicRequest:

View File

@@ -30,6 +30,10 @@ class SerializedRequest:
def headers(self) -> CaseInsensitiveDict[str, str]: def headers(self) -> CaseInsensitiveDict[str, str]:
return self.basic.headers return self.basic.headers
@property
def path_matches(self) -> dict[str, str]:
return self.basic.path_matches
@classmethod @classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
raise NotImplementedError() raise NotImplementedError()

View File

@@ -19,3 +19,4 @@ type MethodType = (
type QTYPE = Annotated[dict[str, Any], 'query_params'] type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body'] type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
type PTYPE = Annotated[dict[str, str], 'path_matches']

View File

@@ -1,13 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, TypeVar from typing import Any, TypeVar, Generic
from mypy.visitor import Generic
Q = TypeVar('Q') Q = TypeVar('Q')
B = TypeVar('B') B = TypeVar('B')
P = TypeVar('P')
@dataclass @dataclass
class UnwrappedRequest(Generic[Q, B]): class UnwrappedRequest(Generic[Q, B, P]):
query: Q query: Q
body: B body: B
path_matches: P

View File

@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import re import re
from re import Pattern
import typing import typing
from typing import Optional, Sequence from typing import Optional, Sequence
from pathspec import Pattern
from .exceptions import MethodNotAllowedException, NotFoundException from .exceptions import MethodNotAllowedException, NotFoundException
from .types import InternalHandlerType from .types import InternalHandlerType
from .internal_types import MethodType from .internal_types import MethodType
@@ -13,7 +12,7 @@ from .internal_types import MethodType
class Route: class Route:
static_subroutes: dict[str, Route] static_subroutes: dict[str, Route]
regexp_subroutes: list[tuple[Pattern, Route]] regexp_subroutes: list[tuple[Pattern, list[str], Route]]
handler: dict[MethodType, InternalHandlerType] handler: dict[MethodType, InternalHandlerType]
def __init__(self) -> None: def __init__(self) -> None:
@@ -21,15 +20,15 @@ class Route:
self.regexp_subroutes = [] self.regexp_subroutes = []
self.handler = {} self.handler = {}
def _find_regexp_subroute(self, p: Pattern) -> Optional[Route]: def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]:
for _p, _r in self.regexp_subroutes: for _p, _n, _r in self.regexp_subroutes:
if _p == p: if _p == p:
return _r return _n, _r
return None return None
def _add_regexp_subroute(self, p: Pattern, r: Route): def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route):
if self._find_regexp_subroute(p) is None: if self._find_regexp_subroute(p) is None:
self.regexp_subroutes.append((p, r)) self.regexp_subroutes.append((p, names, r))
self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True) self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True)
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None: def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
@@ -43,13 +42,16 @@ class Route:
if '}' not in part: if '}' not in part:
raise ValueError(f'Invalid subpath substitute placeholder: {part}') raise ValueError(f'Invalid subpath substitute placeholder: {part}')
re_part = part.replace('(', '\\(').replace(')', '\\)') re_part = part.replace('(', '\\(').replace(')', '\\)')
re_part = re.sub(r'\{.+}', r'(.+)', re_part) names = re.findall(r'\{(.*)}', part)
re_part = re.sub(r'\{.*}', r'(.*)', re_part)
re_part = re.compile('^' + re_part + '$') re_part = re.compile('^' + re_part + '$')
subroute = self._find_regexp_subroute(re_part) d = self._find_regexp_subroute(re_part)
if subroute is None: if d is None:
subroute = Route() subroute = Route()
self._add_regexp_subroute(re_part, subroute) self._add_regexp_subroute(re_part, names, subroute)
else:
_, subroute = d
else: else:
subroute = self.static_subroutes.get(part) subroute = self.static_subroutes.get(part)
if subroute is None: if subroute is None:
@@ -58,26 +60,35 @@ class Route:
subroute.add(method, sequence[1:], handler) subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]: def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]:
if len(sequence) == 0: if len(sequence) == 0:
if len(self.handler) == 0: if len(self.handler) == 0:
raise NotFoundException('') raise NotFoundException('')
ret = self.handler.get(method) ret = self.handler.get(method)
if ret is None: if ret is None:
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys()))) raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
return ret return {}, ret
subroute = self.static_subroutes.get(sequence[0]) subroute = self.static_subroutes.get(sequence[0])
matches = {}
if subroute is None: if subroute is None:
for p, sr in self.regexp_subroutes: for p, subst_names, sr in self.regexp_subroutes:
m = p.match(sequence[0]) m = p.findall(sequence[0])
if m is not None: if len(m) > 0:
if len(m) != len(subst_names):
raise RuntimeError('Unable to match substitutes')
subroute = sr subroute = sr
for k, v in zip(subst_names, m):
matches[k] = v
break
if subroute is None: if subroute is None:
raise NotFoundException('/'.join(sequence)) raise NotFoundException('/'.join(sequence))
return subroute.get(method, sequence[1:]) submatches, handler = subroute.get(method, sequence[1:])
submatches |= matches
return submatches, handler
class Router: class Router:
@@ -87,14 +98,22 @@ class Router:
self._root = Route() self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None: def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
assert method.upper() == method method = typing.cast(MethodType, method.upper())
substitutes = re.findall(r'\{(.*)}', path_pattern)
if len(substitutes) > 0:
subst = []
for s in substitutes:
subst += list(s)
if len(subst) != len(set(subst)):
raise ValueError('Duplicate path substitute names are prohibited')
segments = path_pattern.split('/') segments = path_pattern.split('/')
while len(segments) > 0 and len(segments[0]) == 0: while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:] segments = segments[1:]
self._root.add(method, segments, handler) self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> InternalHandlerType: def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]:
method = typing.cast(MethodType, method.upper()) method = typing.cast(MethodType, method.upper())
segments = path.split('/') segments = path.split('/')

View File

@@ -46,7 +46,7 @@ def test_router_match():
r = Router() r = Router()
r.add('GET', 'asdf', f) r.add('GET', 'asdf', f)
assert r.match('GET', '/asdf') assert r.match('GET', '/asdf') == ({}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'): with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd') r.match('GET', 'asd')
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'): with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
@@ -64,9 +64,9 @@ def test_router_pattern_match():
r.add('GET', '/{some}/asdf', f) r.add('GET', '/{some}/asdf', f)
r.add('GET', '/{some}/b{some1}c', f) r.add('GET', '/{some}/b{some1}c', f)
assert r.match('GET', '/1234/asdf') assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f)
assert r.match('GET', '/ /asdf') assert r.match('GET', '/ /asdf') == ({'some': ' '}, f)
assert r.match('GET', '/ /basdfc') assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'): with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd') r.match('GET', 'asd')