diff --git a/tests/test_pruning.py b/tests/test_pruning.py new file mode 100644 index 0000000..ef8e932 --- /dev/null +++ b/tests/test_pruning.py @@ -0,0 +1,163 @@ +""" +Тесты эвристического отсечения (pruning) для breakshaft. +""" + +from dataclasses import dataclass + +import pytest + +from breakshaft import ConvRepo +from breakshaft.graph_walker import GraphWalker + + +@dataclass +class TypeN: + n: int + + +class TestPruning: + """Тесты эвристического отсечения.""" + + def test_pruning_by_priority(self): + """Pruning по приоритету отсекает низкоприоритетные пути.""" + repo = ConvRepo() + + @repo.mark_injector(priority=10.0) + def int_to_a_high(i: int) -> TypeN: + return TypeN(i) + + @repo.mark_injector(priority=1.0) + def int_to_a_low(i: int) -> TypeN: + return TypeN(i * 10) + + walker = GraphWalker() + + def consumer(dep: TypeN) -> int: + return dep.n + + # Без pruning + cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer) + all_variants = walker.explode_callgraph_branches(cg, frozenset({int})) + + # С pruning (threshold=5.0) + walker.clear_cache() + pruned_variants = walker.explode_callgraph_branches( + cg, frozenset({int}), + priority_threshold=5.0 + ) + + # Pruned должно быть меньше + assert len(pruned_variants) < len(all_variants) + + def test_pruning_no_pruning_by_default(self): + """По умолчанию pruning отключён.""" + repo = ConvRepo() + + @repo.mark_injector(priority=1.0) + def int_to_a(i: int) -> TypeN: + return TypeN(i) + + walker = GraphWalker() + + def consumer(dep: TypeN) -> int: + return dep.n + + cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer) + + # По умолчанию (priority_threshold=-1e9) + all_variants = walker.explode_callgraph_branches(cg, frozenset({int})) + + # Явно без pruning + walker.clear_cache() + no_pruning_variants = walker.explode_callgraph_branches( + cg, frozenset({int}), + priority_threshold=-1e9 + ) + + # Должно быть одинаково + assert len(all_variants) == len(no_pruning_variants) + + def test_pruning_by_consumed_types(self): + """Pruning по consumed_types отсекает пути без потребления.""" + repo = ConvRepo() + + @repo.mark_injector() + def int_to_a(i: int) -> TypeN: + return TypeN(i) + + walker = GraphWalker() + + def consumer(dep: TypeN) -> int: + return dep.n + + cg = walker.generate_callgraph(repo.convertor_set, frozenset({int}), consumer) + + # Без pruning + all_variants = walker.explode_callgraph_branches(cg, frozenset({int})) + + # С pruning (min_consumed_types=1) + walker.clear_cache() + pruned_variants = walker.explode_callgraph_branches( + cg, frozenset({int}), + min_consumed_types=1 + ) + + # Pruned должно быть меньше или равно + assert len(pruned_variants) <= len(all_variants) + + +class TestPruningIntegration: + """Интеграционные тесты pruning.""" + + def test_pruning_with_priorities(self): + """Pruning работает с приоритетами.""" + repo = ConvRepo() + + @repo.mark_injector(priority=10.0) + def int_to_a(i: int) -> TypeN: + return TypeN(i) + + @repo.mark_injector(priority=5.0) + def a_to_b(a: TypeN) -> TypeN: + return TypeN(a.n + 1) + + @repo.mark_injector(priority=1.0) + def int_to_b_low(i: int) -> TypeN: + return TypeN(i * 100) + + def consumer(dep: TypeN) -> int: + return dep.n + + # Без pruning + fn1 = repo.get_conversion((int,), consumer, force_commutative=False) + result1 = fn1(42) + + # С pruning (должен выбрать высокий приоритет) + # Примечание: pruning применяется внутри explode + fn2 = repo.get_conversion((int,), consumer, force_commutative=False) + result2 = fn2(42) + + # Результаты должны быть одинаковыми (приоритеты работают) + assert result1 == result2 + + def test_pruning_preserves_correctness(self): + """Pruning не ломает корректность результатов.""" + repo = ConvRepo() + + @repo.mark_injector() + def int_to_a(i: int) -> TypeN: + return TypeN(i) + + @repo.mark_injector() + def a_to_b(a: TypeN) -> TypeN: + return TypeN(a.n + 1) + + def consumer(dep: TypeN) -> int: + return dep.n + + # Без pruning + fn = repo.get_conversion((int,), consumer, force_commutative=False) + result = fn(42) + + # Результат должен быть корректным + assert result == 42