Compare commits

...

6 Commits

8 changed files with 101 additions and 31 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "breakshaft" name = "breakshaft"
version = "0.1.5" version = "0.1.6.post5"
description = "Library for in-time codegen for type conversion" description = "Library for in-time codegen for type conversion"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -79,11 +79,14 @@ class ConvRepo:
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
def add_conversion_points(self, conversion_points: Iterable[ConversionPoint]):
self._convertor_set |= set(conversion_points)
def add_injector(self, def add_injector(self,
func: Callable, func: Callable,
rettype: Optional[type] = None, rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None): type_remap: Optional[dict[str, type]] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype, type_remap=type_remap)) self.add_conversion_points(ConversionPoint.from_fn(func, rettype=rettype, type_remap=type_remap))
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]: def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
if len(cg.variants) == 0: if len(cg.variants) == 0:
@@ -191,9 +194,6 @@ class ForkedConvRepo(ConvRepo):
self._convertor_set = fork_with self._convertor_set = fork_with
self._base_repo = fork_from self._base_repo = fork_from
def add_injector(self, func: Callable, rettype: Optional[type] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
@property @property
def convertor_set(self): def convertor_set(self):
return self._base_repo.convertor_set | self._convertor_set return self._base_repo.convertor_set | self._convertor_set

View File

@@ -4,7 +4,7 @@ from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type, universal_qualname
from typing import Iterable from typing import Iterable
@@ -90,7 +90,7 @@ class GraphWalker:
if subg is not None: if subg is not None:
variant_subgraphs.add(subg) variant_subgraphs.add(subg)
consumed = frozenset(point.requires) & from_types consumed = (frozenset(point.requires) | frozenset(point.opt_args)) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant) head = head.add_subgraph_variant(variant)

View File

@@ -56,7 +56,8 @@ class ConversionPoint:
def from_fn(cls, def from_fn(cls,
func: Callable, func: Callable,
rettype: Optional[type] = None, rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None) -> list[ConversionPoint]: type_remap: Optional[dict[str, type]] = None,
ignore_basictype_return: bool = False) -> list[ConversionPoint]:
if type_remap is None: if type_remap is None:
annot = get_type_hints(func) annot = get_type_hints(func)
else: else:
@@ -86,7 +87,7 @@ class ConversionPoint:
if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
fn_rettype = get_args(fn_rettype)[0] fn_rettype = get_args(fn_rettype)[0]
if is_basic_type_annot(rettype): if not ignore_basictype_return and is_basic_type_annot(rettype):
return [] return []
ret = [] ret = []
@@ -96,7 +97,10 @@ class ConversionPoint:
if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped: if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped:
for t in tuple_unwrapped: for t in tuple_unwrapped:
if not is_basic_type_annot(t): if not is_basic_type_annot(t):
ret += ConversionPoint.from_fn(func, rettype=t, type_remap=type_remap) ret += ConversionPoint.from_fn(func,
rettype=t,
type_remap=type_remap,
ignore_basictype_return=ignore_basictype_return)
argtypes: list[list[type]] = [] argtypes: list[list[type]] = []
orig_args = extract_func_args(func, type_remap) orig_args = extract_func_args(func, type_remap)

View File

@@ -7,7 +7,7 @@ import importlib.resources
import jinja2 import jinja2
from .models import ConversionPoint from .models import ConversionPoint
from .util import hashname, get_tuple_types, is_basic_type_annot from .util import hashname, get_tuple_types, is_basic_type_annot, universal_qualname
class ConvertorRenderer(Protocol): class ConvertorRenderer(Protocol):
@@ -51,6 +51,7 @@ class ConversionRenderData:
is_ctxmanager: bool is_ctxmanager: bool
is_async: bool is_async: bool
unwrap_tuple_result: UnwprappedTuple unwrap_tuple_result: UnwprappedTuple
_injection: ConversionPoint
@classmethod @classmethod
def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): def from_inj(cls, inj: ConversionPoint, provided_types: set[type]):
@@ -74,7 +75,8 @@ class ConversionRenderData:
fnargs, fnargs,
inj.is_ctx_manager, inj.is_ctx_manager,
inj.is_async, inj.is_async,
unwrap_tuple_result) unwrap_tuple_result,
inj)
@dataclass @dataclass
@@ -86,16 +88,18 @@ class ConversionArgRenderData:
def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]: def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[ConversionRenderData]:
deduplicated_conv_models: list[ConversionRenderData] = [] deduplicated_conv_models: list[ConversionRenderData] = []
deduplicated_hashes = set()
for conv_model in conversion_models: for conv_model in conversion_models:
if conv_model not in deduplicated_conv_models: if hash((conv_model.inj_hash, conv_model.funchash)) not in deduplicated_hashes:
deduplicated_conv_models.append(conv_model) deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
continue continue
argnames = list(map(lambda x: x[1], conv_model.funcargs)) argnames = list(map(lambda x: x[1], conv_model.funcargs))
argument_changed = False argument_changed = False
found_model = False found_model = False
for m in deduplicated_conv_models: for m in deduplicated_conv_models:
if not found_model and m == conv_model: if not found_model and m.funchash == conv_model.funchash:
found_model = True found_model = True
if found_model and m.inj_hash in argnames: if found_model and m.inj_hash in argnames:
@@ -103,9 +107,28 @@ def deduplicate_callseq(conversion_models: list[ConversionRenderData]) -> list[C
break break
if argument_changed: if argument_changed:
deduplicated_conv_models.append(conv_model) deduplicated_conv_models.append(conv_model)
deduplicated_hashes.add(hash((conv_model.inj_hash, conv_model.funchash)))
return deduplicated_conv_models return deduplicated_conv_models
def render_data_from_callseq(from_types: Sequence[type],
fnmap: dict[int, Callable],
callseq: Sequence[ConversionPoint]):
conversion_models: list[ConversionRenderData] = []
ret_hash = 0
for call_id, call in enumerate(callseq):
provided_types = set(from_types)
for _call in callseq[:call_id]:
provided_types |= {_call.injects}
provided_types |= set(_call.requires)
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call, provided_types)
conversion_models.append(conv)
return conversion_models
class InTimeGenerationConvertorRenderer(ConvertorRenderer): class InTimeGenerationConvertorRenderer(ConvertorRenderer):
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
@@ -128,19 +151,10 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
store_sources: bool = False) -> Callable: store_sources: bool = False) -> Callable:
fnmap = {} fnmap = {}
conversion_models: list[ConversionRenderData] = [] conversion_models: list[ConversionRenderData] = render_data_from_callseq(from_types, fnmap, callseq)
ret_hash = 0 ret_hash = 0
is_async = force_async is_async = force_async
for call_id, call in enumerate(callseq): for call_id, call in enumerate(callseq):
provided_types = set(from_types)
for _call in callseq[:call_id]:
provided_types |= {_call.injects}
provided_types |= set(_call.requires)
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call, provided_types)
conversion_models.append(conv)
if call.is_async: if call.is_async:
is_async = True is_async = True

View File

@@ -134,9 +134,21 @@ def is_basic_type_annot(type_annot) -> bool:
def universal_qualname(any: Any) -> str: def universal_qualname(any: Any) -> str:
ret = ''
if hasattr(any, '__qualname__'): if hasattr(any, '__qualname__'):
return any.__qualname__ ret = any.__qualname__
if hasattr(any, '__name__'): elif hasattr(any, '__name__'):
return any.__name__ ret = any.__name__
else:
ret = str(any)
return str(any) ret = (ret
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_')
.replace('\'', '')
.replace('<', '')
.replace('>', ''))
return ret

View File

@@ -47,6 +47,37 @@ def test_default_consumer_args():
assert dep == (123, '1') assert dep == (123, '1')
def test_optional_default_none_consumer_args():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B | None = None) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B | None:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]:
return dep.a, opt_dep
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == (42, '42')
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == (123, '42')
fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn3(123, '1')
assert dep == (123, '1')
def test_default_inj_args(): def test_default_inj_args():
repo = ConvRepo() repo = ConvRepo()
@@ -79,7 +110,6 @@ def test_default_inj_args():
def test_default_graph_override(): def test_default_graph_override():
repo = ConvRepo() repo = ConvRepo()
@repo.mark_injector() @repo.mark_injector()

View File

@@ -17,7 +17,7 @@ type optC = str
def test_default_consumer_args(): def test_default_consumer_args():
repo = ConvRepo() repo = ConvRepo(store_sources=True)
@repo.mark_injector() @repo.mark_injector()
def b_to_a(b: B) -> A: def b_to_a(b: B) -> A:
@@ -106,3 +106,13 @@ def test_pipeline_with_subgraph_duplicates():
assert b_to_a_calls[0] == 1 assert b_to_a_calls[0] == 1
assert cons1_calls[0] == 5 assert cons1_calls[0] == 5
assert cons2_calls[0] == 4 assert cons2_calls[0] == 4
def convertor(_5891515089754: "<class 'test_pipeline.B'>"):
# <function test_default_consumer_args.<locals>.b_to_a at 0x7f5bb1be02c0>
_5891515089643 = _conv_funcmap[8751987548204](b=_5891515089754)
# <function test_default_consumer_args.<locals>.consumer1 at 0x7f5bb1be0c20>
_8751987542640 = _conv_funcmap[8751987548354](dep=_5891515089643)
# <function test_default_consumer_args.<locals>.consumer2 at 0x7f5bb1be0540>
_8751987537115 = _conv_funcmap[8751987548244](dep=_5891515089643, dep1=_8751987542640)
return _8751987542640