Wrap api endpoint handlers into a breakshaft DI repo

This commit is contained in:
2025-07-19 01:31:25 +03:00
parent 95c47a5e90
commit e2e66ef14e
6 changed files with 171 additions and 62 deletions

View File

@@ -3,15 +3,18 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import uvicorn
from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import QTYPE, BTYPE
from turbosloth.req_schema import UnwrappedRequest
app = SlothApp()
@app.get("/")
async def index(req: UnwrappedRequest[Any, Any]) -> SerializedResponse:
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@@ -21,7 +24,7 @@ class UserIdSchema:
@app.get("/user/")
async def get_user(req: UnwrappedRequest[UserIdSchema, Any]) -> SerializedResponse:
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse:
print(req)
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
return SerializedResponse(200, {}, resp)
@@ -39,8 +42,29 @@ class UserPostSchema(UserIdSchema):
data: SomeData
class SomeInternalData:
a: list[int] = [0]
@app.inj_repo.mark_injector()
def foo() -> SomeInternalData:
s = SomeInternalData()
s.a[0] += 1
return s
@app.post("/user")
async def post_user(req: UnwrappedRequest[Any, UserPostSchema]) -> SerializedResponse:
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse:
print(req)
resp: dict[str, Any] = {'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data}
print(dat)
resp: dict[str, Any] = {
'message': f'Hello, User {req.body.user_id}!',
'from': 'server',
'data': req.body.data,
'inj': dat.a
}
return SerializedResponse(200, {}, resp)
if __name__ == '__main__':
uvicorn.run('turbosloth.__main__:app', host='0.0.0.0', port=8000, reload=True)

View File

@@ -1,6 +1,9 @@
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any
import typing
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any, Annotated
import breakshaft.util_mermaid
import megasniff.exceptions
from breakshaft.models import ConversionPoint
from megasniff import SchemaInflatorGenerator
from .exceptions import HTTPException
@@ -10,8 +13,11 @@ from .interfaces.serialized import SerializedResponse, SerializedRequest
from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest
from .router import Router
from .types import HandlerType, InternalHandlerType
from .internal_types import Scope, Receive, Send, MethodType
from .types import HandlerType, InternalHandlerType, ContentType
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE
from breakshaft.convertor import ConvRepo
from .util import parse_content_type
class ASGIApp(Protocol):
@@ -24,6 +30,45 @@ class ASGIApp(Protocol):
class HTTPApp(ASGIApp):
serialize_selector: SerializeSelector
def extract_content_type(self, req: BasicRequest) -> ContentType:
contenttype_header = req.headers.get('Content-Type')
properties: dict[str, str]
if contenttype_header is None:
contenttype = self.serialize_selector.default_content_type
properties = {}
else:
contenttype, properties = parse_content_type(contenttype_header)
charset = properties.get('charset')
if charset is None:
if contenttype == 'application/json':
charset = 'utf-8'
else:
charset = 'latin1'
return ContentType(contenttype, charset)
def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
return ser.req.deserialize(req, ct.charset)
def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: ContentType) -> BasicResponse:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
sresponser = ser.resp
try:
return sresponser.into_basic(sresp, ct.charset)
except UnicodeEncodeError:
return sresponser.into_basic(sresp, 'utf-8')
async def send_answer(self, send: Send, resp: BasicResponse):
await send(resp.into_start_message())
await send({
'type': 'http.response.body',
'body': resp.body,
})
async def _do_http(self, scope: Scope, receive: Receive, send: Send) -> None:
method = scope['method']
path = scope['path']
@@ -45,31 +90,26 @@ class HTTPApp(ASGIApp):
try:
handler = self.router.match(method, path)
ser = self.serialize_selector.select(req.headers)
sresponser = ser.resp
charset = ser.charset
sreq = ser.req.deserialize(req, charset)
sresp = await handler(sreq)
await handler(send, req)
return
except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException):
sresp = SerializedResponse(400, {}, 'Schema error')
except HTTPException as e:
sresp = SerializedResponse(e.code, {}, str(e))
try:
resp = sresponser.into_basic(sresp, charset)
ct = self.extract_content_type(req)
resp = sresponser.into_basic(sresp, ct.charset)
except UnicodeEncodeError:
resp = sresponser.into_basic(sresp, 'utf-8')
await send(resp.into_start_message())
await send({
'type': 'http.response.body',
'body': resp.body,
})
# пока без малейшего понятия как избавиться от ручной обертки обработки ошибок
await self.send_answer(send, resp)
class WSApp(ASGIApp):
async def _do_websocket(self, scope: Scope, receive: Receive, send: Send) -> None:
# TODO: impl ws
raise NotImplementedError()
@@ -144,6 +184,23 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self._on_shutdown = on_shutdown
self.serialize_selector = SerializeSelector()
self.infl_generator = SchemaInflatorGenerator(strict_mode=True)
self.inj_repo = ConvRepo()
@self.inj_repo.mark_injector()
def extract_query(req: BasicRequest) -> QTYPE:
return req.query
@self.inj_repo.mark_injector()
def extract_query(req: SerializedRequest) -> QTYPE:
return req.query
@self.inj_repo.mark_injector()
def extract_body(req: SerializedRequest) -> BTYPE:
return req.body
self.inj_repo.add_injector(self.extract_content_type)
self.inj_repo.add_injector(self.serialize_request)
self.inj_repo.add_injector(self.serialize_response)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
t = scope['type']
@@ -158,9 +215,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
def route(self, method: MethodType, path_pattern: str):
def decorator(fn: HandlerType):
hints = get_type_hints(fn)
handle_hints = get_type_hints(fn)
req_schema = None
for argname, tp in hints.items():
for argname, tp in handle_hints.items():
if argname == 'return':
continue
if get_origin(tp) == UnwrappedRequest:
@@ -172,24 +229,56 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
query_type, body_type = get_args(req_schema)
q_inflator = None
b_inflator = None
if query_type != Any:
q_inflator = self.infl_generator.schema_to_inflator(query_type, strict_mode_override=False)
if body_type != Any:
b_inflator = self.infl_generator.schema_to_inflator(body_type)
def internal_handler(req: SerializedRequest) -> Awaitable[SerializedResponse]:
if q_inflator is not None:
q = q_inflator(req.query)
else:
q = req.query
if b_inflator is not None:
b = b_inflator(req.body)
else:
b = req.body
fork_with = set()
return fn(UnwrappedRequest(q, b))
def none_generator(*args) -> None:
return None
self.router.add(method, path_pattern, internal_handler)
if query_type not in [QTYPE, None, Any]:
q_inflator = self.infl_generator.schema_to_inflator(
query_type,
strict_mode_override=False,
from_type_override=QTYPE
)
fork_with.add(ConversionPoint(q_inflator, query_type, (QTYPE,)))
else:
fork_with.add(ConversionPoint(none_generator, query_type, (QTYPE,)))
if body_type != [BTYPE, None, Any]:
b_inflator = self.infl_generator.schema_to_inflator(
body_type,
from_type_override=BTYPE
)
fork_with.add(ConversionPoint(b_inflator, body_type, (BTYPE,)))
else:
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
def construct_unwrap(q, b) -> UnwrappedRequest:
print(f'unwrapping {query_type} and {body_type} with {q} and {b}')
return UnwrappedRequest(q, b)
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))}
tmp_repo = self.inj_repo.fork(fork_with)
conv = tmp_repo.get_conversion(
(BasicRequest,),
fn,
force_async=True,
)
out_conv = self.inj_repo.get_conversion(
(Send, BasicRequest, handle_hints['return'],),
self.send_answer,
force_async=True
)
async def pipeline(send: Send, req: BasicRequest):
ret = await conv(req)
await out_conv(send, req, ret)
self.router.add(method, path_pattern, pipeline)
return fn
return decorator

View File

@@ -12,7 +12,6 @@ class SerializeChoise(NamedTuple):
req: type[SerializedRequest]
resp: type[SerializedResponse]
charset: str
content_properties: dict[str, str]
class SerializeSelector:
@@ -34,22 +33,7 @@ class SerializeSelector:
ser[k] = v
self.serializers = ser
def select(self, headers: CaseInsensitiveDict[str, str]) -> SerializeChoise:
contenttype_header = headers.get('Content-Type')
properties: dict[str, str]
if contenttype_header is None:
contenttype = self.default_content_type
properties = {}
else:
contenttype, properties = parse_content_type(contenttype_header)
charset = properties.get('charset')
if charset is None:
if contenttype == 'application/json':
charset = 'utf-8'
else:
charset = 'latin1'
def select(self, contenttype: str, charset: str) -> SerializeChoise:
choise = self.serializers.get(typing.cast(str, contenttype))
if choise is None and self.default_content_type is not None:
@@ -58,4 +42,4 @@ class SerializeSelector:
raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys()))
req, resp = choise
return SerializeChoise(req, resp, charset, properties)
return SerializeChoise(req, resp, charset)

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Awaitable, Literal
from typing import Any, Callable, Awaitable, Literal, Annotated
type Scope = dict[str, Any]
type ASGIMessage = dict[str, Any]
@@ -16,3 +16,6 @@ type MethodType = (
Literal['CONNECT'] |
Literal['OPTIONS'] |
Literal['TRACE'])
type QTYPE = Annotated[dict[str, Any], 'query_params']
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']

View File

@@ -1,9 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Awaitable, Literal, Any
from turbosloth.interfaces.base import BasicRequest
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.internal_types import Send
from turbosloth.req_schema import UnwrappedRequest
type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]]
type InternalHandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]]
type InternalHandlerType = Callable[[Send, BasicRequest], Awaitable[None]]
@dataclass
class ContentType:
contenttype: str
charset: str

12
uv.lock generated
View File

@@ -4,15 +4,15 @@ requires-python = ">=3.13"
[[package]]
name = "breakshaft"
version = "0.1.0"
version = "0.1.0.post2"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [
{ name = "hatchling" },
{ name = "jinja2" },
]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0.tar.gz", hash = "sha256:77f114bab957b78f96b5b087e8ef9778c9b1fc4b3433f0b428d9dd084243debe" }
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2.tar.gz", hash = "sha256:d04f8685336080cddb5111422a9624c8afcc8c47264e1b847339139bfd690570" }
wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0-py3-none-any.whl", hash = "sha256:e7c8831e5889fbcd788d59fee05f7aee113663ebef9273cd280893e1ed04f5c2" },
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2-py3-none-any.whl", hash = "sha256:fd4c9b213e2569c2c4d4f86cef2fe04065c08f9bffac3aa46b8ae21459f35074" },
]
[[package]]
@@ -184,15 +184,15 @@ wheels = [
[[package]]
name = "megasniff"
version = "0.2.0"
version = "0.2.1"
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
dependencies = [
{ name = "hatchling" },
{ name = "jinja2" },
]
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0.tar.gz", hash = "sha256:feceefd8a618c7b930832b4ff8edcc474961f5d5f53743c64e142a796fea0e30" }
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1.tar.gz", hash = "sha256:c6ed47b40fbdc92d7aa061267d158eea457b560f1bc8f84d297b55e5e4a0ef2e" }
wheels = [
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0-py3-none-any.whl", hash = "sha256:4f03a1e81dcb1020b6fe80f3fbc7b192a61b4dc55f7c6783141aa464dc3c77d1" },
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1-py3-none-any.whl", hash = "sha256:cff759fd61e9a4b8634329620cf13a941d7d5b38b834f8eab68e6ac78fa23589" },
]
[[package]]