Multipart upload support

This commit is contained in:
2025-10-15 16:59:41 +03:00
parent 4a5ca2cca7
commit e761dd3fdf
13 changed files with 279 additions and 74 deletions

View File

@@ -13,6 +13,7 @@ dependencies = [
"case-insensitive-dictionary>=0.2.1",
"mypy>=1.17.0",
"jinja2>=3.1.6",
"python-multipart>=0.0.20",
]
[tool.uv.sources]

View File

@@ -1,22 +1,27 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Literal
from typing import Optional, Literal, Annotated
import uvicorn
from python_multipart.multipart import File, Field
from turbosloth import SlothApp
from turbosloth.doc.openapi_app import OpenAPIApp
from turbosloth.doc.openapi_models import Info
from turbosloth.interfaces.base import BasicResponse
from turbosloth.interfaces.base import BasicResponse, BasicRequest
from turbosloth.interfaces.serialize_selector import SerializeSelector
from turbosloth.interfaces.serialized import SerializedResponse
from turbosloth.interfaces.serialized.multipart_form_data import MultipartFormSerializedRequest
from turbosloth.schema import RequestBody, HeaderParam, QueryParam, Resp
# from turbosloth.types import HTTPResponse
app = SlothApp(di_autodoc_prefix='/didoc',
serialize_selector=SerializeSelector(default_content_type='application/json'),
serialize_selector=SerializeSelector(
default_content_type='application/json',
default_accept_type='application/json'
),
openapi_app=OpenAPIApp(Info('asdf', '1.0.0')))
@@ -156,6 +161,53 @@ async def test_body(r: RequestBody(UserPostSchema),
return TestResp(**resp), HeadersResp('asdf')
@dataclass
class FieldData:
    """Response entry describing one plain (non-file) form field."""

    # Entry discriminator, annotated with the extra metadata string 'type'
    # (presumably the serialized key name; verify against the schema layer).
    type_: Annotated[str, 'type']
    # Name of the form field.
    name: str
    # Text value of the form field.
    value: str
@dataclass
class FileData:
    """Response entry describing one uploaded file part."""

    # Entry discriminator, annotated with the extra metadata string 'type'
    # (presumably the serialized key name; verify against the schema layer).
    type_: Annotated[str, 'type']
    # Name of the form field the file was submitted under.
    name: str
    # File name reported for the uploaded file.
    fname: str
@dataclass
class FileRespSchema:
    """Response body listing every form part seen in an upload request."""

    # One entry per parsed part: file parts and plain fields share one list.
    fields: list[FileData | FieldData]
def _form_text(raw) -> str:
    """Return a form-part attribute as text.

    ``None`` becomes an empty string, ``str`` is passed through unchanged and
    anything else is decoded with its ``decode()`` method.
    """
    if raw is None:
        return ''
    if isinstance(raw, str):
        return raw
    return raw.decode()


@app.post('/upload_multipart')
async def upload_multipart(req: MultipartFormSerializedRequest) -> Resp(FileRespSchema, 200):
    """Demo endpoint: iterate the parsed form parts and echo their metadata.

    Each ``File`` part becomes a ``FileData`` entry and each ``Field`` part a
    ``FieldData`` entry; parts of any other type are only printed.

    Args:
        req: Streaming form request; iterating it yields the parsed parts.

    Returns:
        ``FileRespSchema`` with one entry per recognised part, in arrival order.
    """
    fields = []
    async for e in req:
        if isinstance(e, File):
            # Names are not guaranteed to be bytes: they may be missing, or
            # presumably already str when taken from request headers, so do
            # not call .decode() on them blindly.
            fields.append(FileData('file', _form_text(e.field_name), _form_text(e.file_name)))
            if e.in_memory:
                # The file gets spilled to disk anyway once it is too large,
                # which is not ideal.
                # TODO: customizable File/Field behaviour, streaming reads of multipart/*
                e.flush_to_disk()
            print(e.actual_file_name)
        elif isinstance(e, Field):
            # A field submitted without a value may carry no bytes at all.
            fields.append(FieldData('field', _form_text(e.field_name), _form_text(e.value)))
        print(e)
    return FileRespSchema(fields)
@app.get('/openapi.json')
async def openapi_schema(a: SlothApp, version: QueryParam(Optional[str], 'v') = None) -> SerializedResponse:
dat = a.openapi_app.export_as_dict()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import html
import traceback
import typing
from dataclasses import dataclass
from types import NoneType
@@ -90,14 +91,14 @@ class HTTPApp(ASGIApp):
return ContentType(contenttype, charset)
def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
return ser.req.deserialize(req, ct.charset)
async def serialize_request(self, ct: ContentType, req: BasicRequest) -> SerializedRequest:
ser = self.serialize_selector.select_req(ct.contenttype)
return await ser.deserialize(req, ct.charset)
def serialize_response(self, req: BasicRequest, sresp: SerializedResponse, ct: Accept) -> BasicResponse:
if type(sresp) is SerializedResponse:
ser = self.serialize_selector.select(ct.contenttype, ct.charset)
sresponser = ser.resp
ser = self.serialize_selector.select_resp(ct.contenttype)
sresponser = ser
try:
return sresponser.into_basic(sresp, ct.charset)
@@ -120,15 +121,15 @@ class HTTPApp(ASGIApp):
method = scope['method']
path = scope['path']
body = b''
while True:
event = await receive()
body += event.get('body', b'')
if not event.get('more_body', False):
break
scope['body'] = body
# body = b''
# while True:
# event = await receive()
# body += event.get('body', b'')
# if not event.get('more_body', False):
# break
# scope['body'] = body
req = BasicRequest.from_scope(scope)
req = BasicRequest.from_scope(scope, receive)
sresponser: type[SerializedResponse] = TextSerializedResponse
charset = 'latin1'
@@ -286,6 +287,19 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
self.inj_repo.add_injector(self.serialize_request)
self.inj_repo.add_injector(self.serialize_response)
added_req_serializers = set()
for ser in self.serialize_selector.req_serializers.values():
if ser in added_req_serializers:
continue
added_req_serializers.add(ser)
# Bind the serializer through a factory call. A closure defined directly in
# the loop body looks `ser` up when it is *called*, i.e. after the loop has
# finished, so every registered injector would end up using the last serializer.
def _bind_serializer(bound_ser):
    """Build a request injector that always deserializes with ``bound_ser``."""

    async def _serialize(ct: ContentType, req: BasicRequest):
        return await bound_ser.deserialize(req, ct.charset)

    return _serialize

_serialize = _bind_serializer(ser)
hints = get_type_hints(_serialize, include_extras=True)
hints['return'] = ser
self.inj_repo.add_injector(_serialize, type_remap=hints)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
t = scope['type']
if t == 'http':
@@ -474,6 +488,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
config, fork_with, fn_type_hints = self._integrate_func(fn, path_substs, ok_return_code=ok_return_code)
tmp_repo = self.inj_repo.fork(fork_with)
try:
p = tmp_repo.create_pipeline(
(Send, BasicRequest),
[ConversionPoint.from_fn(fn, type_remap=fn_type_hints), self.send_answer],
@@ -481,6 +496,10 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
)
self.router.add(method, path_pattern, p)
except Exception as e:
print(f'Error: unable to register handler {method} {path_pattern}: {e}')
traceback.print_exc()
p = None
if self.di_autodoc_prefix is not None and not path_pattern.startswith(
self.di_autodoc_prefix + '/' + method + self.di_autodoc_prefix):
@@ -507,7 +526,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
schema = self.schema_from_type(v.body)
self.openapi_app.register_component(v.body)
content = {}
for ctype in self.serialize_selector.serializers.keys():
for ctype in self.serialize_selector.resp_serializers.keys():
content[ctype] = MediaType(schema=schema)
headers = None
if v.headers_provided:
@@ -519,7 +538,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
)
req_body = {}
req_schema = self.construct_req_body(p)
for ctype in self.serialize_selector.serializers.keys():
for ctype in self.serialize_selector.req_serializers.keys():
req_body[ctype] = MediaType(req_schema)
self.openapi_app.register_endpoint(
method,
@@ -548,7 +567,7 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
ok_return_code: Optional[str] = None):
injectors = self.inj_repo.filtered_injectors(True, True)
injected_types = list(map(lambda x: x.injects, injectors))
injected_types = list(map(lambda x: x.injects, injectors)) + [BasicRequest, SerializedRequest]
config = EndpointConfig.from_handler(func, path_params or set(), injected_types, ok_return_code=ok_return_code)
@@ -602,7 +621,15 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
})
)
if len(combined_header_return_types) > 0:
is_only_dict_headers = True
for t in combined_header_return_types:
origin = t
while o := get_origin(origin):
origin = o
if isinstance(origin, dict):
is_only_dict_headers = False
break
if len(combined_header_return_types) > 0 and not is_only_dict_headers:
fork_with |= set(
ConversionPoint.from_fn(self.defl_generator.schema_to_deflator(combined_header_return_type),
rettype=HTTPResponseHeadersPlaceholder,
@@ -611,6 +638,9 @@ class SlothApp(HTTPApp, WSApp, LifespanApp, MethodRoutersApp):
'from_data': combined_header_return_type
})
)
else:
combined_return_type = tuple[combined_body_return_type, HTTPResponseHeadersPlaceholder]
fork_with |= set(
ConversionPoint.from_fn(

View File

@@ -3,11 +3,15 @@ from __future__ import annotations
import typing
import urllib.parse
from dataclasses import dataclass
from typing import Any, Mapping
from typing import Any, Mapping, Awaitable, Callable
from case_insensitive_dict import CaseInsensitiveDict
from turbosloth.internal_types import MethodType, Scope, ASGIMessage
from turbosloth.internal_types import MethodType, Scope, ASGIMessage, Receive
async def _dummy_recv() -> ASGIMessage:
return {}
@dataclass
@@ -16,24 +20,49 @@ class BasicRequest:
path: str
headers: CaseInsensitiveDict[str, str]
query: dict[str, list[Any] | Any]
body: bytes
path_matches: dict[str, str]
_do_body_recv: Receive
_body_recv_all = False
@property
def is_body_recv_done(self) -> bool:
    """Whether the final chunk of the request body has already been received."""
    received_everything = self._body_recv_all
    return received_everything
async def fetch_full_body(self) -> bytes:
    """Read the request body to the end and return it as a single ``bytes``.

    Chunks are pulled through ``fetch_body_part`` until it reports completion.
    The body is consumed while reading: once everything has been received,
    ``fetch_body_part`` only yields empty chunks, so a later call returns ``b''``.

    Returns:
        All remaining body chunks concatenated in arrival order.
    """
    # Collect the chunks and join once; repeated `bytes +=` copies the whole
    # buffer on every iteration.
    chunks = []
    is_done = False
    while not is_done:
        chunk, is_done = await self.fetch_body_part()
        chunks.append(chunk)
    return b''.join(chunks)
async def fetch_body_part(self) -> tuple[bytes | bytearray, bool]:
    """Receive the next chunk of the request body.

    Returns:
        ``(chunk, is_done)``: the chunk taken from the message's ``'body'`` key
        (``b''`` when absent) and a flag that is true when the message has no
        truthy ``'more_body'`` entry. After the final chunk has been seen, every
        further call returns ``(b'', True)`` without awaiting the receive callable.
    """
    if not self._body_recv_all:
        message = await self._do_body_recv()
        chunk = message.get('body', b'')
        finished = not message.get('more_body', False)
        if finished:
            # Remember completion so the receive callable is not awaited again.
            self._body_recv_all = True
        return chunk, finished
    return b'', True
def __init__(self,
method: MethodType,
path: str,
headers: Mapping[str, str],
query: dict[str, list[Any] | Any],
body: bytes):
body_recv: Receive):
self.method = method
self.path = path
self.headers = CaseInsensitiveDict(headers)
self.query = query
self.body = body
self._do_body_recv = body_recv
self.path_matches = {}
@classmethod
def from_scope(cls, scope: Scope) -> BasicRequest:
def from_scope(cls,
scope: Scope,
recv: Receive) -> BasicRequest:
path = scope['path']
method = typing.cast(MethodType, scope['method'])
headers = {}
@@ -50,9 +79,7 @@ class BasicRequest:
query[k] = v
query = typing.cast(dict[str, list[Any] | Any], query)
body = scope['body']
return BasicRequest(method, path, headers, query, body)
return BasicRequest(method, path, headers, query, recv)
@dataclass

View File

@@ -4,42 +4,60 @@ from typing import NamedTuple, Optional
from case_insensitive_dict import CaseInsensitiveDict
from turbosloth.exceptions import NotAcceptableException
from turbosloth.interfaces.serialized import SerializedRequest, SerializedResponse, default_serializers
from turbosloth.interfaces.serialized import SerializedRequest, SerializedResponse, default_req_serializers, \
default_resp_serializers
from turbosloth.util import parse_content_type
class SerializeChoise(NamedTuple):
req: type[SerializedRequest]
resp: type[SerializedResponse]
charset: str
type SerializeReqChoise = tuple[type[SerializedRequest], str]
type SerializeRespChoise = tuple[type[SerializedResponse], str]
class SerializeSelector:
default_content_type: Optional[str]
serializers: dict[str, tuple[type[SerializedRequest], type[SerializedResponse]]]
default_accept_type: Optional[str]
req_serializers: dict[str, type[SerializedRequest]]
resp_serializers: dict[str, type[SerializedResponse]]
def __init__(self,
default_content_type: Optional[str] = 'text/plain',
filter_content_types: Optional[list[str]] = None):
default_accept_type: Optional[str] = 'text/plain',
filter_content_types: Optional[list[str]] = None,
filter_accept_types: Optional[list[str]] = None):
self.default_content_type = default_content_type
ser = {}
self.default_accept_type = default_accept_type
req_ser = {}
resp_ser = {}
if filter_content_types is None:
filter_content_types = list(default_serializers.keys())
filter_content_types = list(default_req_serializers.keys())
for k, v in default_serializers.items():
if filter_accept_types is None:
filter_accept_types = list(default_resp_serializers.keys())
for k, v in default_req_serializers.items():
if k in filter_content_types:
ser[k] = v
self.serializers = ser
req_ser[k] = v
for k, v in default_resp_serializers.items():
if k in filter_accept_types:
resp_ser[k] = v
self.req_serializers = req_ser
self.resp_serializers = resp_ser
def select_req(self, contenttype: str) -> type[SerializedRequest]:
    """Pick the request deserializer class for a request content type.

    Args:
        contenttype: Media type of the incoming request body.

    Returns:
        The class registered for ``contenttype``; when none is registered, the
        class registered for ``default_content_type`` (if that default is set).

    Raises:
        NotAcceptableException: Neither the given type nor the default
            content type has a registered request serializer.
    """
    choise = self.req_serializers.get(typing.cast(str, contenttype))
    # Requests fall back to the default *content* type; the default accept
    # type describes responses and must not be consulted here.
    if choise is None and self.default_content_type is not None:
        choise = self.req_serializers.get(self.default_content_type)
    if choise is None:
        raise NotAcceptableException('acceptable content types: ' + ', '.join(self.req_serializers.keys()))
    return choise
def select_resp(self, accepttype: str) -> type[SerializedResponse]:
    """Pick the response serializer class for an accepted media type.

    Args:
        accepttype: Media type the client asked for.

    Returns:
        The class registered for ``accepttype``; when none is registered, the
        class registered for ``default_accept_type`` (if that default is set).

    Raises:
        NotAcceptableException: Neither the given type nor the default
            accept type has a registered response serializer.
    """
    choise = self.resp_serializers.get(typing.cast(str, accepttype))
    # Responses fall back to the default *accept* type; the default content
    # type describes request bodies and must not be consulted here.
    if choise is None and self.default_accept_type is not None:
        choise = self.resp_serializers.get(self.default_accept_type)
    if choise is None:
        raise NotAcceptableException('acceptable content types: ' + ', '.join(self.resp_serializers.keys()))
    return choise

View File

@@ -1,31 +1,44 @@
from .base import SerializedRequest, SerializedResponse
default_serializers: dict[str, tuple[type[SerializedRequest], type[SerializedResponse]]] = {}
default_req_serializers: dict[str, type[SerializedRequest]] = {}
default_resp_serializers: dict[str, type[SerializedResponse]] = {}
try:
from .text import TextSerializedRequest, TextSerializedResponse
default_serializers['text/plain'] = (TextSerializedRequest, TextSerializedResponse)
default_req_serializers['text/plain'] = TextSerializedRequest
default_resp_serializers['text/plain'] = TextSerializedResponse
except:
pass
try:
from .json import JsonSerializedRequest, JsonSerializedResponse
default_serializers['application/json'] = (JsonSerializedRequest, JsonSerializedResponse)
default_req_serializers['application/json'] = JsonSerializedRequest
default_resp_serializers['application/json'] = JsonSerializedResponse
except:
pass
try:
from .xml import XMLSerializedRequest, XMLSerializedResponse
default_serializers['application/xml'] = (XMLSerializedRequest, XMLSerializedResponse)
default_req_serializers['application/xml'] = XMLSerializedRequest
default_resp_serializers['application/xml'] = XMLSerializedResponse
except:
pass
try:
from .msgpack import MessagePackSerializedRequest, MessagePackSerializedResponse
default_serializers['application/vnd.msgpack'] = (MessagePackSerializedRequest, MessagePackSerializedResponse)
default_req_serializers['application/vnd.msgpack'] = MessagePackSerializedRequest
default_resp_serializers['application/vnd.msgpack'] = MessagePackSerializedResponse
except:
pass
try:
    from .multipart_form_data import MultipartFormSerializedRequest

    # The form parser handles all three body formats; register the multipart
    # media type as well so such requests are routed to it instead of falling
    # back to the default serializer.
    default_req_serializers['multipart/form-data'] = MultipartFormSerializedRequest
    default_req_serializers['application/octet-stream'] = MultipartFormSerializedRequest
    default_req_serializers['application/x-www-form-urlencoded'] = MultipartFormSerializedRequest
except ImportError:
    # Form support is optional: skip registration when the module or its
    # parser dependency cannot be imported.
    pass

View File

@@ -35,7 +35,7 @@ class SerializedRequest:
return self.basic.path_matches
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
raise NotImplementedError()

View File

@@ -7,13 +7,14 @@ from .base import SerializedRequest, SerializedResponse
class JsonSerializedRequest(SerializedRequest):
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
if len(basic.body) == 0:
async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
body = await basic.fetch_full_body()
if len(body) == 0:
b = None
elif charset.lower() in {'utf-8', 'utf8'}:
b = orjson.loads(basic.body)
b = orjson.loads(body)
else:
btxt = basic.body.decode(charset)
btxt = body.decode(charset)
b = orjson.loads(btxt.encode('utf-8'))
return cls(b, basic, charset)

View File

@@ -7,11 +7,12 @@ from .base import SerializedRequest, SerializedResponse
class MessagePackSerializedRequest(SerializedRequest):
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
if len(basic.body) == 0:
async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
body = await basic.fetch_full_body()
if len(body) == 0:
b = None
else:
b = msgpack.unpackb(basic.body)
b = msgpack.unpackb(body)
return cls(b, basic, charset)

View File

@@ -0,0 +1,49 @@
import asyncio
from dataclasses import dataclass
from queue import Queue, Empty
from typing import AsyncIterable
import python_multipart
from case_insensitive_dict import CaseInsensitiveDict
from python_multipart import MultipartParser, FormParser
from turbosloth.interfaces.base import BasicRequest, BasicResponse
from .base import SerializedRequest, SerializedResponse
@dataclass
class MultipartFormSerializedRequest(SerializedRequest):
    """Form request whose parts are parsed lazily while the body is streamed.

    Body chunks are pulled from the underlying ``BasicRequest`` on demand and
    fed to a ``python_multipart`` form parser. The parser's field/file callbacks
    push finished parts onto an internal queue, and ``async for`` over the
    request drains that queue.
    """

    _parser: FormParser
    _parse_q: Queue

    def __init__(self, basic: BasicRequest, charset: str, ):
        """Create the form parser from the request headers.

        Args:
            basic: Request whose headers configure the parser and whose body
                is streamed into it.
            charset: Stored on the instance as ``charset``.
        """
        self.basic = basic
        self.body = None
        self.charset = charset
        self._parse_q = Queue()
        # Set once the parser has been told that no more data will arrive.
        self._finalized = False
        self._parser = python_multipart.create_form_parser(
            basic.headers,
            on_field=self._parse_q.put,
            on_file=self._parse_q.put,
        )

    def _finalize_parser(self) -> None:
        """Signal end of input to the parser exactly once."""
        if not self._finalized:
            # Flip the flag first so a failing finalize() is not retried.
            self._finalized = True
            # Without this the parser never flushes parts that are only
            # complete at end of input, so they would never reach the queue.
            self._parser.finalize()

    async def _fetch_part(self) -> bool:
        """Feed the next body chunk to the parser.

        Returns:
            ``True`` when the body had already been fully received before this
            call (no new data was read), ``False`` otherwise.
        """
        was_done = self.basic.is_body_recv_done
        if was_done:
            self._finalize_parser()
            return was_done
        buf, is_done = await self.basic.fetch_body_part()
        self._parser.write(buf)
        if is_done:
            self._finalize_parser()
        return was_done

    async def __aiter__(self) -> AsyncIterable:
        """Yield parsed form parts as soon as the parser completes them."""
        while True:
            try:
                it = self._parse_q.get(block=False)
            except Empty:
                # Nothing buffered: stop once the parser has seen all input,
                # otherwise read more of the body and check the queue again.
                if self._finalized:
                    return
                await self._fetch_part()
            else:
                yield it

    @classmethod
    async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
        """Wrap ``basic`` without reading its body; parsing happens on iteration."""
        return cls(basic, charset)

View File

@@ -6,8 +6,9 @@ from .base import SerializedRequest, SerializedResponse
class TextSerializedRequest(SerializedRequest):
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
b = basic.body.decode(charset)
async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
body = await basic.fetch_full_body()
b = body.decode(charset)
return cls(b, basic, charset)

View File

@@ -9,11 +9,12 @@ from .base import SerializedRequest, SerializedResponse
class XMLSerializedRequest(SerializedRequest):
@classmethod
def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
if len(basic.body) == 0:
async def deserialize(cls, basic: BasicRequest, charset: str) -> SerializedRequest:
body = await basic.fetch_full_body()
if len(body) == 0:
b = {}
else:
btxt = basic.body.decode(charset)
btxt = body.decode(charset)
parsed = etree.fromstring(btxt)
b = {child.tag: child.text for child in parsed}

11
uv.lock generated
View File

@@ -390,6 +390,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
]
[[package]]
name = "python-multipart"
version = "0.0.20"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158, upload-time = "2024-12-16T19:45:46.972Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" },
]
[[package]]
name = "trove-classifiers"
version = "2025.9.11.17"
@@ -409,6 +418,7 @@ dependencies = [
{ name = "jinja2" },
{ name = "megasniff" },
{ name = "mypy" },
{ name = "python-multipart" },
]
[package.dev-dependencies]
@@ -445,6 +455,7 @@ requires-dist = [
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "megasniff", specifier = ">=0.2.4", index = "https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple" },
{ name = "mypy", specifier = ">=1.17.0" },
{ name = "python-multipart", specifier = ">=0.0.20" },
]
[package.metadata.requires-dev]