Wrap api endpoint handlers into a breakshaft DI repo
This commit is contained in:
@@ -3,15 +3,18 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import uvicorn
|
||||
|
||||
from turbosloth import SlothApp
|
||||
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
|
||||
from turbosloth.internal_types import QTYPE, BTYPE
|
||||
from turbosloth.req_schema import UnwrappedRequest
|
||||
|
||||
app = SlothApp()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def index(req: UnwrappedRequest[Any, Any]) -> SerializedResponse:
|
||||
async def index(req: UnwrappedRequest[QTYPE, BTYPE]) -> SerializedResponse:
|
||||
return SerializedResponse(200, {}, 'Hello, ASGI Router!')
|
||||
|
||||
|
||||
@@ -21,7 +24,7 @@ class UserIdSchema:
|
||||
|
||||
|
||||
@app.get("/user/")
|
||||
async def get_user(req: UnwrappedRequest[UserIdSchema, Any]) -> SerializedResponse:
|
||||
async def get_user(req: UnwrappedRequest[UserIdSchema, BTYPE]) -> SerializedResponse:
|
||||
print(req)
|
||||
resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
|
||||
return SerializedResponse(200, {}, resp)
|
||||
@@ -39,8 +42,29 @@ class UserPostSchema(UserIdSchema):
|
||||
data: SomeData
|
||||
|
||||
|
||||
class SomeInternalData:
|
||||
a: list[int] = [0]
|
||||
|
||||
|
||||
@app.inj_repo.mark_injector()
|
||||
def foo() -> SomeInternalData:
|
||||
s = SomeInternalData()
|
||||
s.a[0] += 1
|
||||
return s
|
||||
|
||||
|
||||
@app.post("/user")
|
||||
async def post_user(req: UnwrappedRequest[Any, UserPostSchema]) -> SerializedResponse:
|
||||
async def post_user(req: UnwrappedRequest[QTYPE, UserPostSchema], dat: SomeInternalData) -> SerializedResponse:
|
||||
print(req)
|
||||
resp: dict[str, Any] = {'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data}
|
||||
print(dat)
|
||||
resp: dict[str, Any] = {
|
||||
'message': f'Hello, User {req.body.user_id}!',
|
||||
'from': 'server',
|
||||
'data': req.body.data,
|
||||
'inj': dat.a
|
||||
}
|
||||
return SerializedResponse(200, {}, resp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('turbosloth.__main__:app', host='0.0.0.0', port=8000, reload=True)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any
|
||||
import typing
|
||||
from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any, Annotated
|
||||
|
||||
import breakshaft.util_mermaid
|
||||
import megasniff.exceptions
|
||||
from breakshaft.models import ConversionPoint
|
||||
from megasniff import SchemaInflatorGenerator
|
||||
|
||||
from .exceptions import HTTPException
|
||||
@@ -10,8 +13,11 @@ from .interfaces.serialized import SerializedResponse, SerializedRequest
|
||||
from .interfaces.serialized.text import TextSerializedResponse
|
||||
from .req_schema import UnwrappedRequest
|
||||
from .router import Router
|
||||
from .types import HandlerType, InternalHandlerType
|
||||
from .internal_types import Scope, Receive, Send, MethodType
|
||||
from .types import HandlerType, InternalHandlerType, ContentType
|
||||
from .internal_types import Scope, Receive, Send, MethodType, QTYPE, BTYPE
|
||||
from breakshaft.convertor import ConvRepo
|
||||
|
||||
from .util import parse_content_type
|
||||
|
||||
|
||||
class ASGIApp(Protocol):
|
||||
@@ -24,6 +30,45 @@ class ASGIApp(Protocol):
|
||||
class HTTPApp(ASGIApp):
|
||||
serialize_selector: SerializeSelector
|
||||
|
||||
def extract_content_type(self, req: BasicRequest) -> ContentType:
|
||||
contenttype_header = req.headers.get('Content-Type')
|
||||
properties: dict[str, str]
|
||||
if contenttype_header is None:
|
||||
contenttype = self.serialize_selector.default_content_type
|
||||
properties = {}
|
||||
else:
|
||||
contenttype, properties = parse_content_type(contenttype_header)
|
||||
|
||||
charset = properties.get('charset')
|
||||
|
||||
if charset is None:
|
||||
if contenttype == 'application/json':
|
||||
charset = 'utf-8'
|
||||
else:
|
||||
charset = 'latin1'
|
||||
|
||||
return ContentType(contenttype, charset)
|
||||
|
||||
def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest:
|
||||
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
|
||||
return ser.req.deserialize(req, ct.charset)
|
||||
|
||||
def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: ContentType) -> BasicResponse:
|
||||
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
|
||||
sresponser = ser.resp
|
||||
|
||||
try:
|
||||
return sresponser.into_basic(sresp, ct.charset)
|
||||
except UnicodeEncodeError:
|
||||
return sresponser.into_basic(sresp, 'utf-8')
|
||||
|
||||
async def send_answer(self, send: Send, resp: BasicResponse):
|
||||
await send(resp.into_start_message())
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': resp.body,
|
||||
})
|
||||
|
||||
async def _do_http(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
method = scope['method']
|
||||
path = scope['path']
|
||||
@@ -45,31 +90,26 @@ class HTTPApp(ASGIApp):
|
||||
try:
|
||||
handler = self.router.match(method, path)
|
||||
|
||||
ser = self.serialize_selector.select(req.headers)
|
||||
sresponser = ser.resp
|
||||
charset = ser.charset
|
||||
sreq = ser.req.deserialize(req, charset)
|
||||
|
||||
sresp = await handler(sreq)
|
||||
await handler(send, req)
|
||||
return
|
||||
except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException):
|
||||
sresp = SerializedResponse(400, {}, 'Schema error')
|
||||
except HTTPException as e:
|
||||
sresp = SerializedResponse(e.code, {}, str(e))
|
||||
|
||||
try:
|
||||
resp = sresponser.into_basic(sresp, charset)
|
||||
ct = self.extract_content_type(req)
|
||||
resp = sresponser.into_basic(sresp, ct.charset)
|
||||
except UnicodeEncodeError:
|
||||
resp = sresponser.into_basic(sresp, 'utf-8')
|
||||
|
||||
await send(resp.into_start_message())
|
||||
await send({
|
||||
'type': 'http.response.body',
|
||||
'body': resp.body,
|
||||
})
|
||||
# пока без малейшего понятия как избавиться от ручной обертки обработки ошибок
|
||||
await self.send_answer(send, resp)
|
||||
|
||||
|
||||
class WSApp(ASGIApp):
|
||||
async def _do_websocket(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# TODO: impl ws
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -144,6 +184,23 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
||||
self._on_shutdown = on_shutdown
|
||||
self.serialize_selector = SerializeSelector()
|
||||
self.infl_generator = SchemaInflatorGenerator(strict_mode=True)
|
||||
self.inj_repo = ConvRepo()
|
||||
|
||||
@self.inj_repo.mark_injector()
|
||||
def extract_query(req: BasicRequest) -> QTYPE:
|
||||
return req.query
|
||||
|
||||
@self.inj_repo.mark_injector()
|
||||
def extract_query(req: SerializedRequest) -> QTYPE:
|
||||
return req.query
|
||||
|
||||
@self.inj_repo.mark_injector()
|
||||
def extract_body(req: SerializedRequest) -> BTYPE:
|
||||
return req.body
|
||||
|
||||
self.inj_repo.add_injector(self.extract_content_type)
|
||||
self.inj_repo.add_injector(self.serialize_request)
|
||||
self.inj_repo.add_injector(self.serialize_response)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
t = scope['type']
|
||||
@@ -158,9 +215,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
||||
|
||||
def route(self, method: MethodType, path_pattern: str):
|
||||
def decorator(fn: HandlerType):
|
||||
hints = get_type_hints(fn)
|
||||
handle_hints = get_type_hints(fn)
|
||||
req_schema = None
|
||||
for argname, tp in hints.items():
|
||||
for argname, tp in handle_hints.items():
|
||||
if argname == 'return':
|
||||
continue
|
||||
if get_origin(tp) == UnwrappedRequest:
|
||||
@@ -172,24 +229,56 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
|
||||
query_type, body_type = get_args(req_schema)
|
||||
q_inflator = None
|
||||
b_inflator = None
|
||||
if query_type != Any:
|
||||
q_inflator = self.infl_generator.schema_to_inflator(query_type, strict_mode_override=False)
|
||||
if body_type != Any:
|
||||
b_inflator = self.infl_generator.schema_to_inflator(body_type)
|
||||
|
||||
def internal_handler(req: SerializedRequest) -> Awaitable[SerializedResponse]:
|
||||
if q_inflator is not None:
|
||||
q = q_inflator(req.query)
|
||||
else:
|
||||
q = req.query
|
||||
if b_inflator is not None:
|
||||
b = b_inflator(req.body)
|
||||
else:
|
||||
b = req.body
|
||||
fork_with = set()
|
||||
|
||||
return fn(UnwrappedRequest(q, b))
|
||||
def none_generator(*args) -> None:
|
||||
return None
|
||||
|
||||
self.router.add(method, path_pattern, internal_handler)
|
||||
if query_type not in [QTYPE, None, Any]:
|
||||
q_inflator = self.infl_generator.schema_to_inflator(
|
||||
query_type,
|
||||
strict_mode_override=False,
|
||||
from_type_override=QTYPE
|
||||
)
|
||||
fork_with.add(ConversionPoint(q_inflator, query_type, (QTYPE,)))
|
||||
else:
|
||||
fork_with.add(ConversionPoint(none_generator, query_type, (QTYPE,)))
|
||||
|
||||
if body_type != [BTYPE, None, Any]:
|
||||
b_inflator = self.infl_generator.schema_to_inflator(
|
||||
body_type,
|
||||
from_type_override=BTYPE
|
||||
)
|
||||
fork_with.add(ConversionPoint(b_inflator, body_type, (BTYPE,)))
|
||||
else:
|
||||
fork_with.add(ConversionPoint(none_generator, body_type, (BTYPE,)))
|
||||
|
||||
def construct_unwrap(q, b) -> UnwrappedRequest:
|
||||
print(f'unwrapping {query_type} and {body_type} with {q} and {b}')
|
||||
return UnwrappedRequest(q, b)
|
||||
|
||||
fork_with |= {ConversionPoint(construct_unwrap, req_schema, (query_type or QTYPE, body_type or BTYPE,))}
|
||||
|
||||
tmp_repo = self.inj_repo.fork(fork_with)
|
||||
|
||||
conv = tmp_repo.get_conversion(
|
||||
(BasicRequest,),
|
||||
fn,
|
||||
force_async=True,
|
||||
)
|
||||
|
||||
out_conv = self.inj_repo.get_conversion(
|
||||
(Send, BasicRequest, handle_hints['return'],),
|
||||
self.send_answer,
|
||||
force_async=True
|
||||
)
|
||||
|
||||
async def pipeline(send: Send, req: BasicRequest):
|
||||
ret = await conv(req)
|
||||
await out_conv(send, req, ret)
|
||||
|
||||
self.router.add(method, path_pattern, pipeline)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -12,7 +12,6 @@ class SerializeChoise(NamedTuple):
|
||||
req: type[SerializedRequest]
|
||||
resp: type[SerializedResponse]
|
||||
charset: str
|
||||
content_properties: dict[str, str]
|
||||
|
||||
|
||||
class SerializeSelector:
|
||||
@@ -34,22 +33,7 @@ class SerializeSelector:
|
||||
ser[k] = v
|
||||
self.serializers = ser
|
||||
|
||||
def select(self, headers: CaseInsensitiveDict[str, str]) -> SerializeChoise:
|
||||
contenttype_header = headers.get('Content-Type')
|
||||
properties: dict[str, str]
|
||||
if contenttype_header is None:
|
||||
contenttype = self.default_content_type
|
||||
properties = {}
|
||||
else:
|
||||
contenttype, properties = parse_content_type(contenttype_header)
|
||||
|
||||
charset = properties.get('charset')
|
||||
|
||||
if charset is None:
|
||||
if contenttype == 'application/json':
|
||||
charset = 'utf-8'
|
||||
else:
|
||||
charset = 'latin1'
|
||||
def select(self, contenttype: str, charset: str) -> SerializeChoise:
|
||||
|
||||
choise = self.serializers.get(typing.cast(str, contenttype))
|
||||
if choise is None and self.default_content_type is not None:
|
||||
@@ -58,4 +42,4 @@ class SerializeSelector:
|
||||
raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys()))
|
||||
req, resp = choise
|
||||
|
||||
return SerializeChoise(req, resp, charset, properties)
|
||||
return SerializeChoise(req, resp, charset)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Awaitable, Literal
|
||||
from typing import Any, Callable, Awaitable, Literal, Annotated
|
||||
|
||||
type Scope = dict[str, Any]
|
||||
type ASGIMessage = dict[str, Any]
|
||||
@@ -16,3 +16,6 @@ type MethodType = (
|
||||
Literal['CONNECT'] |
|
||||
Literal['OPTIONS'] |
|
||||
Literal['TRACE'])
|
||||
|
||||
type QTYPE = Annotated[dict[str, Any], 'query_params']
|
||||
type BTYPE = Annotated[dict[str, Any] | list[Any] | str | None, 'body']
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Awaitable, Literal, Any
|
||||
|
||||
from turbosloth.interfaces.base import BasicRequest
|
||||
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
|
||||
from turbosloth.internal_types import Send
|
||||
from turbosloth.req_schema import UnwrappedRequest
|
||||
|
||||
type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]]
|
||||
type InternalHandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]]
|
||||
type InternalHandlerType = Callable[[Send, BasicRequest], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentType:
|
||||
contenttype: str
|
||||
charset: str
|
||||
|
||||
12
uv.lock
generated
12
uv.lock
generated
@@ -4,15 +4,15 @@ requires-python = ">=3.13"
|
||||
|
||||
[[package]]
|
||||
name = "breakshaft"
|
||||
version = "0.1.0"
|
||||
version = "0.1.0.post2"
|
||||
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
|
||||
dependencies = [
|
||||
{ name = "hatchling" },
|
||||
{ name = "jinja2" },
|
||||
]
|
||||
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0.tar.gz", hash = "sha256:77f114bab957b78f96b5b087e8ef9778c9b1fc4b3433f0b428d9dd084243debe" }
|
||||
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2.tar.gz", hash = "sha256:d04f8685336080cddb5111422a9624c8afcc8c47264e1b847339139bfd690570" }
|
||||
wheels = [
|
||||
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0/breakshaft-0.1.0-py3-none-any.whl", hash = "sha256:e7c8831e5889fbcd788d59fee05f7aee113663ebef9273cd280893e1ed04f5c2" },
|
||||
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/breakshaft/0.1.0.post2/breakshaft-0.1.0.post2-py3-none-any.whl", hash = "sha256:fd4c9b213e2569c2c4d4f86cef2fe04065c08f9bffac3aa46b8ae21459f35074" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -184,15 +184,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "megasniff"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
source = { registry = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }
|
||||
dependencies = [
|
||||
{ name = "hatchling" },
|
||||
{ name = "jinja2" },
|
||||
]
|
||||
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0.tar.gz", hash = "sha256:feceefd8a618c7b930832b4ff8edcc474961f5d5f53743c64e142a796fea0e30" }
|
||||
sdist = { url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1.tar.gz", hash = "sha256:c6ed47b40fbdc92d7aa061267d158eea457b560f1bc8f84d297b55e5e4a0ef2e" }
|
||||
wheels = [
|
||||
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.0/megasniff-0.2.0-py3-none-any.whl", hash = "sha256:4f03a1e81dcb1020b6fe80f3fbc7b192a61b4dc55f7c6783141aa464dc3c77d1" },
|
||||
{ url = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/files/megasniff/0.2.1/megasniff-0.2.1-py3-none-any.whl", hash = "sha256:cff759fd61e9a4b8634329620cf13a941d7d5b38b834f8eab68e6ac78fa23589" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user