diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 0eb8918..1bdfc9d 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -69,6 +69,16 @@ class GraphWalker: variant_subgraphs.add(subg) if not dead_end: + + for opt, _ in point.opt_args: + subg = cls.generate_callgraph_singletype(injectors, + from_types, + opt, + visited_path=visited_path.copy(), + visited_types=visited_types.copy()) + if subg is not None: + variant_subgraphs.add(subg) + consumed = frozenset(point.requires) & from_types variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) head = head.add_subgraph_variant(variant) diff --git a/tests/test_default_args.py b/tests/test_default_args.py index 3c80b14..ce4d9af 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -76,3 +76,39 @@ def test_default_inj_args(): fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn3(123, '0') assert dep == 123 + + +def test_default_graph_override(): + + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int, opt_dep: optC = '42') -> A: + return A(i + int(opt_dep)) + + @repo.mark_injector() + def inject_opt_dep() -> optC: + return '12345' + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + 12345 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123 \ No newline at end of file