# File: breakshaft/tests/test_tuple_unwrap.py
# (Python; listed by the repository file browser as 85 lines, 2.0 KiB)

from dataclasses import dataclass
from breakshaft.models import ConversionPoint
from src.breakshaft.convertor import ConvRepo
# Fixture types for the conversion graphs built by the tests below; each
# wraps exactly one field.


@dataclass()
class A:
    # Int-carrying input type; the tuple-returning converters take an A.
    a: int


@dataclass()
class B:
    # Annotated as float, but dataclasses do not enforce annotations at
    # runtime, so B(a.a) in the tests stores an int unchanged.
    b: float


@dataclass()
class C:
    # Int-carrying element type produced alongside B.
    c: int


@dataclass()
class D:
    # String-carrying element type; also the input of conv_d_a in the
    # codegen test.
    d: str
def test_conv_point_tuple_unwrap():
    # Each converter returns a (possibly nested) tuple. ConversionPoint.from_fn
    # is expected to yield one point for the full return type plus one per
    # unwrapped element; the breakdown is listed next to each expected count.

    def conv_into_bc(a: A) -> tuple[B, C]:
        return B(a.a), C(a.a)

    def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
        return B(a.a), (C(a.a), D(str(a.a)))

    def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]:
        return B(a.a), (C(a.a), (D(str(a.a)), a))

    cases = (
        (conv_into_bc, 3),    # tuple[...], B, C
        (conv_into_bcd, 5),   # tuple[B,...], B, tuple[C,D], C, D
        # tuple[B,...], B, tuple[C,...], C, tuple[D,A], D;
        # the (A,...) -> A element is ignored.
        (conv_into_bcda, 6),
    )
    for converter, expected_count in cases:
        points = ConversionPoint.from_fn(converter)
        assert len(points) == expected_count
def test_ignore_basic_types():
    # Per the test name, builtin element types such as ``int`` should not get
    # their own conversion point; only the tuple itself and B are counted.
    def conv_into_b_int(a: A) -> tuple[B, int]:
        return B(a.a), a.a

    expected_points = 2  # tuple[...], B
    assert len(ConversionPoint.from_fn(conv_into_b_int)) == expected_points
def test_codegen_tuple_unwrap():
repo = ConvRepo(store_sources=True)
@repo.mark_injector()
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
return B(a.a), (C(a.a), D(str(a.a)))
type Z = A
@repo.mark_injector()
def conv_d_a(d: D) -> Z:
return A(int(d.d))
def consumer1(dep: D) -> int:
return int(dep.d)
def consumer2(dep: Z) -> int:
return int(dep.a)
fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False)
assert fn1(A(1)) == 1
fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False)
assert fn2(A(1)) == 1
pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False)
assert pip(A(1)) == 1
print(pip.__breakshaft_render_src__)