Wrap api endpoint handlers into a breakshaft DI repo

This commit is contained in:
2025-07-19 01:31:25 +03:00
parent 95c47a5e90
commit e2e66ef14e
6 changed files with 171 additions and 62 deletions

View File

@@ -3,15 +3,18 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
import uvicorn
from turbosloth import SlothApp from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
app = SlothApp() app = SlothApp()
@app.get("/") @app.get("/")
async def index(req: UnwrappedRequest[Any, Any]) -> SerializedResponse: async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!') return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -21,7 +24,7 @@ class UserIdSchema:
@app.get("/user/") @app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, Any]) -> SerializedResponse: async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse:
print(req) print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body} resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)
@@ -39,8 +42,29 @@ class UserPostSchema(UserIdSchema):
data: SomeData data: SomeData
class SomeInternalData:
a: list[int] = [0]
@app.inj_repo.mark_injector()
def foo() -> SomeInternalData:
s = SomeInternalData()
s.a[0] += 1
return s
@app.post("/user") @app.post("/user")
async def post_user(req: UnwrappedRequest[Any, UserPostSchema]) -> SerializedResponse: async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse:
print(req) print(req)
resp: dict[str, Any] = {'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data} print(dat)
resp: dict[str, Any] = {
'message': f'Hello, User {req.body.user_id}!',
'from': 'server',
'data': req.body.data,
'inj': dat.a
}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)
if __name__ == '__main__':
uvicorn.run('turbosloth.__main__:app', host='0.0.0.0', port=8000, reload=True)

View File

@@ -1,6 +1,9 @@
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any import typing
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any, Annotated
import breakshaft.util_mermaid
import megasniff.exceptions import megasniff.exceptions
from breakshaft.models import ConversionPoint
from megasniff import SchemaInflatorGenerator from megasniff import SchemaInflatorGenerator
from .exceptions import HTTPException from .exceptions import HTTPException
@@ -10,8 +13,11 @@ from .interfaces.serialized import SerializedResponse, SerializedRequest
from .interfaces.serialized.text import TextSerializedResponse from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest from .req_schema import UnwrappedRequest
from .router import Router from .router import Router
from .types import HandlerType, InternalHandlerType from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE
from breakshaft.convertor import ConvRepo
from .util import parse_content_type
class ASGIApp(Protocol): class ASGIApp(Protocol):
@@ -24,6 +30,45 @@ class ASGIApp(Protocol):
class HTTPApp(ASGIApp): class HTTPApp(ASGIApp):
serialize_selector: SerializeSelector serialize_selector: SerializeSelector
def extract_content_type(self, req: BasicRequest) -> ContentType:
contenttype_header = req.headers.get('Content-Type')
properties: dict[str, str]
if contenttype_header is None:
contenttype = self.serialize_selector.default_content_type
properties = {}
else:
contenttype, properties = parse_content_type(contenttype_header)
charset = properties.get('charset')
if charset is None:
if contenttype == 'application/json':
charset = 'utf-8'
else:
charset = 'latin1'
return ContentType(contenttype, charset)
def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
return ser.req.deserialize(req, ct.charset)
def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: ContentType) -> BasicResponse:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
sresponser = ser.resp
try:
return sresponser.into_basic(sresp, ct.charset)
except UnicodeEncodeError:
return sresponser.into_basic(sresp, 'utf-8')
async def send_answer(self, send: Send, resp: BasicResponse):
await send(resp.into_start_message())
await send({
'type': 'http.response.body',
'body': resp.body,
})
async def _do_http(self, scope: Scope, receive: Receive, send: Send) -> None: async def _do_http(self, scope: Scope, receive: Receive, send: Send) -> None:
method = scope['method'] method = scope['method']
path = scope['path'] path = scope['path']
@@ -45,31 +90,26 @@ class HTTPApp(ASGIApp):
try: try:
handler = self.router.match(method, path) handler = self.router.match(method, path)
ser = self.serialize_selector.select(req.headers) await handler(send, req)
sresponser = ser.resp return
charset = ser.charset
sreq = ser.req.deserialize(req, charset)
sresp = await handler(sreq)
except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException): except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException):
sresp = SerializedResponse(400, {}, 'Schema error') sresp = SerializedResponse(400, {}, 'Schema error')
except HTTPException as e: except HTTPException as e:
sresp = SerializedResponse(e.code, {}, str(e)) sresp = SerializedResponse(e.code, {}, str(e))
try: try:
resp = sresponser.into_basic(sresp, charset) ct = self.extract_content_type(req)
resp = sresponser.into_basic(sresp, ct.charset)
except UnicodeEncodeError: except UnicodeEncodeError:
resp = sresponser.into_basic(sresp, 'utf-8') resp = sresponser.into_basic(sresp, 'utf-8')
await send(resp.into_start_message()) # пока без малейшего понятия как избавиться от ручной обертки обработки ошибок
await send({ await self.send_answer(send, resp)
'type': 'http.response.body',
'body': resp.body,
})
class WSApp(ASGIApp): class WSApp(ASGIApp):
async def _do_websocket(self, scope: Scope, receive: Receive, send: Send) -> None: async def _do_websocket(self, scope: Scope, receive: Receive, send: Send) -> None:
# TODO: impl ws
raise NotImplementedError() raise NotImplementedError()
@@ -144,6 +184,23 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self._on_shutdown = on_shutdown self._on_shutdown = on_shutdown
self.serialize_selector = SerializeSelector() self.serialize_selector = SerializeSelector()
self.infl_generator = SchemaInflatorGenerator(strict_mode=True) self.infl_generator = SchemaInflatorGenerator(strict_mode=True)
self.inj_repo = ConvRepo()
@self.inj_repo.mark_injector()
def extract_query(req: BasicRequest) -> QTYPE:
return req.query
@self.inj_repo.mark_injector()
def extract_query(req: SerializedRequest) -> QTYPE:
return req.query
@self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE:
return req.body
self.inj_repo.add_injector(self.extract_content_type)
self.inj_repo.add_injector(self.serialize_request)
self.inj_repo.add_injector(self.serialize_response)
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
t = scope['type'] t = scope['type']
@@ -158,9 +215,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
def route(self, method: MethodType, path_pattern: str): def route(self, method: MethodType, path_pattern: str):
def decorator(fn: HandlerType): def decorator(fn: HandlerType):
hints = get_type_hints(fn) handle_hints = get_type_hints(fn)
req_schema = None req_schema = None
for argname, tp in hints.items(): for argname, tp in handle_hints.items():
if argname == 'return': if argname == 'return':
continue continue
if get_origin(tp) == UnwrappedRequest: if get_origin(tp) == UnwrappedRequest:
@@ -172,24 +229,56 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
query_type, body_type = get_args(req_schema) query_type, body_type = get_args(req_schema)
q_inflator = None q_inflator = None
b_inflator = None b_inflator = None
if query_type != Any:
q_inflator = self.infl_generator.schema_to_inflator(query_type, strict_mode_override=False)
if body_type != Any:
b_inflator = self.infl_generator.schema_to_inflator(body_type)
def internal_handler(req: SerializedRequest) -> Awaitable[SerializedResponse]: fork_with = set()
if q_inflator is not None:
q = q_inflator(req.query) def none_generator(*args) -> None:
return None
if query_type not in [QTYPE, None, Any]:
q_inflator = self.infl_generator.schema_to_inflator(
query_type,
strict_mode_override=False,
from_type_override=QTYPE
)
fork_with.add(ConversionPoint(q_inflator, query_type, (QTYPE,)))
else: else:
q = req.query fork_with.add(ConversionPoint(none_generator, query_type, (QTYPE,)))
if b_inflator is not None:
b = b_inflator(req.body) if body_type != [BTYPE, None, Any]:
b_inflator = self.infl_generator.schema_to_inflator(
body_type,
from_type_override=BTYPE
)
fork_with.add(ConversionPoint(b_inflator, body_type, (BTYPE,)))
else: else:
b = req.body fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
return fn(UnwrappedRequest(q, b)) def construct_unwrap(q, b) -> UnwrappedRequest:
print(f'unwrapping {query_type} and {body_type} with {q} and {b}')
return UnwrappedRequest(q, b)
self.router.add(method, path_pattern, internal_handler) fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))}
tmp_repo = self.inj_repo.fork(fork_with)
conv = tmp_repo.get_conversion(
(BasicRequest,),
fn,
force_async=True,
)
out_conv = self.inj_repo.get_conversion(
(Send, BasicRequest, handle_hints['return'],),
self.send_answer,
force_async=True
)
async def pipeline(send: Send, req: BasicRequest):
ret = await conv(req)
await out_conv(send, req, ret)
self.router.add(method, path_pattern, pipeline)
return fn return fn
return decorator return decorator

View File

@@ -12,7 +12,6 @@ class SerializeChoise(NamedTuple):
req: type[SerializedRequest] req: type[SerializedRequest]
resp: type[SerializedResponse] resp: type[SerializedResponse]
charset: str charset: str
content_properties: dict[str, str]
class SerializeSelector: class SerializeSelector:
@@ -34,22 +33,7 @@ class SerializeSelector:
ser[k] = v ser[k] = v
self.serializers = ser self.serializers = ser
def select(self, headers: CaseInsensitiveDict[str, str]) -> SerializeChoise: def select(self, contenttype: str, charset: str) -> SerializeChoise:
contenttype_header = headers.get('Content-Type')
properties: dict[str, str]
if contenttype_header is None:
contenttype = self.default_content_type
properties = {}
else:
contenttype, properties = parse_content_type(contenttype_header)
charset = properties.get('charset')
if charset is None:
if contenttype == 'application/json':
charset = 'utf-8'
else:
charset = 'latin1'
choise = self.serializers.get(typing.cast(str, contenttype)) choise = self.serializers.get(typing.cast(str, contenttype))
if choise is None and self.default_content_type is not None: if choise is None and self.default_content_type is not None:
@@ -58,4 +42,4 @@ class SerializeSelector:
raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys())) raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys()))
req, resp = choise req, resp = choise
return SerializeChoise(req, resp, charset, properties) return SerializeChoise(req, resp, charset)

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Awaitable, Literal from typing import Any, Callable, Awaitable, Literal, Annotated
type Scope = dict[str, Any] type Scope = dict[str, Any]
type ASGIMessage = dict[str, Any] type ASGIMessage = dict[str, Any]
@@ -16,3 +16,6 @@ type MethodType = (
Literal['CONNECT'] | Literal['CONNECT'] |
Literal['OPTIONS'] | Literal['OPTIONS'] |
Literal['TRACE']) Literal['TRACE'])
type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']

View File

@@ -1,9 +1,18 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Awaitable, Literal, Any from typing import Callable, Awaitable, Literal, Any
from turbosloth.interfaces.base import BasicRequest
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import Send
from turbosloth.req_schema import UnwrappedRequest from turbosloth.req_schema import UnwrappedRequest
type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]] type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]]
type InternalHandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]] type InternalHandlerType = Callable[[Send, BasicRequest], Awaitable[None]]
@dataclass
class ContentType:
contenttype: str
charset: str

12
uv.lock generated
View File

@@ -4,15 +4,15 @@ requires-python = ">=3.13"
[[package]] [[package]]
name = "breakshaft" name = "breakshaft"
version = "0.1.0" version = "0.1.0.post2"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [ dependencies = [
{ name = "hatchling" }, { name = "hatchling" },
{ name = "jinja2" }, { name = "jinja2" },
] ]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0.tar.gz", hash = "sha256:77f114bab957b78f96b5b087e8ef9778c9b1fc4b3433f0b428d9dd084243debe" } sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2.tar.gz", hash = "sha256:d04f8685336080cddb5111422a9624c8afcc8c47264e1b847339139bfd690570" }
wheels = [ wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0-py3-none-any.whl", hash = "sha256:e7c8831e5889fbcd788d59fee05f7aee113663ebef9273cd280893e1ed04f5c2" }, { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2-py3-none-any.whl", hash = "sha256:fd4c9b213e2569c2c4d4f86cef2fe04065c08f9bffac3aa46b8ae21459f35074" },
] ]
[[package]] [[package]]
@@ -184,15 +184,15 @@ wheels = [
[[package]] [[package]]
name = "megasniff" name = "megasniff"
version = "0.2.0" version = "0.2.1"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" } source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [ dependencies = [
{ name = "hatchling" }, { name = "hatchling" },
{ name = "jinja2" }, { name = "jinja2" },
] ]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0.tar.gz", hash = "sha256:feceefd8a618c7b930832b4ff8edcc474961f5d5f53743c64e142a796fea0e30" } sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1.tar.gz", hash = "sha256:c6ed47b40fbdc92d7aa061267d158eea457b560f1bc8f84d297b55e5e4a0ef2e" }
wheels = [ wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0-py3-none-any.whl", hash = "sha256:4f03a1e81dcb1020b6fe80f3fbc7b192a61b4dc55f7c6783141aa464dc3c77d1" }, { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1-py3-none-any.whl", hash = "sha256:cff759fd61e9a4b8634329620cf13a941d7d5b38b834f8eab68e6ac78fa23589" },
] ]
[[package]] [[package]]