From 9d03affd4146389db754ffbe7eb9dc7b5c992918 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Tue, 19 Aug 2025 02:37:01 +0300 Subject: [PATCH] Allow passing [ConversionPoint] into create_pipeline with a type remap for ConversionPoint --- src/breakshaft/convertor.py | 12 ++++++++++-- tests/test_typehints_remap.py | 23 +++++++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/breakshaft/convertor.py b/src/breakshaft/convertor.py index 9439798..a8f24cf 100644 --- a/src/breakshaft/convertor.py +++ b/src/breakshaft/convertor.py @@ -38,7 +38,7 @@ class ConvRepo: def create_pipeline(self, from_types: Sequence[type], - fns: Sequence[Callable], + fns: Sequence[Callable | Iterable[ConversionPoint] | ConversionPoint], force_commutative: bool = True, allow_async: bool = True, allow_sync: bool = True, @@ -50,7 +50,15 @@ class ConvRepo: from_types = tuple(from_types) for fn in fns: - injects = extract_return_type(fn) + injects = None + if isinstance(fn, collections.abc.Iterable): + for f in fn: + injects = f.injects + break + elif isinstance(fn, ConversionPoint): + injects = fn.injects + else: + injects = extract_return_type(fn) callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) diff --git a/tests/test_typehints_remap.py b/tests/test_typehints_remap.py index 0c6d4d0..46d5508 100644 --- a/tests/test_typehints_remap.py +++ b/tests/test_typehints_remap.py @@ -24,13 +24,11 @@ def test_basic(): def int_to_a(i: int) -> A: return A(i) - type HackInt = int - - def consumer(dep: A) -> int: - return dep.a + def consumer(dep: A) -> B: + return B(float(dep.a)) type NewA = A - type_remap = {'dep': NewA, 'return': Annotated[HackInt, 'fuck']} + type_remap = {'dep': NewA, 'return': B} assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1 @@ -43,5 +41,18 @@ def test_basic(): fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap), force_commutative=True, force_async=False, allow_async=False) - assert fn1(42) == 42 + assert fn1(42).b == 42.0 + def consumer1(dep: B) -> A: + return A(int(dep.b)) + + p1 = repo.create_pipeline( + (int,), + [ConversionPoint.from_fn(consumer, type_remap=type_remap), consumer1, consumer], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + + assert p1(123).b == 123.0