From e2e66ef14eb3fc1392a11b2334414ec202a420f2 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 01:31:25 +0300 Subject: [PATCH] Wrap api endpoint handlers into a `breakshaft` DI repo --- src/turbosloth/__main__.py | 32 +++- src/turbosloth/app.py | 153 ++++++++++++++---- .../interfaces/serialize_selector.py | 20 +-- src/turbosloth/internal_types.py | 5 +- src/turbosloth/types.py | 11 +- uv.lock | 12 +- 6 files changed, 171 insertions(+), 62 deletions(-) diff --git a/src/turbosloth/__main__.py b/src/turbosloth/__main__.py index c7e50cd..27074f7 100644 --- a/src/turbosloth/__main__.py +++ b/src/turbosloth/__main__.py @@ -3,15 +3,18 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Optional +import uvicorn + from turbosloth import SlothApp from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest +from turbosloth.internal_types import QTYPE, BTYPE from turbosloth.req_schema import UnwrappedRequest app = SlothApp() @app.get("/") -async def index(req: UnwrappedRequest[Any, Any]) -> SerializedResponse: +async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse: return SerializedResponse(200, {}, 'Hello, ASGI Router!') @@ -21,7 +24,7 @@ class UserIdSchema: @app.get("/user/") -async def get_user(req: UnwrappedRequest[UserIdSchema, Any]) -> SerializedResponse: +async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse: print(req) resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} return SerializedResponse(200, {}, resp) @@ -39,8 +42,29 @@ class UserPostSchema(UserIdSchema): data: SomeData +class SomeInternalData: + a: list[int] = [0] + + +@app.inj_repo.mark_injector() +def foo() -> SomeInternalData: + s = SomeInternalData() + s.a[0] += 1 + return s + + @app.post("/user") -async def post_user(req: UnwrappedRequest[Any, UserPostSchema]) -> SerializedResponse: +async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse: print(req) - resp: dict[str, Any] = {'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data} + print(dat) + resp: dict[str, Any] = { + 'message': f'Hello, User {req.body.user_id}!', + 'from': 'server', + 'data': req.body.data, + 'inj': dat.a + } return SerializedResponse(200, {}, resp) + + +if __name__ == '__main__': + uvicorn.run('turbosloth.__main__:app', host='0.0.0.0', port=8000, reload=True) diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index 8de5f56..d59627a 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -1,6 +1,9 @@ -from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any +import typing +from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any, Annotated +import breakshaft.util_mermaid import megasniff.exceptions +from breakshaft.models import ConversionPoint from megasniff import SchemaInflatorGenerator from .exceptions import HTTPException @@ -10,8 +13,11 @@ from .interfaces.serialized import SerializedResponse, SerializedRequest from .interfaces.serialized.text import TextSerializedResponse from .req_schema import UnwrappedRequest from .router import Router -from .types import HandlerType, InternalHandlerType -from .internal_types import Scope, Receive, Send, MethodType +from .types import HandlerType, InternalHandlerType, ContentType +from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE +from breakshaft.convertor import ConvRepo + +from .util import parse_content_type class ASGIApp(Protocol): @@ -24,6 +30,45 @@ class ASGIApp(Protocol): class HTTPApp(ASGIApp): serialize_selector: SerializeSelector + def extract_content_type(self, req: BasicRequest) -> ContentType: + contenttype_header = req.headers.get('Content-Type') + properties: dict[str, str] + if contenttype_header is None: + contenttype = self.serialize_selector.default_content_type + properties = {} + else: + contenttype, properties = parse_content_type(contenttype_header) + + charset = properties.get('charset') + + if charset is None: + if contenttype == 'application/json': + charset = 'utf-8' + else: + charset = 'latin1' + + return ContentType(contenttype, charset) + + def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest: + ser = self.serialize_selector.select(ct.contenttype, ct.charset) + return ser.req.deserialize(req, ct.charset) + + def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: ContentType) -> BasicResponse: + ser = self.serialize_selector.select(ct.contenttype, ct.charset) + sresponser = ser.resp + + try: + return sresponser.into_basic(sresp, ct.charset) + except UnicodeEncodeError: + return sresponser.into_basic(sresp, 'utf-8') + + async def send_answer(self, send: Send, resp: BasicResponse): + await send(resp.into_start_message()) + await send({ + 'type': 'http.response.body', + 'body': resp.body, + }) + async def _do_http(self, scope: Scope, receive: Receive, send: Send) -> None: method = scope['method'] path = scope['path'] @@ -45,31 +90,26 @@ class HTTPApp(ASGIApp): try: handler = self.router.match(method, path) - ser = self.serialize_selector.select(req.headers) - sresponser = ser.resp - charset = ser.charset - sreq = ser.req.deserialize(req, charset) - - sresp = await handler(sreq) + await handler(send, req) + return except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException): sresp = SerializedResponse(400, {}, 'Schema error') except HTTPException as e: sresp = SerializedResponse(e.code, {}, str(e)) try: - resp = sresponser.into_basic(sresp, charset) + ct = self.extract_content_type(req) + resp = sresponser.into_basic(sresp, ct.charset) except UnicodeEncodeError: resp = sresponser.into_basic(sresp, 'utf-8') - await send(resp.into_start_message()) - await send({ - 'type': 'http.response.body', - 'body': resp.body, - }) + # пока без малейшего понятия как избавиться от ручной обертки обработки ошибок + await self.send_answer(send, resp) class WSApp(ASGIApp): async def _do_websocket(self, scope: Scope, receive: Receive, send: Send) -> None: + # TODO: impl ws raise NotImplementedError() @@ -144,6 +184,23 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): self._on_shutdown = on_shutdown self.serialize_selector = SerializeSelector() self.infl_generator = SchemaInflatorGenerator(strict_mode=True) + self.inj_repo = ConvRepo() + + @self.inj_repo.mark_injector() + def extract_query(req: BasicRequest) -> QTYPE: + return req.query + + @self.inj_repo.mark_injector() + def extract_query(req: SerializedRequest) -> QTYPE: + return req.query + + @self.inj_repo.mark_injector() + def extract_body(req: SerializedRequest) -> BTYPE: + return req.body + + self.inj_repo.add_injector(self.extract_content_type) + self.inj_repo.add_injector(self.serialize_request) + self.inj_repo.add_injector(self.serialize_response) async def __call__(self, scope: Scope, receive: Receive, send: Send): t = scope['type'] @@ -158,9 +215,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): def route(self, method: MethodType, path_pattern: str): def decorator(fn: HandlerType): - hints = get_type_hints(fn) + handle_hints = get_type_hints(fn) req_schema = None - for argname, tp in hints.items(): + for argname, tp in handle_hints.items(): if argname == 'return': continue if get_origin(tp) == UnwrappedRequest: @@ -172,24 +229,56 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): query_type, body_type = get_args(req_schema) q_inflator = None b_inflator = None - if query_type != Any: - q_inflator = self.infl_generator.schema_to_inflator(query_type, strict_mode_override=False) - if body_type != Any: - b_inflator = self.infl_generator.schema_to_inflator(body_type) - def internal_handler(req: SerializedRequest) -> Awaitable[SerializedResponse]: - if q_inflator is not None: - q = q_inflator(req.query) - else: - q = req.query - if b_inflator is not None: - b = b_inflator(req.body) - else: - b = req.body + fork_with = set() - return fn(UnwrappedRequest(q, b)) + def none_generator(*args) -> None: + return None - self.router.add(method, path_pattern, internal_handler) + if query_type not in [QTYPE, None, Any]: + q_inflator = self.infl_generator.schema_to_inflator( + query_type, + strict_mode_override=False, + from_type_override=QTYPE + ) + fork_with.add(ConversionPoint(q_inflator, query_type, (QTYPE,))) + else: + fork_with.add(ConversionPoint(none_generator, query_type, (QTYPE,))) + + if body_type != [BTYPE, None, Any]: + b_inflator = self.infl_generator.schema_to_inflator( + body_type, + from_type_override=BTYPE + ) + fork_with.add(ConversionPoint(b_inflator, body_type, (BTYPE,))) + else: + fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,))) + + def construct_unwrap(q, b) -> UnwrappedRequest: + print(f'unwrapping {query_type} and {body_type} with {q} and {b}') + return UnwrappedRequest(q, b) + + fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))} + + tmp_repo = self.inj_repo.fork(fork_with) + + conv = tmp_repo.get_conversion( + (BasicRequest,), + fn, + force_async=True, + ) + + out_conv = self.inj_repo.get_conversion( + (Send, BasicRequest, handle_hints['return'],), + self.send_answer, + force_async=True + ) + + async def pipeline(send: Send, req: BasicRequest): + ret = await conv(req) + await out_conv(send, req, ret) + + self.router.add(method, path_pattern, pipeline) return fn return decorator diff --git a/src/turbosloth/interfaces/serialize_selector.py b/src/turbosloth/interfaces/serialize_selector.py index 71767b5..d909b40 100644 --- a/src/turbosloth/interfaces/serialize_selector.py +++ b/src/turbosloth/interfaces/serialize_selector.py @@ -12,7 +12,6 @@ class SerializeChoise(NamedTuple): req: type[SerializedRequest] resp: type[SerializedResponse] charset: str - content_properties: dict[str, str] class SerializeSelector: @@ -34,22 +33,7 @@ class SerializeSelector: ser[k] = v self.serializers = ser - def select(self, headers: CaseInsensitiveDict[str, str]) -> SerializeChoise: - contenttype_header = headers.get('Content-Type') - properties: dict[str, str] - if contenttype_header is None: - contenttype = self.default_content_type - properties = {} - else: - contenttype, properties = parse_content_type(contenttype_header) - - charset = properties.get('charset') - - if charset is None: - if contenttype == 'application/json': - charset = 'utf-8' - else: - charset = 'latin1' + def select(self, contenttype: str, charset: str) -> SerializeChoise: choise = self.serializers.get(typing.cast(str, contenttype)) if choise is None and self.default_content_type is not None: @@ -58,4 +42,4 @@ class SerializeSelector: raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys())) req, resp = choise - return SerializeChoise(req, resp, charset, properties) + return SerializeChoise(req, resp, charset) diff --git a/src/turbosloth/internal_types.py b/src/turbosloth/internal_types.py index 178002e..74df546 100644 --- a/src/turbosloth/internal_types.py +++ b/src/turbosloth/internal_types.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Awaitable, Literal +from typing import Any, Callable, Awaitable, Literal, Annotated type Scope = dict[str, Any] type ASGIMessage = dict[str, Any] @@ -16,3 +16,6 @@ type MethodType = ( Literal['CONNECT'] | Literal['OPTIONS'] | Literal['TRACE']) + +type QTYPE = Annotated[dict[str, Any], 'query_params'] +type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body'] diff --git a/src/turbosloth/types.py b/src/turbosloth/types.py index 94c00d3..7613ce4 100644 --- a/src/turbosloth/types.py +++ b/src/turbosloth/types.py @@ -1,9 +1,18 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Callable, Awaitable, Literal, Any +from turbosloth.interfaces.base import BasicRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest +from turbosloth.internal_types import Send from turbosloth.req_schema import UnwrappedRequest type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]] -type InternalHandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]] +type InternalHandlerType = Callable[[Send, BasicRequest], Awaitable[None]] + + +@dataclass +class ContentType: + contenttype: str + charset: str diff --git a/uv.lock b/uv.lock index b576230..f0b8840 100644 --- a/uv.lock +++ b/uv.lock @@ -4,15 +4,15 @@ requires-python = ">=3.13" [[package]] name = "breakshaft" -version = "0.1.0" +version = "0.1.0.post2" source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } dependencies = [ { name = "hatchling" }, { name = "jinja2" }, ] -sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0.tar.gz", hash = "sha256:77f114bab957b78f96b5b087e8ef9778c9b1fc4b3433f0b428d9dd084243debe" } +sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2.tar.gz", hash = "sha256:d04f8685336080cddb5111422a9624c8afcc8c47264e1b847339139bfd690570" } wheels = [ - { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0-py3-none-any.whl", hash = "sha256:e7c8831e5889fbcd788d59fee05f7aee113663ebef9273cd280893e1ed04f5c2" }, + { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2-py3-none-any.whl", hash = "sha256:fd4c9b213e2569c2c4d4f86cef2fe04065c08f9bffac3aa46b8ae21459f35074" }, ] [[package]] @@ -184,15 +184,15 @@ wheels = [ [[package]] name = "megasniff" -version = "0.2.0" +version = "0.2.1" source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } dependencies = [ { name = "hatchling" }, { name = "jinja2" }, ] -sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0.tar.gz", hash = "sha256:feceefd8a618c7b930832b4ff8edcc474961f5d5f53743c64e142a796fea0e30" } +sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1.tar.gz", hash = "sha256:c6ed47b40fbdc92d7aa061267d158eea457b560f1bc8f84d297b55e5e4a0ef2e" } wheels = [ - { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0-py3-none-any.whl", hash = "sha256:4f03a1e81dcb1020b6fe80f3fbc7b192a61b4dc55f7c6783141aa464dc3c77d1" }, + { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1-py3-none-any.whl", hash = "sha256:cff759fd61e9a4b8634329620cf13a941d7d5b38b834f8eab68e6ac78fa23589" }, ] [[package]]