Compare commits

...

2 Commits

Author SHA1 Message Date
15f6438407 Add path matches support into an UnwrappedRequest 2025-07-19 04:02:04 +03:00
75824d1893 Add key-value path subst matching 2025-07-19 03:45:11 +03:00
8 changed files with 94 additions and 44 deletions

View File

@@ -7,14 +7,14 @@ import uvicorn
from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE
from turbosloth.req_schema import UnwrappedRequest
app = SlothApp()
@app.get("/")
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse:
async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -24,7 +24,7 @@ class UserIdSchema:
@app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse:
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse:
print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp)
@@ -53,15 +53,22 @@ def foo() -> SomeInternalData:
return s
@app.post("/user")
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse:
@dataclass
class PTYPESchema:
user_id: int
@app.post("/user/u{user_id}r")
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema],
dat: SomeInternalData) -> SerializedResponse:
print(req)
print(dat)
resp: dict[str, Any] = {
'message': f'Hello, User {req.body.user_id}!',
'from': 'server',
'data': req.body.data,
'inj': dat.a
'inj': dat.a,
'user_id': req.path_matches.user_id
}
return SerializedResponse(200, {}, resp)

View File

@@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest
from .router import Router
from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE
from breakshaft.convertor import ConvRepo
from .util import parse_content_type
@@ -88,7 +88,8 @@ class HTTPApp(ASGIApp):
sresp: SerializedResponse
resp: BasicResponse
try:
handler = self.router.match(method, path)
matches, handler = self.router.match(method, path)
req.path_matches = matches
await handler(send, req)
return
@@ -187,12 +188,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self.inj_repo = ConvRepo()
@self.inj_repo.mark_injector()
def extract_query(req: BasicRequest) -> QTYPE:
def extract_query(req: BasicRequest | SerializedRequest) -> QTYPE:
return req.query
@self.inj_repo.mark_injector()
def extract_query(req: SerializedRequest) -> QTYPE:
return req.query
def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE:
return req.path_matches
@self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE:
@@ -226,15 +227,26 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
if req_schema is None:
raise ValueError(f'Unable to find request schema in handler {fn}')
query_type, body_type = get_args(req_schema)
query_type, body_type, path_type = get_args(req_schema)
q_inflator = None
b_inflator = None
p_inflator = None
fork_with = set()
def none_generator(*args) -> None:
return None
if path_type not in [PTYPE, None, Any]:
p_inflator = self.infl_generator.schema_to_inflator(
path_type,
strict_mode_override=False,
from_type_override=QTYPE
)
fork_with.add(ConversionPoint(p_inflator, path_type, (PTYPE,)))
else:
fork_with.add(ConversionPoint(none_generator, path_type, (PTYPE,)))
if query_type not in [QTYPE, None, Any]:
q_inflator = self.infl_generator.schema_to_inflator(
query_type,
@@ -254,11 +266,16 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
else:
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
def construct_unwrap(q, b) -> UnwrappedRequest:
print(f'unwrapping {query_type} and {body_type} with {q} and {b}')
return UnwrappedRequest(q, b)
def construct_unwrap(q, b, p) -> UnwrappedRequest:
return UnwrappedRequest(q, b, p)
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))}
fork_with |= {
ConversionPoint(
construct_unwrap,
req_schema,
((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE))
)
}
tmp_repo = self.inj_repo.fork(fork_with)

View File

@@ -17,6 +17,7 @@ class BasicRequest:
headers: CaseInsensitiveDict[str, str]
query: dict[str, list[Any] | Any]
body: bytes
path_matches: dict[str, str]
def __init__(self,
method: MethodType,
@@ -29,6 +30,7 @@ class BasicRequest:
self.headers = CaseInsensitiveDict(headers)
self.query = query
self.body = body
self.path_matches = {}
@classmethod
def from_scope(cls, scope: Scope) -> BasicRequest:

View File

@@ -30,6 +30,10 @@ class SerializedRequest:
def headers(self) -> CaseInsensitiveDict[str, str]:
return self.basic.headers
@property
def path_matches(self) -> dict[str, str]:
return self.basic.path_matches
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
raise NotImplementedError()

View File

@@ -19,3 +19,4 @@ type MethodType = (
type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
type PTYPE = Annotated[dict[str, str], 'path_matches']

View File

@@ -1,13 +1,13 @@
from dataclasses import dataclass
from typing import Any, TypeVar
from mypy.visitor import Generic
from typing import Any, TypeVar, Generic
Q = TypeVar('Q')
B = TypeVar('B')
P = TypeVar('P')
@dataclass
class UnwrappedRequest(Generic[Q, B]):
class UnwrappedRequest(Generic[Q, B, P]):
query: Q
body: B
path_matches: P

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import re
from re import Pattern
import typing
from typing import Optional, Sequence
from pathspec import Pattern
from .exceptions import MethodNotAllowedException, NotFoundException
from .types import InternalHandlerType
from .internal_types import MethodType
@@ -13,7 +12,7 @@ from .internal_types import MethodType
class Route:
static_subroutes: dict[str, Route]
regexp_subroutes: list[tuple[Pattern, Route]]
regexp_subroutes: list[tuple[Pattern, list[str], Route]]
handler: dict[MethodType, InternalHandlerType]
def __init__(self) -> None:
@@ -21,15 +20,15 @@ class Route:
self.regexp_subroutes = []
self.handler = {}
def _find_regexp_subroute(self, p: Pattern) -> Optional[Route]:
for _p, _r in self.regexp_subroutes:
def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]:
for _p, _n, _r in self.regexp_subroutes:
if _p == p:
return _r
return _n, _r
return None
def _add_regexp_subroute(self, p: Pattern, r: Route):
def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route):
if self._find_regexp_subroute(p) is None:
self.regexp_subroutes.append((p, r))
self.regexp_subroutes.append((p, names, r))
self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True)
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
@@ -43,13 +42,16 @@ class Route:
if '}' not in part:
raise ValueError(f'Invalid subpath substitute placeholder: {part}')
re_part = part.replace('(', '\\(').replace(')', '\\)')
re_part = re.sub(r'\{.+}', r'(.+)', re_part)
names = re.findall(r'\{(.*)}', part)
re_part = re.sub(r'\{.*}', r'(.*)', re_part)
re_part = re.compile('^' + re_part + '$')
subroute = self._find_regexp_subroute(re_part)
if subroute is None:
d = self._find_regexp_subroute(re_part)
if d is None:
subroute = Route()
self._add_regexp_subroute(re_part, subroute)
self._add_regexp_subroute(re_part, names, subroute)
else:
_, subroute = d
else:
subroute = self.static_subroutes.get(part)
if subroute is None:
@@ -58,26 +60,35 @@ class Route:
subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]:
def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]:
if len(sequence) == 0:
if len(self.handler) == 0:
raise NotFoundException('')
ret = self.handler.get(method)
if ret is None:
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
return ret
return {}, ret
subroute = self.static_subroutes.get(sequence[0])
matches = {}
if subroute is None:
for p, sr in self.regexp_subroutes:
m = p.match(sequence[0])
if m is not None:
for p, subst_names, sr in self.regexp_subroutes:
m = p.findall(sequence[0])
if len(m) > 0:
if len(m) != len(subst_names):
raise RuntimeError('Unable to match substitutes')
subroute = sr
for k, v in zip(subst_names, m):
matches[k] = v
break
if subroute is None:
raise NotFoundException('/'.join(sequence))
return subroute.get(method, sequence[1:])
submatches, handler = subroute.get(method, sequence[1:])
submatches |= matches
return submatches, handler
class Router:
@@ -87,14 +98,22 @@ class Router:
self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
assert method.upper() == method
method = typing.cast(MethodType, method.upper())
substitutes = re.findall(r'\{(.*)}', path_pattern)
if len(substitutes) > 0:
subst = []
for s in substitutes:
subst += list(s)
if len(subst) != len(set(subst)):
raise ValueError('Duplicate path substitute names are prohibited')
segments = path_pattern.split('/')
while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:]
self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> InternalHandlerType:
def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]:
method = typing.cast(MethodType, method.upper())
segments = path.split('/')

View File

@@ -46,7 +46,7 @@ def test_router_match():
r = Router()
r.add('GET', 'asdf', f)
assert r.match('GET', '/asdf')
assert r.match('GET', '/asdf') == ({}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd')
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
@@ -64,9 +64,9 @@ def test_router_pattern_match():
r.add('GET', '/{some}/asdf', f)
r.add('GET', '/{some}/b{some1}c', f)
assert r.match('GET', '/1234/asdf')
assert r.match('GET', '/ /asdf')
assert r.match('GET', '/ /basdfc')
assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f)
assert r.match('GET', '/ /asdf') == ({'some': ' '}, f)
assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd')