Create basic deflator generator

commit cc77cc7012
parent 36e343d3bc
Date: 2025-08-29 00:34:14 +03:00
4 changed files with 478 additions and 2 deletions
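For orientation: the commit adds a SchemaDeflatorGenerator that compiles a dataclass-style schema into a generated deflate_* function turning schema instances into JSON-ready dicts, the counterpart to the existing SchemaInflatorGenerator. A minimal usage sketch, assuming the package is importable as megasniff; PointSchema is an illustrative schema, not one from the diff below:

from dataclasses import dataclass

from megasniff.deflator import SchemaDeflatorGenerator


@dataclass
class PointSchema:  # hypothetical schema, for illustration only
    x: int
    y: float
    label: str


deflator = SchemaDeflatorGenerator(store_sources=True)
deflate_point = deflator.schema_to_deflator(PointSchema)

print(deflate_point(PointSchema(1, 2.5, 'origin')))
# expected: {'x': 1, 'y': 2.5, 'label': 'origin'}

# store_sources=True keeps the generated source text on the function:
print(getattr(deflate_point, '__megasniff_sources__', '## No data'))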


@@ -1,10 +1,13 @@
 from __future__ import annotations
+import json
 from dataclasses import dataclass
+from types import NoneType
 from typing import Optional
 from typing import TypedDict
 import megasniff.exceptions
+from megasniff.deflator import SchemaDeflatorGenerator, JsonObject
 from . import SchemaInflatorGenerator
@@ -16,7 +19,8 @@ class ASchema:
     c: float = 1.1
-class BSchema(TypedDict):
+@dataclass
+class BSchema:
     a: int
     b: str
     c: float
@@ -55,5 +59,109 @@ def main():
         print(e)
+
+
+@dataclass
+class DSchema:
+    a: dict
+    b: dict[str, int | float | dict]
+    c: str | float | ASchema
+    d: ESchema
+
+
+@dataclass
+class ESchema:
+    a: list[list[list[str]]]
+    b: str
+
+
+@dataclass
+class ZSchema:
+    z: ZSchema | None
+    d: ZSchema | int
+
+
+def main_deflator():
+    deflator = SchemaDeflatorGenerator(store_sources=True)
+    fn = deflator.schema_to_deflator(DSchema)
+    print(getattr(fn, '__megasniff_sources__', '## No data'))
+    # ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
+    ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
+    # assert ret['a'] == 1
+    # assert ret['b'] == 1.1
+    # assert ret['c'] == 'a'
+    # assert ret['d']['a'][0][0][0] == 'a'
+    # assert ret['d']['b'] == 'b'
+    print(json.dumps(ret, indent=4))
+    pass
+
+
 if __name__ == '__main__':
-    main()
+    main_deflator()
+
+
+from typing import Any
+from megasniff.exceptions import MissingFieldException, FieldValidationException
+
+def deflate_DSchema(from_data: DSchema) -> JsonObject:
+    """
+    DSchema -> None
+    """
+    ret = {}
+    ret['a'] = from_data.a
+    ret['b'] = {}
+    for k__8786182384272, v__8786182384272 in from_data.b.items():
+        k__8786182384272 = str(k__8786182384272)
+        v__8786182384272 = str(ret['b'][v__8786182384272])
+    if isinstance(from_data.c, str):
+        ret['c'] = str(from_data.c)
+    elif isinstance(from_data.c, float):
+        ret['c'] = float(from_data.c)
+    elif isinstance(from_data.c, ASchema):
+        ret['c'] = {}
+        ret['c']['a'] = int(from_data.c.a)
+        if isinstance(from_data.c.b, float):
+            ret['c']['b'] = float(from_data.c.b)
+        elif isinstance(from_data.c.b, str):
+            ret['c']['b'] = str(from_data.c.b)
+        if isinstance(from_data.c.bs, BSchema):
+            ret['c']['bs'] = {}
+            ret['c']['bs']['a'] = int(from_data.c.bs.a)
+            ret['c']['bs']['b'] = str(from_data.c.bs.b)
+            ret['c']['bs']['c'] = float(from_data.c.bs.c)
+            ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d)
+        elif isinstance(from_data.c.bs, NoneType):
+            ret['c']['bs'] = from_data.c.bs
+        ret['c']['c'] = float(from_data.c.c)
+    ret['d'] = {}
+    ret['d']['a'] = []
+    for _8786182524849 in from_data.d.a:
+        _8786182524849_tmp_container = []
+        for _8786182524629 in _8786182524849:
+            _8786182524629_tmp_container = []
+            for _8786182384293 in _8786182524629:
+                _8786182384293_tmp_container = str(_8786182384293)
+                _8786182524629_tmp_container.append(_8786182384293_tmp_container)
+            _8786182524849_tmp_container.append(_8786182524629_tmp_container)
+        ret['d']['a'].append(_8786182524849_tmp_container)
+    ret['d']['b'] = str(from_data.d.b)
+    return ret
+
+
+def deflate_ASchema(from_data: ASchema) -> JsonObject:
+    """
+    ASchema -> None
+    """
+    ret = {}
+    ret['a'] = int(from_data.a)
+    if isinstance(from_data.b, float):
+        ret['b'] = float(from_data.b)
+    elif isinstance(from_data.b, str):
+        ret['b'] = str(from_data.b)
+    if isinstance(from_data.bs, BSchema):
+        ret['bs'] = {}
+        ret['bs']['a'] = int(from_data.bs.a)
+        ret['bs']['b'] = str(from_data.bs.b)
+        ret['bs']['c'] = float(from_data.bs.c)
+        ret['bs']['d'] = deflate_ASchema(from_data.bs.d)
+    elif isinstance(from_data.bs, NoneType):
+        ret['bs'] = from_data.bs
+    ret['c'] = float(from_data.c)
+    return ret
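For reference, hand-tracing the pasted deflate_ASchema above on the ASchema(1, 'a', None) value constructed in main_deflator (assuming the declaration order a, b, bs, c that the generated code implies, with c defaulting to 1.1 as shown in the hunk above) yields:

# Hand-traced result of deflate_ASchema(ASchema(1, 'a', None)), i.e. a=1, b='a', bs=None, c=1.1:
{
    'a': 1,      # int(from_data.a)
    'b': 'a',    # str branch of the b union
    'bs': None,  # NoneType branch passes the value through unchanged
    'c': 1.1,    # float(from_data.c), using the default c = 1.1
}

Note that the pasted deflate_DSchema never populates ret['b']: its dict loop reads ret['b'][v_...] instead of assigning to ret['b'][k_...] (cf. the value-unwrap call in render_unwrap_dict in the template below).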

src/megasniff/deflator.py (new file, 278 lines)

@@ -0,0 +1,278 @@
# Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations

import importlib.resources
import typing
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
    OrderedDict, TypeAlias

import jinja2

from .utils import *

JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]


class Unwrapping:
    kind: str


class OtherUnwrapping(Unwrapping):
    cast: str

    def __init__(self, cast: str = ''):
        self.kind = 'other'
        self.cast = cast


@dataclass()
class ObjectFieldUnwrapping:
    key: str
    object_key: str
    unwrapping: Unwrapping


class ObjectUnwrapping(Unwrapping):
    fields: list[ObjectFieldUnwrapping]

    def __init__(self, fields: list[ObjectFieldUnwrapping]):
        self.kind = 'object'
        self.fields = fields


class ListUnwrapping(Unwrapping):
    item_unwrap: Unwrapping

    def __init__(self, item_unwrap: Unwrapping):
        self.kind = 'list'
        self.item_unwrap = item_unwrap


class DictUnwrapping(Unwrapping):
    key_unwrap: Unwrapping
    value_unwrap: Unwrapping

    def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping):
        self.kind = 'dict'
        self.key_unwrap = key_unwrap
        self.value_unwrap = value_unwrap


class FuncUnwrapping(Unwrapping):
    fn: str

    def __init__(self, fn: str):
        self.kind = 'fn'
        self.fn = fn


@dataclass
class UnionKindUnwrapping:
    kind: str
    unwrapping: Unwrapping


class UnionUnwrapping(Unwrapping):
    union_kinds: list[UnionKindUnwrapping]

    def __init__(self, union_kinds: list[UnionKindUnwrapping]):
        self.kind = 'union'
        self.union_kinds = union_kinds


def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]:
    if isinstance(t, TypeAliasType):
        return _flatten_type(t.__value__)
    origin = get_origin(t)
    if origin is Annotated:
        args = get_args(t)
        return _flatten_type(args[0])[0], args[1]
    return t, None


def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
    t, _ = _flatten_type(t)
    return 'deflate_' + typename(t).replace('.', '_')


class SchemaDeflatorGenerator:
    templateLoader: jinja2.BaseLoader
    templateEnv: jinja2.Environment
    object_template: jinja2.Template
    _store_sources: bool
    _strict_mode: bool

    def __init__(self,
                 loader: Optional[jinja2.BaseLoader] = None,
                 strict_mode: bool = False,
                 store_sources: bool = False,
                 *,
                 object_template_filename: str = 'deflator.jinja2',
                 ):
        self._strict_mode = strict_mode
        self._store_sources = store_sources
        if loader is None:
            template_path = importlib.resources.files('megasniff.templates')
            loader = jinja2.FileSystemLoader(str(template_path))
        self.templateLoader = loader
        self.templateEnv = jinja2.Environment(loader=self.templateLoader)
        self.object_template = self.templateEnv.get_template(object_template_filename)

    def schema_to_deflator(self,
                           schema: type,
                           strict_mode_override: Optional[bool] = None,
                           from_type_override: Optional[type | TypeAliasType] = None
                           ) -> Callable[[Any], dict[str, Any]]:
        if from_type_override is not None and '__getitem__' not in dir(from_type_override):
            raise RuntimeError('from_type_override must provide __getitem__')
        txt, namespace = self._schema_to_deflator(schema,
                                                  strict_mode_override=strict_mode_override,
                                                  )
        imports = ('from typing import Any\n'
                   'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
        txt = imports + '\n' + txt
        exec(txt, namespace)
        fn = namespace[_schema_to_deflator_func(schema)]
        if self._store_sources:
            setattr(fn, '__megasniff_sources__', txt)
        return fn

    def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None):
        if _visited_types is None:
            _visited_types = []
        else:
            _visited_types = _visited_types.copy()
        schema, field_rename = _flatten_type(schema)
        if schema in _visited_types:
            return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema}
        _visited_types.append(schema)
        ongoing_types = set()
        recursive_types = set()
        origin = get_origin(schema)
        ret_unw = None
        if origin is not None:
            if origin is list:
                args = get_args(schema)
                item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0],
                                                                                    _visited_types=_visited_types)
                ret_unw = ListUnwrapping(item_unw)
                recursive_types |= item_rec
                ongoing_types |= ongoings
            elif origin is dict:
                args = get_args(schema)
                if len(args) != 2:
                    ret_unw = OtherUnwrapping()
                else:
                    k, v = args
                    k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
                    v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(v, _visited_types=_visited_types)
                    ongoing_types |= k_ongoings | v_ongoings
                    recursive_types |= k_rec | v_rec
                    ret_unw = DictUnwrapping(k_unw, v_unw)
            elif origin is UnionType or origin is Union:
                args = get_args(schema)
                union_unwraps = []
                for targ in args:
                    arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ,
                                                                                      _visited_types=_visited_types)
                    union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw))
                    ongoing_types |= ongoings
                    recursive_types |= arg_rec
                ret_unw = UnionUnwrapping(union_unwraps)
            else:
                raise NotImplementedError
        else:
            if schema is int:
                ret_unw = OtherUnwrapping('int')
            elif schema is float:
                ret_unw = OtherUnwrapping('float')
            elif schema is bool:
                ret_unw = OtherUnwrapping('bool')
            elif schema is str:
                ret_unw = OtherUnwrapping('str')
            elif schema is None or schema is NoneType:
                ret_unw = OtherUnwrapping()
            elif schema is dict:
                ret_unw = OtherUnwrapping()
            elif schema is list:
                ret_unw = OtherUnwrapping()
            elif is_class_definition(schema):
                hints = typing.get_type_hints(schema)
                fields = []
                for k, f in hints.items():
                    f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types)
                    fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw))
                    ongoing_types |= ongoings
                    recursive_types |= f_rec
                ret_unw = ObjectUnwrapping(fields)
            else:
                raise NotImplementedError()
        return ret_unw, field_rename, set(_visited_types) | ongoing_types, recursive_types

    def _schema_to_deflator(self,
                            schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
                            strict_mode_override: Optional[bool] = None,
                            into_type_override: Optional[type | TypeAliasType] = None,
                            *,
                            _funcname='deflate',
                            _namespace=None,
                            ) -> tuple[str, dict]:
        if strict_mode_override is not None:
            strict_mode = strict_mode_override
        else:
            strict_mode = self._strict_mode
        template = self.object_template
        types_for_namespace = set()
        recursive_types = {schema}
        namespace = {
            'JsonObject': JsonObject
        }
        convertor_functext = ''
        added_types = set()
        while len(recursive_types ^ (recursive_types & added_types)) > 0:
            rec_t = list(recursive_types ^ (recursive_types & added_types))[0]
            rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t)
            recursive_types |= rec_rec_t
            types_for_namespace |= rec_t_namespace
            rec_functext = template.render(
                funcname=_schema_to_deflator_func(rec_t),
                from_type=typename(rec_t),
                into_type=None,
                root_unwrap=rec_unw,
                hashname=hashname,
            )
            convertor_functext += '\n\n\n' + rec_functext
            added_types.add(rec_t)
        for t in types_for_namespace:
            namespace[typename(t)] = t
        convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
        return convertor_functext, namespace
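As a rough sketch of the intermediate representation above: schema_to_unwrapper maps a type annotation onto a tree of Unwrapping nodes, which the Jinja template then walks. Built by hand here for list[int | None] (not produced by calling the generator), the tree would look roughly like this:

from megasniff.deflator import (
    ListUnwrapping, OtherUnwrapping, UnionKindUnwrapping, UnionUnwrapping,
)

# Hand-built approximation of the unwrap tree for list[int | None]:
tree = ListUnwrapping(
    UnionUnwrapping([
        UnionKindUnwrapping('int', OtherUnwrapping('int')),  # int branch is cast with int()
        UnionKindUnwrapping('NoneType', OtherUnwrapping()),  # None branch is passed through uncast
    ])
)
assert tree.kind == 'list'
assert tree.item_unwrap.kind == 'union'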


@@ -0,0 +1,82 @@
{% macro render_unwrap_object(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% for kv in unwrapping.fields %}
{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}


{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
    {{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
    {{ render_unwrap(unwrapping.value_unwrap, 'v_' + hashname(unwrapping), into_container + '[k_' + hashname(unwrapping) + ']') | indent(4) }}
{%- endset %}
{{out}}
{%- endmacro %}


{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{into_container}} = []
for {{hashname(unwrapping)}} in {{from_container}}:
    {{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
    {{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
{%- endset %}
{{out}}
{%- endmacro %}


{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.cast != '' %}
{{into_container}} = {{unwrapping.cast}}({{from_container}})
{% else %}
{{into_container}} = {{from_container}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}


{% macro render_unwrap_union(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% for union_kind in unwrapping.union_kinds %}
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
    {{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}


{% macro render_unwrap(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.kind == 'dict' %}
{{ render_unwrap_dict(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'list' %}
{{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %}
{{into_container}} = {{ unwrapping.fn }}({{from_container}})
{% else %}
{{ render_unwrap_other(unwrapping, from_container, into_container) }}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}


def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}:
    """
    {{from_type}} -> {{into_type}}
    """
    {{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }}
    return ret
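To make the macro dispatch concrete: for a hypothetical flat dataclass, say FlatSchema with fields a: int and b: str, the object macro plus the def block at the bottom would render roughly the following (exact whitespace depends on the indent filters and on the blank-line stripping in _schema_to_deflator):

# Approximate output of the template for the hypothetical FlatSchema(a: int, b: str):
def deflate_FlatSchema(from_data: FlatSchema) -> JsonObject:
    """
    FlatSchema -> None
    """
    ret = {}
    ret['a'] = int(from_data.a)
    ret['b'] = str(from_data.b)
    return ret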


@@ -71,3 +71,11 @@ def typename(tp: type) -> str:
     if get_origin(tp) is None and hasattr(tp, '__name__'):
         return tp.__name__
     return str(tp)
+
+
+def is_class_definition(obj):
+    return isinstance(obj, type) or inspect.isclass(obj)
+
+
+def hashname(obj) -> str:
+    return '_' + str(hash(obj)).replace('-', '_')
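A quick usage sketch of the two helpers added above, assuming they are importable from megasniff.utils and that inspect is already imported in that module (is_class_definition depends on it):

from megasniff.utils import hashname, is_class_definition

assert is_class_definition(dict) is True   # class objects qualify
assert is_class_definition({}) is False    # plain instances do not

# hashname() yields the numeric suffixes seen in the generated loop variables above:
print(hashname(object()))  # e.g. '_8786182384272' (value depends on the object's hash)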