Allow passing [ConversionPoint] into get_conversion with a type remap for ConversionPoint

This commit is contained in:
2025-08-19 02:32:15 +03:00
parent 742c21e199
commit 52d82550e6
5 changed files with 112 additions and 30 deletions

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence
import collections.abc
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any, Sequence, Iterable
from .graph_walker import GraphWalker from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
from .util import extract_return_type from .util import extract_return_type, universal_qualname
Tin = TypeVarTuple('Tin') Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout') Tout = TypeVar('Tout')
@@ -69,8 +71,11 @@ class ConvRepo:
def convertor_set(self): def convertor_set(self):
return self._convertor_set return self._convertor_set
def add_injector(self, func: Callable, rettype: Optional[type] = None): def add_injector(self,
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype)) func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype, type_remap=type_remap))
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]: def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
if len(cg.variants) == 0: if len(cg.variants) == 0:
@@ -97,12 +102,12 @@ class ConvRepo:
def get_callseq(self, def get_callseq(self,
injectors: frozenset[ConversionPoint], injectors: frozenset[ConversionPoint],
from_types: frozenset[type], from_types: frozenset[type],
fn: Callable, fn: Callable | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool) -> list[ConversionPoint]: force_commutative: bool) -> list[ConversionPoint]:
cg = self.walker.generate_callgraph(injectors, from_types, fn) cg = self.walker.generate_callgraph(injectors, from_types, fn)
if cg is None: if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}') raise ValueError(f'Unable to compute conversion graph on {from_types}->{universal_qualname(fn)}')
exploded = self.walker.explode_callgraph_branches(cg, from_types) exploded = self.walker.explode_callgraph_branches(cg, from_types)
@@ -116,6 +121,14 @@ class ConvRepo:
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]]))) callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
if len(callseq) > 0: if len(callseq) > 0:
injects = None
if isinstance(fn, collections.abc.Iterable):
for f in fn:
injects = f.injects
break
elif isinstance(fn, ConversionPoint):
injects = fn.injects
else:
injects = extract_return_type(fn) injects = extract_return_type(fn)
callseq[-1] = callseq[-1].copy_with(injects=injects) callseq[-1] = callseq[-1].copy_with(injects=injects)
@@ -123,7 +136,7 @@ class ConvRepo:
def get_conversion(self, def get_conversion(self,
from_types: Sequence[type[Unpack[Tin]]], from_types: Sequence[type[Unpack[Tin]]],
fn: Callable[..., Tout], fn: Callable[..., Tout] | Iterable[ConversionPoint] | ConversionPoint,
force_commutative: bool = True, force_commutative: bool = True,
allow_async: bool = True, allow_async: bool = True,
allow_sync: bool = True, allow_sync: bool = True,
@@ -138,9 +151,9 @@ class ConvRepo:
setattr(ret_fn, '__breakshaft_callseq__', callseq) setattr(ret_fn, '__breakshaft_callseq__', callseq)
return ret_fn return ret_fn
def mark_injector(self, *, rettype: Optional[type] = None): def mark_injector(self, *, rettype: Optional[type] = None, type_remap: Optional[dict[str, type]] = None):
def inner(func: Callable): def inner(func: Callable):
self.add_injector(func) self.add_injector(func, rettype=rettype, type_remap=type_remap)
return func return func
return inner return inner

View File

@@ -1,9 +1,11 @@
import collections.abc
import typing import typing
from types import NoneType from types import NoneType
from typing import Callable, Optional from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type
from typing import Iterable
class GraphWalker: class GraphWalker:
@@ -12,16 +14,24 @@ class GraphWalker:
def generate_callgraph(cls, def generate_callgraph(cls,
injectors: frozenset[ConversionPoint], injectors: frozenset[ConversionPoint],
from_types: frozenset[type], from_types: frozenset[type],
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable | Iterable[ConversionPoint] | ConversionPoint) -> Optional[Callgraph]:
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
rettype = extract_return_type(consumer_fn)
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
# При этом, TypeAliasType также выступает в роли ключа преобразования # При этом, TypeAliasType также выступает в роли ключа преобразования
# Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
type _tmp_type_for_consumer = object type _tmp_type_for_consumer = object
if isinstance(consumer_fn, collections.abc.Iterable):
new_consumer_injectors = set()
for fn in consumer_fn:
new_consumer_injectors.add(fn.copy_with(injects=_tmp_type_for_consumer))
injectors |= new_consumer_injectors
elif isinstance(consumer_fn, ConversionPoint):
injectors |= set(consumer_fn.copy_with(injects=_tmp_type_for_consumer))
else:
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer)) injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer) return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
@@ -143,7 +153,7 @@ class GraphWalker:
if len(variants) > 1: if len(variants) > 1:
# sorting by first injector func name for creating minimal cosistancy # sorting by first injector func name for creating minimal cosistancy
# could lead to heizenbugs due to incosistancy in path selection between calls # could lead to heizenbugs due to incosistancy in path selection between calls
variants.sort(key=lambda x: x.injector.fn.__qualname__) variants.sort(key=lambda x: universal_qualname(x.injector.fn))
return variants return variants
if len(variants) < 2: if len(variants) < 2:

View File

@@ -11,7 +11,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
is_async_context_manager_factory, \ is_async_context_manager_factory, \
all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \ all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args, extract_func_argnames, \
get_tuple_types, is_basic_type_annot get_tuple_types, is_basic_type_annot, universal_qualname
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -34,15 +34,8 @@ class ConversionPoint:
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))
def __repr__(self): def __repr__(self):
if '__qualname__' in dir(self.injects): injects_name = universal_qualname(self.injects)
injects_name = self.injects.__qualname__ fn_name = universal_qualname(self.fn)
else:
injects_name = str(self.injects)
if '__qualname__' in dir(self.fn):
fn_name = self.fn.__qualname__
else:
fn_name = str(self.fn)
return f'({",".join(map(str, self.requires))}) -> {injects_name}: {fn_name}' return f'({",".join(map(str, self.requires))}) -> {injects_name}: {fn_name}'
@@ -60,9 +53,15 @@ class ConversionPoint:
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
@classmethod @classmethod
def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: def from_fn(cls,
func: Callable,
rettype: Optional[type] = None,
type_remap: Optional[dict[str, type]] = None) -> list[ConversionPoint]:
if type_remap is None:
annot = get_type_hints(func) annot = get_type_hints(func)
else:
annot = type_remap
fn_rettype = annot.get('return') fn_rettype = annot.get('return')
if rettype is None: if rettype is None:
rettype = fn_rettype rettype = fn_rettype
@@ -97,10 +96,10 @@ class ConversionPoint:
if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped: if len(tuple_unwrapped) > 0 and Ellipsis not in tuple_unwrapped:
for t in tuple_unwrapped: for t in tuple_unwrapped:
if not is_basic_type_annot(t): if not is_basic_type_annot(t):
ret += ConversionPoint.from_fn(func, rettype=t) ret += ConversionPoint.from_fn(func, rettype=t, type_remap=type_remap)
argtypes: list[list[type]] = [] argtypes: list[list[type]] = []
orig_args = extract_func_args(func) orig_args = extract_func_args(func, type_remap)
defaults = extract_func_arg_defaults(func) defaults = extract_func_arg_defaults(func)
orig_argtypes = [] orig_argtypes = []

View File

@@ -19,9 +19,13 @@ def extract_return_type(func: Callable) -> Optional[type]:
return hints.get('return') return hints.get('return')
def extract_func_args(func: Callable) -> list[tuple[str, type]]: def extract_func_args(func: Callable, type_hints_remap: Optional[dict[str, type]] = None) -> list[tuple[str, type]]:
sig = inspect.signature(func) sig = inspect.signature(func)
if type_hints_remap is None:
type_hints = get_type_hints(func) type_hints = get_type_hints(func)
else:
type_hints = type_hints_remap
params = sig.parameters params = sig.parameters
args_info = [] args_info = []
@@ -127,3 +131,12 @@ def is_basic_type_annot(type_annot) -> bool:
return all(is_basic_type_annot(arg) for arg in args) return all(is_basic_type_annot(arg) for arg in args)
return False return False
def universal_qualname(any: Any) -> str:
if hasattr(any, '__qualname__'):
return any.__qualname__
if hasattr(any, '__name__'):
return any.__name__
return str(any)

View File

@@ -0,0 +1,47 @@
from dataclasses import dataclass
from typing import Annotated
import pytest
from breakshaft.models import ConversionPoint
from src.breakshaft.convertor import ConvRepo
@dataclass
class A:
a: int
@dataclass
class B:
b: float
def test_basic():
repo = ConvRepo()
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
type HackInt = int
def consumer(dep: A) -> int:
return dep.a
type NewA = A
type_remap = {'dep': NewA, 'return': Annotated[HackInt, 'fuck']}
assert len(ConversionPoint.from_fn(consumer, type_remap=type_remap)) == 1
with pytest.raises(ValueError):
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
repo.mark_injector(type_remap={'i': int, 'return': NewA})(int_to_a)
fn1 = repo.get_conversion((int,), ConversionPoint.from_fn(consumer, type_remap=type_remap),
force_commutative=True, force_async=False, allow_async=False)
assert fn1(42) == 42