from dataclasses import dataclass from breakshaft.models import ConversionPoint from src.breakshaft.convertor import ConvRepo @dataclass class A: a: int @dataclass class B: b: float @dataclass class C: c: int @dataclass class D: d: str def test_conv_point_tuple_unwrap(): def conv_into_bc(a: A) -> tuple[B, C]: return B(a.a), C(a.a) def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]: return B(a.a), (C(a.a), D(str(a.a))) def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]: return B(a.a), (C(a.a), (D(str(a.a)), a)) cps_bc = ConversionPoint.from_fn(conv_into_bc) assert len(cps_bc) == 3 # tuple[...], B, C cps_bcd = ConversionPoint.from_fn(conv_into_bcd) assert len(cps_bcd) == 5 # tuple[B,...], B, tuple[C,D], C, D cps_bcda = ConversionPoint.from_fn(conv_into_bcda) assert len(cps_bcda) == 6 # ignores (A,...)->A def test_ignore_basic_types(): def conv_into_b_int(a: A) -> tuple[B, int]: return B(a.a), a.a cps = ConversionPoint.from_fn(conv_into_b_int) assert len(cps) == 2 # tuple[...], B def test_codegen_tuple_unwrap(): repo = ConvRepo(store_sources=True) @repo.mark_injector() def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]: return B(a.a), (C(a.a), D(str(a.a))) type Z = A @repo.mark_injector() def conv_d_a(d: D) -> Z: return A(int(d.d)) def consumer1(dep: D) -> int: return int(dep.d) def consumer2(dep: Z) -> int: return int(dep.a) fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False) assert fn1(A(1)) == 1 fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False) assert fn2(A(1)) == 1 pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False) assert pip(A(1)) == 1 print(pip.__breakshaft_render_src__)