Allow DI injectors use request schemas in a required args

This commit is contained in:
2025-08-20 03:03:19 +03:00
parent faaa43fdf1
commit c40bdca9e4
4 changed files with 137 additions and 74 deletions

View File

@@ -10,7 +10,7 @@ from turbosloth.interfaces.serialize_selector import SerializeSelector
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE
from turbosloth.req_schema import UnwrappedRequest
from turbosloth.schema import RequestBody, HeaderParam
from turbosloth.schema import RequestBody, HeaderParam, PathParam
app = SlothApp(di_autodoc_prefix='/didoc',
serialize_selector=SerializeSelector(default_content_type='application/json'))
@@ -83,24 +83,39 @@ class DummyDbConnection:
self.constructions[0] += 1
@app.inj_repo.mark_injector()
@app.mark_injector()
def create_db_connection() -> DummyDbConnection:
return DummyDbConnection()
@dataclass
class User:
user_id: int
name: str
@app.mark_injector()
def auth_user(user: RequestBody(UserPostSchema)) -> User:
return User(user.user_id, f'user {user.user_id}')
@app.post("/test/body/{a}")
async def test_body(r: RequestBody(UserPostSchema),
q1: str,
a: str,
h1: HeaderParam(str, 'header1'),
db: DummyDbConnection) -> SerializedResponse:
db: DummyDbConnection,
user: User,
q2: int = 321) -> SerializedResponse:
print(r.user_id)
resp = {
'req': r,
'q1': q1,
'h1': h1,
'a': a,
'db': db.constructions
'db': db.constructions,
'user': user.__dict__,
'q2': q2
}
return SerializedResponse(200, {}, resp)

View File

@@ -24,6 +24,7 @@ from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYP
from breakshaft.convertor import ConvRepo
from .util import parse_content_type
import breakshaft.util
class ASGIApp(Protocol):
@@ -260,70 +261,29 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
raise RuntimeError(f'Unsupported scope type: {t}')
def _create_infl_from_schemas(self, schemas: Iterable[ParamSchema], from_type_override: type):
if len(schemas) > 0:
infl = self.infl_generator.schema_to_inflator(
list(schemas),
strict_mode_override=False,
)
type_remap = {
'from_data': from_type_override,
'return': Tuple[tuple(t.replacement_type for t in schemas)]
}
return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap))
return None, None
def route(self, method: MethodType, path_pattern: str):
def decorator(fn: HandlerType):
path_substs = self.router.find_pattern_substs(path_pattern)
injectors = self.inj_repo.filtered_injectors(True, True)
injected_types = list(map(lambda x: x.injects, injectors))
config = EndpointConfig.from_handler(fn, path_substs, injected_types)
fork_with = set()
inflators_for_didoc = {}
def create_infl_from_schemas(schemas: Iterable[ParamSchema], from_type_override: type):
if len(schemas) > 0:
infl = self.infl_generator.schema_to_inflator(
list(schemas),
strict_mode_override=False,
)
type_remap = {
'from_data': from_type_override,
'return': Tuple[tuple(t.replacement_type for t in schemas)]
}
return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap))
return None, None
i, infl = create_infl_from_schemas(config.path_schemas.values(), PTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Path inflator'] = i
i, infl = create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Query inflator'] = i
i, infl = create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Header inflator'] = i
if config.body_schema is not None:
infl = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(infl, type_remap=type_remap))
inflators_for_didoc['Body inflator'] = infl
fork_with, fn_type_hints = self._integrate_func(fn, path_substs)
tmp_repo = self.inj_repo.fork(fork_with)
fn_type_hints = get_type_hints(fn)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
p = tmp_repo.create_pipeline(
(Send, BasicRequest),
[ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer],
force_async=True
force_async=True,
)
self.router.add(method, path_pattern, p)
@@ -344,8 +304,90 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
)
self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)(
create_di_autodoc_handler(method, path_pattern, p, inflators_for_didoc, depgraph))
create_di_autodoc_handler(method, path_pattern, p, depgraph))
return fn
return decorator
def _didoc_name(self, master_name: str, schemas: dict[str, ParamSchema], ret_schema: type):
name = f'{master_name}['
for s in schemas.values():
name += f'{s.key_name}: '
name += breakshaft.util.universal_qualname(s.schema) + ', '
name = name[:-2] + '] -> ' + breakshaft.util.universal_qualname(ret_schema)
return name
def _integrate_func(self,
func: Callable,
path_params: Optional[Iterable[str]] = None):
injectors = self.inj_repo.filtered_injectors(True, True)
injected_types = list(map(lambda x: x.injects, injectors))
config = EndpointConfig.from_handler(func, path_params or set(), injected_types)
fork_with = set()
i_p, infl = self._create_infl_from_schemas(config.path_schemas.values(), PTYPE)
if infl is not None:
fork_with |= infl
i_q, infl = self._create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
i_h, infl = self._create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
i_b = None
if config.body_schema is not None:
i_b = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(i_b, type_remap=type_remap))
fn_type_hints = get_type_hints(func)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
main_cps = ConversionPoint.from_fn(func, type_remap=fn_type_hints)
main_injects = main_cps[0].injects
if self.di_autodoc_prefix is not None:
if len(config.header_schemas) > 0:
setattr(i_h, '__turbosloth_DI_name__',
self._didoc_name('Inflator Header', config.header_schemas, main_injects))
if len(config.query_schemas) > 0:
setattr(i_q, '__turbosloth_DI_name__',
self._didoc_name('Inflator Query', config.query_schemas, main_injects))
if len(config.path_schemas) > 0:
setattr(i_p, '__turbosloth_DI_name__',
self._didoc_name('Inflator Path', config.path_schemas, main_injects))
if i_b is not None:
setattr(i_b, '__turbosloth_DI_name__',
self._didoc_name('Inflator Body', {config.body_schema.key_name: config.body_schema},
main_injects))
setattr(func, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(main_injects))
fork_with |= set(ConversionPoint.from_fn(func, type_remap=fn_type_hints))
return fork_with, fn_type_hints
def add_injector(self, func: Callable):
fork_with, _ = self._integrate_func(func)
self.inj_repo.add_conversion_points(fork_with)
def mark_injector(self):
def inner(func: Callable):
self.add_injector(func)
return func
return inner

View File

@@ -5,11 +5,13 @@ from typing import Callable, Awaitable
import breakshaft.util_mermaid
import jinja2
from breakshaft.models import ConversionPoint
from turbosloth.interfaces.serialized import SerializedResponse
from turbosloth.interfaces.serialized.html import HTMLSerializedResponse
from turbosloth.types import InternalHandlerType
import breakshaft.renderer
import breakshaft.util
@dataclass
@@ -21,14 +23,22 @@ class MMDiagramData:
def create_di_autodoc_handler(method: str,
path: str,
handler: InternalHandlerType,
param_inflators: dict[str, Callable],
depgraph: str) -> Awaitable[SerializedResponse]:
callseq = getattr(handler, '__breakshaft_callseq__', [])
callseq: list[ConversionPoint] = getattr(handler, '__breakshaft_callseq__', [])
mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid(
list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq(
breakshaft.renderer.render_data_from_callseq((), {}, callseq)
)))
)
escaped_sources = []
for c in callseq:
if hasattr(c.fn, '__megasniff_sources__'):
name = getattr(c.fn, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(c.fn))
escaped_sources.append(
(name, html.escape(getattr(c.fn, '__megasniff_sources__', ''))))
sources = getattr(handler, '__breakshaft_render_src__', '')
pipeline_escaped_sources = html.escape(sources)
@@ -42,12 +52,8 @@ def create_di_autodoc_handler(method: str,
MMDiagramData('Dependency graph', html.escape(depgraph)),
]
escaped_sources = [('Injection pipeline', pipeline_escaped_sources)]
escaped_sources.append(('Injection pipeline', pipeline_escaped_sources))
for k, v in param_inflators.items():
sources = getattr(v, '__megasniff_sources__', '')
src = html.escape(sources)
escaped_sources.append((k, src))
html_content = template.render(
handler_method=method,

12
uv.lock generated
View File

@@ -4,15 +4,15 @@ requires-python = ">=3.13"
[[package]]
name = "breakshaft"
version = "0.1.6"
version = "0.1.6.post2"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [
{ name = "hatchling" },
{ name = "jinja2" },
]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6/breakshaft-0.1.6.tar.gz", hash = "sha256:443777f9f13889e79b31f763659b2d84540e045afb0f1c696ebec955213b653a" }
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6.post2/breakshaft-0.1.6.post2.tar.gz", hash = "sha256:523cdbd55f7fcea3a3f664bbc3ec7d524be09f1f45c4c3b39f3de18164fad37d" }
wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6/breakshaft-0.1.6-py3-none-any.whl", hash = "sha256:abc3e99269cac906a0aafbc1f6af628198986eef571f1e27f29c2cb7f7bfde08" },
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.6.post2/breakshaft-0.1.6.post2-py3-none-any.whl", hash = "sha256:a355b4546ef94f92e792e14a6cfcd029298a299b453c7ba47e13c32e5a82ff4b" },
]
[[package]]
@@ -206,15 +206,15 @@ wheels = [
[[package]]
name = "megasniff"
version = "0.2.3"
version = "0.2.3.post1"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [
{ name = "hatchling" },
{ name = "jinja2" },
]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3/megasniff-0.2.3.tar.gz", hash = "sha256:448776b495bb9b6a7c6d8c1a26afa1290585ef648b3e861c779f7530a2dd36bb" }
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3.post1/megasniff-0.2.3.post1.tar.gz", hash = "sha256:41dd6c235225df13b0d18ad417b93bbddfd70d4b9e5ca405ddef976a8f2e753c" }
wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3/megasniff-0.2.3-py3-none-any.whl", hash = "sha256:b58803cfbcd113f18f20850250ac9630dcc5cec9336ac0fc024142e921047b59" },
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.3.post1/megasniff-0.2.3.post1-py3-none-any.whl", hash = "sha256:414560ae1a2b1a00a7fcaf102e8faf45214748f230ad756ee34a71f2598d93e4" },
]
[[package]]