From faaa43fdf17d77163d555c62a334e40af977a677 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Wed, 20 Aug 2025 01:57:57 +0300 Subject: [PATCH] Allow DI args into an app handlers --- src/turbosloth/__main__.py | 21 ++++++++++++++++++--- src/turbosloth/app.py | 6 +++++- src/turbosloth/schema.py | 7 +++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/turbosloth/__main__.py b/src/turbosloth/__main__.py index 6a42cda..43361d8 100644 --- a/src/turbosloth/__main__.py +++ b/src/turbosloth/__main__.py @@ -12,7 +12,8 @@ from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE from turbosloth.req_schema import UnwrappedRequest from turbosloth.schema import RequestBody, HeaderParam -app = SlothApp(di_autodoc_prefix='/didoc', serialize_selector=SerializeSelector(default_content_type='application/json')) +app = SlothApp(di_autodoc_prefix='/didoc', + serialize_selector=SerializeSelector(default_content_type='application/json')) # @app.get("/") @@ -75,17 +76,31 @@ class PTYPESchema: # return SerializedResponse(200, {}, resp) +class DummyDbConnection: + constructions = [0] + + def __init__(self): + self.constructions[0] += 1 + + +@app.inj_repo.mark_injector() +def create_db_connection() -> DummyDbConnection: + return DummyDbConnection() + + @app.post("/test/body/{a}") async def test_body(r: RequestBody(UserPostSchema), q1: str, a: str, - h1: HeaderParam(str, 'header1')) -> SerializedResponse: + h1: HeaderParam(str, 'header1'), + db: DummyDbConnection) -> SerializedResponse: print(r.user_id) resp = { 'req': r, 'q1': q1, 'h1': h1, - 'a': a + 'a': a, + 'db': db.constructions } return SerializedResponse(200, {}, resp) diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index 55452c7..763005f 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -264,7 +264,11 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): def decorator(fn: HandlerType): path_substs = self.router.find_pattern_substs(path_pattern) - config = EndpointConfig.from_handler(fn, path_substs) + + injectors = self.inj_repo.filtered_injectors(True, True) + injected_types = list(map(lambda x: x.injects, injectors)) + + config = EndpointConfig.from_handler(fn, path_substs, injected_types) fork_with = set() diff --git a/src/turbosloth/schema.py b/src/turbosloth/schema.py index eb8e639..21ef2f4 100644 --- a/src/turbosloth/schema.py +++ b/src/turbosloth/schema.py @@ -1,7 +1,8 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, overload, Any, TypeAlias, Optional, get_origin, get_args, Callable, get_type_hints +from typing import TYPE_CHECKING, overload, Any, TypeAlias, Optional, get_origin, get_args, Callable, get_type_hints, \ + Iterable from dataclasses import dataclass from typing import TypeVar, Generic, Annotated @@ -72,7 +73,7 @@ class EndpointConfig: type_replacement: dict[str, type] @classmethod - def from_handler(cls, h: Callable, path_substituts: set[str]) -> EndpointConfig: + def from_handler(cls, h: Callable, path_substituts: set[str], ignore_types: Iterable[type]) -> EndpointConfig: body_schema = None query_schemas = {} path_schemas = {} @@ -82,6 +83,8 @@ class EndpointConfig: handle_hints = get_endpoint_params_info(h) for argname, s in handle_hints.items(): tp = s.schema + if tp in ignore_types: + continue type_replacement[argname] = s.replacement_type if argname == 'return':