Make tuple return types unwrap
This commit is contained in:
84
tests/test_tuple_unwrap.py
Normal file
84
tests/test_tuple_unwrap.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from breakshaft.models import ConversionPoint
|
||||
from src.breakshaft.convertor import ConvRepo
|
||||
|
||||
|
||||
@dataclass
|
||||
class A:
|
||||
a: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class B:
|
||||
b: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
c: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class D:
|
||||
d: str
|
||||
|
||||
|
||||
def test_conv_point_tuple_unwrap():
|
||||
def conv_into_bc(a: A) -> tuple[B, C]:
|
||||
return B(a.a), C(a.a)
|
||||
|
||||
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
|
||||
return B(a.a), (C(a.a), D(str(a.a)))
|
||||
|
||||
def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]:
|
||||
return B(a.a), (C(a.a), (D(str(a.a)), a))
|
||||
|
||||
cps_bc = ConversionPoint.from_fn(conv_into_bc)
|
||||
assert len(cps_bc) == 3 # tuple[...], B, C
|
||||
|
||||
cps_bcd = ConversionPoint.from_fn(conv_into_bcd)
|
||||
|
||||
assert len(cps_bcd) == 5 # tuple[B,...], B, tuple[C,D], C, D
|
||||
|
||||
cps_bcda = ConversionPoint.from_fn(conv_into_bcda)
|
||||
|
||||
assert len(cps_bcda) == 6 # ignores (A,...)->A
|
||||
|
||||
|
||||
def test_ignore_basic_types():
|
||||
def conv_into_b_int(a: A) -> tuple[B, int]:
|
||||
return B(a.a), a.a
|
||||
|
||||
cps = ConversionPoint.from_fn(conv_into_b_int)
|
||||
assert len(cps) == 2 # tuple[...], B
|
||||
|
||||
|
||||
def test_codegen_tuple_unwrap():
|
||||
repo = ConvRepo(store_sources=True)
|
||||
|
||||
@repo.mark_injector()
|
||||
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
|
||||
return B(a.a), (C(a.a), D(str(a.a)))
|
||||
|
||||
type Z = A
|
||||
|
||||
@repo.mark_injector()
|
||||
def conv_d_a(d: D) -> Z:
|
||||
return A(int(d.d))
|
||||
|
||||
def consumer1(dep: D) -> int:
|
||||
return int(dep.d)
|
||||
|
||||
def consumer2(dep: Z) -> int:
|
||||
return int(dep.a)
|
||||
|
||||
fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False)
|
||||
assert fn1(A(1)) == 1
|
||||
|
||||
fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False)
|
||||
assert fn2(A(1)) == 1
|
||||
|
||||
pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False)
|
||||
assert pip(A(1)) == 1
|
||||
print(pip.__breakshaft_render_src__)
|
||||
Reference in New Issue
Block a user