Allow default option to be overriden if there is any conversion point that injects this default option

This commit is contained in:
2025-07-19 21:12:35 +03:00
parent f2ec4fad14
commit 69def6e74c
2 changed files with 46 additions and 0 deletions

View File

@@ -69,6 +69,16 @@ class GraphWalker:
variant_subgraphs.add(subg)
if not dead_end:
for opt, _ in point.opt_args:
subg = cls.generate_callgraph_singletype(injectors,
from_types,
opt,
visited_path=visited_path.copy(),
visited_types=visited_types.copy())
if subg is not None:
variant_subgraphs.add(subg)
consumed = frozenset(point.requires) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant)

View File

@@ -76,3 +76,39 @@ def test_default_inj_args():
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '0')
assert dep == 123
def test_default_graph_override():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int, opt_dep: optC = '42') -> A:
return A(i + int(opt_dep))
@repo.mark_injector()
def inject_opt_dep() -> optC:
return '12345'
def consumer(dep: A) -> int:
return dep.a
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == 42
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == 123 + 12345
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '0')
assert dep == 123