From a0de9fcda8bdbdf71eefdbd882941ea2a17c6a09 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 22:49:15 +0300 Subject: [PATCH] Make smart call deduplication --- src/breakshaft/renderer.py | 35 ++++++++++++++++++++++++----- tests/test_pipeline.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index ffe4b28..27a1244 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -51,6 +51,28 @@ class ConversionArgRenderData: typehash: str +def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]: + deduplicated_conv_models: list[ConversionRenderData] = [] + for conv_model in conversion_models: + if conv_model not in deduplicated_conv_models: + deduplicated_conv_models.append(conv_model) + continue + + argnames = list(map(lambda x: x[1], conv_model.funcargs)) + argument_changed = False + found_model = False + for m in deduplicated_conv_models: + if not found_model and m == conv_model: + found_model = True + + if found_model and m.inj_hash in argnames: + argument_changed = True + break + if argument_changed: + deduplicated_conv_models.append(conv_model) + return deduplicated_conv_models + + class InTimeGenerationConvertorRenderer(ConvertorRenderer): templateLoader: jinja2.BaseLoader templateEnv: jinja2.Environment @@ -72,7 +94,7 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): force_async: bool = False) -> Callable: fnmap = {} - conversion_models = [] + conversion_models: list[ConversionRenderData] = [] ret_hash = 0 is_async = force_async for call_id, call in enumerate(callseq): @@ -84,12 +106,13 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): fnmap[hash(call.fn)] = call.fn conv = ConversionRenderData.from_inj(call, provided_types) - if conv not in conversion_models: - conversion_models.append(conv) - if call.is_async: - is_async = True + conversion_models.append(conv) + if call.is_async: + is_async = True - ret_hash = hash(callseq[-1].injects) + conversion_models = deduplicate_callseq(conversion_models) + + ret_hash = hashname(callseq[-1].injects) conv_args = [] for i, from_type in enumerate(from_types): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 42e3565..056d54b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -60,3 +60,49 @@ def test_default_consumer_args(): ) res = p2(B(42.1)) assert res == (42, "(42, (42, '42'))") + + +def test_pipeline_with_subgraph_duplicates(): + repo = ConvRepo() + + b_to_a_calls = [0] + + @repo.mark_injector() + def b_to_a(b: B) -> A: + b_to_a_calls[0] += 1 + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + type ret1 = tuple[int, str] + + cons1_calls = [0] + cons2_calls = [0] + + def consumer1(dep: A, opt_dep: optC = '42') -> A: + cons1_calls[0] += 1 + return A(dep.a + int(opt_dep)) + + def consumer2(dep: A) -> optC: + cons2_calls[0] += 1 + return str(dep.a) + + p1 = repo.create_pipeline( + (B,), + [consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + res = p1(B(42.1)) + assert res.a == 42 + (42 * 31) + assert b_to_a_calls[0] == 1 + assert cons1_calls[0] == 5 + assert cons2_calls[0] == 4