Add library
This commit is contained in:
1
src/breakshaft/__init__.py
Normal file
1
src/breakshaft/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
162
src/breakshaft/__main__.py
Normal file
162
src/breakshaft/__main__.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.breakshaft import util_mermaid
|
||||
from src.breakshaft.graph_walker import GraphWalker
|
||||
from src.breakshaft.models import Callgraph
|
||||
from .convertor import ConvRepo
|
||||
|
||||
|
||||
@dataclass
|
||||
class SomeSchema:
|
||||
a: int
|
||||
|
||||
|
||||
repo = ConvRepo()
|
||||
|
||||
|
||||
@dataclass
|
||||
class A:
|
||||
a: int
|
||||
|
||||
|
||||
class B:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
a: A
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def zero_schema() -> SomeSchema:
|
||||
return SomeSchema(1)
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def c_from_a(a: A) -> C:
|
||||
return C(a)
|
||||
|
||||
|
||||
# @repo.mark_injector()
|
||||
# def a_from_int(z: int) -> A:
|
||||
# return A(z)
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def zero_a() -> A:
|
||||
return A(34)
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def schema_from_c(c: C) -> SomeSchema:
|
||||
return SomeSchema(c.a.a)
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def a_from_schema(s: SomeSchema) -> A:
|
||||
return A(s.a)
|
||||
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
def b_from_a(a: A) -> B:
|
||||
return B()
|
||||
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
@asynccontextmanager
|
||||
async def a_from_b(b: B | int) -> typing.AsyncIterator[A]:
|
||||
if isinstance(b, B):
|
||||
yield A(0)
|
||||
else:
|
||||
yield A(b)
|
||||
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# @repo.mark_injector()
|
||||
# def schema_from_b(b: B) -> SomeSchema:
|
||||
# return SomeSchema(0)
|
||||
|
||||
|
||||
@repo.mark_injector()
|
||||
@contextmanager
|
||||
def schema_from_a_b(a: A, b: B) -> typing.Generator[SomeSchema]:
|
||||
yield SomeSchema(a.a)
|
||||
|
||||
|
||||
# @repo.mark_injector()
|
||||
# def schema_from_a(a: A) -> SomeSchema:
|
||||
# return SomeSchema(a.a)
|
||||
|
||||
# @repo.mark_injector()
|
||||
# def schema_from_int(i: int) -> SomeSchema:
|
||||
# return SomeSchema(i)
|
||||
|
||||
@repo.mark_injector()
|
||||
def int_from_float(f: float) -> int:
|
||||
return int(f)
|
||||
|
||||
|
||||
# @repo.mark_injector()
|
||||
# def zero_int() -> int:
|
||||
# return 42
|
||||
|
||||
|
||||
# @repo.consumer
|
||||
|
||||
|
||||
def consumer(dep: SomeSchema) -> int:
|
||||
print(f'consume {dep.a}')
|
||||
return 42
|
||||
|
||||
|
||||
async def main():
|
||||
# fn = repo.get_conversion((int,), consumer, force_commutative=True, force_async=True)
|
||||
# await fn(42)
|
||||
# await fn(B())
|
||||
|
||||
# # g = walker.generate_callgraph_singletype(repo._injector_set, frozenset({int}), SomeSchema)
|
||||
from_types = (B | int,)
|
||||
walker = GraphWalker()
|
||||
g = walker.generate_callgraph(repo.convertor_set, frozenset(from_types), consumer)
|
||||
print('full graph:\n')
|
||||
print(util_mermaid.draw_callgraph_mermaid(g, split_duplicates=True))
|
||||
exploded = walker.explode_callgraph_branches(g, frozenset(from_types))
|
||||
|
||||
print('\nexploded:\n\n')
|
||||
for s_i, selected in enumerate(exploded):
|
||||
# print('\nselected path:\n')
|
||||
print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True,
|
||||
skip_title=True,
|
||||
prefix=f'{s_i}_'))
|
||||
|
||||
print('\nselect variants:\n\n')
|
||||
exploded = walker.filter_exploded_callgraph_branch(exploded)
|
||||
# print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True))
|
||||
|
||||
for s_i, selected in enumerate(exploded):
|
||||
# print('\nselected path:\n')
|
||||
print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True,
|
||||
skip_title=True,
|
||||
prefix=f'{s_i}_'))
|
||||
print(f'\nconsumed {selected.consumed_from_types}')
|
||||
# consumer({})
|
||||
# graph = walker.generate_full_depgraph(consumer)
|
||||
# print(draw_depgraph_mermaid(graph))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
114
src/breakshaft/convertor.py
Normal file
114
src/breakshaft/convertor.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any
|
||||
|
||||
from .graph_walker import GraphWalker
|
||||
from .models import ConversionPoint, Callgraph
|
||||
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
|
||||
|
||||
Tin = TypeVarTuple('Tin')
|
||||
Tout = TypeVar('Tout')
|
||||
|
||||
|
||||
class ConvRepo:
|
||||
_convertor_set: set[ConversionPoint]
|
||||
|
||||
walker: GraphWalker
|
||||
renderer: ConvertorRenderer
|
||||
|
||||
def __init__(self,
|
||||
graph_walker: Optional[GraphWalker] = None,
|
||||
renderer: Optional[ConvertorRenderer] = None, ):
|
||||
if graph_walker is None:
|
||||
graph_walker = GraphWalker()
|
||||
if renderer is None:
|
||||
renderer = InTimeGenerationConvertorRenderer()
|
||||
|
||||
self._convertor_set = set()
|
||||
self.walker = graph_walker
|
||||
self.renderer = renderer
|
||||
|
||||
@property
|
||||
def convertor_set(self):
|
||||
return self._convertor_set
|
||||
|
||||
def add_injector(self, func: Callable, rettype: Optional[type] = None):
|
||||
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
|
||||
|
||||
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
|
||||
if len(cg.variants) == 0:
|
||||
return []
|
||||
if len(cg.variants) > 1:
|
||||
raise ValueError('All callgraph subgraphs must be solved for callseq generation')
|
||||
ret = []
|
||||
variant = list(cg.variants)[0]
|
||||
for sg in variant.subgraphs:
|
||||
ret += self._callseq_from_callgraph(sg)
|
||||
ret += [variant.injector]
|
||||
return ret
|
||||
|
||||
def get_conversion(self,
|
||||
from_types: tuple[type[Unpack[Tin]]],
|
||||
fn: Callable[..., Tout],
|
||||
force_commutative: bool = True,
|
||||
allow_async: bool = True,
|
||||
allow_sync: bool = True,
|
||||
force_async: bool = False
|
||||
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]:
|
||||
if not allow_async or force_async:
|
||||
filtered_injectors: frozenset[ConversionPoint] = frozenset()
|
||||
for inj in self.convertor_set:
|
||||
if inj.is_async and not allow_async:
|
||||
continue
|
||||
if not inj.is_async and not allow_sync:
|
||||
continue
|
||||
filtered_injectors |= {inj}
|
||||
else:
|
||||
filtered_injectors = frozenset(self.convertor_set)
|
||||
|
||||
cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn)
|
||||
if cg is None:
|
||||
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}')
|
||||
|
||||
exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types))
|
||||
|
||||
selected = self.walker.filter_exploded_callgraph_branch(exploded)
|
||||
if len(selected) == 0:
|
||||
raise ValueError('Unable to select conversion path')
|
||||
|
||||
if force_commutative and len(selected) > 1:
|
||||
raise ValueError('Conversion path is not commutative')
|
||||
|
||||
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
|
||||
return self.renderer.render(from_types, callseq, force_async=force_async)
|
||||
|
||||
def mark_injector(self, *, rettype: Optional[type] = None):
|
||||
def inner(func: Callable):
|
||||
self.add_injector(func)
|
||||
return func
|
||||
|
||||
return inner
|
||||
|
||||
def fork(self, fork_with: Optional[set[ConversionPoint]] = None) -> ConvRepo:
|
||||
return ForkedConvRepo(self, fork_with or None, self.walker, self.renderer)
|
||||
|
||||
|
||||
class ForkedConvRepo(ConvRepo):
|
||||
_base_repo: ConvRepo
|
||||
|
||||
def __init__(self,
|
||||
fork_from: ConvRepo,
|
||||
fork_with: Optional[set[ConversionPoint]] = None,
|
||||
graph_walker: Optional[GraphWalker] = None,
|
||||
renderer: Optional[ConvertorRenderer] = None):
|
||||
super().__init__(graph_walker, renderer)
|
||||
if fork_with is None:
|
||||
fork_with = set()
|
||||
self._convertor_set = fork_with
|
||||
self._base_repo = fork_from
|
||||
|
||||
def add_injector(self, func: Callable, rettype: Optional[type] = None):
|
||||
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
|
||||
|
||||
@property
|
||||
def convertor_set(self):
|
||||
return self._base_repo.convertor_set | self._convertor_set
|
||||
200
src/breakshaft/graph_walker.py
Normal file
200
src/breakshaft/graph_walker.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import typing
|
||||
from types import NoneType
|
||||
from typing import Callable, Optional
|
||||
|
||||
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
|
||||
from .util import extract_func_argtypes, all_combinations
|
||||
|
||||
|
||||
class GraphWalker:
|
||||
|
||||
@classmethod
|
||||
def generate_callgraph(cls,
|
||||
injectors: frozenset[ConversionPoint],
|
||||
from_types: frozenset[type],
|
||||
consumer_fn: Callable) -> Optional[Callgraph]:
|
||||
|
||||
into_types: frozenset[type] = extract_func_argtypes(consumer_fn)
|
||||
|
||||
branches: frozenset[Callgraph] = frozenset()
|
||||
|
||||
for into_type in into_types:
|
||||
cg = cls.generate_callgraph_singletype(injectors, from_types, into_type)
|
||||
if cg is None:
|
||||
return None
|
||||
branches |= {cg}
|
||||
variant = CallgraphVariant(ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes(consumer_fn))),
|
||||
branches, frozenset())
|
||||
return Callgraph(frozenset({variant}))
|
||||
|
||||
@classmethod
|
||||
def generate_callgraph_singletype(cls,
|
||||
injectors: frozenset[ConversionPoint],
|
||||
from_types: frozenset[type],
|
||||
into_type: type,
|
||||
*,
|
||||
visited_path: Optional[set[ConversionPoint]] = None,
|
||||
visited_types: Optional[set[type]] = None) -> Optional[Callgraph]:
|
||||
if visited_path is None:
|
||||
visited_path = set()
|
||||
if visited_types is None:
|
||||
visited_types = set()
|
||||
|
||||
if into_type in from_types:
|
||||
return Callgraph.new_empty()
|
||||
|
||||
if into_type in visited_types:
|
||||
return None
|
||||
|
||||
head = Callgraph.new_empty()
|
||||
|
||||
visited_types.add(into_type)
|
||||
|
||||
for point in injectors:
|
||||
if point in visited_path:
|
||||
continue
|
||||
if into_type in point.requires:
|
||||
continue
|
||||
if point.injects == into_type:
|
||||
visited_path.add(point)
|
||||
variant_subgraphs = set()
|
||||
dead_end = False
|
||||
for req in point.requires:
|
||||
subg = cls.generate_callgraph_singletype(injectors,
|
||||
from_types,
|
||||
req,
|
||||
visited_path=visited_path.copy(),
|
||||
visited_types=visited_types.copy())
|
||||
if subg is None:
|
||||
dead_end = True
|
||||
break
|
||||
variant_subgraphs.add(subg)
|
||||
|
||||
if not dead_end:
|
||||
consumed = frozenset(point.requires) & from_types
|
||||
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
|
||||
head = head.add_subgraph_variant(variant)
|
||||
|
||||
if len(head.variants) == 0:
|
||||
return None
|
||||
|
||||
return head
|
||||
|
||||
@classmethod
|
||||
def explode_callgraph_branches(cls, g: Callgraph, from_types: frozenset[type]) -> list[CallgraphVariant]:
|
||||
variants = []
|
||||
for variant in g.variants:
|
||||
if len(variant.subgraphs) == 0:
|
||||
variants.append(variant)
|
||||
continue
|
||||
subg_combinations: list[list[CallgraphVariant | None]] = []
|
||||
for subg in variant.subgraphs:
|
||||
combinations: list[CallgraphVariant] = cls.explode_callgraph_branches(subg, from_types)
|
||||
if len(combinations) == 0:
|
||||
subg_combinations.append([None])
|
||||
else:
|
||||
subg_combinations.append(typing.cast(list[CallgraphVariant | None], combinations))
|
||||
|
||||
for combination in all_combinations(subg_combinations):
|
||||
if None in combination:
|
||||
combination.remove(None)
|
||||
cons: frozenset[type] = frozenset()
|
||||
cum_cmb: frozenset[Callgraph] = frozenset()
|
||||
for cmb in combination:
|
||||
if cmb is not None:
|
||||
cons |= cmb.consumed_from_types
|
||||
cum_cmb |= {Callgraph(frozenset({cmb}))}
|
||||
variants.append(
|
||||
CallgraphVariant(variant.injector, cum_cmb,
|
||||
variant.consumed_from_types | cons))
|
||||
|
||||
return variants
|
||||
|
||||
@classmethod
|
||||
def filter_exploded_callgraph_branch(cls,
|
||||
variants: list[CallgraphVariant],
|
||||
priority_injectors: Optional[frozenset[ConversionPoint | Callable]] = None,
|
||||
relevance_metric: Optional[Callable[[CallgraphVariant], int | float]] = None) \
|
||||
-> list[CallgraphVariant]:
|
||||
|
||||
if relevance_metric is None:
|
||||
template_metrics = [
|
||||
lambda x: len(x.consumed_from_types),
|
||||
lambda x: x.consumed_cumsum,
|
||||
lambda x: -x.invokes,
|
||||
]
|
||||
|
||||
for metric in template_metrics:
|
||||
if len(variants) == 1:
|
||||
break
|
||||
new_variants = cls.filter_exploded_callgraph_branch(variants, priority_injectors, metric)
|
||||
if len(new_variants) > 0:
|
||||
variants = new_variants
|
||||
|
||||
if len(variants) > 1:
|
||||
# sorting by first injector func name for creating minimal cosistancy
|
||||
# could lead to heizenbugs due to incosistancy in path selection between calls
|
||||
variants.sort(key=lambda x: x.injector.fn.__qualname__)
|
||||
return variants
|
||||
|
||||
if len(variants) < 2:
|
||||
return variants
|
||||
|
||||
if priority_injectors is None:
|
||||
priority_injectors = frozenset()
|
||||
new_priority_injectors: frozenset[ConversionPoint] = frozenset()
|
||||
for inj in priority_injectors:
|
||||
injs = {inj}
|
||||
if not isinstance(inj, ConversionPoint):
|
||||
injs = ConversionPoint.from_fn(inj)
|
||||
new_priority_injectors |= injs
|
||||
|
||||
priority_injectors = new_priority_injectors
|
||||
|
||||
best_score = max(*list(
|
||||
map(lambda x: relevance_metric(x) * (len(variants) if x.injector in priority_injectors else 1), variants)))
|
||||
|
||||
selected_variants = []
|
||||
for variant in variants:
|
||||
if relevance_metric(variant) >= best_score:
|
||||
selected_variants.append(variant)
|
||||
return selected_variants
|
||||
|
||||
@classmethod
|
||||
def select_callgraph_branch(cls,
|
||||
variants: list[CallgraphVariant],
|
||||
ignore_noncommutative=False) -> Optional[CallgraphVariant]:
|
||||
filtered = cls.filter_exploded_callgraph_branch(variants)
|
||||
if len(filtered) > 1 and not ignore_noncommutative:
|
||||
raise ValueError('Graph is not commutative')
|
||||
if len(filtered) == 0:
|
||||
return None
|
||||
return filtered[0]
|
||||
|
||||
@classmethod
|
||||
def generate_full_depgraph(cls,
|
||||
injectors: frozenset[ConversionPoint],
|
||||
consumer: Optional[Callable] = None) -> frozenset[TransformationPoint]:
|
||||
out_points: list[TransformationPoint] = []
|
||||
|
||||
for point in injectors:
|
||||
out_points.append(TransformationPoint.new_empty(point))
|
||||
|
||||
if consumer is not None:
|
||||
consumer_requires = extract_func_argtypes(consumer)
|
||||
out_points.append(
|
||||
TransformationPoint.new_empty(ConversionPoint(consumer, NoneType, tuple(consumer_requires))))
|
||||
|
||||
for i in range(len(out_points)):
|
||||
pi = out_points[i]
|
||||
for j in range(len(out_points)):
|
||||
pj = out_points[j]
|
||||
cmp = pi.has_composition(pj)
|
||||
match cmp:
|
||||
case CompositionDirection.FORWARD:
|
||||
pi = pi.with_incoming(pj)
|
||||
out_points[j] = pj
|
||||
case CompositionDirection.BACKWARD:
|
||||
pj = pj.with_incoming(pi)
|
||||
out_points[j] = pj
|
||||
return frozenset(out_points)
|
||||
175
src/breakshaft/models.py
Normal file
175
src/breakshaft/models.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import inspect
|
||||
import types
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, get_type_hints, get_origin, Generator, get_args, Union, AsyncIterator
|
||||
|
||||
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
|
||||
is_async_context_manager_factory, \
|
||||
all_combinations, is_context_manager_factory
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversionPoint:
|
||||
fn: Callable
|
||||
injects: type
|
||||
requires: tuple[type, ...]
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.fn, self.injects, self.requires))
|
||||
|
||||
def __repr__(self):
|
||||
return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}'
|
||||
|
||||
@property
|
||||
def fn_args(self) -> list[type]:
|
||||
return extract_func_argtypes_seq(self.fn)
|
||||
|
||||
@property
|
||||
def is_ctx_manager(self) -> bool:
|
||||
return is_sync_context_manager_factory(self.fn) or is_async_context_manager_factory(self.fn)
|
||||
|
||||
@property
|
||||
def is_async(self):
|
||||
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
|
||||
|
||||
@classmethod
|
||||
def from_fn(cls, func: Callable, rettype: Optional[type] = None):
|
||||
if rettype is None:
|
||||
annot = get_type_hints(func)
|
||||
rettype = annot.get('return')
|
||||
|
||||
if rettype is None:
|
||||
raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified')
|
||||
|
||||
rettype_origin = get_origin(rettype)
|
||||
cm_out_origins = [
|
||||
typing.Generator,
|
||||
typing.Iterator,
|
||||
collections.abc.Generator,
|
||||
collections.abc.Iterator,
|
||||
typing.AsyncIterator,
|
||||
typing.AsyncGenerator,
|
||||
collections.abc.AsyncIterator,
|
||||
collections.abc.AsyncGenerator,
|
||||
]
|
||||
if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
|
||||
rettype = get_args(rettype)[0]
|
||||
|
||||
argtypes: list[list[type]] = []
|
||||
orig_argtypes = extract_func_argtypes_seq(func)
|
||||
for argtype in orig_argtypes:
|
||||
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
|
||||
u_types = list(get_args(argtype)) + [argtype]
|
||||
else:
|
||||
u_types = [argtype]
|
||||
argtypes.append(u_types)
|
||||
|
||||
argtype_combinations = all_combinations(argtypes)
|
||||
ret = []
|
||||
for argtype_combination in argtype_combinations:
|
||||
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination)))
|
||||
|
||||
# return InjectorPoint(func, rettype, argtypes)
|
||||
return ret
|
||||
|
||||
|
||||
class CompositionDirection(Enum):
|
||||
FORWARD = 1
|
||||
BACKWARD = -1
|
||||
NONE = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformationPoint:
|
||||
point: ConversionPoint
|
||||
incoming_points: frozenset[TransformationPoint]
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.point, self.incoming_points))
|
||||
|
||||
@classmethod
|
||||
def new_empty(cls, point: ConversionPoint) -> TransformationPoint:
|
||||
return TransformationPoint(point, frozenset())
|
||||
|
||||
def has_composition(self, other: TransformationPoint) -> CompositionDirection:
|
||||
if other.point.injects in self.point.requires:
|
||||
return CompositionDirection.FORWARD
|
||||
if self.point.injects in other.point.requires:
|
||||
return CompositionDirection.BACKWARD
|
||||
return CompositionDirection.NONE
|
||||
|
||||
def copy_with(self, *,
|
||||
incoming_points: Optional[frozenset[TransformationPoint]] = None) -> TransformationPoint:
|
||||
return TransformationPoint(
|
||||
self.point,
|
||||
incoming_points or self.incoming_points,
|
||||
)
|
||||
|
||||
def with_incoming(self, incoming: TransformationPoint) -> TransformationPoint:
|
||||
return self.copy_with(
|
||||
incoming_points=self.incoming_points | frozenset({incoming})
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallgraphVariant:
|
||||
injector: ConversionPoint
|
||||
subgraphs: frozenset[Callgraph]
|
||||
consumed_from_types: frozenset[type]
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.injector, self.subgraphs, self.consumed_from_types))
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
return 1 + max(0, 0, *list(map(lambda x: x.depth, self.subgraphs)))
|
||||
|
||||
@property
|
||||
def invokes(self) -> int:
|
||||
return 1 + sum(list(map(lambda x: x.invokes, self.subgraphs)))
|
||||
|
||||
@property
|
||||
def consumed_cumsum(self):
|
||||
ret = len(self.consumed_from_types)
|
||||
for g in self.subgraphs:
|
||||
ret += g.consumed_cumsum
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Callgraph:
|
||||
variants: frozenset[CallgraphVariant]
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.variants)
|
||||
|
||||
@property
|
||||
def consumed_cumsum(self):
|
||||
return max(0, 0, *list(map(lambda x: x.consumed_cumsum, self.variants)))
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
return 1 + max(0, 0, *list(map(lambda x: x.depth, self.variants)))
|
||||
|
||||
@property
|
||||
def invokes(self) -> int:
|
||||
return max(0, 0, *list(map(lambda x: x.invokes, self.variants)))
|
||||
|
||||
@classmethod
|
||||
def new_empty(cls) -> Callgraph:
|
||||
return cls(frozenset())
|
||||
|
||||
def add_subgraph_variant(self, new_variant: CallgraphVariant) -> Callgraph:
|
||||
return Callgraph(self.variants | {new_variant})
|
||||
|
||||
@property
|
||||
def consumed_from_types_max_cnt(self) -> int:
|
||||
ret = 0
|
||||
for variant in self.variants:
|
||||
ret = max(ret, len(variant.consumed_from_types))
|
||||
return ret
|
||||
97
src/breakshaft/renderer.py
Normal file
97
src/breakshaft/renderer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, Sequence, Callable, Optional
|
||||
|
||||
import importlib.resources
|
||||
|
||||
import jinja2
|
||||
|
||||
from .models import ConversionPoint
|
||||
from .util import hashname
|
||||
|
||||
|
||||
class ConvertorRenderer(Protocol):
|
||||
def render(self,
|
||||
from_types: Sequence[type],
|
||||
callseq: Sequence[ConversionPoint],
|
||||
force_async: bool = False) -> Callable:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversionRenderData:
|
||||
inj_hash: str
|
||||
funchash: str
|
||||
funcname: str
|
||||
funcargs: list[str]
|
||||
is_ctxmanager: bool
|
||||
is_async: bool
|
||||
|
||||
@classmethod
|
||||
def from_inj(cls, inj: ConversionPoint):
|
||||
fnargs = []
|
||||
for argtype in inj.requires:
|
||||
fnargs.append(hashname(argtype))
|
||||
return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversionArgRenderData:
|
||||
name: str
|
||||
typename: str
|
||||
typehash: str
|
||||
|
||||
|
||||
class InTimeGenerationConvertorRenderer(ConvertorRenderer):
|
||||
templateLoader: jinja2.BaseLoader
|
||||
templateEnv: jinja2.Environment
|
||||
template: jinja2.Template
|
||||
|
||||
def __init__(self,
|
||||
loader: Optional[jinja2.BaseLoader] = None,
|
||||
convertor_template: str = 'convertor.jinja2'):
|
||||
if loader is None:
|
||||
template_path = importlib.resources.files('src.breakshaft.templates')
|
||||
loader = jinja2.FileSystemLoader(str(template_path))
|
||||
self.templateLoader = loader
|
||||
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
|
||||
self.template = self.templateEnv.get_template(convertor_template)
|
||||
|
||||
def render(self,
|
||||
from_types: Sequence[type],
|
||||
callseq: Sequence[ConversionPoint],
|
||||
force_async: bool = False) -> Callable:
|
||||
|
||||
fnmap = {}
|
||||
conversion_models = []
|
||||
ret_hash = 0
|
||||
is_async = force_async
|
||||
|
||||
for call in callseq:
|
||||
fnmap[hash(call.fn)] = call.fn
|
||||
conv = ConversionRenderData.from_inj(call)
|
||||
if conv not in conversion_models:
|
||||
conversion_models.append(conv)
|
||||
if call.is_async:
|
||||
is_async = True
|
||||
|
||||
ret_hash = hash(callseq[-1].injects)
|
||||
|
||||
conv_args = []
|
||||
for i, from_type in enumerate(from_types):
|
||||
conv_args.append(ConversionArgRenderData(f'arg{i}', repr(from_type), hashname(from_type)))
|
||||
|
||||
namespace = {
|
||||
'_conv_funcmap': fnmap,
|
||||
}
|
||||
convertor_functext = self.template.render(
|
||||
ret_hash=ret_hash,
|
||||
conv_args=conv_args,
|
||||
conversions=conversion_models,
|
||||
is_async=is_async,
|
||||
)
|
||||
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
|
||||
exec(convertor_functext, namespace)
|
||||
unwrap_func = namespace['convertor']
|
||||
|
||||
return typing.cast(Callable, unwrap_func)
|
||||
0
src/breakshaft/templates/__init__.py
Normal file
0
src/breakshaft/templates/__init__.py
Normal file
13
src/breakshaft/templates/convertor.jinja2
Normal file
13
src/breakshaft/templates/convertor.jinja2
Normal file
@@ -0,0 +1,13 @@
|
||||
{% set ns = namespace(indent=0) %}
|
||||
{% if is_async %}async {% endif %}def convertor({% for arg in conv_args %}_{{arg.typehash}}: "{{arg.typename}}",{% endfor %}){% if rettype %} -> '{{rettype}}'{% endif %}:
|
||||
{% for conv in conversions %}
|
||||
{% if conv.is_ctxmanager %}
|
||||
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
||||
{{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}:
|
||||
{% set ns.indent = ns.indent + 1 %}
|
||||
{% else %}
|
||||
{{ ' ' * ns.indent }}# {{conv.funcname}}
|
||||
{{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %})
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{{ ' ' * ns.indent }}return _{{ret_hash}}
|
||||
65
src/breakshaft/util.py
Normal file
65
src/breakshaft/util.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import inspect
|
||||
from itertools import product
|
||||
from typing import Callable, get_type_hints, TypeVar, Any
|
||||
|
||||
|
||||
def extract_func_args(func: Callable) -> list[tuple[str, type]]:
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
params = sig.parameters
|
||||
|
||||
args_info = []
|
||||
for name, param in params.items():
|
||||
if name not in type_hints:
|
||||
raise TypeError(f"Param {name} must be type-annotated")
|
||||
args_info.append((name, type_hints[name]))
|
||||
return args_info
|
||||
|
||||
|
||||
def extract_func_argtypes(func: Callable) -> frozenset[type]:
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
params = sig.parameters
|
||||
|
||||
ret: frozenset[type] = frozenset()
|
||||
for name, param in params.items():
|
||||
if name not in type_hints:
|
||||
raise TypeError(f"Param {name} must be type-annotated")
|
||||
ret |= {type_hints[name]}
|
||||
return ret
|
||||
|
||||
|
||||
def extract_func_argtypes_seq(func: Callable) -> list[type]:
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
params = sig.parameters
|
||||
|
||||
ret: list[type] = []
|
||||
for name, param in params.items():
|
||||
if name not in type_hints:
|
||||
raise TypeError(f"Param {name} must be type-annotated")
|
||||
ret.append(type_hints[name])
|
||||
return ret
|
||||
|
||||
|
||||
def is_context_manager_factory(obj: object) -> bool:
|
||||
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
|
||||
|
||||
|
||||
def is_sync_context_manager_factory(obj: object) -> bool:
|
||||
return hasattr(obj, '__wrapped__') and inspect.isgeneratorfunction(obj.__wrapped__)
|
||||
|
||||
|
||||
def is_async_context_manager_factory(obj: object) -> bool:
|
||||
return hasattr(obj, '__wrapped__') and inspect.isasyncgenfunction(obj.__wrapped__)
|
||||
|
||||
|
||||
def hashname(any: Any) -> str:
|
||||
return str(hash(any)).replace('-', '_')
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def all_combinations(options: list[list[T]]) -> list[list[T]]:
|
||||
return [list(comb) for comb in product(*options)]
|
||||
70
src/breakshaft/util_mermaid.py
Normal file
70
src/breakshaft/util_mermaid.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from .models import Callgraph, TransformationPoint
|
||||
from .util import hashname
|
||||
|
||||
|
||||
def draw_depgraph_mermaid(depgraph: frozenset[TransformationPoint]):
|
||||
ret = ['flowchart TD']
|
||||
for point in depgraph:
|
||||
n = str(hash(point.point)).replace('-', '_')
|
||||
ret.append(f' {n}["{shield_mermaid_name(str(point.point))}"]')
|
||||
|
||||
for point in depgraph:
|
||||
pn = str(hash(point.point)).replace('-', '_')
|
||||
for incoming in point.incoming_points:
|
||||
n = str(hash(incoming.point)).replace('-', '_')
|
||||
ret.append(f' {n} --> {pn}')
|
||||
|
||||
return '\n'.join(ret)
|
||||
|
||||
|
||||
def shield_mermaid_name(s: str) -> str:
|
||||
syms_for_shielding = {'-', '|', '(', ')', '[', ']', '<', '>', '\'', '"', '_'}
|
||||
syms_for_shielding_after = {'<', '>'}
|
||||
for sym in syms_for_shielding:
|
||||
s = s.replace(sym, '\\' + sym)
|
||||
for sym in syms_for_shielding_after:
|
||||
s = s.replace(sym, sym + ' ')
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def _draw_callgraph_mermaid(g: Callgraph, prefix='', split_duplicates=False):
|
||||
definitions = []
|
||||
edges = []
|
||||
if not split_duplicates:
|
||||
prefix = ''
|
||||
|
||||
if len(g.variants) > 0:
|
||||
if len(g.variants) == 1:
|
||||
definitions.append(f' {prefix}{hashname(g)}["{g.invokes}"]')
|
||||
else:
|
||||
definitions.append(
|
||||
f' {prefix}{hashname(g)}["branch select \\[{g.consumed_from_types_max_cnt}\\] \\[{g.invokes}\\]"]')
|
||||
else:
|
||||
definitions.append(f' {prefix}{hashname(g)}((({g.consumed_from_types_max_cnt})))')
|
||||
|
||||
for v_i, variant in enumerate(g.variants):
|
||||
definitions.append(
|
||||
f' {prefix}{hashname(variant)}("{shield_mermaid_name(str(variant.injector))} \\[{len(variant.consumed_from_types)}\\]")')
|
||||
edges.append(f' {prefix}{hashname(g)} -.-> {prefix}{hashname(variant)}')
|
||||
for s_i, subgraph in enumerate(variant.subgraphs):
|
||||
s_prefix = str(v_i) + '_' + str(s_i) + '_' + prefix
|
||||
if not split_duplicates:
|
||||
s_prefix = ''
|
||||
d, e = _draw_callgraph_mermaid(subgraph, s_prefix, split_duplicates)
|
||||
definitions += d
|
||||
edges += e
|
||||
edges.append(f' {prefix}{hashname(variant)} ---> {s_prefix}{hashname(subgraph)}')
|
||||
|
||||
return definitions, edges
|
||||
|
||||
|
||||
def draw_callgraph_mermaid(g: Callgraph, split_duplicates=False, skip_title=False, prefix=''):
|
||||
d, e = _draw_callgraph_mermaid(g, split_duplicates=split_duplicates, prefix=prefix)
|
||||
e = list(set(e))
|
||||
e = [f' head(((head))) --> {prefix}{hashname(g)}'] + e
|
||||
ret = ''
|
||||
if not skip_title:
|
||||
ret += 'flowchart TD\n\n'
|
||||
ret += ' %%defs:\n' + '\n'.join(d) + '\n\n %%edges:\n' + '\n'.join(e)
|
||||
return ret
|
||||
Reference in New Issue
Block a user