diff --git a/tests/test_default_args.py b/tests/test_default_args.py new file mode 100644 index 0000000..ed5b3bc --- /dev/null +++ b/tests/test_default_args.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +type optC = str + + +def test_default_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]: + return dep.a, opt_dep + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == (42, '42') + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == (123, '42') + +