Compare commits

...

30 Commits

Author SHA1 Message Date
74d78b1957 Fix callseq deduplication error, allow using Some|None=None args with no commutativity error, add ignore_basictypes_return for a ConversionPoint.from_fn 2025-10-17 00:51:29 +03:00
dbecef1977 Fix renderer deduplicate_callseq duplicates slips 2025-09-14 01:55:47 +03:00
27939ef3ea Fix ForkedConvRepo add_injector signature 2025-08-20 22:01:38 +03:00
5ac6ff102f Add method add_conversion_points into a ConvRepo 2025-08-20 03:11:17 +03:00
9142cb05fc Fix import universal_qualname in a GraphWalker 2025-08-20 03:07:35 +03:00
a256db0203 Add ConversionPoint reference into a ConversionRenderData for a further deduplication and reuse of raw call sequence 2025-08-20 00:31:42 +03:00
d68bb79a97 Bump version 2025-08-19 02:37:59 +03:00
9d03affd41 Allow passing [ConversionPoint] into create_pipeline with a type remap for ConversionPoint 2025-08-19 02:37:01 +03:00
52d82550e6 Allow passing [ConversionPoint] into get_conversion with a type remap for ConversionPoint 2025-08-19 02:32:15 +03:00
742c21e199 Bump version 2025-08-16 18:46:06 +03:00
fd8026a2a5 Update README.md: sync feature list 2025-08-16 18:45:55 +03:00
3150c4b2d0 Fix ctxmanager injects hash 2025-08-16 18:44:58 +03:00
d6f8038efa Make tuple return types unwrap 2025-08-16 18:38:46 +03:00
42b0badc65 Fix ConversionPoint.__repr__ on objects that does not have __qualname__ 2025-08-04 22:43:40 +03:00
849d6094a9 Fix draw_callseq_mermaid cell name 2025-08-04 22:40:09 +03:00
45010c1cf3 Add util_mermaid callseq renderer, fix forked convrepo store_* corruption 2025-07-22 00:34:27 +03:00
70e7b4fe3f Add options to store rendered sources and call sequences 2025-07-21 17:40:36 +03:00
e767ccae15 Fix ConversionPoint.fn_args ignorance of type annot override 2025-07-21 15:45:48 +03:00
90409ec774 Update README.md: sync feature list 2025-07-21 15:34:53 +03:00
6fe37a5ae1 Bump version 2025-07-21 15:32:38 +03:00
66241cd01a Update README.md: add pipeline descripiton 2025-07-21 15:31:26 +03:00
a0de9fcda8 Make smart call deduplication 2025-07-19 22:49:15 +03:00
b058a701a0 Add basic pipeline construction, callseq deduplication pending 2025-07-19 22:32:40 +03:00
eae2cd9a4b Remove unused defaults 2025-07-19 21:13:40 +03:00
69def6e74c Allow default option to be overriden if there is any conversion point that injects this default option 2025-07-19 21:12:35 +03:00
f2ec4fad14 Allow default option to be overriden if was ocasionally provided on a conversion path 2025-07-19 21:08:46 +03:00
b04ea2c16a Add test for non-provided default convertor args 2025-07-19 20:50:49 +03:00
fe53cf9270 Add test for non-provided default consumer args 2025-07-19 20:49:10 +03:00
a2cf1bb6e6 Get rid of manual consumer fn unwrapping for callgraph generation 2025-07-19 20:38:43 +03:00
6bf28e5fe8 Add defaulted args into a ConversionPoint 2025-07-19 20:21:10 +03:00
14 changed files with 936 additions and 87 deletions

View File

@@ -18,9 +18,11 @@
- Поддерживает асинхронный контекст - Поддерживает асинхронный контекст
- Поддерживает внедрение зависимости через синхронные/асинхронные менеджеры контекста - Поддерживает внедрение зависимости через синхронные/асинхронные менеджеры контекста
- Поддерживает `Union`-типы в зависимостях - Поддерживает `Union`-типы в зависимостях
- Учитывает default-параметры
- Позволяет выстраивать конвейеры преобразований
- Опционально разворачивает кортежи в возвращаемых значениях
#### Ограничения библиотеки: #### Ограничения библиотеки:
- Зависимости со стандартными параметрами пока не поддерживаются
- Выбор графа преобразований вызывает комбинаторный взрыв - Выбор графа преобразований вызывает комбинаторный взрыв
- Кэширование графов преобразований не поддерживается - Кэширование графов преобразований не поддерживается
- При некоммутативности сгенерированного графа, имеется опасность неконсистентного выбора пути, поскольку порядок обхода методов, а также графа, не гарантирован - При некоммутативности сгенерированного графа, имеется опасность неконсистентного выбора пути, поскольку порядок обхода методов, а также графа, не гарантирован
@@ -103,6 +105,49 @@ assert tst == 1
``` ```
----
#### Сборка конвейеров преобразований:
Пусть, имеется несколько методов-потребителей, которые необходимо вызывать последовательно:
```python
from breakshaft.convertor import ConvRepo
repo = ConvRepo()
# Объявляем A и B, а также методы преобразований - как в прошлом примере
type cons2ret = str # избегаем использования builtin-типов, чтобы избежать простых коллизий
def consumer1(dep: A) -> B:
return B(float(42))
def consumer2(dep: B) -> cons2ret:
return str(dep.b)
def consumer3(dep: cons2ret) -> int:
return int(float(dep))
pipeline = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer3],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
dat = pipeline(B(42))
assert dat == 42
```
---- ----
#### Как получить граф преобразований: #### Как получить граф преобразований:

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "breakshaft" name = "breakshaft"
version = "0.1.0.post2" version = "0.1.6.post5"
description = "Library for in-time codegen for type conversion" description = "Library for in-time codegen for type conversion"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -1,9 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any
import collections.abc
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence, Iterable
from .graph_walker import GraphWalker from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
from .util import extract_return_type, universal_qualname
Tin = TypeVarTuple('Tin') Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout') Tout = TypeVar('Tout')
@@ -14,10 +17,14 @@ class ConvRepo:
walker: GraphWalker walker: GraphWalker
renderer: ConvertorRenderer renderer: ConvertorRenderer
store_callseq: bool
store_sources: bool
def __init__(self, def __init__(self,
graph_walker: Optional[GraphWalker] = None, graph_walker: Optional[GraphWalker] = None,
renderer: Optional[ConvertorRenderer] = None, ): renderer: Optional[ConvertorRenderer] = None,
store_callseq: bool = False,
store_sources: bool = False):
if graph_walker is None: if graph_walker is None:
graph_walker = GraphWalker() graph_walker = GraphWalker()
if renderer is None: if renderer is None:
@@ -26,13 +33,60 @@ class ConvRepo:
self._convertor_set = set() self._convertor_set = set()
self.walker = graph_walker self.walker = graph_walker
self.renderer = renderer self.renderer = renderer
self.store_callseq = store_callseq
self.store_sources = store_sources
def create_pipeline(self,
from_types: Sequence[type],
fns: Sequence[Callable | Iterable[ConversionPoint] | ConversionPoint],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
):
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
pipeline_callseq = []
orig_from_types = tuple(from_types)
from_types = tuple(from_types)
for fn in fns:
injects = None
if isinstance(fn, collections.abc.Iterable):
for f in fn:
injects = f.injects
break
elif isinstance(fn, ConversionPoint):
injects = fn.injects
else:
injects = extract_return_type(fn)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
pipeline_callseq += callseq
if injects is not None:
from_types += (injects,)
ret_fn = self.renderer.render(orig_from_types,
pipeline_callseq,
force_async=force_async,
store_sources=self.store_sources)
if self.store_callseq:
setattr(ret_fn, '__breakshaft_callseq__', pipeline_callseq)
return ret_fn
@property @property
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
def add_injector(self, func: Callable, rettype: Optional[type] = None): def add_conversion_points(self, conversion_points: Iterable[ConversionPoint]):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype)) self._convertor_set |= set(conversion_points)
def add_injector(self,
func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None):
self.add_conversion_points(ConversionPoint.from_fn(func, rettype=rettype, type_remap=type_remap))
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]: def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
if len(cg.variants) == 0: if len(cg.variants) == 0:
@@ -46,30 +100,27 @@ class ConvRepo:
ret += [variant.injector] ret += [variant.injector]
return ret return ret
def get_conversion(self, def filtered_injectors(self, allow_async: bool, allow_sync: bool) -> frozenset[ConversionPoint]:
from_types: tuple[type[Unpack[Tin]]], filtered_injectors: frozenset[ConversionPoint] = frozenset()
fn: Callable[..., Tout], for inj in self.convertor_set:
force_commutative: bool = True, if inj.is_async and not allow_async:
allow_async: bool = True, continue
allow_sync: bool = True, if not inj.is_async and not allow_sync:
force_async: bool = False continue
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]: filtered_injectors |= {inj}
if not allow_async or force_async: return filtered_injectors
filtered_injectors: frozenset[ConversionPoint] = frozenset()
for inj in self.convertor_set:
if inj.is_async and not allow_async:
continue
if not inj.is_async and not allow_sync:
continue
filtered_injectors |= {inj}
else:
filtered_injectors = frozenset(self.convertor_set)
cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn) def get_callseq(self,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type],
fn: Callable | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool) -> list[ConversionPoint]:
cg = self.walker.generate_callgraph(injectors, from_types, fn)
if cg is None: if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') raise ValueError(f'Unable to compute conversion graph on {from_types}->{universal_qualname(fn)}')
exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types)) exploded = self.walker.explode_callgraph_branches(cg, from_types)
selected = self.walker.filter_exploded_callgraph_branch(exploded) selected = self.walker.filter_exploded_callgraph_branch(exploded)
if len(selected) == 0: if len(selected) == 0:
@@ -79,17 +130,51 @@ class ConvRepo:
raise ValueError('Conversion path is not commutative') raise ValueError('Conversion path is not commutative')
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
return self.renderer.render(from_types, callseq, force_async=force_async)
def mark_injector(self, *, rettype: Optional[type] = None): if len(callseq) > 0:
injects = None
if isinstance(fn, collections.abc.Iterable):
for f in fn:
injects = f.injects
break
elif isinstance(fn, ConversionPoint):
injects = fn.injects
else:
injects = extract_return_type(fn)
callseq[-1] = callseq[-1].copy_with(injects=injects)
return callseq
def get_conversion(self,
from_types: Sequence[type[Unpack[Tin]]],
fn: Callable[..., Tout] | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]:
filtered_injectors = self.filtered_injectors(allow_async, allow_sync)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
ret_fn = self.renderer.render(from_types, callseq, force_async=force_async, store_sources=self.store_sources)
if self.store_callseq:
setattr(ret_fn, '__breakshaft_callseq__', callseq)
return ret_fn
def mark_injector(self, *, rettype: Optional[type] = None, type_remap: Optional[dict[str, type]] = None):
def inner(func: Callable): def inner(func: Callable):
self.add_injector(func) self.add_injector(func, rettype=rettype, type_remap=type_remap)
return func return func
return inner return inner
def fork(self, fork_with: Optional[set[ConversionPoint]] = None) -> ConvRepo: def fork(self, fork_with: Optional[set[ConversionPoint]] = None) -> ConvRepo:
return ForkedConvRepo(self, fork_with or None, self.walker, self.renderer) return ForkedConvRepo(self, fork_with or None,
self.walker,
self.renderer,
self.store_callseq,
self.store_sources)
class ForkedConvRepo(ConvRepo): class ForkedConvRepo(ConvRepo):
@@ -99,16 +184,16 @@ class ForkedConvRepo(ConvRepo):
fork_from: ConvRepo, fork_from: ConvRepo,
fork_with: Optional[set[ConversionPoint]] = None, fork_with: Optional[set[ConversionPoint]] = None,
graph_walker: Optional[GraphWalker] = None, graph_walker: Optional[GraphWalker] = None,
renderer: Optional[ConvertorRenderer] = None): renderer: Optional[ConvertorRenderer] = None,
super().__init__(graph_walker, renderer) store_callseq: bool = False,
store_sources: bool = False,
):
super().__init__(graph_walker, renderer, store_callseq, store_sources)
if fork_with is None: if fork_with is None:
fork_with = set() fork_with = set()
self._convertor_set = fork_with self._convertor_set = fork_with
self._base_repo = fork_from self._base_repo = fork_from
def add_injector(self, func: Callable, rettype: Optional[type] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
@property @property
def convertor_set(self): def convertor_set(self):
return self._base_repo.convertor_set | self._convertor_set return self._base_repo.convertor_set | self._convertor_set

View File

@@ -1,9 +1,11 @@
import collections.abc
import typing import typing
from types import NoneType from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type, universal_qualname
from typing import Iterable
class GraphWalker: class GraphWalker:
@@ -12,21 +14,27 @@ class GraphWalker:
def generate_callgraph(cls, def generate_callgraph(cls,
injectors: frozenset[ConversionPoint], injectors: frozenset[ConversionPoint],
from_types: frozenset[type], from_types: frozenset[type],
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable | Iterable[ConversionPoint] | ConversionPoint) -> Optional[Callgraph]:
into_types: frozenset[type] = extract_func_argtypes(consumer_fn)
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
for into_type in into_types: # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
cg = cls.generate_callgraph_singletype(injectors, from_types, into_type) # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
if cg is None: # При этом, TypeAliasType также выступает в роли ключа преобразования
return None # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
branches |= {cg} type _tmp_type_for_consumer = object
variant = CallgraphVariant(
ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))), if isinstance(consumer_fn, collections.abc.Iterable):
branches, frozenset()) new_consumer_injectors = set()
return Callgraph(frozenset({variant})) for fn in consumer_fn:
new_consumer_injectors.add(fn.copy_with(injects=_tmp_type_for_consumer))
injectors |= new_consumer_injectors
elif isinstance(consumer_fn, ConversionPoint):
injectors |= set(consumer_fn.copy_with(injects=_tmp_type_for_consumer))
else:
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
@classmethod @classmethod
def generate_callgraph_singletype(cls, def generate_callgraph_singletype(cls,
@@ -72,7 +80,17 @@ class GraphWalker:
variant_subgraphs.add(subg) variant_subgraphs.add(subg)
if not dead_end: if not dead_end:
consumed = frozenset(point.requires) & from_types
for opt in point.opt_args:
subg = cls.generate_callgraph_singletype(injectors,
from_types,
opt,
visited_path=visited_path.copy(),
visited_types=visited_types.copy())
if subg is not None:
variant_subgraphs.add(subg)
consumed = (frozenset(point.requires) | frozenset(point.opt_args)) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant) head = head.add_subgraph_variant(variant)
@@ -135,7 +153,7 @@ class GraphWalker:
if len(variants) > 1: if len(variants) > 1:
# sorting by first injector func name for creating minimal cosistancy # sorting by first injector func name for creating minimal cosistancy
# could lead to heizenbugs due to incosistancy in path selection between calls # could lead to heizenbugs due to incosistancy in path selection between calls
variants.sort(key=lambda x: x.injector.fn.__qualname__) variants.sort(key=lambda x: universal_qualname(x.injector.fn))
return variants return variants
if len(variants) < 2: if len(variants) < 2:

View File

@@ -10,24 +10,39 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
is_async_context_manager_factory, \ is_async_context_manager_factory, \
all_combinations, is_context_manager_factory all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \
get_tuple_types, is_basic_type_annot, universal_qualname
@dataclass(frozen=True) @dataclass(frozen=True)
class ConversionPoint: class ConversionPoint:
fn: Callable fn: Callable
injects: type injects: type
rettype: type
requires: tuple[type, ...] requires: tuple[type, ...]
opt_args: tuple[type, ...]
def copy_with(self, **kwargs):
fn = kwargs.get('fn', self.fn)
rettype = kwargs.get('rettype', self.rettype)
injects = kwargs.get('injects', self.injects)
requires = kwargs.get('requires', self.requires)
opt_args = kwargs.get('opt_args', self.opt_args)
return ConversionPoint(fn, injects, rettype, requires, opt_args)
def __hash__(self): def __hash__(self):
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))
def __repr__(self): def __repr__(self):
return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}' injects_name = universal_qualname(self.injects)
fn_name = universal_qualname(self.fn)
return f'({",".join(map(str, self.requires))}) -> {injects_name}: {fn_name}'
@property @property
def fn_args(self) -> list[type]: def fn_args(self) -> list[tuple[str, type]]:
return extract_func_argtypes_seq(self.fn) funcnames = extract_func_argnames(self.fn)
return list(zip(funcnames, self.requires + self.opt_args))
@property @property
def is_ctx_manager(self) -> bool: def is_ctx_manager(self) -> bool:
@@ -38,15 +53,25 @@ class ConversionPoint:
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
@classmethod @classmethod
def from_fn(cls, func: Callable, rettype: Optional[type] = None): def from_fn(cls,
if rettype is None: func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None,
ignore_basictype_return: bool = False) -> list[ConversionPoint]:
if type_remap is None:
annot = get_type_hints(func) annot = get_type_hints(func)
rettype = annot.get('return') else:
annot = type_remap
fn_rettype = annot.get('return')
if rettype is None:
rettype = fn_rettype
if rettype is None: if rettype is None:
raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified') raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified')
rettype_origin = get_origin(rettype) rettype_origin = get_origin(rettype)
fn_rettype_origin = get_origin(fn_rettype)
cm_out_origins = [ cm_out_origins = [
typing.Generator, typing.Generator,
typing.Iterator, typing.Iterator,
@@ -59,22 +84,55 @@ class ConversionPoint:
] ]
if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
rettype = get_args(rettype)[0] rettype = get_args(rettype)[0]
if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
fn_rettype = get_args(fn_rettype)[0]
if not ignore_basictype_return and is_basic_type_annot(rettype):
return []
ret = []
tuple_unwrapped = get_tuple_types(rettype)
# Do not unwrap elipsis, but unwrap non-empty tuples
if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped:
for t in tuple_unwrapped:
if not is_basic_type_annot(t):
ret += ConversionPoint.from_fn(func,
rettype=t,
type_remap=type_remap,
ignore_basictype_return=ignore_basictype_return)
argtypes: list[list[type]] = [] argtypes: list[list[type]] = []
orig_argtypes = extract_func_argtypes_seq(func) orig_args = extract_func_args(func, type_remap)
for argtype in orig_argtypes: defaults = extract_func_arg_defaults(func)
orig_argtypes = []
for argname, argtype in orig_args:
orig_argtypes.append((argtype, argname in defaults.keys()))
default_map: list[bool] = []
for argtype, has_default in orig_argtypes:
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
u_types = list(get_args(argtype)) + [argtype] u_types = list(get_args(argtype)) + [argtype]
else: else:
u_types = [argtype] u_types = [argtype]
default_map.append(has_default)
argtypes.append(u_types) argtypes.append(u_types)
argtype_combinations = all_combinations(argtypes) argtype_combinations = all_combinations(argtypes)
ret = []
for argtype_combination in argtype_combinations:
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination)))
# return InjectorPoint(func, rettype, argtypes) for argtype_combination in argtype_combinations:
req_args = []
opt_args = []
for argt, has_default in zip(argtype_combination, default_map):
if has_default:
opt_args.append(argt)
else:
req_args.append(argt)
if rettype in req_args:
continue
ret.append(ConversionPoint(func, rettype, fn_rettype, tuple(req_args), tuple(opt_args)))
return ret return ret

View File

@@ -7,32 +7,76 @@ import importlib.resources
import jinja2 import jinja2
from .models import ConversionPoint from .models import ConversionPoint
from .util import hashname from .util import hashname, get_tuple_types, is_basic_type_annot, universal_qualname
class ConvertorRenderer(Protocol): class ConvertorRenderer(Protocol):
def render(self, def render(self,
from_types: Sequence[type], from_types: Sequence[type],
callseq: Sequence[ConversionPoint], callseq: Sequence[ConversionPoint],
force_async: bool = False) -> Callable: force_async: bool = False,
store_sources: bool = False) -> Callable:
raise NotImplementedError() raise NotImplementedError()
type UnwprappedTuple = tuple[tuple[UnwprappedTuple, str] | str | None, ...]
def unwrap_tuple_type(typ: type) -> UnwprappedTuple:
unwrap_tuple_result = ()
tuple_types = get_tuple_types(typ)
if len(tuple_types) > 0 and Ellipsis not in tuple_types:
for t in tuple_types:
if not is_basic_type_annot(t):
subtuple = unwrap_tuple_type(t)
hn = hashname(t)
if len(subtuple) > 0:
unwrap_tuple_result += ((subtuple, hn),)
else:
unwrap_tuple_result += (hn,)
else:
unwrap_tuple_result += (None,)
if not any(map(lambda x: x is not None, unwrap_tuple_result)):
return ()
return unwrap_tuple_result
@dataclass @dataclass
class ConversionRenderData: class ConversionRenderData:
inj_hash: str inj_hash: str
funchash: str funchash: str
funcname: str funcname: str
funcargs: list[str] funcargs: list[tuple[str, str]]
is_ctxmanager: bool is_ctxmanager: bool
is_async: bool is_async: bool
unwrap_tuple_result: UnwprappedTuple
_injection: ConversionPoint
@classmethod @classmethod
def from_inj(cls, inj: ConversionPoint): def from_inj(cls, inj: ConversionPoint, provided_types: set[type]):
argmap = inj.fn_args
fnargs = [] fnargs = []
for argtype in inj.requires: for arg_id, argtype in enumerate(inj.requires):
fnargs.append(hashname(argtype)) argname = argmap[arg_id][0]
return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async) fnargs.append((argname, hashname(argtype)))
for arg_id, argtype in enumerate(inj.opt_args, len(inj.requires)):
argname = argmap[arg_id][0]
if argtype in provided_types:
fnargs.append((argname, hashname(argtype)))
unwrap_tuple_result = unwrap_tuple_type(inj.rettype)
return cls(hashname(inj.rettype),
hashname(inj.fn),
repr(inj.fn),
fnargs,
inj.is_ctx_manager,
inj.is_async,
unwrap_tuple_result,
inj)
@dataclass @dataclass
@@ -42,6 +86,49 @@ class ConversionArgRenderData:
typehash: str typehash: str
def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]:
deduplicated_conv_models: list[ConversionRenderData] = []
deduplicated_hashes = set()
for conv_model in conversion_models:
if hash((conv_model.inj_hash, conv_model.funchash)) not in deduplicated_hashes:
deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
continue
argnames = list(map(lambda x: x[1], conv_model.funcargs))
argument_changed = False
found_model = False
for m in deduplicated_conv_models:
if not found_model and m.funchash == conv_model.funchash:
found_model = True
if found_model and m.inj_hash in argnames:
argument_changed = True
break
if argument_changed:
deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
return deduplicated_conv_models
def render_data_from_callseq(from_types: Sequence[type],
fnmap: dict[int, Callable],
callseq: Sequence[ConversionPoint]):
conversion_models: list[ConversionRenderData] = []
ret_hash = 0
for call_id, call in enumerate(callseq):
provided_types = set(from_types)
for _call in callseq[:call_id]:
provided_types |= {_call.injects}
provided_types |= set(_call.requires)
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call, provided_types)
conversion_models.append(conv)
return conversion_models
class InTimeGenerationConvertorRenderer(ConvertorRenderer): class InTimeGenerationConvertorRenderer(ConvertorRenderer):
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
@@ -60,22 +147,20 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
def render(self, def render(self,
from_types: Sequence[type], from_types: Sequence[type],
callseq: Sequence[ConversionPoint], callseq: Sequence[ConversionPoint],
force_async: bool = False) -> Callable: force_async: bool = False,
store_sources: bool = False) -> Callable:
fnmap = {} fnmap = {}
conversion_models = [] conversion_models: list[ConversionRenderData] = render_data_from_callseq(from_types, fnmap, callseq)
ret_hash = 0 ret_hash = 0
is_async = force_async is_async = force_async
for call_id, call in enumerate(callseq):
if call.is_async:
is_async = True
for call in callseq: conversion_models = deduplicate_callseq(conversion_models)
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call)
if conv not in conversion_models:
conversion_models.append(conv)
if call.is_async:
is_async = True
ret_hash = hash(callseq[-1].injects) ret_hash = hashname(callseq[-1].rettype)
conv_args = [] conv_args = []
for i, from_type in enumerate(from_types): for i, from_type in enumerate(from_types):
@@ -91,7 +176,10 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
is_async=is_async, is_async=is_async,
) )
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
convertor_functext = convertor_functext.replace(', )', ')').replace(',)', ')')
exec(convertor_functext, namespace) exec(convertor_functext, namespace)
unwrap_func = namespace['convertor'] unwrap_func = namespace['convertor']
if store_sources:
setattr(unwrap_func, '__breakshaft_render_src__', convertor_functext)
return typing.cast(Callable, unwrap_func) return typing.cast(Callable, unwrap_func)

View File

@@ -1,13 +1,34 @@
{% set ns = namespace(indent=0) %} {% set ns = namespace(indent=0) %}
{% macro unwrap_tuple(tupl, unwrap_name) -%}
{%- set out -%}
{% if tupl | length > 0 %}
{% for t in tupl %}
{% if t is string %}
_{{t}} = _{{unwrap_name}}[{{loop.index0}}]
{% endif %}
{% if t.__class__.__name__ == 'tuple' %}
_{{t[1]}} = _{{unwrap_name}}[{{loop.index0}}]
{{unwrap_tuple(t[0], t[1])}}
{% endif %}
{% endfor %}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% if is_async %}async {% endif %}def convertor({% for arg in conv_args %}_{{arg.typehash}}: "{{arg.typename}}",{% endfor %}){% if rettype %} -> '{{rettype}}'{% endif %}: {% if is_async %}async {% endif %}def convertor({% for arg in conv_args %}_{{arg.typehash}}: "{{arg.typename}}",{% endfor %}){% if rettype %} -> '{{rettype}}'{% endif %}:
{% for conv in conversions %} {% for conv in conversions %}
{% if conv.is_ctxmanager %} {% if conv.is_ctxmanager %}
{{ ' ' * ns.indent }}# {{conv.funcname}} {{ ' ' * ns.indent }}# {{conv.funcname}}
{{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}: {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}}=_{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}:
{% set ns.indent = ns.indent + 1 %} {% set ns.indent = ns.indent + 1 %}
{% else %} {% else %}
{{ ' ' * ns.indent }}# {{conv.funcname}} {{ ' ' * ns.indent }}# {{conv.funcname}}
{{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}}=_{{conv_arg[1]}}, {% endfor %})
{% endif %} {% endif %}
{{unwrap_tuple(conv.unwrap_tuple_result, conv.inj_hash) | indent((ns.indent + 1) * 4)}}
{% endfor %} {% endfor %}
{{ ' ' * ns.indent }}return _{{ret_hash}} {{ ' ' * ns.indent }}return _{{ret_hash}}

View File

@@ -1,11 +1,31 @@
import inspect import inspect
import typing
from itertools import product from itertools import product
from typing import Callable, get_type_hints, TypeVar, Any from typing import Callable, get_type_hints, TypeVar, Any, Optional
def extract_func_args(func: Callable) -> list[tuple[str, type]]: def extract_func_argnames(func: Callable) -> list[str]:
sig = inspect.signature(func) sig = inspect.signature(func)
type_hints = get_type_hints(func) params = sig.parameters
args_info = []
for name, _ in params.items():
args_info.append(name)
return args_info
def extract_return_type(func: Callable) -> Optional[type]:
hints = get_type_hints(func)
return hints.get('return')
def extract_func_args(func: Callable, type_hints_remap: Optional[dict[str, type]] = None) -> list[tuple[str, type]]:
sig = inspect.signature(func)
if type_hints_remap is None:
type_hints = get_type_hints(func)
else:
type_hints = type_hints_remap
params = sig.parameters params = sig.parameters
args_info = [] args_info = []
@@ -42,6 +62,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]:
return ret return ret
def extract_func_arg_defaults(func: Callable) -> dict[str, object]:
sig = inspect.signature(func)
defaults = {
name: param.default
for name, param in sig.parameters.items()
if param.default is not inspect._empty
}
return defaults
def is_context_manager_factory(obj: object) -> bool: def is_context_manager_factory(obj: object) -> bool:
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj) return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
@@ -63,3 +93,62 @@ T = TypeVar('T')
def all_combinations(options: list[list[T]]) -> list[list[T]]: def all_combinations(options: list[list[T]]) -> list[list[T]]:
return [list(comb) for comb in product(*options)] return [list(comb) for comb in product(*options)]
def get_tuple_types(type_obj: type) -> tuple:
ret = ()
origin = getattr(type_obj, '__origin__', None)
if origin is tuple:
args = getattr(type_obj, '__args__', ())
ret = args if args else ()
return ret
def is_basic_type_annot(type_annot) -> bool:
basic_types = {
int, float, str, bool, complex,
list, dict, tuple, set, frozenset,
bytes, bytearray, memoryview,
type(None), object
}
origin = getattr(type_annot, '__origin__', None)
args = getattr(type_annot, '__args__', None)
if type_annot in basic_types:
return True
if origin is not None:
if origin in basic_types or origin in {list, dict, tuple, set, frozenset}:
if args:
return all(is_basic_type_annot(arg) for arg in args)
return True
return False
if origin is typing.Union:
return all(is_basic_type_annot(arg) for arg in args)
return False
def universal_qualname(any: Any) -> str:
ret = ''
if hasattr(any, '__qualname__'):
ret = any.__qualname__
elif hasattr(any, '__name__'):
ret = any.__name__
else:
ret = str(any)
ret = (ret
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_')
.replace('\'', '')
.replace('<', '')
.replace('>', ''))
return ret

View File

@@ -1,4 +1,4 @@
from .models import Callgraph, TransformationPoint from .models import Callgraph, TransformationPoint, ConversionPoint
from .util import hashname from .util import hashname
@@ -68,3 +68,14 @@ def draw_callgraph_mermaid(g: Callgraph, split_duplicates=False, skip_title=Fals
ret += 'flowchart TD\n\n' ret += 'flowchart TD\n\n'
ret += ' %%defs:\n' + '\n'.join(d) + '\n\n %%edges:\n' + '\n'.join(e) ret += ' %%defs:\n' + '\n'.join(d) + '\n\n %%edges:\n' + '\n'.join(e)
return ret return ret
def draw_callseq_mermaid(callseq: list[ConversionPoint]):
ret = ['flowchart TD\n\n']
ret += [' %%defs:']
for cp_i, cp in enumerate(callseq):
ret.append(f' e{cp_i}["{shield_mermaid_name(str(cp))}"]')
ret += ['', '', ' %%edges:']
for cp_i, cp in enumerate(callseq[:-1]):
ret.append(f' e{cp_i}-->e{cp_i + 1}')
return '\n'.join(ret)

View File

@@ -38,3 +38,33 @@ def test_basic():
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123) dep = fn2(123)
assert dep == 123 assert dep == 123
def test_union_deps():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A | B) -> int:
if isinstance(dep, A):
return dep.a
else:
return int(dep.b)
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == 42
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == 123

144
tests/test_default_args.py Normal file
View File

@@ -0,0 +1,144 @@
from dataclasses import dataclass
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
type optC = str
def test_default_consumer_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]:
return dep.a, opt_dep
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == (42, '42')
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == (123, '42')
fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '1')
assert dep == (123, '1')
def test_optional_default_none_consumer_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B | None = None) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B | None:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]:
return dep.a, opt_dep
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == (42, '42')
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == (123, '42')
fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '1')
assert dep == (123, '1')
def test_default_inj_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int, opt_dep: optC = '42') -> A:
return A(i + int(opt_dep))
def consumer(dep: A) -> int:
return dep.a
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == 42
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == 123 + 42
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '0')
assert dep == 123
def test_default_graph_override():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int, opt_dep: optC = '42') -> A:
return A(i + int(opt_dep))
@repo.mark_injector()
def inject_opt_dep() -> optC:
return '12345'
def consumer(dep: A) -> int:
return dep.a
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == 42
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == 123 + 12345
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '0')
assert dep == 123

118
tests/test_pipeline.py Normal file
View File

@@ -0,0 +1,118 @@
from dataclasses import dataclass
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
type optC = str
def test_default_consumer_args():
repo = ConvRepo(store_sources=True)
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type ret1 = tuple[int, str]
def consumer1(dep: A, opt_dep: optC = '42') -> ret1:
return dep.a, opt_dep
def consumer2(dep: A, dep1: ret1) -> optC:
return str((dep.a, dep1))
p1 = repo.create_pipeline(
(B,),
[consumer1, consumer2],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p1(B(42.1))
assert res == "(42, (42, '42'))"
p2 = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer1],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p2(B(42.1))
assert res == (42, "(42, (42, '42'))")
def test_pipeline_with_subgraph_duplicates():
repo = ConvRepo()
b_to_a_calls = [0]
@repo.mark_injector()
def b_to_a(b: B) -> A:
b_to_a_calls[0] += 1
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type ret1 = tuple[int, str]
cons1_calls = [0]
cons2_calls = [0]
def consumer1(dep: A, opt_dep: optC = '42') -> A:
cons1_calls[0] += 1
return A(dep.a + int(opt_dep))
def consumer2(dep: A) -> optC:
cons2_calls[0] += 1
return str(dep.a)
p1 = repo.create_pipeline(
(B,),
[consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1, consumer2, consumer1],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
res = p1(B(42.1))
assert res.a == 42 + (42 * 31)
assert b_to_a_calls[0] == 1
assert cons1_calls[0] == 5
assert cons2_calls[0] == 4
def convertor(_5891515089754: "<class 'test_pipeline.B'>"):
# <function test_default_consumer_args.<locals>.b_to_a at 0x7f5bb1be02c0>
_5891515089643 = _conv_funcmap[8751987548204](b=_5891515089754)
# <function test_default_consumer_args.<locals>.consumer1 at 0x7f5bb1be0c20>
_8751987542640 = _conv_funcmap[8751987548354](dep=_5891515089643)
# <function test_default_consumer_args.<locals>.consumer2 at 0x7f5bb1be0540>
_8751987537115 = _conv_funcmap[8751987548244](dep=_5891515089643, dep1=_8751987542640)
return _8751987542640

View File

@@ -0,0 +1,84 @@
from dataclasses import dataclass
from breakshaft.models import ConversionPoint
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
@dataclass
class C:
c: int
@dataclass
class D:
d: str
def test_conv_point_tuple_unwrap():
def conv_into_bc(a: A) -> tuple[B, C]:
return B(a.a), C(a.a)
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
return B(a.a), (C(a.a), D(str(a.a)))
def conv_into_bcda(a: A) -> tuple[B, tuple[C, tuple[D, A]]]:
return B(a.a), (C(a.a), (D(str(a.a)), a))
cps_bc = ConversionPoint.from_fn(conv_into_bc)
assert len(cps_bc) == 3 # tuple[...], B, C
cps_bcd = ConversionPoint.from_fn(conv_into_bcd)
assert len(cps_bcd) == 5 # tuple[B,...], B, tuple[C,D], C, D
cps_bcda = ConversionPoint.from_fn(conv_into_bcda)
assert len(cps_bcda) == 6 # ignores (A,...)->A
def test_ignore_basic_types():
def conv_into_b_int(a: A) -> tuple[B, int]:
return B(a.a), a.a
cps = ConversionPoint.from_fn(conv_into_b_int)
assert len(cps) == 2 # tuple[...], B
def test_codegen_tuple_unwrap():
repo = ConvRepo(store_sources=True)
@repo.mark_injector()
def conv_into_bcd(a: A) -> tuple[B, tuple[C, D]]:
return B(a.a), (C(a.a), D(str(a.a)))
type Z = A
@repo.mark_injector()
def conv_d_a(d: D) -> Z:
return A(int(d.d))
def consumer1(dep: D) -> int:
return int(dep.d)
def consumer2(dep: Z) -> int:
return int(dep.a)
fn1 = repo.get_conversion((A,), consumer1, force_commutative=True, force_async=False, allow_async=False)
assert fn1(A(1)) == 1
fn2 = repo.get_conversion((A,), consumer2, force_commutative=True, force_async=False, allow_async=False)
assert fn2(A(1)) == 1
pip = repo.create_pipeline((A,), [consumer1, consumer2], force_commutative=True, force_async=False, allow_async=False)
assert pip(A(1)) == 1
print(pip.__breakshaft_render_src__)

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Annotated
import pytest
from breakshaft.models import ConversionPoint
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
def test_basic():
repo = ConvRepo()
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A) -> B:
return B(float(dep.a))
type NewA = A
type_remap = {'dep': NewA, 'return': B}
assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1
with pytest.raises(ValueError):
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
repo.mark_injector(type_remap={'i': int, 'return': NewA})(int_to_a)
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
assert fn1(42).b == 42.0
def consumer1(dep: B) -> A:
return A(int(dep.b))
p1 = repo.create_pipeline(
(int,),
[ConversionPoint.from_fn(consumer, type_remap=type_remap), consumer1, consumer],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
assert p1(123).b == 123.0