# Listing metadata: 85 lines, 2.0 KiB, Python
from dataclasses import dataclass
|
|
|
|
from breakshaft.models import ConversionPoint
|
|
from src.breakshaft.convertor import ConvRepo
|
|
|
|
|
|
@dataclass()
class A:
    """Input fixture: every conversion exercised in this module starts from an ``A``."""

    a: int  # integer seed copied into B, C and (stringified) D by the converters
|
|
|
|
|
|
@dataclass()
class B:
    """Fixture wrapping a single ``float`` field; first member of each tuple-returning converter."""

    b: float
|
|
|
|
|
|
@dataclass()
class C:
    """Fixture wrapping a single ``int`` field."""

    c: int
|
|
|
|
|
|
@dataclass()
class D:
    """Fixture wrapping a single ``str`` field (the converters store ``str(A.a)`` here)."""

    d: str
|
|
|
|
|
|
def test_conv_point_tuple_unwrap():
    """Count the conversion points ``ConversionPoint.from_fn`` derives from nested tuple returns.

    The expected counts include one point for each tuple level and one for each
    member type. In the deepest case, the member whose type equals the
    converter's own input (``A``) is expected not to produce a point.
    """

    def conv_into_bc(a: A) -> tuple[B, C]:
        return B(a.a), C(a.a)

    def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
        return B(a.a), (C(a.a), D(str(a.a)))

    def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]:
        return B(a.a), (C(a.a), (D(str(a.a)), a))

    # (converter, expected number of conversion points), checked in order.
    expectations = (
        (conv_into_bc, 3),    # tuple[...], B, C
        (conv_into_bcd, 5),   # tuple[B, ...], B, tuple[C, D], C, D
        (conv_into_bcda, 6),  # as above plus tuple[D, A]; the (A, ...) -> A point is skipped
    )
    for converter, expected_count in expectations:
        points = ConversionPoint.from_fn(converter)
        assert len(points) == expected_count
|
|
|
|
|
|
def test_ignore_basic_types():
    """Basic types such as ``int`` inside a returned tuple get no conversion point of their own."""

    def conv_into_b_int(a: A) -> tuple[B, int]:
        return B(a.a), a.a

    # Only the tuple itself and B are expected; the int member is skipped.
    assert len(ConversionPoint.from_fn(conv_into_b_int)) == 2
|
|
|
|
|
|
def test_codegen_tuple_unwrap():
|
|
repo = ConvRepo(store_sources=True)
|
|
|
|
@repo.mark_injector()
|
|
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
|
|
return B(a.a), (C(a.a), D(str(a.a)))
|
|
|
|
type Z = A
|
|
|
|
@repo.mark_injector()
|
|
def conv_d_a(d: D) -> Z:
|
|
return A(int(d.d))
|
|
|
|
def consumer1(dep: D) -> int:
|
|
return int(dep.d)
|
|
|
|
def consumer2(dep: Z) -> int:
|
|
return int(dep.a)
|
|
|
|
fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False)
|
|
assert fn1(A(1)) == 1
|
|
|
|
fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False)
|
|
assert fn2(A(1)) == 1
|
|
|
|
pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False)
|
|
assert pip(A(1)) == 1
|
|
print(pip.__breakshaft_render_src__)
|