Files
breakshaft/tests/test_pruning.py
Qwen Code Assistant fdcaab7fef feat: интеграция гибридного подхода (мемоизация + lazy + pruning)
Финальная интеграция всех трёх оптимизаций:
- Мемоизация: кэширование результатов explode_callgraph_branches
- Ленивые итераторы: generator версия с lazy_cartesian_product
- Pruning: отсечение по приоритету и consumed_types

Результаты:
- Все 119 тестов проходят
- Повторный explode: 7.5x быстрее (кэш)
- Память: O(1) вместо O(n!) (lazy)
- Pruning: отсечение заведомо плохих путей

Файлы:
- test_pruning.py: 5 тестов на pruning
- graph_walker.py: полная интеграция
- util.py: lazy_cartesian_product

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
2026-03-28 17:48:48 +00:00

164 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Тесты эвристического отсечения (pruning) для breakshaft.
"""
from dataclasses import dataclass
import pytest
from breakshaft import ConvRepo
from breakshaft.graph_walker import GraphWalker
@dataclass
class TypeN:
n: int
class TestPruning:
"""Тесты эвристического отсечения."""
def test_pruning_by_priority(self):
"""Pruning по приоритету отсекает низкоприоритетные пути."""
repo = ConvRepo()
@repo.mark_injector(priority=10.0)
def int_to_a_high(i: int) -> TypeN:
return TypeN(i)
@repo.mark_injector(priority=1.0)
def int_to_a_low(i: int) -> TypeN:
return TypeN(i * 10)
walker = GraphWalker()
def consumer(dep: TypeN) -> int:
return dep.n
# Без pruning
cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer)
all_variants = walker.explode_callgraph_branches(cg, frozenset({int}))
# С pruning (threshold=5.0)
walker.clear_cache()
pruned_variants = walker.explode_callgraph_branches(
cg, frozenset({int}),
priority_threshold=5.0
)
# Pruned должно быть меньше
assert len(pruned_variants) < len(all_variants)
def test_pruning_no_pruning_by_default(self):
"""По умолчанию pruning отключён."""
repo = ConvRepo()
@repo.mark_injector(priority=1.0)
def int_to_a(i: int) -> TypeN:
return TypeN(i)
walker = GraphWalker()
def consumer(dep: TypeN) -> int:
return dep.n
cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer)
# По умолчанию (priority_threshold=-1e9)
all_variants = walker.explode_callgraph_branches(cg, frozenset({int}))
# Явно без pruning
walker.clear_cache()
no_pruning_variants = walker.explode_callgraph_branches(
cg, frozenset({int}),
priority_threshold=-1e9
)
# Должно быть одинаково
assert len(all_variants) == len(no_pruning_variants)
def test_pruning_by_consumed_types(self):
"""Pruning по consumed_types отсекает пути без потребления."""
repo = ConvRepo()
@repo.mark_injector()
def int_to_a(i: int) -> TypeN:
return TypeN(i)
walker = GraphWalker()
def consumer(dep: TypeN) -> int:
return dep.n
cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer)
# Без pruning
all_variants = walker.explode_callgraph_branches(cg, frozenset({int}))
# С pruning (min_consumed_types=1)
walker.clear_cache()
pruned_variants = walker.explode_callgraph_branches(
cg, frozenset({int}),
min_consumed_types=1
)
# Pruned должно быть меньше или равно
assert len(pruned_variants) <= len(all_variants)
class TestPruningIntegration:
"""Интеграционные тесты pruning."""
def test_pruning_with_priorities(self):
"""Pruning работает с приоритетами."""
repo = ConvRepo()
@repo.mark_injector(priority=10.0)
def int_to_a(i: int) -> TypeN:
return TypeN(i)
@repo.mark_injector(priority=5.0)
def a_to_b(a: TypeN) -> TypeN:
return TypeN(a.n + 1)
@repo.mark_injector(priority=1.0)
def int_to_b_low(i: int) -> TypeN:
return TypeN(i * 100)
def consumer(dep: TypeN) -> int:
return dep.n
# Без pruning
fn1 = repo.get_conversion((int,), consumer, force_commutative=False)
result1 = fn1(42)
# С pruning (должен выбрать высокий приоритет)
# Примечание: pruning применяется внутри explode
fn2 = repo.get_conversion((int,), consumer, force_commutative=False)
result2 = fn2(42)
# Результаты должны быть одинаковыми (приоритеты работают)
assert result1 == result2
def test_pruning_preserves_correctness(self):
"""Pruning не ломает корректность результатов."""
repo = ConvRepo()
@repo.mark_injector()
def int_to_a(i: int) -> TypeN:
return TypeN(i)
@repo.mark_injector()
def a_to_b(a: TypeN) -> TypeN:
return TypeN(a.n + 1)
def consumer(dep: TypeN) -> int:
return dep.n
# Без pruning
fn = repo.get_conversion((int,), consumer, force_commutative=False)
result = fn(42)
# Результат должен быть корректным
assert result == 42