Router path substitution #1
@@ -7,14 +7,14 @@ import uvicorn
|
|||||||
|
|
||||||
from turbosloth import SlothApp
|
from turbosloth import SlothApp
|
||||||
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
|
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
|
||||||
from turbosloth.internal_types import QTYPE, BTYPE
|
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE
|
||||||
from turbosloth.req_schema import UnwrappedRequest
|
from turbosloth.req_schema import UnwrappedRequest
|
||||||
|
|
||||||
app = SlothApp()
|
app = SlothApp()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse:
|
async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse:
|
||||||
return SerializedResponse(200, {}, 'Hello, ASGI Router!')
|
return SerializedResponse(200, {}, 'Hello, ASGI Router!')
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ class UserIdSchema:
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/user/")
|
@app.get("/user/")
|
||||||
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse:
|
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse:
|
||||||
print(req)
|
print(req)
|
||||||
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
|
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
|
||||||
return SerializedResponse(200, {}, resp)
|
return SerializedResponse(200, {}, resp)
|
||||||
@@ -53,15 +53,22 @@ def foo() -> SomeInternalData:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@app.post("/user")
|
@dataclass
|
||||||
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse:
|
class PTYPESchema:
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/user/u{user_id}r")
|
||||||
|
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema],
|
||||||
|
dat: SomeInternalData) -> SerializedResponse:
|
||||||
print(req)
|
print(req)
|
||||||
print(dat)
|
print(dat)
|
||||||
resp: dict[str, Any] = {
|
resp: dict[str, Any] = {
|
||||||
'message': f'Hello, User {req.body.user_id}!',
|
'message': f'Hello, User {req.body.user_id}!',
|
||||||
'from': 'server',
|
'from': 'server',
|
||||||
'data': req.body.data,
|
'data': req.body.data,
|
||||||
'inj': dat.a
|
'inj': dat.a,
|
||||||
|
'user_id': req.path_matches.user_id
|
||||||
}
|
}
|
||||||
return SerializedResponse(200, {}, resp)
|
return SerializedResponse(200, {}, resp)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse
|
|||||||
from .req_schema import UnwrappedRequest
|
from .req_schema import UnwrappedRequest
|
||||||
from .router import Router
|
from .router import Router
|
||||||
from .types import HandlerType, InternalHandlerType, ContentType
|
from .types import HandlerType, InternalHandlerType, ContentType
|
||||||
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE
|
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE
|
||||||
from breakshaft.convertor import ConvRepo
|
from breakshaft.convertor import ConvRepo
|
||||||
|
|
||||||
from .util import parse_content_type
|
from .util import parse_content_type
|
||||||
@@ -88,7 +88,8 @@ class HTTPApp(ASGIApp):
|
|||||||
sresp: SerializedResponse
|
sresp: SerializedResponse
|
||||||
resp: BasicResponse
|
resp: BasicResponse
|
||||||
try:
|
try:
|
||||||
handler = self.router.match(method, path)
|
matches, handler = self.router.match(method, path)
|
||||||
|
req.path_matches = matches
|
||||||
|
|
||||||
await handler(send, req)
|
await handler(send, req)
|
||||||
return
|
return
|
||||||
@@ -187,12 +188,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
|||||||
self.inj_repo = ConvRepo()
|
self.inj_repo = ConvRepo()
|
||||||
|
|
||||||
@self.inj_repo.mark_injector()
|
@self.inj_repo.mark_injector()
|
||||||
def extract_query(req: BasicRequest) -> QTYPE:
|
def extract_query(req: BasicRequest | SerializedRequest) -> QTYPE:
|
||||||
return req.query
|
return req.query
|
||||||
|
|
||||||
@self.inj_repo.mark_injector()
|
@self.inj_repo.mark_injector()
|
||||||
def extract_query(req: SerializedRequest) -> QTYPE:
|
def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE:
|
||||||
return req.query
|
return req.path_matches
|
||||||
|
|
||||||
@self.inj_repo.mark_injector()
|
@self.inj_repo.mark_injector()
|
||||||
def extract_body(req: SerializedRequest) -> BTYPE:
|
def extract_body(req: SerializedRequest) -> BTYPE:
|
||||||
@@ -226,15 +227,26 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
|||||||
if req_schema is None:
|
if req_schema is None:
|
||||||
raise ValueError(f'Unable to find request schema in handler {fn}')
|
raise ValueError(f'Unable to find request schema in handler {fn}')
|
||||||
|
|
||||||
query_type, body_type = get_args(req_schema)
|
query_type, body_type, path_type = get_args(req_schema)
|
||||||
q_inflator = None
|
q_inflator = None
|
||||||
b_inflator = None
|
b_inflator = None
|
||||||
|
p_inflator = None
|
||||||
|
|
||||||
fork_with = set()
|
fork_with = set()
|
||||||
|
|
||||||
def none_generator(*args) -> None:
|
def none_generator(*args) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if path_type not in [PTYPE, None, Any]:
|
||||||
|
p_inflator = self.infl_generator.schema_to_inflator(
|
||||||
|
path_type,
|
||||||
|
strict_mode_override=False,
|
||||||
|
from_type_override=QTYPE
|
||||||
|
)
|
||||||
|
fork_with.add(ConversionPoint(p_inflator, path_type, (PTYPE,)))
|
||||||
|
else:
|
||||||
|
fork_with.add(ConversionPoint(none_generator, path_type, (PTYPE,)))
|
||||||
|
|
||||||
if query_type not in [QTYPE, None, Any]:
|
if query_type not in [QTYPE, None, Any]:
|
||||||
q_inflator = self.infl_generator.schema_to_inflator(
|
q_inflator = self.infl_generator.schema_to_inflator(
|
||||||
query_type,
|
query_type,
|
||||||
@@ -254,11 +266,16 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
|||||||
else:
|
else:
|
||||||
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
|
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
|
||||||
|
|
||||||
def construct_unwrap(q, b) -> UnwrappedRequest:
|
def construct_unwrap(q, b, p) -> UnwrappedRequest:
|
||||||
print(f'unwrapping {query_type} and {body_type} with {q} and {b}')
|
return UnwrappedRequest(q, b, p)
|
||||||
return UnwrappedRequest(q, b)
|
|
||||||
|
|
||||||
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))}
|
fork_with |= {
|
||||||
|
ConversionPoint(
|
||||||
|
construct_unwrap,
|
||||||
|
req_schema,
|
||||||
|
((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
tmp_repo = self.inj_repo.fork(fork_with)
|
tmp_repo = self.inj_repo.fork(fork_with)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class BasicRequest:
|
|||||||
headers: CaseInsensitiveDict[str, str]
|
headers: CaseInsensitiveDict[str, str]
|
||||||
query: dict[str, list[Any] | Any]
|
query: dict[str, list[Any] | Any]
|
||||||
body: bytes
|
body: bytes
|
||||||
|
path_matches: dict[str, str]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
method: MethodType,
|
method: MethodType,
|
||||||
@@ -29,6 +30,7 @@ class BasicRequest:
|
|||||||
self.headers = CaseInsensitiveDict(headers)
|
self.headers = CaseInsensitiveDict(headers)
|
||||||
self.query = query
|
self.query = query
|
||||||
self.body = body
|
self.body = body
|
||||||
|
self.path_matches = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_scope(cls, scope: Scope) -> BasicRequest:
|
def from_scope(cls, scope: Scope) -> BasicRequest:
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ class SerializedRequest:
|
|||||||
def headers(self) -> CaseInsensitiveDict[str, str]:
|
def headers(self) -> CaseInsensitiveDict[str, str]:
|
||||||
return self.basic.headers
|
return self.basic.headers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path_matches(self) -> dict[str, str]:
|
||||||
|
return self.basic.path_matches
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
|
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ type MethodType = (
|
|||||||
|
|
||||||
type QTYPE = Annotated[dict[str, Any], 'query_params']
|
type QTYPE = Annotated[dict[str, Any], 'query_params']
|
||||||
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
|
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
|
||||||
|
type PTYPE = Annotated[dict[str, str], 'path_matches']
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar, Generic
|
||||||
|
|
||||||
from mypy.visitor import Generic
|
|
||||||
|
|
||||||
Q = TypeVar('Q')
|
Q = TypeVar('Q')
|
||||||
B = TypeVar('B')
|
B = TypeVar('B')
|
||||||
|
P = TypeVar('P')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UnwrappedRequest(Generic[Q, B]):
|
class UnwrappedRequest(Generic[Q, B, P]):
|
||||||
query: Q
|
query: Q
|
||||||
body: B
|
body: B
|
||||||
|
path_matches: P
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from re import Pattern
|
||||||
import typing
|
import typing
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
@@ -10,34 +12,83 @@ from .internal_types import MethodType
|
|||||||
|
|
||||||
class Route:
|
class Route:
|
||||||
static_subroutes: dict[str, Route]
|
static_subroutes: dict[str, Route]
|
||||||
|
regexp_subroutes: list[tuple[Pattern, list[str], Route]]
|
||||||
handler: dict[MethodType, InternalHandlerType]
|
handler: dict[MethodType, InternalHandlerType]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.static_subroutes = {}
|
self.static_subroutes = {}
|
||||||
|
self.regexp_subroutes = []
|
||||||
self.handler = {}
|
self.handler = {}
|
||||||
|
|
||||||
|
def _find_regexp_subroute(self, p: Pattern) -> Optional[tuple[list[str], Route]]:
|
||||||
|
for _p, _n, _r in self.regexp_subroutes:
|
||||||
|
if _p == p:
|
||||||
|
return _n, _r
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _add_regexp_subroute(self, p: Pattern, names: list[str], r: Route):
|
||||||
|
if self._find_regexp_subroute(p) is None:
|
||||||
|
self.regexp_subroutes.append((p, names, r))
|
||||||
|
self.regexp_subroutes.sort(key=lambda x: len(x[0].pattern), reverse=True)
|
||||||
|
|
||||||
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
|
def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
|
||||||
if len(sequence) == 0:
|
if len(sequence) == 0:
|
||||||
self.handler[method] = handler
|
self.handler[method] = handler
|
||||||
return
|
return
|
||||||
|
|
||||||
subroute = self.static_subroutes.get(sequence[0])
|
part = sequence[0]
|
||||||
if subroute is None:
|
|
||||||
subroute = Route()
|
if '{' in part:
|
||||||
self.static_subroutes[sequence[0]] = subroute
|
if '}' not in part:
|
||||||
|
raise ValueError(f'Invalid subpath substitute placeholder: {part}')
|
||||||
|
re_part = part.replace('(', '\\(').replace(')', '\\)')
|
||||||
|
names = re.findall(r'\{(.*?)}', part)
|
||||||
|
re_part = re.sub(r'\{.*?}', r'(.*?)', re_part)
|
||||||
|
re_part = re.compile('^' + re_part + '$')
|
||||||
|
|
||||||
|
d = self._find_regexp_subroute(re_part)
|
||||||
|
if d is None:
|
||||||
|
subroute = Route()
|
||||||
|
self._add_regexp_subroute(re_part, names, subroute)
|
||||||
|
else:
|
||||||
|
_, subroute = d
|
||||||
|
else:
|
||||||
|
subroute = self.static_subroutes.get(part)
|
||||||
|
if subroute is None:
|
||||||
|
subroute = Route()
|
||||||
|
self.static_subroutes[part] = subroute
|
||||||
|
|
||||||
subroute.add(method, sequence[1:], handler)
|
subroute.add(method, sequence[1:], handler)
|
||||||
|
|
||||||
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]:
|
def get(self, method: MethodType, sequence: Sequence[str]) -> tuple[dict[str, str], Optional[InternalHandlerType]]:
|
||||||
if len(sequence) == 0:
|
if len(sequence) == 0:
|
||||||
|
if len(self.handler) == 0:
|
||||||
|
raise NotFoundException('')
|
||||||
ret = self.handler.get(method)
|
ret = self.handler.get(method)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
|
raise MethodNotAllowedException(', '.join(map(str, self.handler.keys())))
|
||||||
return ret
|
return {}, ret
|
||||||
|
|
||||||
subroute = self.static_subroutes.get(sequence[0])
|
subroute = self.static_subroutes.get(sequence[0])
|
||||||
|
|
||||||
|
matches = {}
|
||||||
|
if subroute is None:
|
||||||
|
for p, subst_names, sr in self.regexp_subroutes:
|
||||||
|
m = p.findall(sequence[0])
|
||||||
|
if len(m) > 0:
|
||||||
|
if len(m) != len(subst_names):
|
||||||
|
raise RuntimeError('Unable to match substitutes')
|
||||||
|
|
||||||
|
subroute = sr
|
||||||
|
for k, v in zip(subst_names, m):
|
||||||
|
matches[k] = v
|
||||||
|
break
|
||||||
|
|
||||||
if subroute is None:
|
if subroute is None:
|
||||||
raise NotFoundException('/'.join(sequence))
|
raise NotFoundException('/'.join(sequence))
|
||||||
return subroute.get(method, sequence[1:])
|
submatches, handler = subroute.get(method, sequence[1:])
|
||||||
|
submatches |= matches
|
||||||
|
return submatches, handler
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
@@ -47,14 +98,24 @@ class Router:
|
|||||||
self._root = Route()
|
self._root = Route()
|
||||||
|
|
||||||
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
|
def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
|
||||||
assert method.upper() == method
|
method = typing.cast(MethodType, method.upper())
|
||||||
|
|
||||||
|
substitutes = re.findall(r'\{(.*?)}', path_pattern)
|
||||||
|
if len(substitutes) > 0:
|
||||||
|
subst = []
|
||||||
|
for s in substitutes:
|
||||||
|
if isinstance(s, str):
|
||||||
|
s = [s]
|
||||||
|
subst += list(s)
|
||||||
|
if len(subst) != len(set(subst)):
|
||||||
|
raise ValueError('Duplicate path substitute names are prohibited')
|
||||||
|
|
||||||
segments = path_pattern.split('/')
|
segments = path_pattern.split('/')
|
||||||
while len(segments) > 0 and len(segments[0]) == 0:
|
while len(segments) > 0 and len(segments[0]) == 0:
|
||||||
segments = segments[1:]
|
segments = segments[1:]
|
||||||
self._root.add(method, segments, handler)
|
self._root.add(method, segments, handler)
|
||||||
|
|
||||||
def match(self, method: MethodType, path: str) -> InternalHandlerType:
|
def match(self, method: MethodType, path: str) -> tuple[dict[str, str], InternalHandlerType]:
|
||||||
method = typing.cast(MethodType, method.upper())
|
method = typing.cast(MethodType, method.upper())
|
||||||
segments = path.split('/')
|
segments = path.split('/')
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ from src.turbosloth.router import Router
|
|||||||
|
|
||||||
|
|
||||||
def test_router_root_handler():
|
def test_router_root_handler():
|
||||||
async def f(_: dict, *args):
|
async def f(*args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def d(_: dict, *args):
|
async def d(*args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def a(_: dict, *args):
|
async def a(*args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
r = Router()
|
r = Router()
|
||||||
@@ -38,19 +38,44 @@ def test_router_root_handler():
|
|||||||
|
|
||||||
|
|
||||||
def test_router_match():
|
def test_router_match():
|
||||||
async def f(_: dict, *args):
|
async def f(*args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def d(_: dict, *args):
|
async def d(*args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
r = Router()
|
r = Router()
|
||||||
r.add('GET', 'asdf', f)
|
r.add('GET', 'asdf', f)
|
||||||
assert r.match('GET', '/asdf')
|
assert r.match('GET', '/asdf') == ({}, f)
|
||||||
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
||||||
r.match('GET', 'asd')
|
r.match('GET', 'asd')
|
||||||
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
|
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /asdf, allowed: GET'):
|
||||||
r.match('POST', 'asdf')
|
r.match('POST', 'asdf')
|
||||||
|
|
||||||
with pytest.raises(MethodNotAllowedException, match=f'405\tMethod Not Allowed: POST /, allowed: none'):
|
with pytest.raises(NotFoundException, match=f'404\tNot Found'):
|
||||||
r.match('POST', '')
|
r.match('POST', '')
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_pattern_match():
|
||||||
|
async def f(*args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
r = Router()
|
||||||
|
r.add('GET', '/{some}/asdf', f)
|
||||||
|
r.add('GET', '/{some}/b{some1}c', f)
|
||||||
|
|
||||||
|
assert r.match('GET', '/1234/asdf') == ({'some': '1234'}, f)
|
||||||
|
assert r.match('GET', '/ /asdf') == ({'some': ' '}, f)
|
||||||
|
assert r.match('GET', '/ /basdfc') == ({'some': ' ', 'some1': 'asdf'}, f)
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
||||||
|
r.match('GET', 'asd')
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
||||||
|
r.match('GET', 'asd/')
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
||||||
|
r.match('GET', 'asd/b')
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundException, match='404\tNot Found: asd'):
|
||||||
|
r.match('GET', 'asd/basdf')
|
||||||
|
|||||||
Reference in New Issue
Block a user