Dummy attach megasniff scheme inflation

This commit is contained in:
2025-07-16 19:43:24 +03:00
parent ddfd6e8d78
commit 95c47a5e90
11 changed files with 121 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
import interfaces # import interfaces
import router # import router
from .app import SlothApp from .app import SlothApp
__all__ = ['SlothApp', 'router', 'interfaces'] __all__ = ['SlothApp', 'router', 'interfaces']

View File

@@ -1,21 +1,46 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from dataclasses import dataclass
from typing import Any, Optional
from turbosloth import SlothApp from turbosloth import SlothApp
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.req_schema import UnwrappedRequest
app = SlothApp() app = SlothApp()
@app.get("/") @app.get("/")
async def index(req: SerializedRequest) -> SerializedResponse: async def index(req: UnwrappedRequest[Any, Any]) -> SerializedResponse:
return SerializedResponse(200, {}, 'Hello, ASGI Router!') return SerializedResponse(200, {}, 'Hello, ASGI Router!')
@dataclass
class UserIdSchema:
user_id: int
@app.get("/user/") @app.get("/user/")
@app.post("/user/") async def get_user(req: UnwrappedRequest[UserIdSchema, Any]) -> SerializedResponse:
async def get_user(req: SerializedRequest) -> SerializedResponse: print(req)
print(req.basic.query) resp: dict[str, Any] = {'message': f'Hello, User ы {req.query.user_id}!', 'from': 'server', 'echo': req.body}
resp: dict[str, Any] = {'message': f'Hello, User ы {req.basic.query["id"]}!', 'from': 'server', 'echo': req.body} return SerializedResponse(200, {}, resp)
@dataclass
class SomeData:
a: int
b: float
c: Optional[SomeData] = None
@dataclass
class UserPostSchema(UserIdSchema):
data: SomeData
@app.post("/user")
async def post_user(req: UnwrappedRequest[Any, UserPostSchema]) -> SerializedResponse:
print(req)
resp: dict[str, Any] = {'message': f'Hello, User {req.body.user_id}!', 'from': 'server', 'data': req.body.data}
return SerializedResponse(200, {}, resp) return SerializedResponse(200, {}, resp)

View File

@@ -1,12 +1,17 @@
from typing import Optional, Callable, Awaitable, Protocol from typing import Optional, Callable, Awaitable, Protocol, get_type_hints, get_origin, get_args, Any
import megasniff.exceptions
from megasniff import SchemaInflatorGenerator
from .exceptions import HTTPException from .exceptions import HTTPException
from .interfaces.base import BasicRequest, BasicResponse from .interfaces.base import BasicRequest, BasicResponse
from .interfaces.serialize_selector import SerializeSelector from .interfaces.serialize_selector import SerializeSelector
from .interfaces.serialized import SerializedResponse from .interfaces.serialized import SerializedResponse, SerializedRequest
from .interfaces.serialized.text import TextSerializedResponse from .interfaces.serialized.text import TextSerializedResponse
from .req_schema import UnwrappedRequest
from .router import Router from .router import Router
from .types import Scope, Receive, Send, MethodType, HandlerType from .types import HandlerType, InternalHandlerType
from .internal_types import Scope, Receive, Send, MethodType
class ASGIApp(Protocol): class ASGIApp(Protocol):
@@ -46,6 +51,8 @@ class HTTPApp(ASGIApp):
sreq = ser.req.deserialize(req, charset) sreq = ser.req.deserialize(req, charset)
sresp = await handler(sreq) sresp = await handler(sreq)
except (megasniff.exceptions.FieldValidationException, megasniff.exceptions.MissingFieldException):
sresp = SerializedResponse(400, {}, 'Schema error')
except HTTPException as e: except HTTPException as e:
sresp = SerializedResponse(e.code, {}, str(e)) sresp = SerializedResponse(e.code, {}, str(e))
@@ -136,6 +143,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self._on_startup = on_startup self._on_startup = on_startup
self._on_shutdown = on_shutdown self._on_shutdown = on_shutdown
self.serialize_selector = SerializeSelector() self.serialize_selector = SerializeSelector()
self.infl_generator = SchemaInflatorGenerator(strict_mode=True)
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
t = scope['type'] t = scope['type']
@@ -150,7 +158,38 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
def route(self, method: MethodType, path_pattern: str): def route(self, method: MethodType, path_pattern: str):
def decorator(fn: HandlerType): def decorator(fn: HandlerType):
self.router.add(method, path_pattern, fn) hints = get_type_hints(fn)
req_schema = None
for argname, tp in hints.items():
if argname == 'return':
continue
if get_origin(tp) == UnwrappedRequest:
req_schema = tp
if req_schema is None:
raise ValueError(f'Unable to find request schema in handler {fn}')
query_type, body_type = get_args(req_schema)
q_inflator = None
b_inflator = None
if query_type != Any:
q_inflator = self.infl_generator.schema_to_inflator(query_type, strict_mode_override=False)
if body_type != Any:
b_inflator = self.infl_generator.schema_to_inflator(body_type)
def internal_handler(req: SerializedRequest) -> Awaitable[SerializedResponse]:
if q_inflator is not None:
q = q_inflator(req.query)
else:
q = req.query
if b_inflator is not None:
b = b_inflator(req.body)
else:
b = req.body
return fn(UnwrappedRequest(q, b))
self.router.add(method, path_pattern, internal_handler)
return fn return fn
return decorator return decorator

View File

@@ -1,4 +1,4 @@
from http_base import HTTPException from .http_base import HTTPException
from .client_errors import * from .client_errors import *
from .server_errors import * from .server_errors import *

View File

View File

@@ -7,7 +7,7 @@ from typing import Any, Mapping
from case_insensitive_dict import CaseInsensitiveDict from case_insensitive_dict import CaseInsensitiveDict
from turbosloth.types import MethodType, Scope, ASGIMessage from turbosloth.internal_types import MethodType, Scope, ASGIMessage
@dataclass @dataclass

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Mapping from typing import Any, Mapping
from case_insensitive_dict import CaseInsensitiveDict from case_insensitive_dict import CaseInsensitiveDict
from turbosloth.interfaces.base import BasicRequest, BasicResponse from turbosloth.interfaces.base import BasicRequest, BasicResponse
from turbosloth.types import MethodType from turbosloth.internal_types import MethodType
@dataclass @dataclass

View File

@@ -0,0 +1,18 @@
from typing import Any, Callable, Awaitable, Literal
type Scope = dict[str, Any]
type ASGIMessage = dict[str, Any]
type Receive = Callable[[], Awaitable[ASGIMessage]]
type Send = Callable[[ASGIMessage], Awaitable[None]]
type MethodType = (
Literal['GET'] |
Literal['POST'] |
Literal['PUSH'] |
Literal['PUT'] |
Literal['PATCH'] |
Literal['DELETE'] |
Literal['HEAD'] |
Literal['CONNECT'] |
Literal['OPTIONS'] |
Literal['TRACE'])

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Any, TypeVar
from mypy.visitor import Generic
Q = TypeVar('Q')
B = TypeVar('B')
@dataclass
class UnwrappedRequest(Generic[Q, B]):
query: Q
body: B

View File

@@ -4,18 +4,19 @@ import typing
from typing import Optional, Sequence from typing import Optional, Sequence
from .exceptions import MethodNotAllowedException, NotFoundException from .exceptions import MethodNotAllowedException, NotFoundException
from .types import HandlerType, MethodType from .types import InternalHandlerType
from .internal_types import MethodType
class Route: class Route:
static_subroutes: dict[str, Route] static_subroutes: dict[str, Route]
handler: dict[MethodType, HandlerType] handler: dict[MethodType, InternalHandlerType]
def __init__(self) -> None: def __init__(self) -> None:
self.static_subroutes = {} self.static_subroutes = {}
self.handler = {} self.handler = {}
def add(self, method: MethodType, sequence: Sequence[str], handler: HandlerType) -> None: def add(self, method: MethodType, sequence: Sequence[str], handler: InternalHandlerType) -> None:
if len(sequence) == 0: if len(sequence) == 0:
self.handler[method] = handler self.handler[method] = handler
return return
@@ -27,7 +28,7 @@ class Route:
subroute.add(method, sequence[1:], handler) subroute.add(method, sequence[1:], handler)
def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[HandlerType]: def get(self, method: MethodType, sequence: Sequence[str]) -> Optional[InternalHandlerType]:
if len(sequence) == 0: if len(sequence) == 0:
ret = self.handler.get(method) ret = self.handler.get(method)
if ret is None: if ret is None:
@@ -45,7 +46,7 @@ class Router:
def __init__(self) -> None: def __init__(self) -> None:
self._root = Route() self._root = Route()
def add(self, method: MethodType, path_pattern: str, handler: HandlerType) -> None: def add(self, method: MethodType, path_pattern: str, handler: InternalHandlerType) -> None:
assert method.upper() == method assert method.upper() == method
segments = path_pattern.split('/') segments = path_pattern.split('/')
@@ -53,7 +54,7 @@ class Router:
segments = segments[1:] segments = segments[1:]
self._root.add(method, segments, handler) self._root.add(method, segments, handler)
def match(self, method: MethodType, path: str) -> HandlerType: def match(self, method: MethodType, path: str) -> InternalHandlerType:
method = typing.cast(MethodType, method.upper()) method = typing.cast(MethodType, method.upper())
segments = path.split('/') segments = path.split('/')

View File

@@ -3,21 +3,7 @@ from __future__ import annotations
from typing import Callable, Awaitable, Literal, Any from typing import Callable, Awaitable, Literal, Any
from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest from turbosloth.interfaces.serialized import SerializedResponse, SerializedRequest
from turbosloth.req_schema import UnwrappedRequest
type Scope = dict[str, Any] type HandlerType = Callable[[UnwrappedRequest], Awaitable[SerializedResponse]]
type ASGIMessage = dict[str, Any] type InternalHandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]]
type Receive = Callable[[], Awaitable[ASGIMessage]]
type Send = Callable[[ASGIMessage], Awaitable[None]]
type HandlerType = Callable[[SerializedRequest], Awaitable[SerializedResponse]]
type MethodType = (
Literal['GET'] |
Literal['POST'] |
Literal['PUSH'] |
Literal['PUT'] |
Literal['PATCH'] |
Literal['DELETE'] |
Literal['HEAD'] |
Literal['CONNECT'] |
Literal['OPTIONS'] |
Literal['TRACE'])