Compare commits

...

9 Commits

9 changed files with 227 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "breakshaft" name = "breakshaft"
version = "0.1.4" version = "0.1.6.post5"
description = "Library for in-time codegen for type conversion" description = "Library for in-time codegen for type conversion"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence
import collections.abc
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence, Iterable
from .graph_walker import GraphWalker from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
from .util import extract_return_type from .util import extract_return_type, universal_qualname
Tin = TypeVarTuple('Tin') Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout') Tout = TypeVar('Tout')
@@ -36,7 +38,7 @@ class ConvRepo:
def create_pipeline(self, def create_pipeline(self,
from_types: Sequence[type], from_types: Sequence[type],
fns: Sequence[Callable], fns: Sequence[Callable | Iterable[ConversionPoint] | ConversionPoint],
force_commutative: bool = True, force_commutative: bool = True,
allow_async: bool = True, allow_async: bool = True,
allow_sync: bool = True, allow_sync: bool = True,
@@ -48,7 +50,15 @@ class ConvRepo:
from_types = tuple(from_types) from_types = tuple(from_types)
for fn in fns: for fn in fns:
injects = extract_return_type(fn) injects = None
if isinstance(fn, collections.abc.Iterable):
for f in fn:
injects = f.injects
break
elif isinstance(fn, ConversionPoint):
injects = fn.injects
else:
injects = extract_return_type(fn)
callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative) callseq = self.get_callseq(filtered_injectors, frozenset(from_types), fn, force_commutative)
@@ -69,8 +79,14 @@ class ConvRepo:
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
def add_injector(self, func: Callable, rettype: Optional[type] = None): def add_conversion_points(self, conversion_points: Iterable[ConversionPoint]):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype)) self._convertor_set |= set(conversion_points)
def add_injector(self,
func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None):
self.add_conversion_points(ConversionPoint.from_fn(func, rettype=rettype, type_remap=type_remap))
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]: def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
if len(cg.variants) == 0: if len(cg.variants) == 0:
@@ -97,12 +113,12 @@ class ConvRepo:
def get_callseq(self, def get_callseq(self,
injectors: frozenset[ConversionPoint], injectors: frozenset[ConversionPoint],
from_types: frozenset[type], from_types: frozenset[type],
fn: Callable, fn: Callable | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool) -> list[ConversionPoint]: force_commutative: bool) -> list[ConversionPoint]:
cg = self.walker.generate_callgraph(injectors, from_types, fn) cg = self.walker.generate_callgraph(injectors, from_types, fn)
if cg is None: if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') raise ValueError(f'Unable to compute conversion graph on {from_types}->{universal_qualname(fn)}')
exploded = self.walker.explode_callgraph_branches(cg, from_types) exploded = self.walker.explode_callgraph_branches(cg, from_types)
@@ -116,14 +132,22 @@ class ConvRepo:
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
if len(callseq) > 0: if len(callseq) > 0:
injects = extract_return_type(fn) injects = None
if isinstance(fn, collections.abc.Iterable):
for f in fn:
injects = f.injects
break
elif isinstance(fn, ConversionPoint):
injects = fn.injects
else:
injects = extract_return_type(fn)
callseq[-1] = callseq[-1].copy_with(injects=injects) callseq[-1] = callseq[-1].copy_with(injects=injects)
return callseq return callseq
def get_conversion(self, def get_conversion(self,
from_types: Sequence[type[Unpack[Tin]]], from_types: Sequence[type[Unpack[Tin]]],
fn: Callable[..., Tout], fn: Callable[..., Tout] | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool = True, force_commutative: bool = True,
allow_async: bool = True, allow_async: bool = True,
allow_sync: bool = True, allow_sync: bool = True,
@@ -138,9 +162,9 @@ class ConvRepo:
setattr(ret_fn, '__breakshaft_callseq__', callseq) setattr(ret_fn, '__breakshaft_callseq__', callseq)
return ret_fn return ret_fn
def mark_injector(self, *, rettype: Optional[type] = None): def mark_injector(self, *, rettype: Optional[type] = None, type_remap: Optional[dict[str, type]] = None):
def inner(func: Callable): def inner(func: Callable):
self.add_injector(func) self.add_injector(func, rettype=rettype, type_remap=type_remap)
return func return func
return inner return inner
@@ -170,9 +194,6 @@ class ForkedConvRepo(ConvRepo):
self._convertor_set = fork_with self._convertor_set = fork_with
self._base_repo = fork_from self._base_repo = fork_from
def add_injector(self, func: Callable, rettype: Optional[type] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
@property @property
def convertor_set(self): def convertor_set(self):
return self._base_repo.convertor_set | self._convertor_set return self._base_repo.convertor_set | self._convertor_set

View File

@@ -1,9 +1,11 @@
import collections.abc
import typing import typing
from types import NoneType from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type, universal_qualname
from typing import Iterable
class GraphWalker: class GraphWalker:
@@ -12,17 +14,25 @@ class GraphWalker:
def generate_callgraph(cls, def generate_callgraph(cls,
injectors: frozenset[ConversionPoint], injectors: frozenset[ConversionPoint],
from_types: frozenset[type], from_types: frozenset[type],
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable | Iterable[ConversionPoint] | ConversionPoint) -> Optional[Callgraph]:
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
rettype = extract_return_type(consumer_fn)
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
# При этом, TypeAliasType также выступает в роли ключа преобразования # При этом, TypeAliasType также выступает в роли ключа преобразования
# Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
type _tmp_type_for_consumer = object type _tmp_type_for_consumer = object
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
if isinstance(consumer_fn, collections.abc.Iterable):
new_consumer_injectors = set()
for fn in consumer_fn:
new_consumer_injectors.add(fn.copy_with(injects=_tmp_type_for_consumer))
injectors |= new_consumer_injectors
elif isinstance(consumer_fn, ConversionPoint):
injectors |= set(consumer_fn.copy_with(injects=_tmp_type_for_consumer))
else:
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer) return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
@@ -80,7 +90,7 @@ class GraphWalker:
if subg is not None: if subg is not None:
variant_subgraphs.add(subg) variant_subgraphs.add(subg)
consumed = frozenset(point.requires) & from_types consumed = (frozenset(point.requires) | frozenset(point.opt_args)) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant) head = head.add_subgraph_variant(variant)
@@ -143,7 +153,7 @@ class GraphWalker:
if len(variants) > 1: if len(variants) > 1:
# sorting by first injector func name for creating minimal consistency # sorting by first injector func name for creating minimal consistency
# could lead to heisenbugs due to inconsistency in path selection between calls # could lead to heisenbugs due to inconsistency in path selection between calls
variants.sort(key=lambda x: x.injector.fn.__qualname__) variants.sort(key=lambda x: universal_qualname(x.injector.fn))
return variants return variants
if len(variants) < 2: if len(variants) < 2:

View File

@@ -11,7 +11,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
is_async_context_manager_factory, \ is_async_context_manager_factory, \
all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \ all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \
get_tuple_types, is_basic_type_annot get_tuple_types, is_basic_type_annot, universal_qualname
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -34,15 +34,8 @@ class ConversionPoint:
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))
def __repr__(self): def __repr__(self):
if '__qualname__' in dir(self.injects): injects_name = universal_qualname(self.injects)
injects_name = self.injects.__qualname__ fn_name = universal_qualname(self.fn)
else:
injects_name = str(self.injects)
if '__qualname__' in dir(self.fn):
fn_name = self.fn.__qualname__
else:
fn_name = str(self.fn)
return f'({",".join(map(str, self.requires))}) -> {injects_name}: {fn_name}' return f'({",".join(map(str, self.requires))}) -> {injects_name}: {fn_name}'
@@ -60,9 +53,16 @@ class ConversionPoint:
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
@classmethod @classmethod
def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: def from_fn(cls,
func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None,
ignore_basictype_return: bool = False) -> list[ConversionPoint]:
if type_remap is None:
annot = get_type_hints(func)
else:
annot = type_remap
annot = get_type_hints(func)
fn_rettype = annot.get('return') fn_rettype = annot.get('return')
if rettype is None: if rettype is None:
rettype = fn_rettype rettype = fn_rettype
@@ -87,7 +87,7 @@ class ConversionPoint:
if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
fn_rettype = get_args(fn_rettype)[0] fn_rettype = get_args(fn_rettype)[0]
if is_basic_type_annot(rettype): if not ignore_basictype_return and is_basic_type_annot(rettype):
return [] return []
ret = [] ret = []
@@ -97,10 +97,13 @@ class ConversionPoint:
if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped: if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped:
for t in tuple_unwrapped: for t in tuple_unwrapped:
if not is_basic_type_annot(t): if not is_basic_type_annot(t):
ret += ConversionPoint.from_fn(func, rettype=t) ret += ConversionPoint.from_fn(func,
rettype=t,
type_remap=type_remap,
ignore_basictype_return=ignore_basictype_return)
argtypes: list[list[type]] = [] argtypes: list[list[type]] = []
orig_args = extract_func_args(func) orig_args = extract_func_args(func, type_remap)
defaults = extract_func_arg_defaults(func) defaults = extract_func_arg_defaults(func)
orig_argtypes = [] orig_argtypes = []

View File

@@ -7,7 +7,7 @@ import importlib.resources
import jinja2 import jinja2
from .models import ConversionPoint from .models import ConversionPoint
from .util import hashname, get_tuple_types, is_basic_type_annot from .util import hashname, get_tuple_types, is_basic_type_annot, universal_qualname
class ConvertorRenderer(Protocol): class ConvertorRenderer(Protocol):
@@ -51,6 +51,7 @@ class ConversionRenderData:
is_ctxmanager: bool is_ctxmanager: bool
is_async: bool is_async: bool
unwrap_tuple_result: UnwprappedTuple unwrap_tuple_result: UnwprappedTuple
_injection: ConversionPoint
@classmethod @classmethod
def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): def from_inj(cls, inj: ConversionPoint, provided_types: set[type]):
@@ -74,7 +75,8 @@ class ConversionRenderData:
fnargs, fnargs,
inj.is_ctx_manager, inj.is_ctx_manager,
inj.is_async, inj.is_async,
unwrap_tuple_result) unwrap_tuple_result,
inj)
@dataclass @dataclass
@@ -86,16 +88,18 @@ class ConversionArgRenderData:
def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]: def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]:
deduplicated_conv_models: list[ConversionRenderData] = [] deduplicated_conv_models: list[ConversionRenderData] = []
deduplicated_hashes = set()
for conv_model in conversion_models: for conv_model in conversion_models:
if conv_model not in deduplicated_conv_models: if hash((conv_model.inj_hash, conv_model.funchash)) not in deduplicated_hashes:
deduplicated_conv_models.append(conv_model) deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
continue continue
argnames = list(map(lambda x: x[1], conv_model.funcargs)) argnames = list(map(lambda x: x[1], conv_model.funcargs))
argument_changed = False argument_changed = False
found_model = False found_model = False
for m in deduplicated_conv_models: for m in deduplicated_conv_models:
if not found_model and m == conv_model: if not found_model and m.funchash == conv_model.funchash:
found_model = True found_model = True
if found_model and m.inj_hash in argnames: if found_model and m.inj_hash in argnames:
@@ -103,9 +107,28 @@ def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[C
break break
if argument_changed: if argument_changed:
deduplicated_conv_models.append(conv_model) deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
return deduplicated_conv_models return deduplicated_conv_models
def render_data_from_callseq(from_types: Sequence[type],
                             fnmap: dict[int, Callable],
                             callseq: Sequence[ConversionPoint]) -> list[ConversionRenderData]:
    """Create one ``ConversionRenderData`` per call of ``callseq``.

    For each call, the set of provided types is ``from_types`` plus the
    ``injects`` and ``requires`` of every call that precedes it in the sequence.

    :param from_types: types available before the first call.
    :param fnmap: mapping that is filled in place with ``hash(call.fn) -> call.fn``
        for every call in the sequence.
    :param callseq: ordered conversion points to render.
    :return: render models in the same order as ``callseq``.
    """
    conversion_models: list[ConversionRenderData] = []
    # Grown incrementally instead of being rebuilt from scratch for every call.
    provided_types: set[type] = set(from_types)

    for call in callseq:
        fnmap[hash(call.fn)] = call.fn
        # Hand over a copy so each model sees the snapshot that was valid for
        # its own position, independent of later additions.
        conversion_models.append(ConversionRenderData.from_inj(call, set(provided_types)))
        provided_types.add(call.injects)
        provided_types.update(call.requires)

    return conversion_models
class InTimeGenerationConvertorRenderer(ConvertorRenderer): class InTimeGenerationConvertorRenderer(ConvertorRenderer):
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
@@ -128,19 +151,10 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
store_sources: bool = False) -> Callable: store_sources: bool = False) -> Callable:
fnmap = {} fnmap = {}
conversion_models: list[ConversionRenderData] = [] conversion_models: list[ConversionRenderData] = render_data_from_callseq(from_types, fnmap, callseq)
ret_hash = 0 ret_hash = 0
is_async = force_async is_async = force_async
for call_id, call in enumerate(callseq): for call_id, call in enumerate(callseq):
provided_types = set(from_types)
for _call in callseq[:call_id]:
provided_types |= {_call.injects}
provided_types |= set(_call.requires)
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call, provided_types)
conversion_models.append(conv)
if call.is_async: if call.is_async:
is_async = True is_async = True

View File

@@ -19,9 +19,13 @@ def extract_return_type(func: Callable) -> Optional[type]:
return hints.get('return') return hints.get('return')
def extract_func_args(func: Callable) -> list[tuple[str, type]]: def extract_func_args(func: Callable, type_hints_remap: Optional[dict[str, type]] = None) -> list[tuple[str, type]]:
sig = inspect.signature(func) sig = inspect.signature(func)
type_hints = get_type_hints(func) if type_hints_remap is None:
type_hints = get_type_hints(func)
else:
type_hints = type_hints_remap
params = sig.parameters params = sig.parameters
args_info = [] args_info = []
@@ -127,3 +131,24 @@ def is_basic_type_annot(type_annot) -> bool:
return all(is_basic_type_annot(arg) for arg in args) return all(is_basic_type_annot(arg) for arg in args)
return False return False
def universal_qualname(any: Any) -> str:
    """Return an identifier-friendly name for an arbitrary object.

    The base name is ``any.__qualname__`` when that attribute exists,
    otherwise ``any.__name__``, otherwise ``str(any)``. Punctuation is then
    rewritten: ``.`` and spaces become ``_``, ``[`` and ``]`` become ``_of_``,
    ``,`` becomes ``_and_``; single quotes and angle brackets are removed.
    """
    try:
        base_name = any.__qualname__
    except AttributeError:
        try:
            base_name = any.__name__
        except AttributeError:
            base_name = str(any)

    # One pass is enough: no replacement text contains a character that is
    # itself a replacement target.
    substitutions = str.maketrans({
        '.': '_',
        '[': '_of_',
        ']': '_of_',
        ',': '_and_',
        ' ': '_',
        '\'': '',
        '<': '',
        '>': '',
    })
    return base_name.translate(substitutions)

View File

@@ -47,6 +47,37 @@ def test_default_consumer_args():
assert dep == (123, '1') assert dep == (123, '1')
def test_optional_default_none_consumer_args():
    """Conversions into a consumer whose optional argument has a default.

    One injector in the repo takes an optional ``B | None`` argument. The
    consumer must receive its default ``'42'`` when no ``optC`` is supplied
    and the supplied value when ``optC`` is among the source types.
    """
    repo = ConvRepo()

    @repo.mark_injector()
    def b_to_a(b: B | None = None) -> A:
        return A(int(b.b))

    @repo.mark_injector()
    def a_to_b(a: A) -> B | None:
        return B(float(a.a))

    @repo.mark_injector()
    def int_to_a(i: int) -> A:
        return A(i)

    def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]:
        return dep.a, opt_dep

    # Every conversion in this test is requested with the same sync-only flags.
    sync_only = dict(force_commutative=True, force_async=False, allow_async=False)

    from_b = repo.get_conversion((B,), consumer, **sync_only)
    result = from_b(B(42.1))
    assert result == (42, '42')

    from_int = repo.get_conversion((int,), consumer, **sync_only)
    result = from_int(123)
    assert result == (123, '42')

    from_int_and_opt = repo.get_conversion((int, optC), consumer, **sync_only)
    result = from_int_and_opt(123, '1')
    assert result == (123, '1')
def test_default_inj_args(): def test_default_inj_args():
repo = ConvRepo() repo = ConvRepo()
@@ -79,7 +110,6 @@ def test_default_inj_args():
def test_default_graph_override(): def test_default_graph_override():
repo = ConvRepo() repo = ConvRepo()
@repo.mark_injector() @repo.mark_injector()
@@ -111,4 +141,4 @@ def test_default_graph_override():
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '0') dep = fn3(123, '0')
assert dep == 123 assert dep == 123

View File

@@ -17,7 +17,7 @@ type optC = str
def test_default_consumer_args(): def test_default_consumer_args():
repo = ConvRepo() repo = ConvRepo(store_sources=True)
@repo.mark_injector() @repo.mark_injector()
def b_to_a(b: B) -> A: def b_to_a(b: B) -> A:
@@ -106,3 +106,13 @@ def test_pipeline_with_subgraph_duplicates():
assert b_to_a_calls[0] == 1 assert b_to_a_calls[0] == 1
assert cons1_calls[0] == 5 assert cons1_calls[0] == 5
assert cons2_calls[0] == 4 assert cons2_calls[0] == 4
def convertor(_5891515089754: "<class 'test_pipeline.B'>"):
# <function test_default_consumer_args.<locals>.b_to_a at 0x7f5bb1be02c0>
_5891515089643 = _conv_funcmap[8751987548204](b=_5891515089754)
# <function test_default_consumer_args.<locals>.consumer1 at 0x7f5bb1be0c20>
_8751987542640 = _conv_funcmap[8751987548354](dep=_5891515089643)
# <function test_default_consumer_args.<locals>.consumer2 at 0x7f5bb1be0540>
_8751987537115 = _conv_funcmap[8751987548244](dep=_5891515089643, dep1=_8751987542640)
return _8751987542640

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Annotated
import pytest
from breakshaft.models import ConversionPoint
from src.breakshaft.convertor import ConvRepo
# Test fixture type holding a single integer; used as a conversion source/target.
@dataclass
class A:
    a: int  # integer payload
# Test fixture type holding a single float; used as a conversion source/target.
@dataclass
class B:
    b: float  # float payload
def test_basic():
repo = ConvRepo()
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A) -> B:
return B(float(dep.a))
type NewA = A
type_remap = {'dep': NewA, 'return': B}
assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1
with pytest.raises(ValueError):
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
repo.mark_injector(type_remap={'i': int, 'return': NewA})(int_to_a)
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
assert fn1(42).b == 42.0
def consumer1(dep: B) -> A:
return A(int(dep.b))
p1 = repo.create_pipeline(
(int,),
[ConversionPoint.from_fn(consumer, type_remap=type_remap), consumer1, consumer],
force_commutative=True,
allow_sync=True,
allow_async=False,
force_async=False
)
assert p1(123).b == 123.0