From 74d78b19572fd54c13130b87011853cdef179d69 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Fri, 17 Oct 2025 00:51:29 +0300 Subject: [PATCH] Fix callseq deduplication error, allow using Some|None=None args with no commutativity error, add `ignore_basictypes_return` for a `ConversionPoint.from_fn` --- pyproject.toml | 2 +- src/breakshaft/graph_walker.py | 2 +- src/breakshaft/models.py | 10 +++++++--- src/breakshaft/renderer.py | 2 +- tests/test_default_args.py | 34 ++++++++++++++++++++++++++++++++-- tests/test_pipeline.py | 12 +++++++++++- 6 files changed, 53 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 561b382..d3db4d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "breakshaft" -version = "0.1.6.post4" +version = "0.1.6.post5" description = "Library for in-time codegen for type conversion" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 768c7d5..4666f88 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -90,7 +90,7 @@ class GraphWalker: if subg is not None: variant_subgraphs.add(subg) - consumed = frozenset(point.requires) & from_types + consumed = (frozenset(point.requires) | frozenset(point.opt_args)) & from_types variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) head = head.add_subgraph_variant(variant) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 8aab9e0..c4a4c60 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -56,7 +56,8 @@ class ConversionPoint: def from_fn(cls, func: Callable, rettype: Optional[type] = None, - type_remap: Optional[dict[str, type]] = None) -> list[ConversionPoint]: + type_remap: Optional[dict[str, type]] = None, + ignore_basictype_return: bool = False) -> list[ConversionPoint]: if type_remap is None: annot = get_type_hints(func) else: @@ -86,7 +87,7 @@ class ConversionPoint: if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): fn_rettype = get_args(fn_rettype)[0] - if is_basic_type_annot(rettype): + if not ignore_basictype_return and is_basic_type_annot(rettype): return [] ret = [] @@ -96,7 +97,10 @@ class ConversionPoint: if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped: for t in tuple_unwrapped: if not is_basic_type_annot(t): - ret += ConversionPoint.from_fn(func, rettype=t, type_remap=type_remap) + ret += ConversionPoint.from_fn(func, + rettype=t, + type_remap=type_remap, + ignore_basictype_return=ignore_basictype_return) argtypes: list[list[type]] = [] orig_args = extract_func_args(func, type_remap) diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index 4a38f0f..5928ab1 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -99,7 +99,7 @@ def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[C argument_changed = False found_model = False for m in deduplicated_conv_models: - if not found_model and m == conv_model: + if not found_model and m.funchash == conv_model.funchash: found_model = True if found_model and m.inj_hash in argnames: diff --git a/tests/test_default_args.py b/tests/test_default_args.py index ce4d9af..de3ef12 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -47,6 +47,37 @@ def test_default_consumer_args(): assert dep == (123, '1') +def test_optional_default_none_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B | None = None) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B | None: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]: + return dep.a, opt_dep + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == (42, '42') + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == (123, '42') + + fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '1') + assert dep == (123, '1') + + def test_default_inj_args(): repo = ConvRepo() @@ -79,7 +110,6 @@ def test_default_inj_args(): def test_default_graph_override(): - repo = ConvRepo() @repo.mark_injector() @@ -111,4 +141,4 @@ def test_default_graph_override(): fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn3(123, '0') - assert dep == 123 \ No newline at end of file + assert dep == 123 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 056d54b..aec00fc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -17,7 +17,7 @@ type optC = str def test_default_consumer_args(): - repo = ConvRepo() + repo = ConvRepo(store_sources=True) @repo.mark_injector() def b_to_a(b: B) -> A: @@ -106,3 +106,13 @@ def test_pipeline_with_subgraph_duplicates(): assert b_to_a_calls[0] == 1 assert cons1_calls[0] == 5 assert cons2_calls[0] == 4 + + +def convertor(_5891515089754: ""): + # .b_to_a at 0x7f5bb1be02c0> + _5891515089643 = _conv_funcmap[8751987548204](b=_5891515089754) + # .consumer1 at 0x7f5bb1be0c20> + _8751987542640 = _conv_funcmap[8751987548354](dep=_5891515089643) + # .consumer2 at 0x7f5bb1be0540> + _8751987537115 = _conv_funcmap[8751987548244](dep=_5891515089643, dep1=_8751987542640) + return _8751987542640 \ No newline at end of file