Allow passing [ConversionPoint] into create_pipeline with a type remap for ConversionPoint
This commit is contained in:
@@ -38,7 +38,7 @@ class ConvRepo:
|
|||||||
|
|
||||||
def create_pipeline(self,
|
def create_pipeline(self,
|
||||||
from_types: Sequence[type],
|
from_types: Sequence[type],
|
||||||
fns: Sequence[Callable],
|
fns: Sequence[Callable | Iterable[ConversionPoint] | ConversionPoint],
|
||||||
force_commutative: bool = True,
|
force_commutative: bool = True,
|
||||||
allow_async: bool = True,
|
allow_async: bool = True,
|
||||||
allow_sync: bool = True,
|
allow_sync: bool = True,
|
||||||
@@ -50,6 +50,14 @@ class ConvRepo:
|
|||||||
from_types = tuple(from_types)
|
from_types = tuple(from_types)
|
||||||
|
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
|
injects = None
|
||||||
|
if isinstance(fn, collections.abc.Iterable):
|
||||||
|
for f in fn:
|
||||||
|
injects = f.injects
|
||||||
|
break
|
||||||
|
elif isinstance(fn, ConversionPoint):
|
||||||
|
injects = fn.injects
|
||||||
|
else:
|
||||||
injects = extract_return_type(fn)
|
injects = extract_return_type(fn)
|
||||||
|
|
||||||
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
|
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
|
||||||
|
|||||||
@@ -24,13 +24,11 @@ def test_basic():
|
|||||||
def int_to_a(i: int) -> A:
|
def int_to_a(i: int) -> A:
|
||||||
return A(i)
|
return A(i)
|
||||||
|
|
||||||
type HackInt = int
|
def consumer(dep: A) -> B:
|
||||||
|
return B(float(dep.a))
|
||||||
def consumer(dep: A) -> int:
|
|
||||||
return dep.a
|
|
||||||
|
|
||||||
type NewA = A
|
type NewA = A
|
||||||
type_remap = {'dep': NewA, 'return': Annotated[HackInt, 'fuck']}
|
type_remap = {'dep': NewA, 'return': B}
|
||||||
|
|
||||||
assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1
|
assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1
|
||||||
|
|
||||||
@@ -43,5 +41,18 @@ def test_basic():
|
|||||||
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
|
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
|
||||||
force_commutative=True, force_async=False, allow_async=False)
|
force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
|
||||||
assert fn1(42) == 42
|
assert fn1(42).b == 42.0
|
||||||
|
|
||||||
|
def consumer1(dep: B) -> A:
|
||||||
|
return A(int(dep.b))
|
||||||
|
|
||||||
|
p1 = repo.create_pipeline(
|
||||||
|
(int,),
|
||||||
|
[ConversionPoint.from_fn(consumer, type_remap=type_remap), consumer1, consumer],
|
||||||
|
force_commutative=True,
|
||||||
|
allow_sync=True,
|
||||||
|
allow_async=False,
|
||||||
|
force_async=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert p1(123).b == 123.0
|
||||||
|
|||||||
Reference in New Issue
Block a user