Add library

This commit is contained in:
2025-07-14 22:47:09 +03:00
parent baf76597f5
commit 7ffc620f06
15 changed files with 1465 additions and 1 deletions

View File

@@ -0,0 +1 @@

162
src/breakshaft/__main__.py Normal file
View File

@@ -0,0 +1,162 @@
from __future__ import annotations
import asyncio
import typing
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from src.breakshaft import util_mermaid
from src.breakshaft.graph_walker import GraphWalker
from src.breakshaft.models import Callgraph
from .convertor import ConvRepo
@dataclass
class SomeSchema:
a: int
repo = ConvRepo()
@dataclass
class A:
a: int
class B:
pass
@dataclass
class C:
a: A
@repo.mark_injector()
def zero_schema() -> SomeSchema:
return SomeSchema(1)
@repo.mark_injector()
def c_from_a(a: A) -> C:
return C(a)
# @repo.mark_injector()
# def a_from_int(z: int) -> A:
# return A(z)
@repo.mark_injector()
def zero_a() -> A:
return A(34)
@repo.mark_injector()
def schema_from_c(c: C) -> SomeSchema:
return SomeSchema(c.a.a)
@repo.mark_injector()
def a_from_schema(s: SomeSchema) -> A:
return A(s.a)
pass
@repo.mark_injector()
def b_from_a(a: A) -> B:
return B()
pass
@repo.mark_injector()
@asynccontextmanager
async def a_from_b(b: B | int) -> typing.AsyncIterator[A]:
if isinstance(b, B):
yield A(0)
else:
yield A(b)
pass
# @repo.mark_injector()
# def schema_from_b(b: B) -> SomeSchema:
# return SomeSchema(0)
@repo.mark_injector()
@contextmanager
def schema_from_a_b(a: A, b: B) -> typing.Generator[SomeSchema]:
yield SomeSchema(a.a)
# @repo.mark_injector()
# def schema_from_a(a: A) -> SomeSchema:
# return SomeSchema(a.a)
# @repo.mark_injector()
# def schema_from_int(i: int) -> SomeSchema:
# return SomeSchema(i)
@repo.mark_injector()
def int_from_float(f: float) -> int:
return int(f)
# @repo.mark_injector()
# def zero_int() -> int:
# return 42
# @repo.consumer
def consumer(dep: SomeSchema) -> int:
print(f'consume {dep.a}')
return 42
async def main():
# fn = repo.get_conversion((int,), consumer, force_commutative=True, force_async=True)
# await fn(42)
# await fn(B())
# # g = walker.generate_callgraph_singletype(repo._injector_set, frozenset({int}), SomeSchema)
from_types = (B | int,)
walker = GraphWalker()
g = walker.generate_callgraph(repo.convertor_set, frozenset(from_types), consumer)
print('full graph:\n')
print(util_mermaid.draw_callgraph_mermaid(g, split_duplicates=True))
exploded = walker.explode_callgraph_branches(g, frozenset(from_types))
print('\nexploded:\n\n')
for s_i, selected in enumerate(exploded):
# print('\nselected path:\n')
print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True,
skip_title=True,
prefix=f'{s_i}_'))
print('\nselect variants:\n\n')
exploded = walker.filter_exploded_callgraph_branch(exploded)
# print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True))
for s_i, selected in enumerate(exploded):
# print('\nselected path:\n')
print(util_mermaid.draw_callgraph_mermaid(Callgraph(frozenset({selected})), split_duplicates=True,
skip_title=True,
prefix=f'{s_i}_'))
print(f'\nconsumed {selected.consumed_from_types}')
# consumer({})
# graph = walker.generate_full_depgraph(consumer)
# print(draw_depgraph_mermaid(graph))
if __name__ == '__main__':
asyncio.run(main())

114
src/breakshaft/convertor.py Normal file
View File

@@ -0,0 +1,114 @@
from __future__ import annotations
from typing import Optional, Callable, Unpack, TypeVarTuple, TypeVar, Awaitable, Any
from .graph_walker import GraphWalker
from .models import ConversionPoint, Callgraph
from .renderer import ConvertorRenderer, InTimeGenerationConvertorRenderer
Tin = TypeVarTuple('Tin')
Tout = TypeVar('Tout')
class ConvRepo:
_convertor_set: set[ConversionPoint]
walker: GraphWalker
renderer: ConvertorRenderer
def __init__(self,
graph_walker: Optional[GraphWalker] = None,
renderer: Optional[ConvertorRenderer] = None, ):
if graph_walker is None:
graph_walker = GraphWalker()
if renderer is None:
renderer = InTimeGenerationConvertorRenderer()
self._convertor_set = set()
self.walker = graph_walker
self.renderer = renderer
@property
def convertor_set(self):
return self._convertor_set
def add_injector(self, func: Callable, rettype: Optional[type] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
def _callseq_from_callgraph(self, cg: Callgraph) -> list[ConversionPoint]:
if len(cg.variants) == 0:
return []
if len(cg.variants) > 1:
raise ValueError('All callgraph subgraphs must be solved for callseq generation')
ret = []
variant = list(cg.variants)[0]
for sg in variant.subgraphs:
ret += self._callseq_from_callgraph(sg)
ret += [variant.injector]
return ret
def get_conversion(self,
from_types: tuple[type[Unpack[Tin]]],
fn: Callable[..., Tout],
force_commutative: bool = True,
allow_async: bool = True,
allow_sync: bool = True,
force_async: bool = False
) -> Callable[[Unpack[Tin]], Tout] | Awaitable[Callable[[Unpack[Tin]], Tout]]:
if not allow_async or force_async:
filtered_injectors: frozenset[ConversionPoint] = frozenset()
for inj in self.convertor_set:
if inj.is_async and not allow_async:
continue
if not inj.is_async and not allow_sync:
continue
filtered_injectors |= {inj}
else:
filtered_injectors = frozenset(self.convertor_set)
cg = self.walker.generate_callgraph(filtered_injectors, frozenset(from_types), fn)
if cg is None:
raise ValueError(f'Unable to compute conversion graph on {from_types}->{fn.__qualname__}')
exploded = self.walker.explode_callgraph_branches(cg, frozenset(from_types))
selected = self.walker.filter_exploded_callgraph_branch(exploded)
if len(selected) == 0:
raise ValueError('Unable to select conversion path')
if force_commutative and len(selected) > 1:
raise ValueError('Conversion path is not commutative')
callseq = self._callseq_from_callgraph(Callgraph(frozenset([selected[0]])))
return self.renderer.render(from_types, callseq, force_async=force_async)
def mark_injector(self, *, rettype: Optional[type] = None):
def inner(func: Callable):
self.add_injector(func)
return func
return inner
def fork(self, fork_with: Optional[set[ConversionPoint]] = None) -> ConvRepo:
return ForkedConvRepo(self, fork_with or None, self.walker, self.renderer)
class ForkedConvRepo(ConvRepo):
_base_repo: ConvRepo
def __init__(self,
fork_from: ConvRepo,
fork_with: Optional[set[ConversionPoint]] = None,
graph_walker: Optional[GraphWalker] = None,
renderer: Optional[ConvertorRenderer] = None):
super().__init__(graph_walker, renderer)
if fork_with is None:
fork_with = set()
self._convertor_set = fork_with
self._base_repo = fork_from
def add_injector(self, func: Callable, rettype: Optional[type] = None):
self._convertor_set |= set(ConversionPoint.from_fn(func, rettype=rettype))
@property
def convertor_set(self):
return self._base_repo.convertor_set | self._convertor_set

View File

@@ -0,0 +1,200 @@
import typing
from types import NoneType
from typing import Callable, Optional
from .models import ConversionPoint, Callgraph, CallgraphVariant, TransformationPoint, CompositionDirection
from .util import extract_func_argtypes, all_combinations
class GraphWalker:
@classmethod
def generate_callgraph(cls,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type],
consumer_fn: Callable) -> Optional[Callgraph]:
into_types: frozenset[type] = extract_func_argtypes(consumer_fn)
branches: frozenset[Callgraph] = frozenset()
for into_type in into_types:
cg = cls.generate_callgraph_singletype(injectors, from_types, into_type)
if cg is None:
return None
branches |= {cg}
variant = CallgraphVariant(ConversionPoint(consumer_fn, NoneType, tuple(extract_func_argtypes(consumer_fn))),
branches, frozenset())
return Callgraph(frozenset({variant}))
@classmethod
def generate_callgraph_singletype(cls,
injectors: frozenset[ConversionPoint],
from_types: frozenset[type],
into_type: type,
*,
visited_path: Optional[set[ConversionPoint]] = None,
visited_types: Optional[set[type]] = None) -> Optional[Callgraph]:
if visited_path is None:
visited_path = set()
if visited_types is None:
visited_types = set()
if into_type in from_types:
return Callgraph.new_empty()
if into_type in visited_types:
return None
head = Callgraph.new_empty()
visited_types.add(into_type)
for point in injectors:
if point in visited_path:
continue
if into_type in point.requires:
continue
if point.injects == into_type:
visited_path.add(point)
variant_subgraphs = set()
dead_end = False
for req in point.requires:
subg = cls.generate_callgraph_singletype(injectors,
from_types,
req,
visited_path=visited_path.copy(),
visited_types=visited_types.copy())
if subg is None:
dead_end = True
break
variant_subgraphs.add(subg)
if not dead_end:
consumed = frozenset(point.requires) & from_types
variant = CallgraphVariant(point, frozenset(variant_subgraphs), consumed)
head = head.add_subgraph_variant(variant)
if len(head.variants) == 0:
return None
return head
@classmethod
def explode_callgraph_branches(cls, g: Callgraph, from_types: frozenset[type]) -> list[CallgraphVariant]:
variants = []
for variant in g.variants:
if len(variant.subgraphs) == 0:
variants.append(variant)
continue
subg_combinations: list[list[CallgraphVariant | None]] = []
for subg in variant.subgraphs:
combinations: list[CallgraphVariant] = cls.explode_callgraph_branches(subg, from_types)
if len(combinations) == 0:
subg_combinations.append([None])
else:
subg_combinations.append(typing.cast(list[CallgraphVariant | None], combinations))
for combination in all_combinations(subg_combinations):
if None in combination:
combination.remove(None)
cons: frozenset[type] = frozenset()
cum_cmb: frozenset[Callgraph] = frozenset()
for cmb in combination:
if cmb is not None:
cons |= cmb.consumed_from_types
cum_cmb |= {Callgraph(frozenset({cmb}))}
variants.append(
CallgraphVariant(variant.injector, cum_cmb,
variant.consumed_from_types | cons))
return variants
@classmethod
def filter_exploded_callgraph_branch(cls,
variants: list[CallgraphVariant],
priority_injectors: Optional[frozenset[ConversionPoint | Callable]] = None,
relevance_metric: Optional[Callable[[CallgraphVariant], int | float]] = None) \
-> list[CallgraphVariant]:
if relevance_metric is None:
template_metrics = [
lambda x: len(x.consumed_from_types),
lambda x: x.consumed_cumsum,
lambda x: -x.invokes,
]
for metric in template_metrics:
if len(variants) == 1:
break
new_variants = cls.filter_exploded_callgraph_branch(variants, priority_injectors, metric)
if len(new_variants) > 0:
variants = new_variants
if len(variants) > 1:
# sorting by first injector func name for creating minimal cosistancy
# could lead to heizenbugs due to incosistancy in path selection between calls
variants.sort(key=lambda x: x.injector.fn.__qualname__)
return variants
if len(variants) < 2:
return variants
if priority_injectors is None:
priority_injectors = frozenset()
new_priority_injectors: frozenset[ConversionPoint] = frozenset()
for inj in priority_injectors:
injs = {inj}
if not isinstance(inj, ConversionPoint):
injs = ConversionPoint.from_fn(inj)
new_priority_injectors |= injs
priority_injectors = new_priority_injectors
best_score = max(*list(
map(lambda x: relevance_metric(x) * (len(variants) if x.injector in priority_injectors else 1), variants)))
selected_variants = []
for variant in variants:
if relevance_metric(variant) >= best_score:
selected_variants.append(variant)
return selected_variants
@classmethod
def select_callgraph_branch(cls,
variants: list[CallgraphVariant],
ignore_noncommutative=False) -> Optional[CallgraphVariant]:
filtered = cls.filter_exploded_callgraph_branch(variants)
if len(filtered) > 1 and not ignore_noncommutative:
raise ValueError('Graph is not commutative')
if len(filtered) == 0:
return None
return filtered[0]
@classmethod
def generate_full_depgraph(cls,
injectors: frozenset[ConversionPoint],
consumer: Optional[Callable] = None) -> frozenset[TransformationPoint]:
out_points: list[TransformationPoint] = []
for point in injectors:
out_points.append(TransformationPoint.new_empty(point))
if consumer is not None:
consumer_requires = extract_func_argtypes(consumer)
out_points.append(
TransformationPoint.new_empty(ConversionPoint(consumer, NoneType, tuple(consumer_requires))))
for i in range(len(out_points)):
pi = out_points[i]
for j in range(len(out_points)):
pj = out_points[j]
cmp = pi.has_composition(pj)
match cmp:
case CompositionDirection.FORWARD:
pi = pi.with_incoming(pj)
out_points[j] = pj
case CompositionDirection.BACKWARD:
pj = pj.with_incoming(pi)
out_points[j] = pj
return frozenset(out_points)

175
src/breakshaft/models.py Normal file
View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import collections.abc
import inspect
import types
import typing
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional, get_type_hints, get_origin, Generator, get_args, Union, AsyncIterator
from .util import extract_func_argtypes, extract_func_argtypes_seq, is_sync_context_manager_factory, \
is_async_context_manager_factory, \
all_combinations, is_context_manager_factory
@dataclass(frozen=True)
class ConversionPoint:
fn: Callable
injects: type
requires: tuple[type, ...]
def __hash__(self):
return hash((self.fn, self.injects, self.requires))
def __repr__(self):
return f'({",".join(map(str, self.requires))}) -> {self.injects.__qualname__}: {self.fn.__qualname__}'
@property
def fn_args(self) -> list[type]:
return extract_func_argtypes_seq(self.fn)
@property
def is_ctx_manager(self) -> bool:
return is_sync_context_manager_factory(self.fn) or is_async_context_manager_factory(self.fn)
@property
def is_async(self):
return inspect.iscoroutinefunction(self.fn) or is_async_context_manager_factory(self.fn)
@classmethod
def from_fn(cls, func: Callable, rettype: Optional[type] = None):
if rettype is None:
annot = get_type_hints(func)
rettype = annot.get('return')
if rettype is None:
raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified')
rettype_origin = get_origin(rettype)
cm_out_origins = [
typing.Generator,
typing.Iterator,
collections.abc.Generator,
collections.abc.Iterator,
typing.AsyncIterator,
typing.AsyncGenerator,
collections.abc.AsyncIterator,
collections.abc.AsyncGenerator,
]
if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func):
rettype = get_args(rettype)[0]
argtypes: list[list[type]] = []
orig_argtypes = extract_func_argtypes_seq(func)
for argtype in orig_argtypes:
if isinstance(argtype, types.UnionType) or get_origin(argtype) is Union:
u_types = list(get_args(argtype)) + [argtype]
else:
u_types = [argtype]
argtypes.append(u_types)
argtype_combinations = all_combinations(argtypes)
ret = []
for argtype_combination in argtype_combinations:
ret.append(ConversionPoint(func, rettype, tuple(argtype_combination)))
# return InjectorPoint(func, rettype, argtypes)
return ret
class CompositionDirection(Enum):
FORWARD = 1
BACKWARD = -1
NONE = 0
@dataclass(frozen=True)
class TransformationPoint:
point: ConversionPoint
incoming_points: frozenset[TransformationPoint]
def __hash__(self):
return hash((self.point, self.incoming_points))
@classmethod
def new_empty(cls, point: ConversionPoint) -> TransformationPoint:
return TransformationPoint(point, frozenset())
def has_composition(self, other: TransformationPoint) -> CompositionDirection:
if other.point.injects in self.point.requires:
return CompositionDirection.FORWARD
if self.point.injects in other.point.requires:
return CompositionDirection.BACKWARD
return CompositionDirection.NONE
def copy_with(self, *,
incoming_points: Optional[frozenset[TransformationPoint]] = None) -> TransformationPoint:
return TransformationPoint(
self.point,
incoming_points or self.incoming_points,
)
def with_incoming(self, incoming: TransformationPoint) -> TransformationPoint:
return self.copy_with(
incoming_points=self.incoming_points | frozenset({incoming})
)
@dataclass(frozen=True)
class CallgraphVariant:
injector: ConversionPoint
subgraphs: frozenset[Callgraph]
consumed_from_types: frozenset[type]
def __hash__(self):
return hash((self.injector, self.subgraphs, self.consumed_from_types))
@property
def depth(self) -> int:
return 1 + max(0, 0, *list(map(lambda x: x.depth, self.subgraphs)))
@property
def invokes(self) -> int:
return 1 + sum(list(map(lambda x: x.invokes, self.subgraphs)))
@property
def consumed_cumsum(self):
ret = len(self.consumed_from_types)
for g in self.subgraphs:
ret += g.consumed_cumsum
return ret
@dataclass(frozen=True)
class Callgraph:
variants: frozenset[CallgraphVariant]
def __hash__(self):
return hash(self.variants)
@property
def consumed_cumsum(self):
return max(0, 0, *list(map(lambda x: x.consumed_cumsum, self.variants)))
@property
def depth(self) -> int:
return 1 + max(0, 0, *list(map(lambda x: x.depth, self.variants)))
@property
def invokes(self) -> int:
return max(0, 0, *list(map(lambda x: x.invokes, self.variants)))
@classmethod
def new_empty(cls) -> Callgraph:
return cls(frozenset())
def add_subgraph_variant(self, new_variant: CallgraphVariant) -> Callgraph:
return Callgraph(self.variants | {new_variant})
@property
def consumed_from_types_max_cnt(self) -> int:
ret = 0
for variant in self.variants:
ret = max(ret, len(variant.consumed_from_types))
return ret

View File

@@ -0,0 +1,97 @@
import typing
from dataclasses import dataclass
from typing import Protocol, Sequence, Callable, Optional
import importlib.resources
import jinja2
from .models import ConversionPoint
from .util import hashname
class ConvertorRenderer(Protocol):
def render(self,
from_types: Sequence[type],
callseq: Sequence[ConversionPoint],
force_async: bool = False) -> Callable:
raise NotImplementedError()
@dataclass
class ConversionRenderData:
inj_hash: str
funchash: str
funcname: str
funcargs: list[str]
is_ctxmanager: bool
is_async: bool
@classmethod
def from_inj(cls, inj: ConversionPoint):
fnargs = []
for argtype in inj.requires:
fnargs.append(hashname(argtype))
return cls(hashname(inj.injects), hashname(inj.fn), repr(inj.fn), fnargs, inj.is_ctx_manager, inj.is_async)
@dataclass
class ConversionArgRenderData:
name: str
typename: str
typehash: str
class InTimeGenerationConvertorRenderer(ConvertorRenderer):
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
template: jinja2.Template
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
convertor_template: str = 'convertor.jinja2'):
if loader is None:
template_path = importlib.resources.files('src.breakshaft.templates')
loader = jinja2.FileSystemLoader(str(template_path))
self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(convertor_template)
def render(self,
from_types: Sequence[type],
callseq: Sequence[ConversionPoint],
force_async: bool = False) -> Callable:
fnmap = {}
conversion_models = []
ret_hash = 0
is_async = force_async
for call in callseq:
fnmap[hash(call.fn)] = call.fn
conv = ConversionRenderData.from_inj(call)
if conv not in conversion_models:
conversion_models.append(conv)
if call.is_async:
is_async = True
ret_hash = hash(callseq[-1].injects)
conv_args = []
for i, from_type in enumerate(from_types):
conv_args.append(ConversionArgRenderData(f'arg{i}', repr(from_type), hashname(from_type)))
namespace = {
'_conv_funcmap': fnmap,
}
convertor_functext = self.template.render(
ret_hash=ret_hash,
conv_args=conv_args,
conversions=conversion_models,
is_async=is_async,
)
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
exec(convertor_functext, namespace)
unwrap_func = namespace['convertor']
return typing.cast(Callable, unwrap_func)

View File

View File

@@ -0,0 +1,13 @@
{% set ns = namespace(indent=0) %}
{% if is_async %}async {% endif %}def convertor({% for arg in conv_args %}_{{arg.typehash}}: "{{arg.typename}}",{% endfor %}){% if rettype %} -> '{{rettype}}'{% endif %}:
{% for conv in conversions %}
{% if conv.is_ctxmanager %}
{{ ' ' * ns.indent }}# {{conv.funcname}}
{{ ' ' * ns.indent }}{% if conv.is_async %}async {% endif %}with _conv_funcmap[{{ conv.funchash }}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %}) as _{{ conv.inj_hash }}:
{% set ns.indent = ns.indent + 1 %}
{% else %}
{{ ' ' * ns.indent }}# {{conv.funcname}}
{{ ' ' * ns.indent }}_{{conv.inj_hash}} = {% if conv.is_async %}await {% endif %}_conv_funcmap[{{conv.funchash}}]({% for conv_arg in conv.funcargs %}_{{conv_arg}}, {% endfor %})
{% endif %}
{% endfor %}
{{ ' ' * ns.indent }}return _{{ret_hash}}

65
src/breakshaft/util.py Normal file
View File

@@ -0,0 +1,65 @@
import inspect
from itertools import product
from typing import Callable, get_type_hints, TypeVar, Any
def extract_func_args(func: Callable) -> list[tuple[str, type]]:
sig = inspect.signature(func)
type_hints = get_type_hints(func)
params = sig.parameters
args_info = []
for name, param in params.items():
if name not in type_hints:
raise TypeError(f"Param {name} must be type-annotated")
args_info.append((name, type_hints[name]))
return args_info
def extract_func_argtypes(func: Callable) -> frozenset[type]:
sig = inspect.signature(func)
type_hints = get_type_hints(func)
params = sig.parameters
ret: frozenset[type] = frozenset()
for name, param in params.items():
if name not in type_hints:
raise TypeError(f"Param {name} must be type-annotated")
ret |= {type_hints[name]}
return ret
def extract_func_argtypes_seq(func: Callable) -> list[type]:
sig = inspect.signature(func)
type_hints = get_type_hints(func)
params = sig.parameters
ret: list[type] = []
for name, param in params.items():
if name not in type_hints:
raise TypeError(f"Param {name} must be type-annotated")
ret.append(type_hints[name])
return ret
def is_context_manager_factory(obj: object) -> bool:
return is_sync_context_manager_factory(obj) or is_async_context_manager_factory(obj)
def is_sync_context_manager_factory(obj: object) -> bool:
return hasattr(obj, '__wrapped__') and inspect.isgeneratorfunction(obj.__wrapped__)
def is_async_context_manager_factory(obj: object) -> bool:
return hasattr(obj, '__wrapped__') and inspect.isasyncgenfunction(obj.__wrapped__)
def hashname(any: Any) -> str:
return str(hash(any)).replace('-', '_')
T = TypeVar('T')
def all_combinations(options: list[list[T]]) -> list[list[T]]:
return [list(comb) for comb in product(*options)]

View File

@@ -0,0 +1,70 @@
from .models import Callgraph, TransformationPoint
from .util import hashname
def draw_depgraph_mermaid(depgraph: frozenset[TransformationPoint]):
ret = ['flowchart TD']
for point in depgraph:
n = str(hash(point.point)).replace('-', '_')
ret.append(f' {n}["{shield_mermaid_name(str(point.point))}"]')
for point in depgraph:
pn = str(hash(point.point)).replace('-', '_')
for incoming in point.incoming_points:
n = str(hash(incoming.point)).replace('-', '_')
ret.append(f' {n} --> {pn}')
return '\n'.join(ret)
def shield_mermaid_name(s: str) -> str:
syms_for_shielding = {'-', '|', '(', ')', '[', ']', '<', '>', '\'', '"', '_'}
syms_for_shielding_after = {'<', '>'}
for sym in syms_for_shielding:
s = s.replace(sym, '\\' + sym)
for sym in syms_for_shielding_after:
s = s.replace(sym, sym + ' ')
return s
def _draw_callgraph_mermaid(g: Callgraph, prefix='', split_duplicates=False):
definitions = []
edges = []
if not split_duplicates:
prefix = ''
if len(g.variants) > 0:
if len(g.variants) == 1:
definitions.append(f' {prefix}{hashname(g)}["{g.invokes}"]')
else:
definitions.append(
f' {prefix}{hashname(g)}["branch select \\[{g.consumed_from_types_max_cnt}\\] \\[{g.invokes}\\]"]')
else:
definitions.append(f' {prefix}{hashname(g)}((({g.consumed_from_types_max_cnt})))')
for v_i, variant in enumerate(g.variants):
definitions.append(
f' {prefix}{hashname(variant)}("{shield_mermaid_name(str(variant.injector))} \\[{len(variant.consumed_from_types)}\\]")')
edges.append(f' {prefix}{hashname(g)} -.-> {prefix}{hashname(variant)}')
for s_i, subgraph in enumerate(variant.subgraphs):
s_prefix = str(v_i) + '_' + str(s_i) + '_' + prefix
if not split_duplicates:
s_prefix = ''
d, e = _draw_callgraph_mermaid(subgraph, s_prefix, split_duplicates)
definitions += d
edges += e
edges.append(f' {prefix}{hashname(variant)} ---> {s_prefix}{hashname(subgraph)}')
return definitions, edges
def draw_callgraph_mermaid(g: Callgraph, split_duplicates=False, skip_title=False, prefix=''):
d, e = _draw_callgraph_mermaid(g, split_duplicates=split_duplicates, prefix=prefix)
e = list(set(e))
e = [f' head(((head))) --> {prefix}{hashname(g)}'] + e
ret = ''
if not skip_title:
ret += 'flowchart TD\n\n'
ret += ' %%defs:\n' + '\n'.join(d) + '\n\n %%edges:\n' + '\n'.join(e)
return ret