Add basic pipeline construction, callseq deduplication pending

This commit is contained in:
2025-07-19 22:32:40 +03:00
parent eae2cd9a4b
commit b058a701a0
5 changed files with 139 additions and 23 deletions

View File

@@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence
from .graph_walker import GraphWalker from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
from .util import extract_return_type
Tin = TypeVarTuple('Tin') Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout') Tout = TypeVar('Tout')
@@ -27,6 +28,31 @@ class ConvRepo:
self.walker = graph_walker self.walker = graph_walker
self.renderer = renderer self.renderer = renderer
def create_pipeline(self,
from_types: Sequence[type],
fns: Sequence[Callable],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
):
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
pipeline_callseq = []
orig_from_types = tuple(from_types)
from_types = tuple(from_types)
for fn in fns:
injects = extract_return_type(fn)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
pipeline_callseq += callseq
if injects is not None:
from_types += (injects,)
return self.renderer.render(orig_from_types, pipeline_callseq, force_async=force_async)
@property @property
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
@@ -46,30 +72,26 @@ class ConvRepo:
ret += [variant.injector] ret += [variant.injector]
return ret return ret
def get_conversion(self, def filtered_injectors(self, allow_async: bool, allow_sync: bool) -> frozenset[ConversionPoint]:
from_types: tuple[type[Unpack[Tin]]], filtered_injectors: frozenset[ConversionPoint] = frozenset()
fn: Callable[..., Tout], for inj in self.convertor_set:
force_commutative: bool = True, if inj.is_async and not allow_async:
allow_async: bool = True, continue
allow_sync: bool = True, if not inj.is_async and not allow_sync:
force_async: bool = False continue
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: filtered_injectors |= {inj}
if not allow_async or force_async: return filtered_injectors
filtered_injectors: frozenset[ConversionPoint] = frozenset()
for inj in self.convertor_set:
if inj.is_async and not allow_async:
continue
if not inj.is_async and not allow_sync:
continue
filtered_injectors |= {inj}
else:
filtered_injectors = frozenset(self.convertor_set)
cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn) def get_callseq(self,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type], fn: Callable,
force_commutative: bool) -> list[ConversionPoint]:
cg = self.walker.generate_callgraph(injectors, from_types, fn)
if cg is None: if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}')
exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types)) exploded = self.walker.explode_callgraph_branches(cg, from_types)
selected = self.walker.filter_exploded_callgraph_branch(exploded) selected = self.walker.filter_exploded_callgraph_branch(exploded)
if len(selected) == 0: if len(selected) == 0:
@@ -79,6 +101,25 @@ class ConvRepo:
raise ValueError('Conversion path is not commutative') raise ValueError('Conversion path is not commutative')
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
if len(callseq) > 0:
injects = extract_return_type(fn)
callseq[-1] = callseq[-1].copy_with(injects=injects)
return callseq
def get_conversion(self,
from_types: Sequence[type[Unpack[Tin]]],
fn: Callable[..., Tout],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]:
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
return self.renderer.render(from_types, callseq, force_async=force_async) return self.renderer.render(from_types, callseq, force_async=force_async)
def mark_injector(self, *, rettype: Optional[type] = None): def mark_injector(self, *, rettype: Optional[type] = None):

View File

@@ -3,7 +3,7 @@ from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type
class GraphWalker: class GraphWalker:
@@ -15,6 +15,7 @@ class GraphWalker:
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable) -> Optional[Callgraph]:
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
rettype = extract_return_type(consumer_fn)
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого

View File

@@ -20,6 +20,13 @@ class ConversionPoint:
requires: tuple[type, ...] requires: tuple[type, ...]
opt_args: tuple[type, ...] opt_args: tuple[type, ...]
def copy_with(self, **kwargs):
fn = kwargs.get('fn', self.fn)
injects = kwargs.get('injects', self.injects)
requires = kwargs.get('requires', self.requires)
opt_args = kwargs.get('opt_args', self.opt_args)
return ConversionPoint(fn, injects, requires, opt_args)
def __hash__(self): def __hash__(self):
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))

View File

@@ -1,6 +1,11 @@
import inspect import inspect
from itertools import product from itertools import product
from typing import Callable, get_type_hints, TypeVar, Any from typing import Callable, get_type_hints, TypeVar, Any, Optional
def extract_return_type(func: Callable) -> Optional[type]:
hints = get_type_hints(func)
return hints.get('return')
def extract_func_args(func: Callable) -> list[tuple[str, type]]: def extract_func_args(func: Callable) -> list[tuple[str, type]]:

62
tests/test_pipeline.py Normal file
View File

@@ -0,0 +1,62 @@
from dataclasses import dataclass
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
type optC = str
def test_default_consumer_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type ret1 = tuple[int, str]
def consumer1(dep: A, opt_dep: optC = '42') -> ret1:
return dep.a, opt_dep
def consumer2(dep: A, dep1: ret1) -> optC:
return str((dep.a, dep1))
p1 = repo.create_pipeline(
(B,),
[consumer1, consumer2],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p1(B(42.1))
assert res == "(42, (42, '42'))"
p2 = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer1],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p2(B(42.1))
assert res == (42, "(42, (42, '42'))")