diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index ba37e35..127b814 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -33,7 +33,7 @@ import breakshaft.util class ASGIApp(Protocol): router: Router - def route(self, method: MethodType, path_pattern: str): + def route(self, method: MethodType, path_pattern: str, ok_return_code: str): raise RuntimeError('stub!') def add_subroute(self, subr: Route | Router, basepath: str) -> None: @@ -178,35 +178,32 @@ class LifespanApp: class MethodRoutersApp(ASGIApp): - def get(self, path_pattern: str): - return self.route('GET', path_pattern) + def get(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('GET', path_pattern, ok_return_code) - def post(self, path_pattern: str): - return self.route('POST', path_pattern) + def post(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('POST', path_pattern, ok_return_code) - def push(self, path_pattern: str): - return self.route('PUSH', path_pattern) + def put(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('PUT', path_pattern, ok_return_code) - def put(self, path_pattern: str): - return self.route('PUT', path_pattern) + def patch(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('PATCH', path_pattern, ok_return_code) - def patch(self, path_pattern: str): - return self.route('PATCH', path_pattern) + def delete(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('DELETE', path_pattern, ok_return_code) - def delete(self, path_pattern: str): - return self.route('DELETE', path_pattern) + def head(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('HEAD', path_pattern, ok_return_code) - def head(self, path_pattern: str): - return self.route('HEAD', path_pattern) + def connect(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('CONNECT', path_pattern, ok_return_code) - def connect(self, path_pattern: str): - return self.route('CONNECT', path_pattern) + def options(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('OPTIONS', path_pattern, ok_return_code) - def options(self, path_pattern: str): - return self.route('OPTIONS', path_pattern) - - def trace(self, path_pattern: str): - return self.route('TRACE', path_pattern) + def trace(self, path_pattern: str, ok_return_code: str = '200'): + return self.route('TRACE', path_pattern, ok_return_code) class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): @@ -442,11 +439,11 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): return ret - def route(self, method: MethodType, path_pattern: str): + def route(self, method: MethodType, path_pattern: str, ok_return_code: str): def decorator(fn: HandlerType): path_substs = self.router.find_pattern_substs(path_pattern) - fork_with, fn_type_hints = self._integrate_func(fn, path_substs) + fork_with, fn_type_hints = self._integrate_func(fn, path_substs, ok_return_code=ok_return_code) tmp_repo = self.inj_repo.fork(fork_with) p = tmp_repo.create_pipeline( (Send, BasicRequest), @@ -471,7 +468,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): ) ) - self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern)( + self.route('GET', self.di_autodoc_prefix + '/' + method + path_pattern, '200')( create_di_autodoc_handler(method, path_pattern, p, depgraph)) if self.openapi_app is not None: @@ -504,12 +501,13 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): def _integrate_func(self, func: Callable, - path_params: Optional[Iterable[str]] = None): + path_params: Optional[Iterable[str]] = None, + ok_return_code: Optional[str] = None): injectors = self.inj_repo.filtered_injectors(True, True) injected_types = list(map(lambda x: x.injects, injectors)) - config = EndpointConfig.from_handler(func, path_params or set(), injected_types) + config = EndpointConfig.from_handler(func, path_params or set(), injected_types, ok_return_code=ok_return_code) fork_with = set() diff --git a/src/turbosloth/internal_types.py b/src/turbosloth/internal_types.py index 5fe4c55..d82e72a 100644 --- a/src/turbosloth/internal_types.py +++ b/src/turbosloth/internal_types.py @@ -10,7 +10,6 @@ type Send = Callable[[ASGIMessage], Awaitable[None]] type MethodType = ( Literal['GET'] | Literal['POST'] | - Literal['PUSH'] | Literal['PUT'] | Literal['PATCH'] | Literal['DELETE'] | diff --git a/src/turbosloth/schema.py b/src/turbosloth/schema.py index 21ef2f4..626c316 100644 --- a/src/turbosloth/schema.py +++ b/src/turbosloth/schema.py @@ -9,6 +9,9 @@ from typing import TypeVar, Generic, Annotated from megasniff.utils import TupleSchemaItem +from turbosloth.interfaces.base import BasicResponse +from turbosloth.interfaces.serialized import SerializedResponse + T = TypeVar("T") from typing import Annotated, TypeAlias @@ -63,22 +66,33 @@ def get_endpoint_params_info(func) -> dict[str, ParamSchema]: return result +@dataclass +class ReturnSchema: + code_map: dict[str, type] + + @dataclass class EndpointConfig: body_schema: ParamSchema | None query_schemas: dict[str, ParamSchema] path_schemas: dict[str, ParamSchema] header_schemas: dict[str, ParamSchema] + return_schema: ReturnSchema | None fn: Callable type_replacement: dict[str, type] @classmethod - def from_handler(cls, h: Callable, path_substituts: set[str], ignore_types: Iterable[type]) -> EndpointConfig: + def from_handler(cls, + h: Callable, + path_substituts: set[str], + ignore_types: Iterable[type], + ok_return_code: Optional[str] = None) -> EndpointConfig: body_schema = None query_schemas = {} path_schemas = {} header_schemas = {} type_replacement = {} + return_schema = None handle_hints = get_endpoint_params_info(h) for argname, s in handle_hints.items(): @@ -88,7 +102,11 @@ class EndpointConfig: type_replacement[argname] = s.replacement_type if argname == 'return': - continue + if ok_return_code is None: + continue + if issubclass(s.schema, (SerializedResponse, BasicResponse)): + continue + return_schema = ReturnSchema({ok_return_code: s.schema}) if get_origin(tp) == Annotated: args = get_args(tp) @@ -120,4 +138,10 @@ class EndpointConfig: else: query_schemas[argname] = s - return EndpointConfig(body_schema, query_schemas, path_schemas, header_schemas, h, type_replacement) + return EndpointConfig(body_schema, + query_schemas, + path_schemas, + header_schemas, + return_schema, + h, + type_replacement)