Compare commits

...

2 Commits

Author SHA1 Message Date
328647702d Allow DI injectors use request schemas in a required args 2025-08-20 03:03:19 +03:00
faaa43fdf1 Allow DI args into an app handlers 2025-08-20 01:57:57 +03:00
4 changed files with 148 additions and 65 deletions

View File

@@ -10,9 +10,10 @@ from turbosloth.interfaces.serialize_selector import SerializeSelector
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE from turbosloth.internal_types import QTYPE, BTYPE, PTYPE, HTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
from turbosloth.schema import RequestBody, HeaderParam from turbosloth.schema import RequestBody, HeaderParam, PathParam
app = SlothApp(di_autodoc_prefix='/didoc', serialize_selector=SerializeSelector(default_content_type='application/json')) app = SlothApp(di_autodoc_prefix='/didoc',
serialize_selector=SerializeSelector(default_content_type='application/json'))
# @app.get("/") # @app.get("/")
@@ -75,17 +76,44 @@ class PTYPESchema:
# return SerializedResponse(200, {}, resp) # return SerializedResponse(200, {}, resp)
class DummyDbConnection:
constructions = [0]
def __init__(self):
self.constructions[0] += 1
@app.mark_injector()
def create_db_connection() -> DummyDbConnection:
return DummyDbConnection()
@dataclass
class User:
user_id: int
name: str
@app.mark_injector()
def auth_user(user: RequestBody(UserPostSchema)) -> User:
return User(user.user_id, f'user {user.user_id}')
@app.post("/test/body/{a}") @app.post("/test/body/{a}")
async def test_body(r: RequestBody(UserPostSchema), async def test_body(r: RequestBody(UserPostSchema),
q1: str, q1: str,
a: str, a: str,
h1: HeaderParam(str, 'header1')) -> SerializedResponse: h1: HeaderParam(str, 'header1'),
db: DummyDbConnection,
user: User) -> SerializedResponse:
print(r.user_id) print(r.user_id)
resp = { resp = {
'req': r, 'req': r,
'q1': q1, 'q1': q1,
'h1': h1, 'h1': h1,
'a': a 'a': a,
'db': db.constructions,
'user': user.__dict__
} }
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)

View File

@@ -24,6 +24,7 @@ from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE, PTYP
from breakshaft.convertor import ConvRepo from breakshaft.convertor import ConvRepo
from .util import parse_content_type from .util import parse_content_type
import breakshaft.util
class ASGIApp(Protocol): class ASGIApp(Protocol):
@@ -260,17 +261,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
raise RuntimeError(f'Unsupported scope type: {t}') raise RuntimeError(f'Unsupported scope type: {t}')
def route(self, method: MethodType, path_pattern: str): def _create_infl_from_schemas(self, schemas: Iterable[ParamSchema], from_type_override: type):
def decorator(fn: HandlerType):
path_substs = self.router.find_pattern_substs(path_pattern)
config = EndpointConfig.from_handler(fn, path_substs)
fork_with = set()
inflators_for_didoc = {}
def create_infl_from_schemas(schemas: Iterable[ParamSchema], from_type_override: type):
if len(schemas) > 0: if len(schemas) > 0:
infl = self.infl_generator.schema_to_inflator( infl = self.infl_generator.schema_to_inflator(
list(schemas), list(schemas),
@@ -283,39 +274,12 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap)) return infl, set(ConversionPoint.from_fn(infl, type_remap=type_remap))
return None, None return None, None
i, infl = create_infl_from_schemas(config.path_schemas.values(), PTYPE) def route(self, method: MethodType, path_pattern: str):
if infl is not None: def decorator(fn: HandlerType):
fork_with |= infl path_substs = self.router.find_pattern_substs(path_pattern)
inflators_for_didoc['Path inflator'] = i
i, infl = create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Query inflator'] = i
i, infl = create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
inflators_for_didoc['Header inflator'] = i
if config.body_schema is not None:
infl = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(infl, type_remap=type_remap))
inflators_for_didoc['Body inflator'] = infl
fork_with, fn_type_hints = self._integrate_func(fn, path_substs)
tmp_repo = self.inj_repo.fork(fork_with) tmp_repo = self.inj_repo.fork(fork_with)
fn_type_hints = get_type_hints(fn)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
p = tmp_repo.create_pipeline( p = tmp_repo.create_pipeline(
(Send, BasicRequest), (Send, BasicRequest),
[ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer], [ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer],
@@ -340,8 +304,90 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
) )
self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)( self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)(
create_di_autodoc_handler(method, path_pattern, p, inflators_for_didoc, depgraph)) create_di_autodoc_handler(method, path_pattern, p, depgraph))
return fn return fn
return decorator return decorator
def _didoc_name(self, master_name: str, schemas: dict[str, ParamSchema], ret_schema: type):
name = f'{master_name}['
for s in schemas.values():
name += f'{s.key_name}: '
name += breakshaft.util.universal_qualname(s.schema) + ', '
name = name[:-2] + '] -> ' + breakshaft.util.universal_qualname(ret_schema)
return name
def _integrate_func(self,
func: Callable,
path_params: Optional[Iterable[str]] = None):
injectors = self.inj_repo.filtered_injectors(True, True)
injected_types = list(map(lambda x: x.injects, injectors))
config = EndpointConfig.from_handler(func, path_params or set(), injected_types)
fork_with = set()
i_p, infl = self._create_infl_from_schemas(config.path_schemas.values(), PTYPE)
if infl is not None:
fork_with |= infl
i_q, infl = self._create_infl_from_schemas(config.query_schemas.values(), QTYPE)
if infl is not None:
fork_with |= infl
i_h, infl = self._create_infl_from_schemas(config.header_schemas.values(), HTYPE)
if infl is not None:
fork_with |= infl
i_b = None
if config.body_schema is not None:
i_b = self.infl_generator.schema_to_inflator(
config.body_schema.schema,
strict_mode_override=False,
)
type_remap = {
'from_data': BTYPE,
'return': config.body_schema.replacement_type
}
fork_with |= set(ConversionPoint.from_fn(i_b, type_remap=type_remap))
fn_type_hints = get_type_hints(func)
for k, v in config.type_replacement.items():
fn_type_hints[k] = v
main_cps = ConversionPoint.from_fn(func, type_remap=fn_type_hints)
main_injects = main_cps[0].injects
if self.di_autodoc_prefix is not None:
if len(config.header_schemas) > 0:
setattr(i_h, '__turbosloth_DI_name__',
self._didoc_name('Inflator Header', config.header_schemas, main_injects))
if len(config.query_schemas) > 0:
setattr(i_q, '__turbosloth_DI_name__',
self._didoc_name('Inflator Query', config.query_schemas, main_injects))
if len(config.path_schemas) > 0:
setattr(i_p, '__turbosloth_DI_name__',
self._didoc_name('Inflator Path', config.path_schemas, main_injects))
if i_b is not None:
setattr(i_b, '__turbosloth_DI_name__',
self._didoc_name('Inflator Body', {config.body_schema.key_name: config.body_schema},
main_injects))
setattr(func, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(main_injects))
fork_with |= set(ConversionPoint.from_fn(func, type_remap=fn_type_hints))
return fork_with, fn_type_hints
def add_injector(self, func: Callable):
fork_with, _ = self._integrate_func(func)
self.inj_repo.add_conversion_points(fork_with)
def mark_injector(self):
def inner(func: Callable):
self.add_injector(func)
return func
return inner

View File

@@ -5,11 +5,13 @@ from typing import Callable, Awaitable
import breakshaft.util_mermaid import breakshaft.util_mermaid
import jinja2 import jinja2
from breakshaft.models import ConversionPoint
from turbosloth.interfaces.serialized import SerializedResponse from turbosloth.interfaces.serialized import SerializedResponse
from turbosloth.interfaces.serialized.html import HTMLSerializedResponse from turbosloth.interfaces.serialized.html import HTMLSerializedResponse
from turbosloth.types import InternalHandlerType from turbosloth.types import InternalHandlerType
import breakshaft.renderer import breakshaft.renderer
import breakshaft.util
@dataclass @dataclass
@@ -21,14 +23,22 @@ class MMDiagramData:
def create_di_autodoc_handler(method: str, def create_di_autodoc_handler(method: str,
path: str, path: str,
handler: InternalHandlerType, handler: InternalHandlerType,
param_inflators: dict[str, Callable],
depgraph: str) -> Awaitable[SerializedResponse]: depgraph: str) -> Awaitable[SerializedResponse]:
callseq = getattr(handler, '__breakshaft_callseq__', []) callseq: list[ConversionPoint] = getattr(handler, '__breakshaft_callseq__', [])
mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid( mmd_flowchart = breakshaft.util_mermaid.draw_callseq_mermaid(
list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq( list(map(lambda x: x._injection, breakshaft.renderer.deduplicate_callseq(
breakshaft.renderer.render_data_from_callseq((), {}, callseq) breakshaft.renderer.render_data_from_callseq((), {}, callseq)
))) )))
) )
escaped_sources = []
for c in callseq:
if hasattr(c.fn, '__megasniff_sources__'):
name = getattr(c.fn, '__turbosloth_DI_name__', breakshaft.util.universal_qualname(c.fn))
escaped_sources.append(
(name, html.escape(getattr(c.fn, '__megasniff_sources__', ''))))
sources = getattr(handler, '__breakshaft_render_src__', '') sources = getattr(handler, '__breakshaft_render_src__', '')
pipeline_escaped_sources = html.escape(sources) pipeline_escaped_sources = html.escape(sources)
@@ -42,12 +52,8 @@ def create_di_autodoc_handler(method: str,
MMDiagramData('Dependency graph', html.escape(depgraph)), MMDiagramData('Dependency graph', html.escape(depgraph)),
] ]
escaped_sources = [('Injection pipeline', pipeline_escaped_sources)] escaped_sources.append(('Injection pipeline', pipeline_escaped_sources))
for k, v in param_inflators.items():
sources = getattr(v, '__megasniff_sources__', '')
src = html.escape(sources)
escaped_sources.append((k, src))
html_content = template.render( html_content = template.render(
handler_method=method, handler_method=method,

View File

@@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from typing import TYPE_CHECKING, overload, Any, TypeAlias, Optional, get_origin, get_args, Callable, get_type_hints from typing import TYPE_CHECKING, overload, Any, TypeAlias, Optional, get_origin, get_args, Callable, get_type_hints, \
Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeVar, Generic, Annotated from typing import TypeVar, Generic, Annotated
@@ -72,7 +73,7 @@ class EndpointConfig:
type_replacement: dict[str, type] type_replacement: dict[str, type]
@classmethod @classmethod
def from_handler(cls, h: Callable, path_substituts: set[str]) -> EndpointConfig: def from_handler(cls, h: Callable, path_substituts: set[str], ignore_types: Iterable[type]) -> EndpointConfig:
body_schema = None body_schema = None
query_schemas = {} query_schemas = {}
path_schemas = {} path_schemas = {}
@@ -82,6 +83,8 @@ class EndpointConfig:
handle_hints = get_endpoint_params_info(h) handle_hints = get_endpoint_params_info(h)
for argname, s in handle_hints.items(): for argname, s in handle_hints.items():
tp = s.schema tp = s.schema
if tp in ignore_types:
continue
type_replacement[argname] = s.replacement_type type_replacement[argname] = s.replacement_type
if argname == 'return': if argname == 'return':