diff --git a/pyproject.toml b/pyproject.toml index fb12384..63dea1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "case-insensitive-dictionary>=0.2.1", "mypy>=1.17.0", "jinja2>=3.1.6", + "python-multipart>=0.0.20", ] [tool.uv.sources] diff --git a/src/turbosloth/__main__.py b/src/turbosloth/__main__.py index 29b01e6..bb22e09 100644 --- a/src/turbosloth/__main__.py +++ b/src/turbosloth/__main__.py @@ -1,22 +1,27 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Literal +from typing import Optional, Literal, Annotated import uvicorn +from python_multipart.multipart import File, Field from turbosloth import SlothApp from turbosloth.doc.openapi_app import OpenAPIApp from turbosloth.doc.openapi_models import Info -from turbosloth.interfaces.base import BasicResponse +from turbosloth.interfaces.base import BasicResponse, BasicRequest from turbosloth.interfaces.serialize_selector import SerializeSelector from turbosloth.interfaces.serialized import SerializedResponse +from turbosloth.interfaces.serialized.multipart_form_data import MultipartFormSerializedRequest from turbosloth.schema import RequestBody, HeaderParam, QueryParam, Resp # from turbosloth.types import HTTPResponse app = SlothApp(di_autodoc_prefix='/didoc', - serialize_selector=SerializeSelector(default_content_type='application/json'), + serialize_selector=SerializeSelector( + default_content_type='application/json', + default_accept_type='application/json' + ), openapi_app=OpenAPIApp(Info('asdf', '1.0.0'))) @@ -156,6 +161,53 @@ async def test_body(r: RequestBody(UserPostSchema), return TestResp(**resp), HeadersResp('asdf') +@dataclass +class FieldData: + type_: Annotated[str, 'type'] + name: str + value: str + + +@dataclass +class FileData: + type_: Annotated[str, 'type'] + name: str + fname: str + + +@dataclass +class FileRespSchema: + fields: list[FileData | FieldData] + + +@app.post('/upload_multipart') +async def upload_multipart(req: MultipartFormSerializedRequest) -> Resp(FileRespSchema, 200): + fields = [] + fields_raw = [] + async for e in req: + ee = {} + if isinstance(e, File): + f = FileData('file', e.field_name.decode(), e.file_name.decode()) + fields.append(f) + ee = f.__dict__ + fields_raw.append(ee) + if e.in_memory: + # оно в любом случае будет сохраняться на диск если слишком большое, что мне не очень нравится + # TODO: кастомизация поведения File и Field полей, потоковое чтение multipart/* + e.flush_to_disk() + print(e.actual_file_name) + elif isinstance(e, Field): + f = FieldData('field', e.field_name.decode(), e.value.decode()) + fields.append(f) + ee = f.__dict__ + fields_raw.append(ee) + else: + pass + print(e) + + return FileRespSchema(fields) + + @app.get('/openapi.json') async def openapi_schema(a: SlothApp, version: QueryParam(Optional[str], 'v') = None) -> SerializedResponse: dat = a.openapi_app.export_as_dict() diff --git a/src/turbosloth/app.py b/src/turbosloth/app.py index 55141a4..04b01a1 100644 --- a/src/turbosloth/app.py +++ b/src/turbosloth/app.py @@ -1,5 +1,6 @@ from __future__ import annotations import html +import traceback import typing from dataclasses import dataclass from types import NoneType @@ -90,14 +91,14 @@ class HTTPApp(ASGIApp): return ContentType(contenttype, charset) - def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest: - ser = self.serialize_selector.select(ct.contenttype, ct.charset) - return ser.req.deserialize(req, ct.charset) + async def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest: + ser = self.serialize_selector.select_req(ct.contenttype) + return await ser.deserialize(req, ct.charset) def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: Accept) -> BasicResponse: if type(sresp) is SerializedResponse: - ser = self.serialize_selector.select(ct.contenttype, ct.charset) - sresponser = ser.resp + ser = self.serialize_selector.select_resp(ct.contenttype) + sresponser = ser try: return sresponser.into_basic(sresp, ct.charset) @@ -120,15 +121,15 @@ class HTTPApp(ASGIApp): method = scope['method'] path = scope['path'] - body = b'' - while True: - event = await receive() - body += event.get('body', b'') - if not event.get('more_body', False): - break - scope['body'] = body + # body = b'' + # while True: + # event = await receive() + # body += event.get('body', b'') + # if not event.get('more_body', False): + # break + # scope['body'] = body - req = BasicRequest.from_scope(scope) + req = BasicRequest.from_scope(scope, receive) sresponser: type[SerializedResponse] = TextSerializedResponse charset = 'latin1' @@ -286,6 +287,19 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): self.inj_repo.add_injector(self.serialize_request) self.inj_repo.add_injector(self.serialize_response) + added_req_serializers = set() + for ser in self.serialize_selector.req_serializers.values(): + if ser in added_req_serializers: + continue + added_req_serializers.add(ser) + + async def _serialize(ct: ContentType, req: BasicRequest): + return await ser.deserialize(req, ct.charset) + + hints = get_type_hints(_serialize, include_extras=True) + hints['return'] = ser + self.inj_repo.add_injector(_serialize, type_remap=hints) + async def __call__(self, scope: Scope, receive: Receive, send: Send): t = scope['type'] if t == 'http': @@ -474,13 +488,18 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): config, fork_with, fn_type_hints = self._integrate_func(fn, path_substs, ok_return_code=ok_return_code) tmp_repo = self.inj_repo.fork(fork_with) - p = tmp_repo.create_pipeline( - (Send, BasicRequest), - [ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer], - force_async=True, - ) + try: + p = tmp_repo.create_pipeline( + (Send, BasicRequest), + [ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer], + force_async=True, + ) - self.router.add(method, path_pattern, p) + self.router.add(method, path_pattern, p) + except Exception as e: + print(f'Error: unable to register handler {method} {path_pattern}: {e}') + traceback.print_exc() + p = None if self.di_autodoc_prefix is not None and not path_pattern.startswith( self.di_autodoc_prefix + '/' + method + self.di_autodoc_prefix): @@ -507,7 +526,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): schema = self.schema_from_type(v.body) self.openapi_app.register_component(v.body) content = {} - for ctype in self.serialize_selector.serializers.keys(): + for ctype in self.serialize_selector.resp_serializers.keys(): content[ctype] = MediaType(schema=schema) headers = None if v.headers_provided: @@ -519,7 +538,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): ) req_body = {} req_schema = self.construct_req_body(p) - for ctype in self.serialize_selector.serializers.keys(): + for ctype in self.serialize_selector.req_serializers.keys(): req_body[ctype] = MediaType(req_schema) self.openapi_app.register_endpoint( method, @@ -548,7 +567,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): ok_return_code: Optional[str] = None): injectors = self.inj_repo.filtered_injectors(True, True) - injected_types = list(map(lambda x: x.injects, injectors)) + injected_types = list(map(lambda x: x.injects, injectors)) + [BasicRequest, SerializedRequest] config = EndpointConfig.from_handler(func, path_params or set(), injected_types, ok_return_code=ok_return_code) @@ -602,7 +621,15 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): }) ) - if len(combined_header_return_types) > 0: + is_only_dict_headers = True + for t in combined_header_return_types: + origin = t + while o := get_origin(origin): + origin = o + if isinstance(origin, dict): + is_only_dict_headers = False + break + if len(combined_header_return_types) > 0 and not is_only_dict_headers: fork_with |= set( ConversionPoint.from_fn(self.defl_generator.schema_to_deflator(combined_header_return_type), rettype=HTTPResponseHeadersPlaceholder, @@ -611,6 +638,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp): 'from_data': combined_header_return_type }) ) + else: + combined_return_type = tuple[combined_body_return_type, HTTPResponseHeadersPlaceholder] + fork_with |= set( ConversionPoint.from_fn( diff --git a/src/turbosloth/interfaces/base.py b/src/turbosloth/interfaces/base.py index c7bc417..cd1c214 100644 --- a/src/turbosloth/interfaces/base.py +++ b/src/turbosloth/interfaces/base.py @@ -3,11 +3,15 @@ from __future__ import annotations import typing import urllib.parse from dataclasses import dataclass -from typing import Any, Mapping +from typing import Any, Mapping, Awaitable, Callable from case_insensitive_dict import CaseInsensitiveDict -from turbosloth.internal_types import MethodType, Scope, ASGIMessage +from turbosloth.internal_types import MethodType, Scope, ASGIMessage, Receive + + +async def _dummy_recv() -> ASGIMessage: + return {} @dataclass @@ -16,24 +20,49 @@ class BasicRequest: path: str headers: CaseInsensitiveDict[str, str] query: dict[str, list[Any] | Any] - body: bytes path_matches: dict[str, str] + _do_body_recv: Receive + _body_recv_all = False + + @property + def is_body_recv_done(self): + return self._body_recv_all + + async def fetch_full_body(self) -> bytes: + is_done = False + buf = b'' + while not is_done: + b, is_done = await self.fetch_body_part() + buf += b + return buf + + async def fetch_body_part(self) -> tuple[bytes | bytearray, bool]: + if self._body_recv_all: + return b'', True + event = await self._do_body_recv() + buf = event.get('body', b'') + is_done = not event.get('more_body', False) + if is_done: + self._body_recv_all = True + return buf, is_done def __init__(self, method: MethodType, path: str, headers: Mapping[str, str], query: dict[str, list[Any] | Any], - body: bytes): + body_recv: Receive): self.method = method self.path = path self.headers = CaseInsensitiveDict(headers) self.query = query - self.body = body + self._do_body_recv = body_recv self.path_matches = {} @classmethod - def from_scope(cls, scope: Scope) -> BasicRequest: + def from_scope(cls, + scope: Scope, + recv: Receive) -> BasicRequest: path = scope['path'] method = typing.cast(MethodType, scope['method']) headers = {} @@ -50,9 +79,7 @@ class BasicRequest: query[k] = v query = typing.cast(dict[str, list[Any] | Any], query) - body = scope['body'] - - return BasicRequest(method, path, headers, query, body) + return BasicRequest(method, path, headers, query, recv) @dataclass diff --git a/src/turbosloth/interfaces/serialize_selector.py b/src/turbosloth/interfaces/serialize_selector.py index d909b40..de3a605 100644 --- a/src/turbosloth/interfaces/serialize_selector.py +++ b/src/turbosloth/interfaces/serialize_selector.py @@ -4,42 +4,60 @@ from typing import NamedTuple, Optional from case_insensitive_dict import CaseInsensitiveDict from turbosloth.exceptions import NotAcceptableException -from turbosloth.interfaces.serialized import SerializedRequest, SerializedResponse, default_serializers +from turbosloth.interfaces.serialized import SerializedRequest, SerializedResponse, default_req_serializers, \ + default_resp_serializers from turbosloth.util import parse_content_type - -class SerializeChoise(NamedTuple): - req: type[SerializedRequest] - resp: type[SerializedResponse] - charset: str +type SerializeReqChoise = tuple[type[SerializedRequest], str] +type SerializeRespChoise = tuple[type[SerializedResponse], str] class SerializeSelector: default_content_type: Optional[str] - serializers: dict[str, tuple[type[SerializedRequest], type[SerializedResponse]]] + default_accept_type: Optional[str] + req_serializers: dict[str, type[SerializedRequest]] + resp_serializers: dict[str, type[SerializedResponse]] def __init__(self, default_content_type: Optional[str] = 'text/plain', - filter_content_types: Optional[list[str]] = None): + default_accept_type: Optional[str] = 'text/plain', + filter_content_types: Optional[list[str]] = None, + filter_accept_types: Optional[list[str]] = None): self.default_content_type = default_content_type - ser = {} + self.default_accept_type = default_accept_type + req_ser = {} + resp_ser = {} if filter_content_types is None: - filter_content_types = list(default_serializers.keys()) + filter_content_types = list(default_req_serializers.keys()) - for k, v in default_serializers.items(): + if filter_accept_types is None: + filter_accept_types = list(default_resp_serializers.keys()) + + for k, v in default_req_serializers.items(): if k in filter_content_types: - ser[k] = v - self.serializers = ser + req_ser[k] = v + for k, v in default_resp_serializers.items(): + if k in filter_accept_types: + resp_ser[k] = v + self.req_serializers = req_ser + self.resp_serializers = resp_ser - def select(self, contenttype: str, charset: str) -> SerializeChoise: - - choise = self.serializers.get(typing.cast(str, contenttype)) - if choise is None and self.default_content_type is not None: - choise = self.serializers.get(self.default_content_type) + def select_req(self, contenttype: str) -> type[SerializedRequest]: + choise = self.req_serializers.get(typing.cast(str, contenttype)) + if choise is None and self.default_accept_type is not None: + choise = self.req_serializers.get(self.default_accept_type) if choise is None: - raise NotAcceptableException('acceptable content types: ' + ', '.join(self.serializers.keys())) - req, resp = choise + raise NotAcceptableException('acceptable content types: ' + ', '.join(self.req_serializers.keys())) - return SerializeChoise(req, resp, charset) + return choise + + def select_resp(self, accepttype: str) -> type[SerializedResponse]: + choise = self.resp_serializers.get(typing.cast(str, accepttype)) + if choise is None and self.default_content_type is not None: + choise = self.resp_serializers.get(self.default_content_type) + if choise is None: + raise NotAcceptableException('acceptable content types: ' + ', '.join(self.resp_serializers.keys())) + + return choise diff --git a/src/turbosloth/interfaces/serialized/__init__.py b/src/turbosloth/interfaces/serialized/__init__.py index 995fe80..8ccf000 100644 --- a/src/turbosloth/interfaces/serialized/__init__.py +++ b/src/turbosloth/interfaces/serialized/__init__.py @@ -1,31 +1,44 @@ from .base import SerializedRequest, SerializedResponse -default_serializers: dict[str, tuple[type[SerializedRequest], type[SerializedResponse]]] = {} +default_req_serializers: dict[str, type[SerializedRequest]] = {} +default_resp_serializers: dict[str, type[SerializedResponse]] = {} try: from .text import TextSerializedRequest, TextSerializedResponse - default_serializers['text/plain'] = (TextSerializedRequest, TextSerializedResponse) + default_req_serializers['text/plain'] = TextSerializedRequest + default_resp_serializers['text/plain'] = TextSerializedResponse except: pass try: from .json import JsonSerializedRequest, JsonSerializedResponse - default_serializers['application/json'] = (JsonSerializedRequest, JsonSerializedResponse) + default_req_serializers['application/json'] = JsonSerializedRequest + default_resp_serializers['application/json'] = JsonSerializedResponse except: pass try: from .xml import XMLSerializedRequest, XMLSerializedResponse - default_serializers['application/xml'] = (XMLSerializedRequest, XMLSerializedResponse) + default_req_serializers['application/xml'] = XMLSerializedRequest + default_resp_serializers['application/xml'] = XMLSerializedResponse except: pass try: from .msgpack import MessagePackSerializedRequest, MessagePackSerializedResponse - default_serializers['application/vnd.msgpack'] = (MessagePackSerializedRequest, MessagePackSerializedResponse) + default_req_serializers['application/vnd.msgpack'] = MessagePackSerializedRequest + default_resp_serializers['application/vnd.msgpack'] = MessagePackSerializedResponse +except: + pass + +try: + from .multipart_form_data import MultipartFormSerializedRequest + + default_req_serializers['application/octet-stream'] = MultipartFormSerializedRequest + default_req_serializers['application/x-www-form-urlencoded'] = MultipartFormSerializedRequest except: pass diff --git a/src/turbosloth/interfaces/serialized/base.py b/src/turbosloth/interfaces/serialized/base.py index 4ac4cd4..dcd78c4 100644 --- a/src/turbosloth/interfaces/serialized/base.py +++ b/src/turbosloth/interfaces/serialized/base.py @@ -35,7 +35,7 @@ class SerializedRequest: return self.basic.path_matches @classmethod - def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: raise NotImplementedError() diff --git a/src/turbosloth/interfaces/serialized/json.py b/src/turbosloth/interfaces/serialized/json.py index 47b080e..0623be0 100644 --- a/src/turbosloth/interfaces/serialized/json.py +++ b/src/turbosloth/interfaces/serialized/json.py @@ -7,13 +7,14 @@ from .base import SerializedRequest, SerializedResponse class JsonSerializedRequest(SerializedRequest): @classmethod - def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: - if len(basic.body) == 0: + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + body = await basic.fetch_full_body() + if len(body) == 0: b = None elif charset.lower() in {'utf-8', 'utf8'}: - b = orjson.loads(basic.body) + b = orjson.loads(body) else: - btxt = basic.body.decode(charset) + btxt = body.decode(charset) b = orjson.loads(btxt.encode('utf-8')) return cls(b, basic, charset) diff --git a/src/turbosloth/interfaces/serialized/msgpack.py b/src/turbosloth/interfaces/serialized/msgpack.py index 396657e..b741af5 100644 --- a/src/turbosloth/interfaces/serialized/msgpack.py +++ b/src/turbosloth/interfaces/serialized/msgpack.py @@ -7,11 +7,12 @@ from .base import SerializedRequest, SerializedResponse class MessagePackSerializedRequest(SerializedRequest): @classmethod - def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: - if len(basic.body) == 0: + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + body = await basic.fetch_full_body() + if len(body) == 0: b = None else: - b = msgpack.unpackb(basic.body) + b = msgpack.unpackb(body) return cls(b, basic, charset) diff --git a/src/turbosloth/interfaces/serialized/multipart_form_data.py b/src/turbosloth/interfaces/serialized/multipart_form_data.py new file mode 100644 index 0000000..db964c5 --- /dev/null +++ b/src/turbosloth/interfaces/serialized/multipart_form_data.py @@ -0,0 +1,49 @@ +import asyncio +from dataclasses import dataclass +from queue import Queue, Empty +from typing import AsyncIterable + +import python_multipart +from case_insensitive_dict import CaseInsensitiveDict +from python_multipart import MultipartParser, FormParser + +from turbosloth.interfaces.base import BasicRequest, BasicResponse +from .base import SerializedRequest, SerializedResponse + + +@dataclass +class MultipartFormSerializedRequest(SerializedRequest): + _parser: FormParser + _parse_q: Queue + + async def _fetch_part(self) -> bool: + was_done = self.basic.is_body_recv_done + buf, _ = await self.basic.fetch_body_part() + self._parser.write(buf) + return was_done + + async def __aiter__(self) -> AsyncIterable: + is_done = False + while not is_done: + try: + it = self._parse_q.get(block=False) + except Empty: + is_done = await self._fetch_part() + else: + yield it + + def __init__(self, basic: BasicRequest, charset: str, ): + content_type = basic.headers.get('content-type', 'multipart/form-data') + self.basic = basic + self.body = None + self.charset = charset + self._parse_q = Queue() + self._parser = python_multipart.create_form_parser( + basic.headers, + on_field=self._parse_q.put, + on_file=self._parse_q.put, + ) + + @classmethod + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + return MultipartFormSerializedRequest(basic, charset) diff --git a/src/turbosloth/interfaces/serialized/text.py b/src/turbosloth/interfaces/serialized/text.py index 0d59e3c..1b08af6 100644 --- a/src/turbosloth/interfaces/serialized/text.py +++ b/src/turbosloth/interfaces/serialized/text.py @@ -6,8 +6,9 @@ from .base import SerializedRequest, SerializedResponse class TextSerializedRequest(SerializedRequest): @classmethod - def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: - b = basic.body.decode(charset) + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + body = await basic.fetch_full_body() + b = body.decode(charset) return cls(b, basic, charset) diff --git a/src/turbosloth/interfaces/serialized/xml.py b/src/turbosloth/interfaces/serialized/xml.py index 39dc260..ff22f42 100644 --- a/src/turbosloth/interfaces/serialized/xml.py +++ b/src/turbosloth/interfaces/serialized/xml.py @@ -9,11 +9,12 @@ from .base import SerializedRequest, SerializedResponse class XMLSerializedRequest(SerializedRequest): @classmethod - def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: - if len(basic.body) == 0: + async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest: + body = await basic.fetch_full_body() + if len(body) == 0: b = {} else: - btxt = basic.body.decode(charset) + btxt = body.decode(charset) parsed = etree.fromstring(btxt) b = {child.tag: child.text for child in parsed} diff --git a/uv.lock b/uv.lock index 233eb47..2504ad0 100644 --- a/uv.lock +++ b/uv.lock @@ -390,6 +390,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158, upload-time = "2024-12-16T19:45:46.972Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" }, +] + [[package]] name = "trove-classifiers" version = "2025.9.11.17" @@ -409,6 +418,7 @@ dependencies = [ { name = "jinja2" }, { name = "megasniff" }, { name = "mypy" }, + { name = "python-multipart" }, ] [package.dev-dependencies] @@ -445,6 +455,7 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.6" }, { name = "megasniff", specifier = ">=0.2.4", index = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" }, { name = "mypy", specifier = ">=1.17.0" }, + { name = "python-multipart", specifier = ">=0.0.20" }, ] [package.metadata.requires-dev]