Default-параметры в зависимостях преобразователей #4
@@ -14,19 +14,16 @@ class GraphWalker:
|
|||||||
from_types: frozenset[type],
|
from_types: frozenset[type],
|
||||||
consumer_fn: Callable) -> Optional[Callgraph]:
|
consumer_fn: Callable) -> Optional[Callgraph]:
|
||||||
|
|
||||||
into_types: frozenset[type] = extract_func_argtypes(consumer_fn)
|
|
||||||
|
|
||||||
branches: frozenset[Callgraph] = frozenset()
|
branches: frozenset[Callgraph] = frozenset()
|
||||||
|
|
||||||
for into_type in into_types:
|
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
|
||||||
cg = cls.generate_callgraph_singletype(injectors, from_types, into_type)
|
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
|
||||||
if cg is None:
|
# При этом, TypeAliasType также выступает в роли ключа преобразования
|
||||||
return None
|
# Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
|
||||||
branches |= {cg}
|
type _tmp_type_for_consumer = object
|
||||||
variant = CallgraphVariant(
|
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
|
||||||
ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))),
|
|
||||||
branches, frozenset())
|
return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
|
||||||
return Callgraph(frozenset({variant}))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_callgraph_singletype(cls,
|
def generate_callgraph_singletype(cls,
|
||||||
@@ -72,6 +69,16 @@ class GraphWalker:
|
|||||||
variant_subgraphs.add(subg)
|
variant_subgraphs.add(subg)
|
||||||
|
|
||||||
if not dead_end:
|
if not dead_end:
|
||||||
|
|
||||||
|
for opt in point.opt_args:
|
||||||
|
subg = cls.generate_callgraph_singletype(injectors,
|
||||||
|
from_types,
|
||||||
|
opt,
|
||||||
|
visited_path=visited_path.copy(),
|
||||||
|
visited_types=visited_types.copy())
|
||||||
|
if subg is not None:
|
||||||
|
variant_subgraphs.add(subg)
|
||||||
|
|
||||||
consumed = frozenset(point.requires) & from_types
|
consumed = frozenset(point.requires) & from_types
|
||||||
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
|
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
|
||||||
head = head.add_subgraph_variant(variant)
|
head = head.add_subgraph_variant(variant)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
|
|||||||
|
|
||||||
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
|
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
|
||||||
is_async_context_manager_factory, \
|
is_async_context_manager_factory, \
|
||||||
all_combinations, is_context_manager_factory
|
all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -18,6 +18,7 @@ class ConversionPoint:
|
|||||||
fn: Callable
|
fn: Callable
|
||||||
injects: type
|
injects: type
|
||||||
requires: tuple[type, ...]
|
requires: tuple[type, ...]
|
||||||
|
opt_args: tuple[type, ...]
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.fn, self.injects, self.requires))
|
return hash((self.fn, self.injects, self.requires))
|
||||||
@@ -26,8 +27,8 @@ class ConversionPoint:
|
|||||||
return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}'
|
return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fn_args(self) -> list[type]:
|
def fn_args(self) -> list[tuple[str, type]]:
|
||||||
return extract_func_argtypes_seq(self.fn)
|
return extract_func_args(self.fn)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_ctx_manager(self) -> bool:
|
def is_ctx_manager(self) -> bool:
|
||||||
@@ -38,7 +39,7 @@ class ConversionPoint:
|
|||||||
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
|
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_fn(cls, func: Callable, rettype: Optional[type] = None):
|
def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]:
|
||||||
if rettype is None:
|
if rettype is None:
|
||||||
annot = get_type_hints(func)
|
annot = get_type_hints(func)
|
||||||
rettype = annot.get('return')
|
rettype = annot.get('return')
|
||||||
@@ -61,20 +62,34 @@ class ConversionPoint:
|
|||||||
rettype = get_args(rettype)[0]
|
rettype = get_args(rettype)[0]
|
||||||
|
|
||||||
argtypes: list[list[type]] = []
|
argtypes: list[list[type]] = []
|
||||||
orig_argtypes = extract_func_argtypes_seq(func)
|
orig_args = extract_func_args(func)
|
||||||
for argtype in orig_argtypes:
|
defaults = extract_func_arg_defaults(func)
|
||||||
|
|
||||||
|
orig_argtypes = []
|
||||||
|
for argname, argtype in orig_args:
|
||||||
|
orig_argtypes.append((argtype, argname in defaults.keys()))
|
||||||
|
|
||||||
|
default_map: list[bool] = []
|
||||||
|
for argtype, has_default in orig_argtypes:
|
||||||
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
|
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
|
||||||
u_types = list(get_args(argtype)) + [argtype]
|
u_types = list(get_args(argtype)) + [argtype]
|
||||||
else:
|
else:
|
||||||
u_types = [argtype]
|
u_types = [argtype]
|
||||||
|
default_map.append(has_default)
|
||||||
argtypes.append(u_types)
|
argtypes.append(u_types)
|
||||||
|
|
||||||
argtype_combinations = all_combinations(argtypes)
|
argtype_combinations = all_combinations(argtypes)
|
||||||
ret = []
|
ret = []
|
||||||
for argtype_combination in argtype_combinations:
|
for argtype_combination in argtype_combinations:
|
||||||
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination)))
|
req_args = []
|
||||||
|
opt_args = []
|
||||||
|
for argt, has_default in zip(argtype_combination, default_map):
|
||||||
|
if has_default:
|
||||||
|
opt_args.append(argt)
|
||||||
|
else:
|
||||||
|
req_args.append(argt)
|
||||||
|
ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args)))
|
||||||
|
|
||||||
# return InjectorPoint(func, rettype, argtypes)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,15 +23,24 @@ class ConversionRenderData:
|
|||||||
inj_hash: str
|
inj_hash: str
|
||||||
funchash: str
|
funchash: str
|
||||||
funcname: str
|
funcname: str
|
||||||
funcargs: list[str]
|
funcargs: list[tuple[str, str]]
|
||||||
is_ctxmanager: bool
|
is_ctxmanager: bool
|
||||||
is_async: bool
|
is_async: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_inj(cls, inj: ConversionPoint):
|
def from_inj(cls, inj: ConversionPoint, provided_types: set[type]):
|
||||||
|
argmap = inj.fn_args
|
||||||
|
|
||||||
fnargs = []
|
fnargs = []
|
||||||
for argtype in inj.requires:
|
for arg_id, argtype in enumerate(inj.requires):
|
||||||
fnargs.append(hashname(argtype))
|
argname = argmap[arg_id][0]
|
||||||
|
fnargs.append((argname, hashname(argtype)))
|
||||||
|
|
||||||
|
for arg_id, argtype in enumerate(inj.opt_args, len(inj.requires)):
|
||||||
|
argname = argmap[arg_id][0]
|
||||||
|
if argtype in provided_types:
|
||||||
|
fnargs.append((argname, hashname(argtype)))
|
||||||
|
|
||||||
return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async)
|
return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,10 +75,15 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer):
|
|||||||
conversion_models = []
|
conversion_models = []
|
||||||
ret_hash = 0
|
ret_hash = 0
|
||||||
is_async = force_async
|
is_async = force_async
|
||||||
|
for call_id, call in enumerate(callseq):
|
||||||
|
|
||||||
|
provided_types = set(from_types)
|
||||||
|
for _call in callseq[:call_id]:
|
||||||
|
provided_types |= {_call.injects}
|
||||||
|
provided_types |= set(_call.requires)
|
||||||
|
|
||||||
for call in callseq:
|
|
||||||
fnmap[hash(call.fn)] = call.fn
|
fnmap[hash(call.fn)] = call.fn
|
||||||
conv = ConversionRenderData.from_inj(call)
|
conv = ConversionRenderData.from_inj(call, provided_types)
|
||||||
if conv not in conversion_models:
|
if conv not in conversion_models:
|
||||||
conversion_models.append(conv)
|
conversion_models.append(conv)
|
||||||
if call.is_async:
|
if call.is_async:
|
||||||
|
|||||||
@@ -3,11 +3,11 @@
|
|||||||
{% for conv in conversions %}
|
{% for conv in conversions %}
|
||||||
{% if conv.is_ctxmanager %}
|
{% if conv.is_ctxmanager %}
|
||||||
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
||||||
{{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}:
|
{{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}:
|
||||||
{% set ns.indent = ns.indent + 1 %}
|
{% set ns.indent = ns.indent + 1 %}
|
||||||
{% else %}
|
{% else %}
|
||||||
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
||||||
{{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %})
|
{{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %})
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{{ ' ' * ns.indent }}return _{{ret_hash}}
|
{{ ' ' * ns.indent }}return _{{ret_hash}}
|
||||||
|
|||||||
@@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def extract_func_arg_defaults(func: Callable) -> dict[str, object]:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
defaults = {
|
||||||
|
name: param.default
|
||||||
|
for name, param in sig.parameters.items()
|
||||||
|
if param.default is not inspect._empty
|
||||||
|
}
|
||||||
|
return defaults
|
||||||
|
|
||||||
|
|
||||||
def is_context_manager_factory(obj: object) -> bool:
|
def is_context_manager_factory(obj: object) -> bool:
|
||||||
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
|
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
|
||||||
|
|
||||||
|
|||||||
@@ -38,3 +38,33 @@ def test_basic():
|
|||||||
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
dep = fn2(123)
|
dep = fn2(123)
|
||||||
assert dep == 123
|
assert dep == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_union_deps():
|
||||||
|
repo = ConvRepo()
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def b_to_a(b: B) -> A:
|
||||||
|
return A(int(b.b))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def a_to_b(a: A) -> B:
|
||||||
|
return B(float(a.a))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def int_to_a(i: int) -> A:
|
||||||
|
return A(i)
|
||||||
|
|
||||||
|
def consumer(dep: A | B) -> int:
|
||||||
|
if isinstance(dep, A):
|
||||||
|
return dep.a
|
||||||
|
else:
|
||||||
|
return int(dep.b)
|
||||||
|
|
||||||
|
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn1(B(42.1))
|
||||||
|
assert dep == 42
|
||||||
|
|
||||||
|
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn2(123)
|
||||||
|
assert dep == 123
|
||||||
|
|||||||
114
tests/test_default_args.py
Normal file
114
tests/test_default_args.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from src.breakshaft.convertor import ConvRepo
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class A:
|
||||||
|
a: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class B:
|
||||||
|
b: float
|
||||||
|
|
||||||
|
|
||||||
|
type optC = str
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_consumer_args():
|
||||||
|
repo = ConvRepo()
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def b_to_a(b: B) -> A:
|
||||||
|
return A(int(b.b))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def a_to_b(a: A) -> B:
|
||||||
|
return B(float(a.a))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def int_to_a(i: int) -> A:
|
||||||
|
return A(i)
|
||||||
|
|
||||||
|
def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]:
|
||||||
|
return dep.a, opt_dep
|
||||||
|
|
||||||
|
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn1(B(42.1))
|
||||||
|
assert dep == (42, '42')
|
||||||
|
|
||||||
|
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn2(123)
|
||||||
|
assert dep == (123, '42')
|
||||||
|
|
||||||
|
fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn3(123, '1')
|
||||||
|
assert dep == (123, '1')
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_inj_args():
|
||||||
|
repo = ConvRepo()
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def b_to_a(b: B) -> A:
|
||||||
|
return A(int(b.b))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def a_to_b(a: A) -> B:
|
||||||
|
return B(float(a.a))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def int_to_a(i: int, opt_dep: optC = '42') -> A:
|
||||||
|
return A(i + int(opt_dep))
|
||||||
|
|
||||||
|
def consumer(dep: A) -> int:
|
||||||
|
return dep.a
|
||||||
|
|
||||||
|
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn1(B(42.1))
|
||||||
|
assert dep == 42
|
||||||
|
|
||||||
|
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn2(123)
|
||||||
|
assert dep == 123 + 42
|
||||||
|
|
||||||
|
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn3(123, '0')
|
||||||
|
assert dep == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_graph_override():
|
||||||
|
|
||||||
|
repo = ConvRepo()
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def b_to_a(b: B) -> A:
|
||||||
|
return A(int(b.b))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def a_to_b(a: A) -> B:
|
||||||
|
return B(float(a.a))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def int_to_a(i: int, opt_dep: optC = '42') -> A:
|
||||||
|
return A(i + int(opt_dep))
|
||||||
|
|
||||||
|
@repo.mark_injector()
|
||||||
|
def inject_opt_dep() -> optC:
|
||||||
|
return '12345'
|
||||||
|
|
||||||
|
def consumer(dep: A) -> int:
|
||||||
|
return dep.a
|
||||||
|
|
||||||
|
fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn1(B(42.1))
|
||||||
|
assert dep == 42
|
||||||
|
|
||||||
|
fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn2(123)
|
||||||
|
assert dep == 123 + 12345
|
||||||
|
|
||||||
|
fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False)
|
||||||
|
dep = fn3(123, '0')
|
||||||
|
assert dep == 123
|
||||||
Reference in New Issue
Block a user