From c40bdca9e4bd97d11710d3c097a6c7fe020ad2b2 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Wed, 20 Aug 2025 03:03:19 +0300 Subject: [PATCH] Allow DI injectors use request schemas in a required args --- src/turbosloth/__main__.py | 23 ++++- src/turbosloth/app.py | 156 ++++++++++++++++++++----------- src/turbosloth/didoc/__init__.py | 20 ++-- uv.lock | 12 +-- 4 files changed, 137 insertions(+), 74 deletions(-) diff --git a/src/turbosloth/__main__.py b/src/turbosloth/__main__.py index 43361d8..32817bb 100644 --- a/src/turbosloth/__main__.py +++ b/src/turbosloth/__main__.py @@ -10,7 +10,7 @@ from turbosloth.interfaces.serialize_selector import SerializeSelector from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE from turbosloth.req_schema import UnwrappedRequest -from turbosloth.schema import RequestBody, HeaderParam +from turbosloth.schema import RequestBody, HeaderParam, PathParam app = SlothApp(di_autodoc_prefix='/didoc', serialize_selector=SerializeSelector(default_content_type='application/json')) @@ -83,24 +83,39 @@ class DummyDbConnection: self.constructions[0] += 1 -@app.inj_repo.mark_injector() +@app.mark_injector() def create_db_connection() -> DummyDbConnection: return DummyDbConnection() +@dataclass +class User: + user_id: int + name: str + + +@app.mark_injector() +def auth_user(user: RequestBody(UserPostSchema)) -> User: + return User(user.user_id, f'user {user.user_id}') + + @app.post("/test/body/{a}") async def test_body(r: RequestBody(UserPostSchema), q1: str, a: str, h1: HeaderParam(str, 'header1'), - db: DummyDbConnection) -> SerializedResponse: + db: DummyDbConnection, + user: User, + q2: int = 321) -> SerializedResponse: print(r.user_id) resp = { 'req': r, 'q1': q1, 'h1': h1, 'a': a, - 'db': db.constructions + 'db': db.constructions, + 'user': user.__dict__, + 'q2': q2 } return SerializedResponse(200, {}, resp) diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index 763005f..c7ae850 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -24,6 +24,7 @@ from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYP from breakshaft.convertor import ConvRepo from .util import parse_content_type +import breakshaft.util class ASGIApp(Protocol): @@ -260,70 +261,29 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): raise RuntimeError(f'Unsupported scope type: {t}') + def _create_infl_from_schemas(self, schemas: Iterable[ParamSchema], from_type_override: type): + if len(schemas) > 0: + infl = self.infl_generator.schema_to_inflator( + list(schemas), + strict_mode_override=False, + ) + type_remap = { + 'from_data': from_type_override, + 'return': Tuple[tuple(t.replacement_type for t in schemas)] + } + return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap)) + return None, None + def route(self, method: MethodType, path_pattern: str): def decorator(fn: HandlerType): - path_substs = self.router.find_pattern_substs(path_pattern) - injectors = self.inj_repo.filtered_injectors(True, True) - injected_types = list(map(lambda x: x.injects, injectors)) - - config = EndpointConfig.from_handler(fn, path_substs, injected_types) - - fork_with = set() - - inflators_for_didoc = {} - - def create_infl_from_schemas(schemas: Iterable[ParamSchema], from_type_override: type): - if len(schemas) > 0: - infl = self.infl_generator.schema_to_inflator( - list(schemas), - strict_mode_override=False, - ) - type_remap = { - 'from_data': from_type_override, - 'return': Tuple[tuple(t.replacement_type for t in schemas)] - } - return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap)) - return None, None - - i, infl = create_infl_from_schemas(config.path_schemas.values(), PTYPE) - if infl is not None: - fork_with |= infl - inflators_for_didoc['Path inflator'] = i - i, infl = create_infl_from_schemas(config.query_schemas.values(), QTYPE) - if infl is not None: - fork_with |= infl - inflators_for_didoc['Query inflator'] = i - i, infl = create_infl_from_schemas(config.header_schemas.values(), HTYPE) - if infl is not None: - fork_with |= infl - inflators_for_didoc['Header inflator'] = i - - if config.body_schema is not None: - infl = self.infl_generator.schema_to_inflator( - config.body_schema.schema, - strict_mode_override=False, - ) - - type_remap = { - 'from_data': BTYPE, - 'return': config.body_schema.replacement_type - } - - fork_with |= set(ConversionPoint.from_fn(infl, type_remap=type_remap)) - inflators_for_didoc['Body inflator'] = infl - + fork_with, fn_type_hints = self._integrate_func(fn, path_substs) tmp_repo = self.inj_repo.fork(fork_with) - - fn_type_hints = get_type_hints(fn) - for k, v in config.type_replacement.items(): - fn_type_hints[k] = v - p = tmp_repo.create_pipeline( (Send, BasicRequest), [ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer], - force_async=True + force_async=True, ) self.router.add(method, path_pattern, p) @@ -344,8 +304,90 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): ) self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)( - create_di_autodoc_handler(method, path_pattern, p, inflators_for_didoc, depgraph)) + create_di_autodoc_handler(method, path_pattern, p, depgraph)) return fn return decorator + + def _didoc_name(self, master_name: str, schemas: dict[str, ParamSchema], ret_schema: type): + name = f'{master_name}[' + for s in schemas.values(): + name += f'{s.key_name}: ' + name += breakshaft.util.universal_qualname(s.schema) + ', ' + name = name[:-2] + '] -> ' + breakshaft.util.universal_qualname(ret_schema) + return name + + def _integrate_func(self, + func: Callable, + path_params: Optional[Iterable[str]] = None): + + injectors = self.inj_repo.filtered_injectors(True, True) + injected_types = list(map(lambda x: x.injects, injectors)) + + config = EndpointConfig.from_handler(func, path_params or set(), injected_types) + + fork_with = set() + + i_p, infl = self._create_infl_from_schemas(config.path_schemas.values(), PTYPE) + if infl is not None: + fork_with |= infl + i_q, infl = self._create_infl_from_schemas(config.query_schemas.values(), QTYPE) + if infl is not None: + fork_with |= infl + i_h, infl = self._create_infl_from_schemas(config.header_schemas.values(), HTYPE) + if infl is not None: + fork_with |= infl + + i_b = None + if config.body_schema is not None: + i_b = self.infl_generator.schema_to_inflator( + config.body_schema.schema, + strict_mode_override=False, + ) + + type_remap = { + 'from_data': BTYPE, + 'return': config.body_schema.replacement_type + } + + fork_with |= set(ConversionPoint.from_fn(i_b, type_remap=type_remap)) + + fn_type_hints = get_type_hints(func) + for k, v in config.type_replacement.items(): + fn_type_hints[k] = v + + main_cps = ConversionPoint.from_fn(func, type_remap=fn_type_hints) + main_injects = main_cps[0].injects + + if self.di_autodoc_prefix is not None: + if len(config.header_schemas) > 0: + setattr(i_h, '__turbosloth_DI_name__', + self._didoc_name('Inflator Header', config.header_schemas, main_injects)) + if len(config.query_schemas) > 0: + setattr(i_q, '__turbosloth_DI_name__', + self._didoc_name('Inflator Query', config.query_schemas, main_injects)) + if len(config.path_schemas) > 0: + setattr(i_p, '__turbosloth_DI_name__', + self._didoc_name('Inflator Path', config.path_schemas, main_injects)) + + if i_b is not None: + setattr(i_b, '__turbosloth_DI_name__', + self._didoc_name('Inflator Body', {config.body_schema.key_name: config.body_schema}, + main_injects)) + + setattr(func, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(main_injects)) + + fork_with |= set(ConversionPoint.from_fn(func, type_remap=fn_type_hints)) + return fork_with, fn_type_hints + + def add_injector(self, func: Callable): + fork_with, _ = self._integrate_func(func) + self.inj_repo.add_conversion_points(fork_with) + + def mark_injector(self): + def inner(func: Callable): + self.add_injector(func) + return func + + return inner diff --git a/src/turbosloth/didoc/__init__.py b/src/turbosloth/didoc/__init__.py index ba4232b..de2dd25 100644 --- a/src/turbosloth/didoc/__init__.py +++ b/src/turbosloth/didoc/__init__.py @@ -5,11 +5,13 @@ from typing import Callable, Awaitable import breakshaft.util_mermaid import jinja2 +from breakshaft.models import ConversionPoint from turbosloth.interfaces.serialized import SerializedResponse from turbosloth.interfaces.serialized.html import HTMLSerializedResponse from turbosloth.types import InternalHandlerType import breakshaft.renderer +import breakshaft.util @dataclass @@ -21,14 +23,22 @@ class MMDiagramData: def create_di_autodoc_handler(method: str, path: str, handler: InternalHandlerType, - param_inflators: dict[str, Callable], depgraph: str) -> Awaitable[SerializedResponse]: - callseq = getattr(handler, '__breakshaft_callseq__', []) + callseq: list[ConversionPoint] = getattr(handler, '__breakshaft_callseq__', []) mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid( list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq( breakshaft.renderer.render_data_from_callseq((), {}, callseq) ))) ) + + escaped_sources = [] + + for c in callseq: + if hasattr(c.fn, '__megasniff_sources__'): + name = getattr(c.fn, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(c.fn)) + escaped_sources.append( + (name, html.escape(getattr(c.fn, '__megasniff_sources__', '')))) + sources = getattr(handler, '__breakshaft_render_src__', '') pipeline_escaped_sources = html.escape(sources) @@ -42,12 +52,8 @@ def create_di_autodoc_handler(method: str, MMDiagramData('Dependency graph', html.escape(depgraph)), ] - escaped_sources = [('Injection pipeline', pipeline_escaped_sources)] + escaped_sources.append(('Injection pipeline', pipeline_escaped_sources)) - for k, v in param_inflators.items(): - sources = getattr(v, '__megasniff_sources__', '') - src = html.escape(sources) - escaped_sources.append((k, src)) html_content = template.render( handler_method=method, diff --git a/uv.lock b/uv.lock index 276ee53..80ca419 100644 --- a/uv.lock +++ b/uv.lock @@ -4,15 +4,15 @@ requires-python = ">=3.13" [[package]] name = "breakshaft" -version = "0.1.6" +version = "0.1.6.post2" source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } dependencies = [ { name = "hatchling" }, { name = "jinja2" }, ] -sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6/breakshaft-0.1.6.tar.gz", hash = "sha256:443777f9f13889e79b31f763659b2d84540e045afb0f1c696ebec955213b653a" } +sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6.post2/breakshaft-0.1.6.post2.tar.gz", hash = "sha256:523cdbd55f7fcea3a3f664bbc3ec7d524be09f1f45c4c3b39f3de18164fad37d" } wheels = [ - { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6/breakshaft-0.1.6-py3-none-any.whl", hash = "sha256:abc3e99269cac906a0aafbc1f6af628198986eef571f1e27f29c2cb7f7bfde08" }, + { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6.post2/breakshaft-0.1.6.post2-py3-none-any.whl", hash = "sha256:a355b4546ef94f92e792e14a6cfcd029298a299b453c7ba47e13c32e5a82ff4b" }, ] [[package]] @@ -206,15 +206,15 @@ wheels = [ [[package]] name = "megasniff" -version = "0.2.3" +version = "0.2.3.post1" source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } dependencies = [ { name = "hatchling" }, { name = "jinja2" }, ] -sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3/megasniff-0.2.3.tar.gz", hash = "sha256:448776b495bb9b6a7c6d8c1a26afa1290585ef648b3e861c779f7530a2dd36bb" } +sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3.post1/megasniff-0.2.3.post1.tar.gz", hash = "sha256:41dd6c235225df13b0d18ad417b93bbddfd70d4b9e5ca405ddef976a8f2e753c" } wheels = [ - { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3/megasniff-0.2.3-py3-none-any.whl", hash = "sha256:b58803cfbcd113f18f20850250ac9630dcc5cec9336ac0fc024142e921047b59" }, + { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3.post1/megasniff-0.2.3.post1-py3-none-any.whl", hash = "sha256:414560ae1a2b1a00a7fcaf102e8faf45214748f230ad756ee34a71f2598d93e4" }, ] [[package]]