diff --git a/src/breakshaft/convertor.py b/src/breakshaft/convertor.py index c45f350..6089871 100644 --- a/src/breakshaft/convertor.py +++ b/src/breakshaft/convertor.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any +from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence from .graph_walker import GraphWalker from .models import ConversionPoint, Callgraph from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer +from .util import extract_return_type Tin = TypeVarTuple('Tin') Tout = TypeVar('Tout') @@ -27,6 +28,31 @@ class ConvRepo: self.walker = graph_walker self.renderer = renderer + def create_pipeline(self, + from_types: Sequence[type], + fns: Sequence[Callable], + force_commutative: bool = True, + allow_async: bool = True, + allow_sync: bool = True, + force_async: bool = False + ): + filtered_injectors = self.filtered_injectors(allow_async, allow_sync) + pipeline_callseq = [] + orig_from_types = tuple(from_types) + from_types = tuple(from_types) + + for fn in fns: + injects = extract_return_type(fn) + + callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) + + pipeline_callseq += callseq + + if injects is not None: + from_types += (injects,) + + return self.renderer.render(orig_from_types, pipeline_callseq, force_async=force_async) + @property def convertor_set(self): return self._convertor_set @@ -46,30 +72,26 @@ class ConvRepo: ret += [variant.injector] return ret - def get_conversion(self, - from_types: tuple[type[Unpack[Tin]]], - fn: Callable[..., Tout], - force_commutative: bool = True, - allow_async: bool = True, - allow_sync: bool = True, - force_async: bool = False - ) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: - if not allow_async or force_async: - filtered_injectors: frozenset[ConversionPoint] = frozenset() - for inj in self.convertor_set: - if inj.is_async and not allow_async: - continue - if not inj.is_async and not allow_sync: - continue - filtered_injectors |= {inj} - else: - filtered_injectors = frozenset(self.convertor_set) + def filtered_injectors(self, allow_async: bool, allow_sync: bool) -> frozenset[ConversionPoint]: + filtered_injectors: frozenset[ConversionPoint] = frozenset() + for inj in self.convertor_set: + if inj.is_async and not allow_async: + continue + if not inj.is_async and not allow_sync: + continue + filtered_injectors |= {inj} + return filtered_injectors - cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn) + def get_callseq(self, + injectors: frozenset[ConversionPoint], + from_types: frozenset[type], fn: Callable, + force_commutative: bool) -> list[ConversionPoint]: + + cg = self.walker.generate_callgraph(injectors, from_types, fn) if cg is None: raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') - exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types)) + exploded = self.walker.explode_callgraph_branches(cg, from_types) selected = self.walker.filter_exploded_callgraph_branch(exploded) if len(selected) == 0: @@ -79,6 +101,25 @@ class ConvRepo: raise ValueError('Conversion path is not commutative') callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) + + if len(callseq) > 0: + injects = extract_return_type(fn) + callseq[-1] = callseq[-1].copy_with(injects=injects) + + return callseq + + def get_conversion(self, + from_types: Sequence[type[Unpack[Tin]]], + fn: Callable[..., Tout], + force_commutative: bool = True, + allow_async: bool = True, + allow_sync: bool = True, + force_async: bool = False + ) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: + + filtered_injectors = self.filtered_injectors(allow_async, allow_sync) + callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) + return self.renderer.render(from_types, callseq, force_async=force_async) def mark_injector(self, *, rettype: Optional[type] = None): diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 2aa0aac..b7ff4aa 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -3,7 +3,7 @@ from types import NoneType from typing import Callable, Optional from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection -from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq +from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type class GraphWalker: @@ -15,6 +15,7 @@ class GraphWalker: consumer_fn: Callable) -> Optional[Callgraph]: branches: frozenset[Callgraph] = frozenset() + rettype = extract_return_type(consumer_fn) # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index c49fc91..8ce80d9 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -20,6 +20,13 @@ class ConversionPoint: requires: tuple[type, ...] opt_args: tuple[type, ...] + def copy_with(self, **kwargs): + fn = kwargs.get('fn', self.fn) + injects = kwargs.get('injects', self.injects) + requires = kwargs.get('requires', self.requires) + opt_args = kwargs.get('opt_args', self.opt_args) + return ConversionPoint(fn, injects, requires, opt_args) + def __hash__(self): return hash((self.fn, self.injects, self.requires)) diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index bfdae60..b343f76 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -1,6 +1,11 @@ import inspect from itertools import product -from typing import Callable, get_type_hints, TypeVar, Any +from typing import Callable, get_type_hints, TypeVar, Any, Optional + + +def extract_return_type(func: Callable) -> Optional[type]: + hints = get_type_hints(func) + return hints.get('return') def extract_func_args(func: Callable) -> list[tuple[str, type]]: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..42e3565 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass + +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +type optC = str + + +def test_default_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + type ret1 = tuple[int, str] + + def consumer1(dep: A, opt_dep: optC = '42') -> ret1: + return dep.a, opt_dep + + def consumer2(dep: A, dep1: ret1) -> optC: + return str((dep.a, dep1)) + + p1 = repo.create_pipeline( + (B,), + [consumer1, consumer2], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + res = p1(B(42.1)) + assert res == "(42, (42, '42'))" + + p2 = repo.create_pipeline( + (B,), + [consumer1, consumer2, consumer1], + force_commutative=True, + allow_sync=True, + allow_async=False, + force_async=False + ) + res = p2(B(42.1)) + assert res == (42, "(42, (42, '42'))")