From a2cf1bb6e658a00cce93f375c083af1451c7e789 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:35:27 +0300 Subject: [PATCH] Get rid of manual consumer fn unwrapping for callgraph generation --- src/breakshaft/graph_walker.py | 19 ++++++++----------- src/breakshaft/models.py | 2 +- tests/test_basic.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 7c1d20c..0eb8918 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -14,19 +14,16 @@ class GraphWalker: from_types: frozenset[type], consumer_fn: Callable) -> Optional[Callgraph]: - into_types: frozenset[type] = extract_func_argtypes(consumer_fn) - branches: frozenset[Callgraph] = frozenset() - for into_type in into_types: - cg = cls.generate_callgraph_singletype(injectors, from_types, into_type) - if cg is None: - return None - branches |= {cg} - variant = CallgraphVariant( - ConversionPoint.from_fn(consumer_fn, NoneType)[0], - branches, frozenset()) - return Callgraph(frozenset({variant})) + # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer + # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого + # При этом, TypeAliasType также выступает в роли ключа преобразования + # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований + type _tmp_type_for_consumer = object + injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer)) + + return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer) @classmethod def generate_callgraph_singletype(cls, diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 591c30d..fdd1012 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -39,7 +39,7 @@ class ConversionPoint: return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) @classmethod - def from_fn(cls, func: Callable, rettype: Optional[type] = None): + def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: if rettype is None: annot = get_type_hints(func) rettype = annot.get('return') diff --git a/tests/test_basic.py b/tests/test_basic.py index b29a7df..e6f7972 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -38,3 +38,33 @@ def test_basic(): fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn2(123) assert dep == 123 + + +def test_union_deps(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A | B) -> int: + if isinstance(dep, A): + return dep.a + else: + return int(dep.b) + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123