Add defaulted args into a ConversionPoint
This commit is contained in:
@@ -24,7 +24,7 @@ class GraphWalker:
|
|||||||
return None
|
return None
|
||||||
branches |= {cg}
|
branches |= {cg}
|
||||||
variant = CallgraphVariant(
|
variant = CallgraphVariant(
|
||||||
ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes_seq(consumer_fn))),
|
ConversionPoint.from_fn(consumer_fn, NoneType)[0],
|
||||||
branches, frozenset())
|
branches, frozenset())
|
||||||
return Callgraph(frozenset({variant}))
|
return Callgraph(frozenset({variant}))
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Callable, Optional, get_type_hints, get_origin, Generator, ge
|
|||||||
|
|
||||||
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
|
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
|
||||||
is_async_context_manager_factory, \
|
is_async_context_manager_factory, \
|
||||||
all_combinations, is_context_manager_factory
|
all_combinations, is_context_manager_factory, extract_func_arg_defaults, extract_func_args
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -18,6 +18,7 @@ class ConversionPoint:
|
|||||||
fn: Callable
|
fn: Callable
|
||||||
injects: type
|
injects: type
|
||||||
requires: tuple[type, ...]
|
requires: tuple[type, ...]
|
||||||
|
opt_args: tuple[tuple[type, object], ...]
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.fn, self.injects, self.requires))
|
return hash((self.fn, self.injects, self.requires))
|
||||||
@@ -61,20 +62,34 @@ class ConversionPoint:
|
|||||||
rettype = get_args(rettype)[0]
|
rettype = get_args(rettype)[0]
|
||||||
|
|
||||||
argtypes: list[list[type]] = []
|
argtypes: list[list[type]] = []
|
||||||
orig_argtypes = extract_func_argtypes_seq(func)
|
orig_args = extract_func_args(func)
|
||||||
for argtype in orig_argtypes:
|
defaults = extract_func_arg_defaults(func)
|
||||||
|
|
||||||
|
orig_argtypes = []
|
||||||
|
for argname, argtype in orig_args:
|
||||||
|
orig_argtypes.append((argtype, argname in defaults.keys(), defaults.get(argname)))
|
||||||
|
|
||||||
|
default_map: list[tuple[bool, object]] = []
|
||||||
|
for argtype, has_default, default in orig_argtypes:
|
||||||
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
|
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
|
||||||
u_types = list(get_args(argtype)) + [argtype]
|
u_types = list(get_args(argtype)) + [argtype]
|
||||||
else:
|
else:
|
||||||
u_types = [argtype]
|
u_types = [argtype]
|
||||||
|
default_map.append((has_default, default))
|
||||||
argtypes.append(u_types)
|
argtypes.append(u_types)
|
||||||
|
|
||||||
argtype_combinations = all_combinations(argtypes)
|
argtype_combinations = all_combinations(argtypes)
|
||||||
ret = []
|
ret = []
|
||||||
for argtype_combination in argtype_combinations:
|
for argtype_combination in argtype_combinations:
|
||||||
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination)))
|
req_args = []
|
||||||
|
opt_args = []
|
||||||
|
for argt, (has_default, default) in zip(argtype_combination, default_map):
|
||||||
|
if has_default:
|
||||||
|
opt_args.append((argt, default))
|
||||||
|
else:
|
||||||
|
req_args.append(argt)
|
||||||
|
ret.append(ConversionPoint(func, rettype, tuple(req_args), tuple(opt_args)))
|
||||||
|
|
||||||
# return InjectorPoint(func, rettype, argtypes)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,16 @@ def extract_func_argtypes_seq(func: Callable) -> list[type]:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def extract_func_arg_defaults(func: Callable) -> dict[str, object]:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
defaults = {
|
||||||
|
name: param.default
|
||||||
|
for name, param in sig.parameters.items()
|
||||||
|
if param.default is not inspect._empty
|
||||||
|
}
|
||||||
|
return defaults
|
||||||
|
|
||||||
|
|
||||||
def is_context_manager_factory(obj: object) -> bool:
|
def is_context_manager_factory(obj: object) -> bool:
|
||||||
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
|
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user