Router path substitution #1

Merged
nikto_b merged 4 commits from path-substitute-feature into master 2025-07-19 04:13:44 +03:00
8 changed files with 153 additions and 36 deletions

View File

@@ -7,14 +7,14 @@ import uvicorn
from turbosloth import SlothApp from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE from turbosloth.internal_types import QTYPE, BTYPE, PTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
app = SlothApp() app = SlothApp()
@app.get("/") @app.get("/")
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse: async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!') return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -24,7 +24,7 @@ class UserIdSchema:
@app.get("/user/") @app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse: async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse:
print(req) print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)
@@ -53,15 +53,22 @@ def foo() -> SomeInternalData:
return s return s
@app.post("/user") @dataclass
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse: class PTYPESchema:
user_id: int
@app.post("/user/u{user_id}r")
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema],
dat: SomeInternalData) -> SerializedResponse:
print(req) print(req)
print(dat) print(dat)
resp: dict[str, Any] = { resp: dict[str, Any] = {
'message': f'Hello, User {req.body.user_id}!', 'message': f'Hello, User {req.body.user_id}!',
'from': 'server', 'from': 'server',
'data': req.body.data, 'data': req.body.data,
'inj': dat.a 'inj': dat.a,
'user_id': req.path_matches.user_id
} }
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)

View File

@@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest from .req_schema import UnwrappedRequest
from .router import Router from .router import Router
from .types import HandlerType, InternalHandlerType, ContentType from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE
from breakshaft.convertor import ConvRepo from breakshaft.convertor import ConvRepo
from .util import parse_content_type from .util import parse_content_type
@@ -88,7 +88,8 @@ class HTTPApp(ASGIApp):
sresp: SerializedResponse sresp: SerializedResponse
resp: BasicResponse resp: BasicResponse
try: try:
handler = self.router.match(method, path) matches, handler = self.router.match(method, path)
req.path_matches = matches
await handler(send, req) await handler(send, req)
return return
@@ -187,12 +188,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self.inj_repo = ConvRepo() self.inj_repo = ConvRepo()
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_query(req: BasicRequest) -> QTYPE: def extract_query(req: BasicRequest | SerializedRequest) -> QTYPE:
return req.query return req.query
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_query(req: SerializedRequest) -> QTYPE: def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE:
return req.query return req.path_matches
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE: def extract_body(req: SerializedRequest) -> BTYPE:
@@ -226,15 +227,26 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
if req_schema is None: if req_schema is None:
raise ValueError(f'Unable to find request schema in handler {fn}') raise ValueError(f'Unable to find request schema in handler {fn}')
query_type, body_type = get_args(req_schema) query_type, body_type, path_type = get_args(req_schema)
q_inflator = None q_inflator = None
b_inflator = None b_inflator = None
p_inflator = None
fork_with = set() fork_with = set()
def none_generator(*args) -> None: def none_generator(*args) -> None:
return None return None
if path_type not in [PTYPE, None, Any]:
p_inflator = self.infl_generator.schema_to_inflator(
path_type,
strict_mode_override=False,
from_type_override=QTYPE
)
fork_with.add(ConversionPoint(p_inflator, path_type, (PTYPE,)))
else:
fork_with.add(ConversionPoint(none_generator, path_type, (PTYPE,)))
if query_type not in [QTYPE, None, Any]: if query_type not in [QTYPE, None, Any]:
q_inflator = self.infl_generator.schema_to_inflator( q_inflator = self.infl_generator.schema_to_inflator(
query_type, query_type,
@@ -254,11 +266,16 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
else: else:
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,))) fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
def construct_unwrap(q, b) -> UnwrappedRequest: def construct_unwrap(q, b, p) -> UnwrappedRequest:
print(f'unwrapping {query_type} and {body_type} with {q} and {b}') return UnwrappedRequest(q, b, p)
return UnwrappedRequest(q, b)
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))} fork_with |= {
ConversionPoint(
construct_unwrap,
req_schema,
((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE))
)
}
tmp_repo = self.inj_repo.fork(fork_with) tmp_repo = self.inj_repo.fork(fork_with)

View File

@@ -17,6 +17,7 @@ class BasicRequest:
headers: CaseInsensitiveDict[str, str] headers: CaseInsensitiveDict[str, str]
query: dict[str, list[Any] | Any] query: dict[str, list[Any] | Any]
body: bytes body: bytes
path_matches: dict[str, str]
def __init__(self, def __init__(self,
method: MethodType, method: MethodType,
@@ -29,6 +30,7 @@ class BasicRequest:
self.headers = CaseInsensitiveDict(headers) self.headers = CaseInsensitiveDict(headers)
self.query = query self.query = query
self.body = body self.body = body
self.path_matches = {}
@classmethod @classmethod
def from_scope(cls, scope: Scope) -> BasicRequest: def from_scope(cls, scope: Scope) -> BasicRequest:

View File

@@ -30,6 +30,10 @@ class SerializedRequest:
def headers(self) -> CaseInsensitiveDict[str, str]: def headers(self) -> CaseInsensitiveDict[str, str]:
return self.basic.headers return self.basic.headers
@property
def path_matches(self) -> dict[str, str]:
return self.basic.path_matches
@classmethod @classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
raise NotImplementedError() raise NotImplementedError()

View File

@@ -19,3 +19,4 @@ type MethodType = (
type QTYPE = Annotated[dict[str, Any], 'query_params'] type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body'] type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
type PTYPE = Annotated[dict[str, str], 'path_matches']

View File

@@ -1,13 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, TypeVar from typing import Any, TypeVar, Generic
from mypy.visitor import Generic
Q = TypeVar('Q') Q = TypeVar('Q')
B = TypeVar('B') B = TypeVar('B')
P = TypeVar('P')
@dataclass @dataclass
class UnwrappedRequest(Generic[Q, B]): class UnwrappedRequest(Generic[Q, B, P]):
query: Q query: Q
body: B body: B
path_matches: P

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re
from re import Pattern
import typing import typing
from typing import Optional, Sequence from typing import Optional, Sequence
@@ -10,34 +12,83 @@ from .internal_types import MethodType
class Route: class Route:
static_subroutes: dict[str, Route] static_subroutes: dict[str, Route]
regexp_subroutes: list[tuple[Pattern, list[str], Route]]
handler: dict[MethodType, InternalHandlerType] handler: dict[MethodType, InternalHandlerType]
def __init__(self) -> None: def __init__(self) -> None:
self.static_subroutes = {} self.static_subroutes = {}
self.regexp_subroutes = []
self.handler = {} self.handler = {}
def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]:
for _p, _n, _r in self.regexp_subroutes:
if _p == p:
return _n, _r
return None
def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route):
if self._find_regexp_subroute(p) is None:
self.regexp_subroutes.append((p, names, r))
self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True)
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None: def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
if len(sequence) == 0: if len(sequence) == 0:
self.handler[method] = handler self.handler[method] = handler
return return
subroute = self.static_subroutes.get(sequence[0]) part = sequence[0]
if subroute is None:
subroute = Route() if '{' in part:
self.static_subroutes[sequence[0]] = subroute if '}' not in part:
raise ValueError(f'Invalid subpath substitute placeholder: {part}')
re_part = part.replace('(', '\\(').replace(')', '\\)')
names = re.findall(r'\{(.*?)}', part)
re_part = re.sub(r'\{.*?}', r'(.*?)', re_part)
re_part = re.compile('^' + re_part + '$')
d = self._find_regexp_subroute(re_part)
if d is None:
subroute = Route()
self._add_regexp_subroute(re_part, names, subroute)
else:
_, subroute = d
else:
subroute = self.static_subroutes.get(part)
if subroute is None:
subroute = Route()
self.static_subroutes[part] = subroute
subroute.add(method, sequence[1:], handler) subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]: def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]:
if len(sequence) == 0: if len(sequence) == 0:
if len(self.handler) == 0:
raise NotFoundException('')
ret = self.handler.get(method) ret = self.handler.get(method)
if ret is None: if ret is None:
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys()))) raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
return ret return {}, ret
subroute = self.static_subroutes.get(sequence[0]) subroute = self.static_subroutes.get(sequence[0])
matches = {}
if subroute is None:
for p, subst_names, sr in self.regexp_subroutes:
m = p.findall(sequence[0])
if len(m) > 0:
if len(m) != len(subst_names):
raise RuntimeError('Unable to match substitutes')
subroute = sr
for k, v in zip(subst_names, m):
matches[k] = v
break
if subroute is None: if subroute is None:
raise NotFoundException('/'.join(sequence)) raise NotFoundException('/'.join(sequence))
return subroute.get(method, sequence[1:]) submatches, handler = subroute.get(method, sequence[1:])
submatches |= matches
return submatches, handler
class Router: class Router:
@@ -47,14 +98,24 @@ class Router:
self._root = Route() self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None: def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
assert method.upper() == method method = typing.cast(MethodType, method.upper())
substitutes = re.findall(r'\{(.*?)}', path_pattern)
if len(substitutes) > 0:
subst = []
for s in substitutes:
if isinstance(s, str):
s = [s]
subst += list(s)
if len(subst) != len(set(subst)):
raise ValueError('Duplicate path substitute names are prohibited')
segments = path_pattern.split('/') segments = path_pattern.split('/')
while len(segments) > 0 and len(segments[0]) == 0: while len(segments) > 0 and len(segments[0]) == 0:
segments = segments[1:] segments = segments[1:]
self._root.add(method, segments, handler) self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> InternalHandlerType: def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]:
method = typing.cast(MethodType, method.upper()) method = typing.cast(MethodType, method.upper())
segments = path.split('/') segments = path.split('/')

View File

@@ -5,13 +5,13 @@ from src.turbosloth.router import Router
def test_router_root_handler(): def test_router_root_handler():
async def f(_: dict, *args): async def f(*args):
pass pass
async def d(_: dict, *args): async def d(*args):
pass pass
async def a(_: dict, *args): async def a(*args):
pass pass
r = Router() r = Router()
@@ -38,19 +38,44 @@ def test_router_root_handler():
def test_router_match(): def test_router_match():
async def f(_: dict, *args): async def f(*args):
pass pass
async def d(_: dict, *args): async def d(*args):
pass pass
r = Router() r = Router()
r.add('GET', 'asdf', f) r.add('GET', 'asdf', f)
assert r.match('GET', '/asdf') assert r.match('GET', '/asdf') == ({}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'): with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd') r.match('GET', 'asd')
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'): with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
r.match('POST', 'asdf') r.match('POST', 'asdf')
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /, allowed: none'): with pytest.raises(NotFoundException, match=f'404\tNot Found'):
r.match('POST', '') r.match('POST', '')
def test_router_pattern_match():
async def f(*args):
pass
r = Router()
r.add('GET', '/{some}/asdf', f)
r.add('GET', '/{some}/b{some1}c', f)
assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f)
assert r.match('GET', '/ /asdf') == ({'some': ' '}, f)
assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f)
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd')
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd/')
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd/b')
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
r.match('GET', 'asd/basdf')