Compare commits

...

2 Commits

4 changed files with 69 additions and 17 deletions

View File

@@ -14,19 +14,16 @@ class GraphWalker:
from_types: frozenset[type], from_types: frozenset[type],
consumer_fn: Callable) -> Optional[Callgraph]: consumer_fn: Callable) -> Optional[Callgraph]:
into_types: frozenset[type] = extract_func_argtypes(consumer_fn)
branches: frozenset[Callgraph] = frozenset() branches: frozenset[Callgraph] = frozenset()
for into_type in into_types: # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
cg = cls.generate_callgraph_singletype(injectors, from_types, into_type) # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
if cg is None: # При этом, TypeAliasType также выступает в роли ключа преобразования
return None # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
branches |= {cg} type _tmp_type_for_consumer = object
variant = CallgraphVariant( injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))),
branches, frozenset()) return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
return Callgraph(frozenset({variant}))
@classmethod @classmethod
def generate_callgraph_singletype(cls, def generate_callgraph_singletype(cls,

View File

@@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
is_async_context_manager_factory, \ is_async_context_manager_factory, \
all_combinations, is_context_manager_factory all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -18,6 +18,7 @@ class ConversionPoint:
fn: Callable fn: Callable
injects: type injects: type
requires: tuple[type, ...] requires: tuple[type, ...]
opt_args: tuple[tuple[type, object], ...]
def __hash__(self): def __hash__(self):
return hash((self.fn, self.injects, self.requires)) return hash((self.fn, self.injects, self.requires))
@@ -38,7 +39,7 @@ class ConversionPoint:
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
@classmethod @classmethod
def from_fn(cls, func: Callable, rettype: Optional[type] = None): def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]:
if rettype is None: if rettype is None:
annot = get_type_hints(func) annot = get_type_hints(func)
rettype = annot.get('return') rettype = annot.get('return')
@@ -61,20 +62,34 @@ class ConversionPoint:
rettype = get_args(rettype)[0] rettype = get_args(rettype)[0]
argtypes: list[list[type]] = [] argtypes: list[list[type]] = []
orig_argtypes = extract_func_argtypes_seq(func) orig_args = extract_func_args(func)
for argtype in orig_argtypes: defaults = extract_func_arg_defaults(func)
orig_argtypes = []
for argname, argtype in orig_args:
orig_argtypes.append((argtype, argname in defaults.keys(), defaults.get(argname)))
default_map: list[tuple[bool, object]] = []
for argtype, has_default, default in orig_argtypes:
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
u_types = list(get_args(argtype)) + [argtype] u_types = list(get_args(argtype)) + [argtype]
else: else:
u_types = [argtype] u_types = [argtype]
default_map.append((has_default, default))
argtypes.append(u_types) argtypes.append(u_types)
argtype_combinations = all_combinations(argtypes) argtype_combinations = all_combinations(argtypes)
ret = [] ret = []
for argtype_combination in argtype_combinations: for argtype_combination in argtype_combinations:
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination))) req_args = []
opt_args = []
for argt, (has_default, default) in zip(argtype_combination, default_map):
if has_default:
opt_args.append((argt, default))
else:
req_args.append(argt)
ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args)))
# return InjectorPoint(func, rettype, argtypes)
return ret return ret

View File

@@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]:
return ret return ret
def extract_func_arg_defaults(func: Callable) -> dict[str, object]:
sig = inspect.signature(func)
defaults = {
name: param.default
for name, param in sig.parameters.items()
if param.default is not inspect._empty
}
return defaults
def is_context_manager_factory(obj: object) -> bool: def is_context_manager_factory(obj: object) -> bool:
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj) return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)

View File

@@ -38,3 +38,33 @@ def test_basic():
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123) dep = fn2(123)
assert dep == 123 assert dep == 123
def test_union_deps():
repo = ConvRepo()
@repo.mark_injector()
def b_to_a(b: B) -> A:
return A(int(b.b))
@repo.mark_injector()
def a_to_b(a: A) -> B:
return B(float(a.a))
@repo.mark_injector()
def int_to_a(i: int) -> A:
return A(i)
def consumer(dep: A | B) -> int:
if isinstance(dep, A):
return dep.a
else:
return int(dep.b)
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn1(B(42.1))
assert dep == 42
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
dep = fn2(123)
assert dep == 123