Compare commits

..

3 Commits

7 changed files with 257 additions and 29 deletions

View File

@@ -103,6 +103,49 @@ assert tst == 1
``` ```
----
#### Сборка конвейеров преобразований:
Пусть, имеется несколько методов-потребителей, которые необходимо вызывать последовательно:
```python
from breakshaft.convertor import ConvRepo
repo = ConvRepo()
# Объявляем A и B, а также методы преобразований - как в прошлом примере
type cons2ret = str # избегаем использования builtin-типов, чтобы избежать простых коллизий
def consumer1(dep: A) -> B:
return B(float(42))
def consumer2(dep: B) -> cons2ret:
return str(dep.b)
def consumer3(dep: cons2ret) -> int:
return int(float(dep))
pipeline = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer3],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
dat = pipeline(B(42))
assert dat == 42
```
---- ----
#### Как получить граф преобразований: #### Как получить граф преобразований:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence
from .graph_walker import GraphWalker from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
from .util import extract_return_type
Tin = TypeVarTuple('Tin') Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout') Tout = TypeVar('Tout')
@@ -27,6 +28,31 @@ class ConvRepo:
self.walker = graph_walker self.walker = graph_walker
self.renderer = renderer self.renderer = renderer
def create_pipeline(self,
from_types: Sequence[type],
fns: Sequence[Callable],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
):
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
pipeline_callseq = []
orig_from_types = tuple(from_types)
from_types = tuple(from_types)
for fn in fns:
injects = extract_return_type(fn)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
pipeline_callseq += callseq
if injects is not None:
from_types += (injects,)
return self.renderer.render(orig_from_types, pipeline_callseq, force_async=force_async)
@property @property
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
@@ -46,30 +72,26 @@ class ConvRepo:
ret += [variant.injector] ret += [variant.injector]
return ret return ret
def get_conversion(self, def filtered_injectors(self, allow_async: bool, allow_sync: bool) -> frozenset[ConversionPoint]:
from_types: tuple[type[Unpack[Tin]]], filtered_injectors: frozenset[ConversionPoint] = frozenset()
fn: Callable[..., Tout], for inj in self.convertor_set:
force_commutative: bool = True, if inj.is_async and not allow_async:
allow_async: bool = True, continue
allow_sync: bool = True, if not inj.is_async and not allow_sync:
force_async: bool = False continue
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: filtered_injectors |= {inj}
if not allow_async or force_async: return filtered_injectors
filtered_injectors: frozenset[ConversionPoint] = frozenset()
for inj in self.convertor_set:
if inj.is_async and not allow_async:
continue
if not inj.is_async and not allow_sync:
continue
filtered_injectors |= {inj}
else:
filtered_injectors = frozenset(self.convertor_set)
cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn) def get_callseq(self,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type], fn: Callable,
force_commutative: bool) -> list[ConversionPoint]:
cg = self.walker.generate_callgraph(injectors, from_types, fn)
if cg is None: if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}')
exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types)) exploded = self.walker.explode_callgraph_branches(cg, from_types)
selected = self.walker.filter_exploded_callgraph_branch(exploded) selected = self.walker.filter_exploded_callgraph_branch(exploded)
if len(selected) == 0: if len(selected) == 0:
@@ -79,6 +101,25 @@ class ConvRepo:
raise ValueError('Conversion path is not commutative') raise ValueError('Conversion path is not commutative')
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
if len(callseq) > 0:
injects = extract_return_type(fn)
callseq[-1] = callseq[-1].copy_with(injects=injects)
return callseq
def get_conversion(self,
from_types: Sequence[type[Unpack[Tin]]],
fn: Callable[..., Tout],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]:
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
return self.renderer.render(from_types, callseq, force_async=force_async) return self.renderer.render(from_types, callseq, force_async=force_async)
def mark_injector(self, *, rettype: Optional[type] = None): def mark_injector(self, *, rettype: Optional[type] = None):

View File

@@ -3,7 +3,7 @@ from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type
class GraphWalker: class GraphWalker:
@@ -15,6 +15,7 @@ class GraphWalker:
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable) -> Optional[Callgraph]:
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
rettype = extract_return_type(consumer_fn)
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого

View File

@@ -20,6 +20,13 @@ class ConversionPoint:
requires: tuple[type, ...] requires: tuple[type, ...]
opt_args: tuple[type, ...] opt_args: tuple[type, ...]
def copy_with(self, **kwargs):
fn = kwargs.get('fn', self.fn)
injects = kwargs.get('injects', self.injects)
requires = kwargs.get('requires', self.requires)
opt_args = kwargs.get('opt_args', self.opt_args)
return ConversionPoint(fn, injects, requires, opt_args)
def __hash__(self): def __hash__(self):
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))

View File

@@ -51,6 +51,28 @@ class ConversionArgRenderData:
typehash: str typehash: str
def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]:
deduplicated_conv_models: list[ConversionRenderData] = []
for conv_model in conversion_models:
if conv_model not in deduplicated_conv_models:
deduplicated_conv_models.append(conv_model)
continue
argnames = list(map(lambda x: x[1], conv_model.funcargs))
argument_changed = False
found_model = False
for m in deduplicated_conv_models:
if not found_model and m == conv_model:
found_model = True
if found_model and m.inj_hash in argnames:
argument_changed = True
break
if argument_changed:
deduplicated_conv_models.append(conv_model)
return deduplicated_conv_models
class InTimeGenerationConvertorRenderer(ConvertorRenderer): class InTimeGenerationConvertorRenderer(ConvertorRenderer):
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
@@ -72,7 +94,7 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
force_async: bool = False) -> Callable: force_async: bool = False) -> Callable:
fnmap = {} fnmap = {}
conversion_models = [] conversion_models: list[ConversionRenderData] = []
ret_hash = 0 ret_hash = 0
is_async = force_async is_async = force_async
for call_id, call in enumerate(callseq): for call_id, call in enumerate(callseq):
@@ -84,12 +106,13 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
fnmap[hash(call.fn)] = call.fn fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call, provided_types) conv = ConversionRenderData.from_inj(call, provided_types)
if conv not in conversion_models: conversion_models.append(conv)
conversion_models.append(conv) if call.is_async:
if call.is_async: is_async = True
is_async = True
ret_hash = hash(callseq[-1].injects) conversion_models = deduplicate_callseq(conversion_models)
ret_hash = hashname(callseq[-1].injects)
conv_args = [] conv_args = []
for i, from_type in enumerate(from_types): for i, from_type in enumerate(from_types):

View File

@@ -1,6 +1,11 @@
import inspect import inspect
from itertools import product from itertools import product
from typing import Callable, get_type_hints, TypeVar, Any from typing import Callable, get_type_hints, TypeVar, Any, Optional
def extract_return_type(func: Callable) -> Optional[type]:
hints = get_type_hints(func)
return hints.get('return')
def extract_func_args(func: Callable) -> list[tuple[str, type]]: def extract_func_args(func: Callable) -> list[tuple[str, type]]:

108
tests/test_pipeline.py Normal file
View File

@@ -0,0 +1,108 @@
from dataclasses import dataclass
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
type optC = str
def test_default_consumer_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type ret1 = tuple[int, str]
def consumer1(dep: A, opt_dep: optC = '42') -> ret1:
return dep.a, opt_dep
def consumer2(dep: A, dep1: ret1) -> optC:
return str((dep.a, dep1))
p1 = repo.create_pipeline(
(B,),
[consumer1, consumer2],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p1(B(42.1))
assert res == "(42, (42, '42'))"
p2 = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer1],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p2(B(42.1))
assert res == (42, "(42, (42, '42'))")
def test_pipeline_with_subgraph_duplicates():
repo = ConvRepo()
b_to_a_calls = [0]
@repo.mark_injector()
def b_to_a(b: B) -> A:
b_to_a_calls[0] += 1
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type ret1 = tuple[int, str]
cons1_calls = [0]
cons2_calls = [0]
def consumer1(dep: A, opt_dep: optC = '42') -> A:
cons1_calls[0] += 1
return A(dep.a + int(opt_dep))
def consumer2(dep: A) -> optC:
cons2_calls[0] += 1
return str(dep.a)
p1 = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p1(B(42.1))
assert res.a == 42 + (42 * 31)
assert b_to_a_calls[0] == 1
assert cons1_calls[0] == 5
assert cons2_calls[0] == 4