From d6f8038efa43c3852b8f59a27f52ad8c4f36b422 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 16 Aug 2025 18:38:46 +0300 Subject: [PATCH 1/4] Make tuple return types unwrap --- src/breakshaft/models.py | 31 +++++++-- src/breakshaft/renderer.py | 36 +++++++++- src/breakshaft/templates/convertor.jinja2 | 25 ++++++- src/breakshaft/util.py | 39 +++++++++++ tests/test_tuple_unwrap.py | 84 +++++++++++++++++++++++ 5 files changed, 205 insertions(+), 10 deletions(-) create mode 100644 tests/test_tuple_unwrap.py diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 7a20ed5..24bff02 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -10,22 +10,25 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ is_async_context_manager_factory, \ - all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames + all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \ + get_tuple_types, is_basic_type_annot @dataclass(frozen=True) class ConversionPoint: fn: Callable injects: type + rettype: type requires: tuple[type, ...] opt_args: tuple[type, ...] def copy_with(self, **kwargs): fn = kwargs.get('fn', self.fn) + rettype = kwargs.get('rettype', self.rettype) injects = kwargs.get('injects', self.injects) requires = kwargs.get('requires', self.requires) opt_args = kwargs.get('opt_args', self.opt_args) - return ConversionPoint(fn, injects, requires, opt_args) + return ConversionPoint(fn, injects, rettype, requires, opt_args) def __hash__(self): return hash((self.fn, self.injects, self.requires)) @@ -58,9 +61,11 @@ class ConversionPoint: @classmethod def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: + + annot = get_type_hints(func) + fn_rettype = annot.get('return') if rettype is None: - annot = get_type_hints(func) - rettype = annot.get('return') + rettype = fn_rettype if rettype is None: raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified') @@ -79,6 +84,18 @@ class ConversionPoint: if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): rettype = get_args(rettype)[0] + if is_basic_type_annot(rettype): + return [] + + ret = [] + + tuple_unwrapped = get_tuple_types(rettype) + # Do not unwrap elipsis, but unwrap non-empty tuples + if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped: + for t in tuple_unwrapped: + if not is_basic_type_annot(t): + ret += ConversionPoint.from_fn(func, rettype=t) + argtypes: list[list[type]] = [] orig_args = extract_func_args(func) defaults = extract_func_arg_defaults(func) @@ -97,7 +114,7 @@ class ConversionPoint: argtypes.append(u_types) argtype_combinations = all_combinations(argtypes) - ret = [] + for argtype_combination in argtype_combinations: req_args = [] opt_args = [] @@ -106,7 +123,9 @@ class ConversionPoint: opt_args.append(argt) else: req_args.append(argt) - ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args))) + if rettype in req_args: + continue + ret.append(ConversionPoint(func, rettype, fn_rettype, tuple(req_args), tuple(opt_args))) return ret diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index a29c87c..79a7aa3 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -7,7 +7,7 @@ import importlib.resources import jinja2 from .models import ConversionPoint -from .util import hashname +from .util import hashname, get_tuple_types, is_basic_type_annot class ConvertorRenderer(Protocol): @@ -19,6 +19,29 @@ class ConvertorRenderer(Protocol): raise NotImplementedError() +type UnwprappedTuple = tuple[tuple[UnwprappedTuple, str] | str | None, ...] + + +def unwrap_tuple_type(typ: type) -> UnwprappedTuple: + unwrap_tuple_result = () + tuple_types = get_tuple_types(typ) + if len(tuple_types) > 0 and Ellipsis not in tuple_types: + for t in tuple_types: + if not is_basic_type_annot(t): + subtuple = unwrap_tuple_type(t) + hn = hashname(t) + if len(subtuple) > 0: + unwrap_tuple_result += ((subtuple, hn),) + else: + unwrap_tuple_result += (hn,) + else: + unwrap_tuple_result += (None,) + + if not any(map(lambda x: x is not None, unwrap_tuple_result)): + return () + return unwrap_tuple_result + + @dataclass class ConversionRenderData: inj_hash: str @@ -27,6 +50,7 @@ class ConversionRenderData: funcargs: list[tuple[str, str]] is_ctxmanager: bool is_async: bool + unwrap_tuple_result: UnwprappedTuple @classmethod def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): @@ -42,7 +66,15 @@ class ConversionRenderData: if argtype in provided_types: fnargs.append((argname, hashname(argtype))) - return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async) + unwrap_tuple_result = unwrap_tuple_type(inj.rettype) + + return cls(hashname(inj.rettype), + hashname(inj.fn), + repr(inj.fn), + fnargs, + inj.is_ctx_manager, + inj.is_async, + unwrap_tuple_result) @dataclass diff --git a/src/breakshaft/templates/convertor.jinja2 b/src/breakshaft/templates/convertor.jinja2 index 1bc3189..2538e74 100644 --- a/src/breakshaft/templates/convertor.jinja2 +++ b/src/breakshaft/templates/convertor.jinja2 @@ -1,13 +1,34 @@ {% set ns = namespace(indent=0) %} + +{% macro unwrap_tuple(tupl, unwrap_name) -%} +{%- set out -%} +{% if tupl | length > 0 %} +{% for t in tupl %} +{% if t is string %} +_{{t}} = _{{unwrap_name}}[{{loop.index0}}] +{% endif %} +{% if t.__class__.__name__ == 'tuple' %} +_{{t[1]}} = _{{unwrap_name}}[{{loop.index0}}] +{{unwrap_tuple(t[0], t[1])}} +{% endif %} +{% endfor %} + +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + {% if is_async %}async {% endif %}def convertor({% for arg in conv_args %}_{{arg.typehash}}: "{{arg.typename}}",{% endfor %}){% if rettype %} -> '{{rettype}}'{% endif %}: {% for conv in conversions %} {% if conv.is_ctxmanager %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}: + {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}}=_{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}: {% set ns.indent = ns.indent + 1 %} {% else %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) + {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}}=_{{conv_arg[1]}}, {% endfor %}) {% endif %} +{{unwrap_tuple(conv.unwrap_tuple_result, conv.inj_hash) | indent((ns.indent + 1) * 4)}} {% endfor %} {{ ' ' * ns.indent }}return _{{ret_hash}} diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index 6668253..deba74b 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -1,4 +1,5 @@ import inspect +import typing from itertools import product from typing import Callable, get_type_hints, TypeVar, Any, Optional @@ -88,3 +89,41 @@ T = TypeVar('T') def all_combinations(options: list[list[T]]) -> list[list[T]]: return [list(comb) for comb in product(*options)] + + +def get_tuple_types(type_obj: type) -> tuple: + ret = () + + origin = getattr(type_obj, '__origin__', None) + if origin is tuple: + args = getattr(type_obj, '__args__', ()) + ret = args if args else () + + return ret + + +def is_basic_type_annot(type_annot) -> bool: + basic_types = { + int, float, str, bool, complex, + list, dict, tuple, set, frozenset, + bytes, bytearray, memoryview, + type(None), object + } + + origin = getattr(type_annot, '__origin__', None) + args = getattr(type_annot, '__args__', None) + + if type_annot in basic_types: + return True + + if origin is not None: + if origin in basic_types or origin in {list, dict, tuple, set, frozenset}: + if args: + return all(is_basic_type_annot(arg) for arg in args) + return True + return False + + if origin is typing.Union: + return all(is_basic_type_annot(arg) for arg in args) + + return False diff --git a/tests/test_tuple_unwrap.py b/tests/test_tuple_unwrap.py new file mode 100644 index 0000000..10db0f5 --- /dev/null +++ b/tests/test_tuple_unwrap.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass + +from breakshaft.models import ConversionPoint +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +@dataclass +class C: + c: int + + +@dataclass +class D: + d: str + + +def test_conv_point_tuple_unwrap(): + def conv_into_bc(a: A) -> tuple[B, C]: + return B(a.a), C(a.a) + + def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]: + return B(a.a), (C(a.a), D(str(a.a))) + + def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]: + return B(a.a), (C(a.a), (D(str(a.a)), a)) + + cps_bc = ConversionPoint.from_fn(conv_into_bc) + assert len(cps_bc) == 3 # tuple[...], B, C + + cps_bcd = ConversionPoint.from_fn(conv_into_bcd) + + assert len(cps_bcd) == 5 # tuple[B,...], B, tuple[C,D], C, D + + cps_bcda = ConversionPoint.from_fn(conv_into_bcda) + + assert len(cps_bcda) == 6 # ignores (A,...)->A + + +def test_ignore_basic_types(): + def conv_into_b_int(a: A) -> tuple[B, int]: + return B(a.a), a.a + + cps = ConversionPoint.from_fn(conv_into_b_int) + assert len(cps) == 2 # tuple[...], B + + +def test_codegen_tuple_unwrap(): + repo = ConvRepo(store_sources=True) + + @repo.mark_injector() + def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]: + return B(a.a), (C(a.a), D(str(a.a))) + + type Z = A + + @repo.mark_injector() + def conv_d_a(d: D) -> Z: + return A(int(d.d)) + + def consumer1(dep: D) -> int: + return int(dep.d) + + def consumer2(dep: Z) -> int: + return int(dep.a) + + fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False) + assert fn1(A(1)) == 1 + + fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False) + assert fn2(A(1)) == 1 + + pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False) + assert pip(A(1)) == 1 + print(pip.__breakshaft_render_src__) -- 2.49.1 From 3150c4b2d0deaae0ba2931bd68a6480dedd8d662 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 16 Aug 2025 18:44:58 +0300 Subject: [PATCH 2/4] Fix ctxmanager injects hash --- src/breakshaft/models.py | 3 +++ src/breakshaft/renderer.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 24bff02..abd4b2a 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -71,6 +71,7 @@ class ConversionPoint: raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified') rettype_origin = get_origin(rettype) + fn_rettype_origin = get_origin(fn_rettype) cm_out_origins = [ typing.Generator, typing.Iterator, @@ -83,6 +84,8 @@ class ConversionPoint: ] if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): rettype = get_args(rettype)[0] + if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): + fn_rettype = get_args(fn_rettype)[0] if is_basic_type_annot(rettype): return [] diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index 79a7aa3..473686c 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -146,7 +146,7 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): conversion_models = deduplicate_callseq(conversion_models) - ret_hash = hashname(callseq[-1].injects) + ret_hash = hashname(callseq[-1].rettype) conv_args = [] for i, from_type in enumerate(from_types): -- 2.49.1 From fd8026a2a5efa89e54f4d2c25d4379cd7051f6d6 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 16 Aug 2025 18:45:55 +0300 Subject: [PATCH 3/4] Update `README.md`: sync feature list --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d5ec52d..d05b39c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ - Поддерживает `Union`-типы в зависимостях - Учитывает default-параметры - Позволяет выстраивать конвейеры преобразований +- Опционально разворачивает кортежи в возвращаемых значениях #### Ограничения библиотеки: - Выбор графа преобразований вызывает комбинаторный взрыв -- 2.49.1 From 742c21e1994c93e8a729764aa4cb6681ca2580d0 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 16 Aug 2025 18:46:06 +0300 Subject: [PATCH 4/4] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7c88694..a55262a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "breakshaft" -version = "0.1.3.post2" +version = "0.1.4" description = "Library for in-time codegen for type conversion" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } -- 2.49.1