Add header schema for UnwrappedRequest

This commit is contained in:
2025-07-21 15:58:33 +03:00
parent 6cb01f6204
commit a763f0960c
4 changed files with 23 additions and 17 deletions

View File

@@ -7,14 +7,14 @@ import uvicorn
from turbosloth import SlothApp from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
app = SlothApp() app = SlothApp()
@app.get("/") @app.get("/")
async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE]) -> SerializedResponse: async def index(req: UnwrappedRequest[QTYPE, BTYPE, PTYPE, HTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!') return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -24,7 +24,7 @@ class UserIdSchema:
@app.get("/user/") @app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE]) -> SerializedResponse: async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE, PTYPE, HTYPE]) -> SerializedResponse:
print(req) print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)
@@ -59,7 +59,7 @@ class PTYPESchema:
@app.post("/user/u{user_id}r") @app.post("/user/u{user_id}r")
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema], async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema, PTYPESchema, HTYPE],
dat: SomeInternalData) -> SerializedResponse: dat: SomeInternalData) -> SerializedResponse:
print(req) print(req)
print(dat) print(dat)

View File

@@ -14,7 +14,7 @@ from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest from .req_schema import UnwrappedRequest
from .router import Router, Route from .router import Router, Route
from .types import HandlerType, InternalHandlerType, ContentType from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYPE, HTYPE
from breakshaft.convertor import ConvRepo from breakshaft.convertor import ConvRepo
from .util import parse_content_type from .util import parse_content_type
@@ -198,6 +198,10 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE: def extract_path_matches(req: BasicRequest | SerializedRequest) -> PTYPE:
return req.path_matches return req.path_matches
@self.inj_repo.mark_injector()
def extract_headers(req: BasicRequest | SerializedRequest) -> HTYPE:
return req.headers
@self.inj_repo.mark_injector() @self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE: def extract_body(req: SerializedRequest) -> BTYPE:
return req.body return req.body
@@ -230,7 +234,8 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
if req_schema is None: if req_schema is None:
raise ValueError(f'Unable to find request schema in handler {fn}') raise ValueError(f'Unable to find request schema in handler {fn}')
query_type, body_type, path_type = get_args(req_schema) unwrap_types = get_args(req_schema)
defaults = (QTYPE, BTYPE, PTYPE, HTYPE)
def none_generator(*args) -> None: def none_generator(*args) -> None:
return None return None
@@ -243,24 +248,20 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
strict_mode_override=False, strict_mode_override=False,
from_type_override=def_type from_type_override=def_type
) )
return ConversionPoint(p_inflator, path_type, (PTYPE,), (), ) return ConversionPoint(infl, t, (def_type,), (), )
else: else:
return ConversionPoint(none_generator, path_type, (PTYPE,), (), ) return ConversionPoint(none_generator, t, (def_type,), (), )
q_inflator = create_convertor(query_type, QTYPE) fork_with = set(map(lambda x: create_convertor(*x), zip(unwrap_types, defaults)))
b_inflator = create_convertor(body_type, BTYPE)
p_inflator = create_convertor(path_type, PTYPE)
fork_with = {q_inflator, b_inflator, p_inflator} def construct_unwrap(q: QTYPE, b: BTYPE, p: PTYPE, h: HTYPE) -> UnwrappedRequest:
return UnwrappedRequest(q, b, p, h)
def construct_unwrap(q: QTYPE, b: BTYPE, p: PTYPE) -> UnwrappedRequest:
return UnwrappedRequest(q, b, p)
fork_with |= { fork_with |= {
ConversionPoint( ConversionPoint(
construct_unwrap, construct_unwrap,
req_schema, req_schema,
((query_type or QTYPE), (body_type or BTYPE), (path_type or PTYPE)), unwrap_types,
(), (),
) )
} }

View File

@@ -1,5 +1,7 @@
from typing import Any, Callable, Awaitable, Literal, Annotated from typing import Any, Callable, Awaitable, Literal, Annotated
from case_insensitive_dict import CaseInsensitiveDict
type Scope = dict[str, Any] type Scope = dict[str, Any]
type ASGIMessage = dict[str, Any] type ASGIMessage = dict[str, Any]
type Receive = Callable[[], Awaitable[ASGIMessage]] type Receive = Callable[[], Awaitable[ASGIMessage]]
@@ -20,3 +22,4 @@ type MethodType = (
type QTYPE = Annotated[dict[str, Any], 'query_params'] type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body'] type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
type PTYPE = Annotated[dict[str, str], 'path_matches'] type PTYPE = Annotated[dict[str, str], 'path_matches']
type HTYPE = Annotated[CaseInsensitiveDict[str, str], 'headers']

View File

@@ -4,10 +4,12 @@ from typing import Any, TypeVar, Generic
Q = TypeVar('Q') Q = TypeVar('Q')
B = TypeVar('B') B = TypeVar('B')
P = TypeVar('P') P = TypeVar('P')
H = TypeVar('H')
@dataclass @dataclass
class UnwrappedRequest(Generic[Q, B, P]): class UnwrappedRequest(Generic[Q, B, P, H]):
query: Q query: Q
body: B body: B
path_matches: P path_matches: P
headers: H