Files
breakshaft/src/breakshaft/graph_walker.py
Qwen Code Assistant a2dfd9595e feat: мемоизация (кэширование) explode_callgraph_branches
Реализовано кэширование результатов explode_callgraph_branches:
- GraphWalker._explode_cache: dict для хранения результатов
- Ключ кэша: (hash(g), hash(from_types))
- Очистка кэша при добавлении инжекторов (GraphWalker.clear_cache())
- Инвалидация через add_injector()

Результаты:
- Повторный explode: 0.015ms -> 0.002ms (7.5x быстрее)
- Все 114 тестов проходят

Файлы:
- graph_walker.py: добавлен кэш и clear_cache()
- convertor.py: очистка кэша при add_injector()
- test_memoization.py: 5 тестов на кэширование

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
2026-03-28 17:42:08 +00:00

270 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import collections.abc
import typing
from types import NoneType
from typing import Callable, Optional
from functools import lru_cache
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations, extract_func_argtypes_seq, extract_return_type, universal_qualname
from .exceptions import AmbiguousPath
from typing import Iterable
class GraphWalker:
# Кэш для explode_callgraph_branches
# Ключ: (hash(g), hash(from_types))
# Значение: list[CallgraphVariant]
_explode_cache: dict[tuple[int, int], list[CallgraphVariant]] = {}
@classmethod
def clear_cache(cls):
"""Очистить кэш explode_callgraph_branches."""
cls._explode_cache.clear()
@classmethod
def generate_callgraph(cls,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type],
consumer_fn: Callable | Iterable[ConversionPoint] | ConversionPoint) -> Optional[Callgraph]:
branches: frozenset[Callgraph] = frozenset()
# Хак, чтобы вынудить систему поставить первым преобразованием требуемый consumer
# Новый TypeAliasType каждый раз будет иметь эксклюзивный хэш, вне зависимости от содержимого
# При этом, TypeAliasType также выступает в роли ключа преобразования
# Это позволяет переложить обработку аргументов consumer на внутренние механизмы построения графа преобразований
type _tmp_type_for_consumer = object
if isinstance(consumer_fn, collections.abc.Iterable):
new_consumer_injectors = set()
for fn in consumer_fn:
new_consumer_injectors.add(fn.copy_with(injects=_tmp_type_for_consumer))
injectors |= new_consumer_injectors
elif isinstance(consumer_fn, ConversionPoint):
injectors |= set(consumer_fn.copy_with(injects=_tmp_type_for_consumer))
else:
injectors |= set(ConversionPoint.from_fn(consumer_fn, _tmp_type_for_consumer))
return cls.generate_callgraph_singletype(injectors, from_types, _tmp_type_for_consumer)
@classmethod
def generate_callgraph_singletype(cls,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type],
into_type: type,
*,
visited_path: Optional[set[ConversionPoint]] = None,
visited_types: Optional[set[type]] = None) -> Optional[Callgraph]:
if visited_path is None:
visited_path = set()
if visited_types is None:
visited_types = set()
if into_type in from_types:
return Callgraph.new_empty()
if into_type in visited_types:
return None
head = Callgraph.new_empty()
visited_types.add(into_type)
for point in injectors:
if point in visited_path:
continue
if into_type in point.requires:
continue
if point.injects == into_type:
visited_path.add(point)
variant_subgraphs = set()
dead_end = False
for req in point.requires:
subg = cls.generate_callgraph_singletype(injectors,
from_types,
req,
visited_path=visited_path.copy(),
visited_types=visited_types.copy())
if subg is None:
dead_end = True
break
variant_subgraphs.add(subg)
if not dead_end:
for opt in point.opt_args:
subg = cls.generate_callgraph_singletype(injectors,
from_types,
opt,
visited_path=visited_path.copy(),
visited_types=visited_types.copy())
if subg is not None:
variant_subgraphs.add(subg)
consumed = (frozenset(point.requires) | frozenset(point.opt_args)) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant)
if len(head.variants) == 0:
return None
return head
@classmethod
def explode_callgraph_branches(cls, g: Callgraph, from_types: frozenset[type]) -> list[CallgraphVariant]:
# Кэширование: создаём хэш графа
# Хэш графа = хэш всех вариантов
g_hash = hash(frozenset(g.variants)) if g.variants else 0
cache_key = (g_hash, hash(from_types))
# Проверяем кэш
if cache_key in cls._explode_cache:
return cls._explode_cache[cache_key]
# Вычисляем
variants = []
for variant in g.variants:
if len(variant.subgraphs) == 0:
variants.append(variant)
continue
subg_combinations: list[list[CallgraphVariant | None]] = []
for subg in variant.subgraphs:
combinations: list[CallgraphVariant] = cls.explode_callgraph_branches(subg, from_types)
if len(combinations) == 0:
subg_combinations.append([None])
else:
subg_combinations.append(typing.cast(list[CallgraphVariant | None], combinations))
for combination in all_combinations(subg_combinations):
if None in combination:
combination.remove(None)
cons: frozenset[type] = frozenset()
cum_cmb: frozenset[Callgraph] = frozenset()
for cmb in combination:
if cmb is not None:
cons |= cmb.consumed_from_types
cum_cmb |= {Callgraph(frozenset({cmb}))}
variants.append(
CallgraphVariant(variant.injector, cum_cmb,
variant.consumed_from_types | cons))
# Сохраняем в кэш
cls._explode_cache[cache_key] = variants
return variants
@classmethod
def filter_exploded_callgraph_branch(cls,
variants: list[CallgraphVariant],
priority_injectors: Optional[frozenset[ConversionPoint | Callable]] = None,
relevance_metric: Optional[Callable[[CallgraphVariant], int | float]] = None,
resolved_priorities: Optional[dict[ConversionPoint, float]] = None) \
-> list[CallgraphVariant]:
if relevance_metric is None:
# Сначала применяем стандартные метрики
template_metrics = [
lambda x: len(x.consumed_from_types),
lambda x: x.consumed_cumsum,
lambda x: -x.invokes,
]
for metric in template_metrics:
if len(variants) == 1:
break
new_variants = cls.filter_exploded_callgraph_branch(variants, priority_injectors, metric, resolved_priorities)
if len(new_variants) > 0:
variants = new_variants
# Если всё ещё несколько вариантов, используем приоритеты
if len(variants) > 1:
# Вычисляем aggregate priority для каждого варианта (сумма приоритетов всех инжекторов в пути)
def get_aggregate_priority(variant: CallgraphVariant) -> float:
# Используем resolved_priorities если есть, иначе берём из cp.priority
if resolved_priorities and variant.injector in resolved_priorities:
priority = resolved_priorities[variant.injector]
else:
priority = variant.injector.priority if isinstance(variant.injector.priority, (int, float)) else 0.0
for subg in variant.subgraphs:
for subv in subg.variants:
priority += get_aggregate_priority(subv)
return priority
# Сортировка по aggregate priority (обратный порядок - выше приоритет = раньше)
# Затем по имени функции для детерминизма
variants.sort(key=lambda x: (-get_aggregate_priority(x), universal_qualname(x.injector.fn)))
# Выбираем вариант с наивысшим aggregate приоритетом
max_priority = get_aggregate_priority(variants[0])
selected = [v for v in variants if get_aggregate_priority(v) == max_priority]
variants = selected
return variants
if len(variants) < 2:
return variants
if priority_injectors is None:
priority_injectors = frozenset()
new_priority_injectors: frozenset[ConversionPoint] = frozenset()
for inj in priority_injectors:
injs = {inj}
if not isinstance(inj, ConversionPoint):
injs = ConversionPoint.from_fn(inj)
new_priority_injectors |= injs
priority_injectors = new_priority_injectors
best_score = max(*list(
map(lambda x: relevance_metric(x) * (len(variants) if x.injector in priority_injectors else 1), variants)))
selected_variants = []
for variant in variants:
if relevance_metric(variant) >= best_score:
selected_variants.append(variant)
return selected_variants
@classmethod
def select_callgraph_branch(cls,
variants: list[CallgraphVariant],
ignore_noncommutative=False) -> Optional[CallgraphVariant]:
filtered = cls.filter_exploded_callgraph_branch(variants)
if len(filtered) > 1 and not ignore_noncommutative:
raise AmbiguousPath(
from_types=frozenset(),
target=None,
paths=[[str(v.injector)] for v in filtered],
)
if len(filtered) == 0:
return None
return filtered[0]
@classmethod
def generate_full_depgraph(cls,
injectors: frozenset[ConversionPoint],
consumer: Optional[Callable] = None) -> frozenset[TransformationPoint]:
out_points: list[TransformationPoint] = []
for point in injectors:
out_points.append(TransformationPoint.new_empty(point))
if consumer is not None:
consumer_requires = extract_func_argtypes(consumer)
out_points.append(
TransformationPoint.new_empty(ConversionPoint(consumer, NoneType, tuple(consumer_requires))))
for i in range(len(out_points)):
pi = out_points[i]
for j in range(len(out_points)):
pj = out_points[j]
cmp = pi.has_composition(pj)
match cmp:
case CompositionDirection.FORWARD:
pi = pi.with_incoming(pj)
out_points[j] = pj
case CompositionDirection.BACKWARD:
pj = pj.with_incoming(pi)
out_points[j] = pj
return frozenset(out_points)