From 6bf28e5fe88d4e1c64b80447b3a0f6eecb41cd99 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 19 Jul 2025 20:21:10 +0300 Subject: [PATCH] Add defaulted args into a `ConversionPoint` --- src/breakshaft/graph_walker.py | 2 +- src/breakshaft/models.py | 25 ++++++++++++++++++++----- src/breakshaft/util.py | 10 ++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/breakshaft/graph_walker.py b/src/breakshaft/graph_walker.py index dd3719b..7c1d20c 100644 --- a/src/breakshaft/graph_walker.py +++ b/src/breakshaft/graph_walker.py @@ -24,7 +24,7 @@ class GraphWalker: return None branches |= {cg} variant = CallgraphVariant( - ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))), + ConversionPoint.from_fn(consumer_fn, NoneType)[0], branches, frozenset()) return Callgraph(frozenset({variant})) diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 356e7d7..591c30d 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \ is_async_context_manager_factory, \ - all_combinations, is_context_manager_factory + all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args @dataclass(frozen=True) @@ -18,6 +18,7 @@ class ConversionPoint: fn: Callable injects: type requires: tuple[type, ...] + opt_args: tuple[tuple[type, object], ...] def __hash__(self): return hash((self.fn, self.injects, self.requires)) @@ -61,20 +62,34 @@ class ConversionPoint: rettype = get_args(rettype)[0] argtypes: list[list[type]] = [] - orig_argtypes = extract_func_argtypes_seq(func) - for argtype in orig_argtypes: + orig_args = extract_func_args(func) + defaults = extract_func_arg_defaults(func) + + orig_argtypes = [] + for argname, argtype in orig_args: + orig_argtypes.append((argtype, argname in defaults.keys(), defaults.get(argname))) + + default_map: list[tuple[bool, object]] = [] + for argtype, has_default, default in orig_argtypes: if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union: u_types = list(get_args(argtype)) + [argtype] else: u_types = [argtype] + default_map.append((has_default, default)) argtypes.append(u_types) argtype_combinations = all_combinations(argtypes) ret = [] for argtype_combination in argtype_combinations: - ret.append(ConversionPoint(func, rettype, tuple(argtype_combination))) + req_args = [] + opt_args = [] + for argt, (has_default, default) in zip(argtype_combination, default_map): + if has_default: + opt_args.append((argt, default)) + else: + req_args.append(argt) + ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args))) - # return InjectorPoint(func, rettype, argtypes) return ret diff --git a/src/breakshaft/util.py b/src/breakshaft/util.py index 9b4db5d..bfdae60 100644 --- a/src/breakshaft/util.py +++ b/src/breakshaft/util.py @@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]: return ret +def extract_func_arg_defaults(func: Callable) -> dict[str, object]: + sig = inspect.signature(func) + defaults = { + name: param.default + for name, param in sig.parameters.items() + if param.default is not inspect._empty + } + return defaults + + def is_context_manager_factory(obj: object) -> bool: return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)