From b058a701a03a53c767b68cc6e8acdd389e2fbce9 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 22:32:40 +0300 Subject: [PATCH 1/3] Add basic pipeline construction, callseq deduplication pending --- src/breakshaft/convertor.py | 83 +++++++++++++++++++++++++--------- src/breakshaft/graph_walker.py | 3 +- src/breakshaft/models.py | 7 +++ src/breakshaft/util.py | 7 ++- tests/test_pipeline.py | 62 +++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 23 deletions(-) create mode 100644 tests/test_pipeline.py diff --git a/src/breakshaft/convertor.py b/src/breakshaft/convertor.py index c45f350..6089871 100644 --- a/src/breakshaft/convertor.py +++ b/src/breakshaft/convertor.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any +from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence from .graph_walker import GraphWalker from .models import ConversionPoint, Callgraph from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer +from .util import extract_return_type Tin = TypeVarTuple('Tin') Tout = TypeVar('Tout') @@ -27,6 +28,31 @@ class ConvRepo: self.walker = graph_walker self.renderer = renderer + def create_pipeline(self, + from_types: Sequence[type], + fns: Sequence[Callable], + force_commutative: bool = True, + allow_async: bool = True, + allow_sync: bool = True, + force_async: bool = False + ): + filtered_injectors = self.filtered_injectors(allow_async, allow_sync) + pipeline_callseq = [] + orig_from_types = tuple(from_types) + from_types = tuple(from_types) + + for fn in fns: + injects = extract_return_type(fn) + + callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) + + pipeline_callseq += callseq + + if injects is not None: + from_types += (injects,) + + return self.renderer.render(orig_from_types, pipeline_callseq, force_async=force_async) + @property def 
convertor_set(self): return self._convertor_set @@ -46,30 +72,26 @@ class ConvRepo: ret += [variant.injector] return ret - def get_conversion(self, - from_types: tuple[type[Unpack[Tin]]], - fn: Callable[..., Tout], - force_commutative: bool = True, - allow_async: bool = True, - allow_sync: bool = True, - force_async: bool = False - ) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: - if not allow_async or force_async: - filtered_injectors: frozenset[ConversionPoint] = frozenset() - for inj in self.convertor_set: - if inj.is_async and not allow_async: - continue - if not inj.is_async and not allow_sync: - continue - filtered_injectors |= {inj} - else: - filtered_injectors = frozenset(self.convertor_set) + def filtered_injectors(self, allow_async: bool, allow_sync: bool) -> frozenset[ConversionPoint]: + filtered_injectors: frozenset[ConversionPoint] = frozenset() + for inj in self.convertor_set: + if inj.is_async and not allow_async: + continue + if not inj.is_async and not allow_sync: + continue + filtered_injectors |= {inj} + return filtered_injectors - cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn) + def get_callseq(self, + injectors: frozenset[ConversionPoint], + from_types: frozenset[type], fn: Callable, + force_commutative: bool) -> list[ConversionPoint]: + + cg = self.walker.generate_callgraph(injectors, from_types, fn) if cg is None: raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') - exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types)) + exploded = self.walker.explode_callgraph_branches(cg, from_types) selected = self.walker.filter_exploded_callgraph_branch(exploded) if len(selected) == 0: @@ -79,6 +101,25 @@ class ConvRepo: raise ValueError('Conversion path is not commutative') callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) + + if len(callseq) > 0: + injects = extract_return_type(fn) + callseq[-1] = 
callseq[-1].copy_with(injects=injects) + + return callseq + + def get_conversion(self, + from_types: Sequence[type[Unpack[Tin]]], + fn: Callable[..., Tout], + force_commutative: bool = True, + allow_async: bool = True, + allow_sync: bool = True, + force_async: bool = False + ) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: + + filtered_injectors = self.filtered_injectors(allow_async, allow_sync) + callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) + return self.renderer.render(from_types, callseq, force_async=force_async) def mark_injector(self, *, rettype: Optional[type] = None): diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 2aa0aac..b7ff4aa 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -3,7 +3,7 @@ from types import NoneType from typing import Callable, Optional from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection -from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq +from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type class GraphWalker: @@ -15,6 +15,7 @@ class GraphWalker: consumer_fn: Callable) -> Optional[Callgraph]: branches: frozenset[Callgraph] = frozenset() + rettype = extract_return_type(consumer_fn) # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index c49fc91..8ce80d9 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -20,6 +20,13 @@ class ConversionPoint: requires: tuple[type, ...] opt_args: tuple[type, ...] 
+ def copy_with(self, **kwargs): + fn = kwargs.get('fn', self.fn) + injects = kwargs.get('injects', self.injects) + requires = kwargs.get('requires', self.requires) + opt_args = kwargs.get('opt_args', self.opt_args) + return ConversionPoint(fn, injects, requires, opt_args) + def __hash__(self): return hash((self.fn, self.injects, self.requires)) diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index bfdae60..b343f76 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -1,6 +1,11 @@ import inspect from itertools import product -from typing import Callable, get_type_hints, TypeVar, Any +from typing import Callable, get_type_hints, TypeVar, Any, Optional + + +def extract_return_type(func: Callable) -> Optional[type]: + hints = get_type_hints(func) + return hints.get('return') def extract_func_args(func: Callable) -> list[tuple[str, type]]: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..42e3565 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass + +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +type optC = str + + +def test_default_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + type ret1 = tuple[int, str] + + def consumer1(dep: A, opt_dep: optC = '42') -> ret1: + return dep.a, opt_dep + + def consumer2(dep: A, dep1: ret1) -> optC: + return str((dep.a, dep1)) + + p1 = repo.create_pipeline( + (B,), + [consumer1, consumer2], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + res = p1(B(42.1)) + assert res == "(42, (42, '42'))" + + p2 = repo.create_pipeline( + (B,), + [consumer1, consumer2, consumer1], + force_commutative=True, + 
allow_sync=True, + allow_async=False, + force_async=False + ) + res = p2(B(42.1)) + assert res == (42, "(42, (42, '42'))") -- 2.49.1 From a0de9fcda8bdbdf71eefdbd882941ea2a17c6a09 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 22:49:15 +0300 Subject: [PATCH 2/3] Make smart call deduplication --- src/breakshaft/renderer.py | 35 ++++++++++++++++++++++++----- tests/test_pipeline.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index ffe4b28..27a1244 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -51,6 +51,28 @@ class ConversionArgRenderData: typehash: str +def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]: + deduplicated_conv_models: list[ConversionRenderData] = [] + for conv_model in conversion_models: + if conv_model not in deduplicated_conv_models: + deduplicated_conv_models.append(conv_model) + continue + + argnames = list(map(lambda x: x[1], conv_model.funcargs)) + argument_changed = False + found_model = False + for m in deduplicated_conv_models: + if not found_model and m == conv_model: + found_model = True + + if found_model and m.inj_hash in argnames: + argument_changed = True + break + if argument_changed: + deduplicated_conv_models.append(conv_model) + return deduplicated_conv_models + + class InTimeGenerationConvertorRenderer(ConvertorRenderer): templateLoader: jinja2.BaseLoader templateEnv: jinja2.Environment @@ -72,7 +94,7 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): force_async: bool = False) -> Callable: fnmap = {} - conversion_models = [] + conversion_models: list[ConversionRenderData] = [] ret_hash = 0 is_async = force_async for call_id, call in enumerate(callseq): @@ -84,12 +106,13 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): fnmap[hash(call.fn)] = call.fn conv = ConversionRenderData.from_inj(call, 
provided_types) - if conv not in conversion_models: - conversion_models.append(conv) - if call.is_async: - is_async = True + conversion_models.append(conv) + if call.is_async: + is_async = True - ret_hash = hash(callseq[-1].injects) + conversion_models = deduplicate_callseq(conversion_models) + + ret_hash = hashname(callseq[-1].injects) conv_args = [] for i, from_type in enumerate(from_types): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 42e3565..056d54b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -60,3 +60,49 @@ def test_default_consumer_args(): ) res = p2(B(42.1)) assert res == (42, "(42, (42, '42'))") + + +def test_pipeline_with_subgraph_duplicates(): + repo = ConvRepo() + + b_to_a_calls = [0] + + @repo.mark_injector() + def b_to_a(b: B) -> A: + b_to_a_calls[0] += 1 + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + type ret1 = tuple[int, str] + + cons1_calls = [0] + cons2_calls = [0] + + def consumer1(dep: A, opt_dep: optC = '42') -> A: + cons1_calls[0] += 1 + return A(dep.a + int(opt_dep)) + + def consumer2(dep: A) -> optC: + cons2_calls[0] += 1 + return str(dep.a) + + p1 = repo.create_pipeline( + (B,), + [consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + res = p1(B(42.1)) + assert res.a == 42 + (42 * 31) + assert b_to_a_calls[0] == 1 + assert cons1_calls[0] == 5 + assert cons2_calls[0] == 4 -- 2.49.1 From 66241cd01a5dd70ec0f5730a9e58a79a652ebfe2 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Mon, 21 Jul 2025 15:31:26 +0300 Subject: [PATCH 3/3] Update `README.md`: add pipeline description --- README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/README.md b/README.md index 93fe6a7..c307967 100644 --- a/README.md
+++ b/README.md @@ -103,6 +103,49 @@ assert tst == 1 ``` +---- + +#### Сборка конвейеров преобразований: + +Пусть имеется несколько методов-потребителей, которые необходимо вызывать последовательно: + +```python + +from breakshaft.convertor import ConvRepo + +repo = ConvRepo() + +# Объявляем A и B, а также методы преобразований - как в прошлом примере + +type cons2ret = str # избегаем использования builtin-типов, чтобы избежать простых коллизий + + +def consumer1(dep: A) -> B: + return B(float(42)) + + +def consumer2(dep: B) -> cons2ret: + return str(dep.b) + + +def consumer3(dep: cons2ret) -> int: + return int(float(dep)) + + +pipeline = repo.create_pipeline( + (B,), + [consumer1, consumer2, consumer3], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False +) + +dat = pipeline(B(42)) +assert dat == 42 +``` + + ---- #### Как получить граф преобразований: -- 2.49.1