diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index dd3719b..2aa0aac 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -14,19 +14,16 @@ class GraphWalker: from_types: frozenset[type], consumer_fn: Callable) -> Optional[Callgraph]: - into_types: frozenset[type] = extract_func_argtypes(consumer_fn) - branches: frozenset[Callgraph] = frozenset() - for into_type in into_types: - cg = cls.generate_callgraph_singletype(injectors, from_types, into_type) - if cg is None: - return None - branches |= {cg} - variant = CallgraphVariant( - ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))), - branches, frozenset()) - return Callgraph(frozenset({variant})) + # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer + # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого + # При этом, TypeAliasType также выступает в роли ключа преобразования + # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований + type _tmp_type_for_consumer = object + injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer)) + + return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer) @classmethod def generate_callgraph_singletype(cls, @@ -72,6 +69,16 @@ class GraphWalker: variant_subgraphs.add(subg) if not dead_end: + + for opt in point.opt_args: + subg = cls.generate_callgraph_singletype(injectors, + from_types, + opt, + visited_path=visited_path.copy(), + visited_types=visited_types.copy()) + if subg is not None: + variant_subgraphs.add(subg) + consumed = frozenset(point.requires) & from_types variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) head = head.add_subgraph_variant(variant) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 356e7d7..c49fc91 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ is_async_context_manager_factory, \ - all_combinations, is_context_manager_factory + all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args @dataclass(frozen=True) @@ -18,6 +18,7 @@ class ConversionPoint: fn: Callable injects: type requires: tuple[type, ...] + opt_args: tuple[type, ...] def __hash__(self): return hash((self.fn, self.injects, self.requires)) @@ -26,8 +27,8 @@ class ConversionPoint: return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}' @property - def fn_args(self) -> list[type]: - return extract_func_argtypes_seq(self.fn) + def fn_args(self) -> list[tuple[str, type]]: + return extract_func_args(self.fn) @property def is_ctx_manager(self) -> bool: @@ -38,7 +39,7 @@ class ConversionPoint: return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) @classmethod - def from_fn(cls, func: Callable, rettype: Optional[type] = None): + def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: if rettype is None: annot = get_type_hints(func) rettype = annot.get('return') @@ -61,20 +62,34 @@ class ConversionPoint: rettype = get_args(rettype)[0] argtypes: list[list[type]] = [] - orig_argtypes = extract_func_argtypes_seq(func) - for argtype in orig_argtypes: + orig_args = extract_func_args(func) + defaults = extract_func_arg_defaults(func) + + orig_argtypes = [] + for argname, argtype in orig_args: + orig_argtypes.append((argtype, argname in defaults.keys())) + + default_map: list[bool] = [] + for argtype, has_default in orig_argtypes: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: u_types = list(get_args(argtype)) + [argtype] else: u_types = [argtype] + default_map.append(has_default) argtypes.append(u_types) argtype_combinations = all_combinations(argtypes) ret = [] for argtype_combination in argtype_combinations: - ret.append(ConversionPoint(func, rettype, tuple(argtype_combination))) + req_args = [] + opt_args = [] + for argt, has_default in zip(argtype_combination, default_map): + if has_default: + opt_args.append(argt) + else: + req_args.append(argt) + ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args))) - # return InjectorPoint(func, rettype, argtypes) return ret diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index be2b5f7..ffe4b28 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -23,15 +23,24 @@ class ConversionRenderData: inj_hash: str funchash: str funcname: str - funcargs: list[str] + funcargs: list[tuple[str, str]] is_ctxmanager: bool is_async: bool @classmethod - def from_inj(cls, inj: ConversionPoint): + def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): + argmap = inj.fn_args + fnargs = [] - for argtype in inj.requires: - fnargs.append(hashname(argtype)) + for arg_id, argtype in enumerate(inj.requires): + argname = argmap[arg_id][0] + fnargs.append((argname, hashname(argtype))) + + for arg_id, argtype in enumerate(inj.opt_args, len(inj.requires)): + argname = argmap[arg_id][0] + if argtype in provided_types: + fnargs.append((argname, hashname(argtype))) + return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async) @@ -66,10 +75,15 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): conversion_models = [] ret_hash = 0 is_async = force_async + for call_id, call in enumerate(callseq): + + provided_types = set(from_types) + for _call in callseq[:call_id]: + provided_types |= {_call.injects} + provided_types |= set(_call.requires) - for call in callseq: fnmap[hash(call.fn)] = call.fn - conv = ConversionRenderData.from_inj(call) + conv = ConversionRenderData.from_inj(call, provided_types) if conv not in conversion_models: conversion_models.append(conv) if call.is_async: diff --git a/src/breakshaft/templates/convertor.jinja2 b/src/breakshaft/templates/convertor.jinja2 index 74128ad..1bc3189 100644 --- a/src/breakshaft/templates/convertor.jinja2 +++ b/src/breakshaft/templates/convertor.jinja2 @@ -3,11 +3,11 @@ {% for conv in conversions %} {% if conv.is_ctxmanager %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}: + {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}: {% set ns.indent = ns.indent + 1 %} {% else %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) + {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) {% endif %} {% endfor %} {{ ' ' * ns.indent }}return _{{ret_hash}} diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index 9b4db5d..bfdae60 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]: return ret +def extract_func_arg_defaults(func: Callable) -> dict[str, object]: + sig = inspect.signature(func) + defaults = { + name: param.default + for name, param in sig.parameters.items() + if param.default is not inspect._empty + } + return defaults + + def is_context_manager_factory(obj: object) -> bool: return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj) diff --git a/tests/test_basic.py b/tests/test_basic.py index b29a7df..e6f7972 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -38,3 +38,33 @@ def test_basic(): fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn2(123) assert dep == 123 + + +def test_union_deps(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A | B) -> int: + if isinstance(dep, A): + return dep.a + else: + return int(dep.b) + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 diff --git a/tests/test_default_args.py b/tests/test_default_args.py new file mode 100644 index 0000000..ce4d9af --- /dev/null +++ b/tests/test_default_args.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass + +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +type optC = str + + +def test_default_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]: + return dep.a, opt_dep + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == (42, '42') + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == (123, '42') + + fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '1') + assert dep == (123, '1') + + +def test_default_inj_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int, opt_dep: optC = '42') -> A: + return A(i + int(opt_dep)) + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + 42 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123 + + +def test_default_graph_override(): + + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int, opt_dep: optC = '42') -> A: + return A(i + int(opt_dep)) + + @repo.mark_injector() + def inject_opt_dep() -> optC: + return '12345' + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + 12345 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123 \ No newline at end of file