Allow DI injectors use request schemas in a required args

This commit is contained in:
2025-08-20 03:03:19 +03:00
parent faaa43fdf1
commit 328647702d
3 changed files with 128 additions and 67 deletions

View File

@@ -10,7 +10,7 @@ from turbosloth.interfaces.serialize_selector import SerializeSelector
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
from turbosloth.schema import RequestBody, HeaderParam from turbosloth.schema import RequestBody, HeaderParam, PathParam
app = SlothApp(di_autodoc_prefix='/didoc', app = SlothApp(di_autodoc_prefix='/didoc',
serialize_selector=SerializeSelector(default_content_type='application/json')) serialize_selector=SerializeSelector(default_content_type='application/json'))
@@ -83,24 +83,37 @@ class DummyDbConnection:
self.constructions[0] += 1 self.constructions[0] += 1
@app.inj_repo.mark_injector() @app.mark_injector()
def create_db_connection() -> DummyDbConnection: def create_db_connection() -> DummyDbConnection:
return DummyDbConnection() return DummyDbConnection()
@dataclass
class User:
user_id: int
name: str
@app.mark_injector()
def auth_user(user: RequestBody(UserPostSchema)) -> User:
return User(user.user_id, f'user {user.user_id}')
@app.post("/test/body/{a}") @app.post("/test/body/{a}")
async def test_body(r: RequestBody(UserPostSchema), async def test_body(r: RequestBody(UserPostSchema),
q1: str, q1: str,
a: str, a: str,
h1: HeaderParam(str, 'header1'), h1: HeaderParam(str, 'header1'),
db: DummyDbConnection) -> SerializedResponse: db: DummyDbConnection,
user: User) -> SerializedResponse:
print(r.user_id) print(r.user_id)
resp = { resp = {
'req': r, 'req': r,
'q1': q1, 'q1': q1,
'h1': h1, 'h1': h1,
'a': a, 'a': a,
'db': db.constructions 'db': db.constructions,
'user': user.__dict__
} }
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)

View File

@@ -24,6 +24,7 @@ from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYP
from breakshaft.convertor import ConvRepo from breakshaft.convertor import ConvRepo
from .util import parse_content_type from .util import parse_content_type
import breakshaft.util
class ASGIApp(Protocol): class ASGIApp(Protocol):
@@ -260,66 +261,25 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
raise RuntimeError(f'Unsupported scope type: {t}') raise RuntimeError(f'Unsupported scope type: {t}')
def _create_infl_from_schemas(self, schemas: Iterable[ParamSchema], from_type_override: type):
if len(schemas) > 0:
infl = self.infl_generator.schema_to_inflator(
list(schemas),
strict_mode_override=False,
)
type_remap = {
'from_data': from_type_override,
'return': Tuple[tuple(t.replacement_type for t in schemas)]
}
return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap))
return None, None
def route(self, method: MethodType, path_pattern: str): def route(self, method: MethodType, path_pattern: str):
def decorator(fn: HandlerType): def decorator(fn: HandlerType):
path_substs = self.router.find_pattern_substs(path_pattern) path_substs = self.router.find_pattern_substs(path_pattern)
injectors = self.inj_repo.filtered_injectors(True, True) fork_with, fn_type_hints = self._integrate_func(fn, path_substs)
injected_types = list(map(lambda x: x.injects, injectors))
config = EndpointConfig.from_handler(fn, path_substs, injected_types)
fork_with = set()
inflators_for_didoc = {}
def create_infl_from_schemas(schemas: Iterable[ParamSchema], from_type_override: type):
if len(schemas) > 0:
infl = self.infl_generator.schema_to_inflator(
list(schemas),
strict_mode_override=False,
)
type_remap = {
'from_data': from_type_override,
'return': Tuple[tuple(t.replacement_type for t in schemas)]
}
return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap))
return None, None
i, infl = create_infl_from_schemas(config.path_schemas.values(), PTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Path inflator'] = i
i, infl = create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Query inflator'] = i
i, infl = create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Header inflator'] = i
if config.body_schema is not None:
infl = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(infl, type_remap=type_remap))
inflators_for_didoc['Body inflator'] = infl
tmp_repo = self.inj_repo.fork(fork_with) tmp_repo = self.inj_repo.fork(fork_with)
fn_type_hints = get_type_hints(fn)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
p = tmp_repo.create_pipeline( p = tmp_repo.create_pipeline(
(Send, BasicRequest), (Send, BasicRequest),
[ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer], [ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer],
@@ -344,8 +304,90 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
) )
self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)( self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)(
create_di_autodoc_handler(method, path_pattern, p, inflators_for_didoc, depgraph)) create_di_autodoc_handler(method, path_pattern, p, depgraph))
return fn return fn
return decorator return decorator
def _didoc_name(self, master_name: str, schemas: dict[str, ParamSchema], ret_schema: type):
name = f'{master_name}['
for s in schemas.values():
name += f'{s.key_name}: '
name += breakshaft.util.universal_qualname(s.schema) + ', '
name = name[:-2] + '] -> ' + breakshaft.util.universal_qualname(ret_schema)
return name
def _integrate_func(self,
func: Callable,
path_params: Optional[Iterable[str]] = None):
injectors = self.inj_repo.filtered_injectors(True, True)
injected_types = list(map(lambda x: x.injects, injectors))
config = EndpointConfig.from_handler(func, path_params or set(), injected_types)
fork_with = set()
i_p, infl = self._create_infl_from_schemas(config.path_schemas.values(), PTYPE)
if infl is not None:
fork_with |= infl
i_q, infl = self._create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
i_h, infl = self._create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
i_b = None
if config.body_schema is not None:
i_b = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(i_b, type_remap=type_remap))
fn_type_hints = get_type_hints(func)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
main_cps = ConversionPoint.from_fn(func, type_remap=fn_type_hints)
main_injects = main_cps[0].injects
if self.di_autodoc_prefix is not None:
if len(config.header_schemas) > 0:
setattr(i_h, '__turbosloth_DI_name__',
self._didoc_name('Inflator Header', config.header_schemas, main_injects))
if len(config.query_schemas) > 0:
setattr(i_q, '__turbosloth_DI_name__',
self._didoc_name('Inflator Query', config.query_schemas, main_injects))
if len(config.path_schemas) > 0:
setattr(i_p, '__turbosloth_DI_name__',
self._didoc_name('Inflator Path', config.path_schemas, main_injects))
if i_b is not None:
setattr(i_b, '__turbosloth_DI_name__',
self._didoc_name('Inflator Body', {config.body_schema.key_name: config.body_schema},
main_injects))
setattr(func, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(main_injects))
fork_with |= set(ConversionPoint.from_fn(func, type_remap=fn_type_hints))
return fork_with, fn_type_hints
def add_injector(self, func: Callable):
fork_with, _ = self._integrate_func(func)
self.inj_repo.add_conversion_points(fork_with)
def mark_injector(self):
def inner(func: Callable):
self.add_injector(func)
return func
return inner

View File

@@ -5,11 +5,13 @@ from typing import Callable, Awaitable
import breakshaft.util_mermaid import breakshaft.util_mermaid
import jinja2 import jinja2
from breakshaft.models import ConversionPoint
from turbosloth.interfaces.serialized import SerializedResponse from turbosloth.interfaces.serialized import SerializedResponse
from turbosloth.interfaces.serialized.html import HTMLSerializedResponse from turbosloth.interfaces.serialized.html import HTMLSerializedResponse
from turbosloth.types import InternalHandlerType from turbosloth.types import InternalHandlerType
import breakshaft.renderer import breakshaft.renderer
import breakshaft.util
@dataclass @dataclass
@@ -21,14 +23,22 @@ class MMDiagramData:
def create_di_autodoc_handler(method: str, def create_di_autodoc_handler(method: str,
path: str, path: str,
handler: InternalHandlerType, handler: InternalHandlerType,
param_inflators: dict[str, Callable],
depgraph: str) -> Awaitable[SerializedResponse]: depgraph: str) -> Awaitable[SerializedResponse]:
callseq = getattr(handler, '__breakshaft_callseq__', []) callseq: list[ConversionPoint] = getattr(handler, '__breakshaft_callseq__', [])
mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid( mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid(
list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq( list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq(
breakshaft.renderer.render_data_from_callseq((), {}, callseq) breakshaft.renderer.render_data_from_callseq((), {}, callseq)
))) )))
) )
escaped_sources = []
for c in callseq:
if hasattr(c.fn, '__megasniff_sources__'):
name = getattr(c.fn, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(c.fn))
escaped_sources.append(
(name, html.escape(getattr(c.fn, '__megasniff_sources__', ''))))
sources = getattr(handler, '__breakshaft_render_src__', '') sources = getattr(handler, '__breakshaft_render_src__', '')
pipeline_escaped_sources = html.escape(sources) pipeline_escaped_sources = html.escape(sources)
@@ -42,12 +52,8 @@ def create_di_autodoc_handler(method: str,
MMDiagramData('Dependency graph', html.escape(depgraph)), MMDiagramData('Dependency graph', html.escape(depgraph)),
] ]
escaped_sources = [('Injection pipeline', pipeline_escaped_sources)] escaped_sources.append(('Injection pipeline', pipeline_escaped_sources))
for k, v in param_inflators.items():
sources = getattr(v, '__megasniff_sources__', '')
src = html.escape(sources)
escaped_sources.append((k, src))
html_content = template.render( html_content = template.render(
handler_method=method, handler_method=method,