From 6bf28e5fe88d4e1c64b80447b3a0f6eecb41cd99 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:21:10 +0300 Subject: [PATCH 1/7] Add defaulted args into a `ConversionPoint` --- src/breakshaft/graph_walker.py | 2 +- src/breakshaft/models.py | 25 ++++++++++++++++++++----- src/breakshaft/util.py | 10 ++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index dd3719b..7c1d20c 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -24,7 +24,7 @@ class GraphWalker: return None branches |= {cg} variant = CallgraphVariant( - ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))), + ConversionPoint.from_fn(consumer_fn, NoneType)[0], branches, frozenset()) return Callgraph(frozenset({variant})) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 356e7d7..591c30d 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ is_async_context_manager_factory, \ - all_combinations, is_context_manager_factory + all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args @dataclass(frozen=True) @@ -18,6 +18,7 @@ class ConversionPoint: fn: Callable injects: type requires: tuple[type, ...] + opt_args: tuple[tuple[type, object], ...] def __hash__(self): return hash((self.fn, self.injects, self.requires)) @@ -61,20 +62,34 @@ class ConversionPoint: rettype = get_args(rettype)[0] argtypes: list[list[type]] = [] - orig_argtypes = extract_func_argtypes_seq(func) - for argtype in orig_argtypes: + orig_args = extract_func_args(func) + defaults = extract_func_arg_defaults(func) + + orig_argtypes = [] + for argname, argtype in orig_args: + orig_argtypes.append((argtype, argname in defaults.keys(), defaults.get(argname))) + + default_map: list[tuple[bool, object]] = [] + for argtype, has_default, default in orig_argtypes: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: u_types = list(get_args(argtype)) + [argtype] else: u_types = [argtype] + default_map.append((has_default, default)) argtypes.append(u_types) argtype_combinations = all_combinations(argtypes) ret = [] for argtype_combination in argtype_combinations: - ret.append(ConversionPoint(func, rettype, tuple(argtype_combination))) + req_args = [] + opt_args = [] + for argt, (has_default, default) in zip(argtype_combination, default_map): + if has_default: + opt_args.append((argt, default)) + else: + req_args.append(argt) + ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args))) - # return InjectorPoint(func, rettype, argtypes) return ret diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index 9b4db5d..bfdae60 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]: return ret +def extract_func_arg_defaults(func: Callable) -> dict[str, object]: + sig = inspect.signature(func) + defaults = { + name: param.default + for name, param in sig.parameters.items() + if param.default is not inspect._empty + } + return defaults + + def is_context_manager_factory(obj: object) -> bool: return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj) -- 2.49.1 From a2cf1bb6e658a00cce93f375c083af1451c7e789 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:35:27 +0300 Subject: [PATCH 2/7] Get rid of manual consumer fn unwrapping for callgraph generation --- src/breakshaft/graph_walker.py | 19 ++++++++----------- src/breakshaft/models.py | 2 +- tests/test_basic.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 7c1d20c..0eb8918 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -14,19 +14,16 @@ class GraphWalker: from_types: frozenset[type], consumer_fn: Callable) -> Optional[Callgraph]: - into_types: frozenset[type] = extract_func_argtypes(consumer_fn) - branches: frozenset[Callgraph] = frozenset() - for into_type in into_types: - cg = cls.generate_callgraph_singletype(injectors, from_types, into_type) - if cg is None: - return None - branches |= {cg} - variant = CallgraphVariant( - ConversionPoint.from_fn(consumer_fn, NoneType)[0], - branches, frozenset()) - return Callgraph(frozenset({variant})) + # Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer + # Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого + # При этом, TypeAliasType также выступает в роли ключа преобразования + # Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований + type _tmp_type_for_consumer = object + injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer)) + + return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer) @classmethod def generate_callgraph_singletype(cls, diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 591c30d..fdd1012 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -39,7 +39,7 @@ class ConversionPoint: return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn) @classmethod - def from_fn(cls, func: Callable, rettype: Optional[type] = None): + def from_fn(cls, func: Callable, rettype: Optional[type] = None) -> list[ConversionPoint]: if rettype is None: annot = get_type_hints(func) rettype = annot.get('return') diff --git a/tests/test_basic.py b/tests/test_basic.py index b29a7df..e6f7972 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -38,3 +38,33 @@ def test_basic(): fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn2(123) assert dep == 123 + + +def test_union_deps(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A | B) -> int: + if isinstance(dep, A): + return dep.a + else: + return int(dep.b) + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 -- 2.49.1 From fe53cf9270b6bd72a1ffc1bd600408d993af65f1 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:49:10 +0300 Subject: [PATCH 3/7] Add test for non-provided default consumer args --- tests/test_default_args.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/test_default_args.py diff --git a/tests/test_default_args.py b/tests/test_default_args.py new file mode 100644 index 0000000..ed5b3bc --- /dev/null +++ b/tests/test_default_args.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from src.breakshaft.convertor import ConvRepo + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +type optC = str + + +def test_default_consumer_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int) -> A: + return A(i) + + def consumer(dep: A, opt_dep: optC = '42') -> tuple[int, str]: + return dep.a, opt_dep + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == (42, '42') + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == (123, '42') + + -- 2.49.1 From b04ea2c16ac34c61b86f7e0036e106032047de0a Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:50:49 +0300 Subject: [PATCH 4/7] Add test for non-provided default convertor args --- tests/test_default_args.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_default_args.py b/tests/test_default_args.py index ed5b3bc..58d5243 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -43,3 +43,28 @@ def test_default_consumer_args(): assert dep == (123, '42') +def test_default_inj_args(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int, opt_dep: optC = '42') -> A: + return A(i + int(opt_dep)) + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + 42 -- 2.49.1 From f2ec4fad14ffe73b1a905a1ad6ce1afaa061cdf1 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 21:08:46 +0300 Subject: [PATCH 5/7] Allow default option to be overriden if was ocasionally provided on a conversion path --- src/breakshaft/models.py | 4 ++-- src/breakshaft/renderer.py | 26 +++++++++++++++++------ src/breakshaft/templates/convertor.jinja2 | 4 ++-- tests/test_default_args.py | 8 +++++++ 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index fdd1012..0d1afb4 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -27,8 +27,8 @@ class ConversionPoint: return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}' @property - def fn_args(self) -> list[type]: - return extract_func_argtypes_seq(self.fn) + def fn_args(self) -> list[tuple[str, type]]: + return extract_func_args(self.fn) @property def is_ctx_manager(self) -> bool: diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index be2b5f7..8494e6b 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -23,15 +23,24 @@ class ConversionRenderData: inj_hash: str funchash: str funcname: str - funcargs: list[str] + funcargs: list[tuple[str, str]] is_ctxmanager: bool is_async: bool @classmethod - def from_inj(cls, inj: ConversionPoint): + def from_inj(cls, inj: ConversionPoint, provided_types: set[type]): + argmap = inj.fn_args + fnargs = [] - for argtype in inj.requires: - fnargs.append(hashname(argtype)) + for arg_id, argtype in enumerate(inj.requires): + argname = argmap[arg_id][0] + fnargs.append((argname, hashname(argtype))) + + for arg_id, (argtype, _) in enumerate(inj.opt_args, len(inj.requires)): + argname = argmap[arg_id][0] + if argtype in provided_types: + fnargs.append((argname, hashname(argtype))) + return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async) @@ -66,10 +75,15 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): conversion_models = [] ret_hash = 0 is_async = force_async + for call_id, call in enumerate(callseq): + + provided_types = set(from_types) + for _call in callseq[:call_id]: + provided_types |= {_call.injects} + provided_types |= set(_call.requires) - for call in callseq: fnmap[hash(call.fn)] = call.fn - conv = ConversionRenderData.from_inj(call) + conv = ConversionRenderData.from_inj(call, provided_types) if conv not in conversion_models: conversion_models.append(conv) if call.is_async: diff --git a/src/breakshaft/templates/convertor.jinja2 b/src/breakshaft/templates/convertor.jinja2 index 74128ad..1bc3189 100644 --- a/src/breakshaft/templates/convertor.jinja2 +++ b/src/breakshaft/templates/convertor.jinja2 @@ -3,11 +3,11 @@ {% for conv in conversions %} {% if conv.is_ctxmanager %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}: + {{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) as _{{ conv.inj_hash }}: {% set ns.indent = ns.indent + 1 %} {% else %} {{ ' ' * ns.indent }}# {{conv.funcname}} - {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) + {{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}{{conv_arg[0]}} = _{{conv_arg[1]}}, {% endfor %}) {% endif %} {% endfor %} {{ ' ' * ns.indent }}return _{{ret_hash}} diff --git a/tests/test_default_args.py b/tests/test_default_args.py index 58d5243..3c80b14 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -42,6 +42,10 @@ def test_default_consumer_args(): dep = fn2(123) assert dep == (123, '42') + fn3 = repo.get_conversion((int, optC), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '1') + assert dep == (123, '1') + def test_default_inj_args(): repo = ConvRepo() @@ -68,3 +72,7 @@ def test_default_inj_args(): fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn2(123) assert dep == 123 + 42 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123 -- 2.49.1 From 69def6e74c5756d2529f534e5b9ff459fb99f95f Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 21:12:35 +0300 Subject: [PATCH 6/7] Allow default option to be overriden if there is any conversion point that injects this default option --- src/breakshaft/graph_walker.py | 10 ++++++++++ tests/test_default_args.py | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 0eb8918..1bdfc9d 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -69,6 +69,16 @@ class GraphWalker: variant_subgraphs.add(subg) if not dead_end: + + for opt, _ in point.opt_args: + subg = cls.generate_callgraph_singletype(injectors, + from_types, + opt, + visited_path=visited_path.copy(), + visited_types=visited_types.copy()) + if subg is not None: + variant_subgraphs.add(subg) + consumed = frozenset(point.requires) & from_types variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed) head = head.add_subgraph_variant(variant) diff --git a/tests/test_default_args.py b/tests/test_default_args.py index 3c80b14..ce4d9af 100644 --- a/tests/test_default_args.py +++ b/tests/test_default_args.py @@ -76,3 +76,39 @@ def test_default_inj_args(): fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) dep = fn3(123, '0') assert dep == 123 + + +def test_default_graph_override(): + + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + @repo.mark_injector() + def int_to_a(i: int, opt_dep: optC = '42') -> A: + return A(i + int(opt_dep)) + + @repo.mark_injector() + def inject_opt_dep() -> optC: + return '12345' + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + 12345 + + fn3 = repo.get_conversion((int, optC,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn3(123, '0') + assert dep == 123 \ No newline at end of file -- 2.49.1 From eae2cd9a4b04b6de1630630757995e1711999bb3 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 21:13:40 +0300 Subject: [PATCH 7/7] Remove unused defaults --- src/breakshaft/graph_walker.py | 2 +- src/breakshaft/models.py | 14 +++++++------- src/breakshaft/renderer.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index 1bdfc9d..2aa0aac 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -70,7 +70,7 @@ class GraphWalker: if not dead_end: - for opt, _ in point.opt_args: + for opt in point.opt_args: subg = cls.generate_callgraph_singletype(injectors, from_types, opt, diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 0d1afb4..c49fc91 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -18,7 +18,7 @@ class ConversionPoint: fn: Callable injects: type requires: tuple[type, ...] - opt_args: tuple[tuple[type, object], ...] + opt_args: tuple[type, ...] def __hash__(self): return hash((self.fn, self.injects, self.requires)) @@ -67,15 +67,15 @@ class ConversionPoint: orig_argtypes = [] for argname, argtype in orig_args: - orig_argtypes.append((argtype, argname in defaults.keys(), defaults.get(argname))) + orig_argtypes.append((argtype, argname in defaults.keys())) - default_map: list[tuple[bool, object]] = [] - for argtype, has_default, default in orig_argtypes: + default_map: list[bool] = [] + for argtype, has_default in orig_argtypes: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: u_types = list(get_args(argtype)) + [argtype] else: u_types = [argtype] - default_map.append((has_default, default)) + default_map.append(has_default) argtypes.append(u_types) argtype_combinations = all_combinations(argtypes) @@ -83,9 +83,9 @@ class ConversionPoint: for argtype_combination in argtype_combinations: req_args = [] opt_args = [] - for argt, (has_default, default) in zip(argtype_combination, default_map): + for argt, has_default in zip(argtype_combination, default_map): if has_default: - opt_args.append((argt, default)) + opt_args.append(argt) else: req_args.append(argt) ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args))) diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index 8494e6b..ffe4b28 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -36,7 +36,7 @@ class ConversionRenderData: argname = argmap[arg_id][0] fnargs.append((argname, hashname(argtype))) - for arg_id, (argtype, _) in enumerate(inj.opt_args, len(inj.requires)): + for arg_id, argtype in enumerate(inj.opt_args, len(inj.requires)): argname = argmap[arg_id][0] if argtype in provided_types: fnargs.append((argname, hashname(argtype))) -- 2.49.1