Create basic deflator generator

This commit is contained in:
2025-08-29 00:34:14 +03:00
parent 36e343d3bc
commit cc77cc7012
4 changed files with 478 additions and 2 deletions

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from types import NoneType
from typing import Optional
from typing import TypedDict
import megasniff.exceptions
from megasniff.deflator import SchemaDeflatorGenerator, JsonObject
from . import SchemaInflatorGenerator
@@ -16,7 +19,8 @@ class ASchema:
c: float = 1.1
class BSchema(TypedDict):
@dataclass
class BSchema:
a: int
b: str
c: float
@@ -55,5 +59,109 @@ def main():
print(e)
@dataclass
class DSchema:
a: dict
b: dict[str, int | float | dict]
c: str | float | ASchema
d: ESchema
@dataclass
class ESchema:
a: list[list[list[str]]]
b: str
@dataclass
class ZSchema:
z: ZSchema | None
d: ZSchema | int
def main_deflator():
deflator = SchemaDeflatorGenerator(store_sources=True)
fn = deflator.schema_to_deflator(DSchema)
print(getattr(fn, '__megasniff_sources__', '## No data'))
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
# assert ret['a'] == 1
# assert ret['b'] == 1.1
# assert ret['c'] == 'a'
# assert ret['d']['a'][0][0][0] == 'a'
# assert ret['d']['b'] == 'b'
print(json.dumps(ret, indent=4))
pass
if __name__ == '__main__':
main()
main_deflator()
from typing import Any
from megasniff.exceptions import MissingFieldException, FieldValidationException
def deflate_DSchema(from_data: DSchema) -> JsonObject:
"""
DSchema -> None
"""
ret = {}
ret['a'] = from_data.a
ret['b'] = {}
for k__8786182384272, v__8786182384272 in from_data.b.items():
k__8786182384272 = str(k__8786182384272)
v__8786182384272 = str(ret['b'][v__8786182384272])
if isinstance(from_data.c, str):
ret['c'] = str(from_data.c)
elif isinstance(from_data.c, float):
ret['c'] = float(from_data.c)
elif isinstance(from_data.c, ASchema):
ret['c'] = {}
ret['c']['a'] = int(from_data.c.a)
if isinstance(from_data.c.b, float):
ret['c']['b'] = float(from_data.c.b)
elif isinstance(from_data.c.b, str):
ret['c']['b'] = str(from_data.c.b)
if isinstance(from_data.c.bs, BSchema):
ret['c']['bs'] = {}
ret['c']['bs']['a'] = int(from_data.c.bs.a)
ret['c']['bs']['b'] = str(from_data.c.bs.b)
ret['c']['bs']['c'] = float(from_data.c.bs.c)
ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d)
elif isinstance(from_data.c.bs, NoneType):
ret['c']['bs'] = from_data.c.bs
ret['c']['c'] = float(from_data.c.c)
ret['d'] = {}
ret['d']['a'] = []
for _8786182524849 in from_data.d.a:
_8786182524849_tmp_container = []
for _8786182524629 in _8786182524849:
_8786182524629_tmp_container = []
for _8786182384293 in _8786182524629:
_8786182384293_tmp_container = str(_8786182384293)
_8786182524629_tmp_container.append(_8786182384293_tmp_container)
_8786182524849_tmp_container.append(_8786182524629_tmp_container)
ret['d']['a'].append(_8786182524849_tmp_container)
ret['d']['b'] = str(from_data.d.b)
return ret
def deflate_ASchema(from_data: ASchema) -> JsonObject:
"""
ASchema -> None
"""
ret = {}
ret['a'] = int(from_data.a)
if isinstance(from_data.b, float):
ret['b'] = float(from_data.b)
elif isinstance(from_data.b, str):
ret['b'] = str(from_data.b)
if isinstance(from_data.bs, BSchema):
ret['bs'] = {}
ret['bs']['a'] = int(from_data.bs.a)
ret['bs']['b'] = str(from_data.bs.b)
ret['bs']['c'] = float(from_data.bs.c)
ret['bs']['d'] = deflate_ASchema(from_data.bs.d)
elif isinstance(from_data.bs, NoneType):
ret['bs'] = from_data.bs
ret['c'] = float(from_data.c)
return ret

278
src/megasniff/deflator.py Normal file
View File

@@ -0,0 +1,278 @@
# Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import importlib.resources
import typing
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
OrderedDict, TypeAlias
import jinja2
from .utils import *
JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]
class Unwrapping:
kind: str
class OtherUnwrapping(Unwrapping):
cast: str
def __init__(self, cast: str = ''):
self.kind = 'other'
self.cast = cast
@dataclass()
class ObjectFieldUnwrapping:
key: str
object_key: str
unwrapping: Unwrapping
class ObjectUnwrapping(Unwrapping):
fields: list[ObjectFieldUnwrapping]
def __init__(self, fields: list[ObjectFieldUnwrapping]):
self.kind = 'object'
self.fields = fields
class ListUnwrapping(Unwrapping):
item_unwrap: Unwrapping
def __init__(self, item_unwrap: Unwrapping):
self.kind = 'list'
self.item_unwrap = item_unwrap
class DictUnwrapping(Unwrapping):
key_unwrap: Unwrapping
value_unwrap: Unwrapping
def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping):
self.kind = 'dict'
self.key_unwrap = key_unwrap
self.value_unwrap = value_unwrap
class FuncUnwrapping(Unwrapping):
fn: str
def __init__(self, fn: str):
self.kind = 'fn'
self.fn = fn
@dataclass
class UnionKindUnwrapping:
kind: str
unwrapping: Unwrapping
class UnionUnwrapping(Unwrapping):
union_kinds: list[UnionKindUnwrapping]
def __init__(self, union_kinds: list[UnionKindUnwrapping]):
self.kind = 'union'
self.union_kinds = union_kinds
def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]:
if isinstance(t, TypeAliasType):
return _flatten_type(t.__value__)
origin = get_origin(t)
if origin is Annotated:
args = get_args(t)
return _flatten_type(args[0])[0], args[1]
return t, None
def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
t, _ = _flatten_type(t)
return 'deflate_' + typename(t).replace('.', '_')
class SchemaDeflatorGenerator:
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
object_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
strict_mode: bool = False,
store_sources: bool = False,
*,
object_template_filename: str = 'deflator.jinja2',
):
self._strict_mode = strict_mode
self._store_sources = store_sources
if loader is None:
template_path = importlib.resources.files('megasniff.templates')
loader = jinja2.FileSystemLoader(str(template_path))
self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.object_template = self.templateEnv.get_template(object_template_filename)
def schema_to_deflator(self,
schema: type,
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None
) -> Callable[[Any], dict[str, Any]]:
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
raise RuntimeError('from_type_override must provide __getitem__')
txt, namespace = self._schema_to_deflator(schema,
strict_mode_override=strict_mode_override,
)
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
exec(txt, namespace)
fn = namespace[_schema_to_deflator_func(schema)]
if self._store_sources:
setattr(fn, '__megasniff_sources__', txt)
return fn
def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None):
if _visited_types is None:
_visited_types = []
else:
_visited_types = _visited_types.copy()
schema, field_rename = _flatten_type(schema)
if schema in _visited_types:
return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema}
_visited_types.append(schema)
ongoing_types = set()
recurcive_types = set()
origin = get_origin(schema)
ret_unw = None
if origin is not None:
if origin is list:
args = get_args(schema)
item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0],
_visited_types=_visited_types)
ret_unw = ListUnwrapping(item_unw)
recurcive_types |= item_rec
ongoing_types |= ongoings
elif origin is dict:
args = get_args(schema)
if len(args) != 2:
ret_unw = OtherUnwrapping()
else:
k, v = args
k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
ongoing_types |= k_ongoings | v_ongoings
recurcive_types |= k_rec | v_rec
ret_unw = DictUnwrapping(k_unw, v_unw)
elif origin is UnionType or origin is Union:
args = get_args(schema)
union_unwraps = []
for targ in args:
arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ,
_visited_types=_visited_types)
union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw))
ongoing_types |= ongoings
recurcive_types |= arg_rec
ret_unw = UnionUnwrapping(union_unwraps)
else:
raise NotImplementedError
else:
if schema is int:
ret_unw = OtherUnwrapping('int')
elif schema is float:
ret_unw = OtherUnwrapping('float')
elif schema is bool:
ret_unw = OtherUnwrapping('bool')
elif schema is str:
ret_unw = OtherUnwrapping('str')
elif schema is None or schema is NoneType:
ret_unw = OtherUnwrapping()
elif schema is dict:
ret_unw = OtherUnwrapping()
elif schema is list:
ret_unw = OtherUnwrapping()
elif is_class_definition(schema):
hints = typing.get_type_hints(schema)
fields = []
for k, f in hints.items():
f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types)
fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw))
ongoing_types |= ongoings
recurcive_types |= f_rec
ret_unw = ObjectUnwrapping(fields)
else:
raise NotImplementedError()
return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types
def _schema_to_deflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
into_type_override: Optional[type | TypeAliasType] = None,
*,
_funcname='deflate',
_namespace=None,
) -> tuple[str, dict]:
if strict_mode_override is not None:
strict_mode = strict_mode_override
else:
strict_mode = self._strict_mode
template = self.object_template
types_for_namespace = set()
recursive_types = {schema}
namespace = {
'JsonObject': JsonObject
}
convertor_functext = ''
added_types = set()
while len(recursive_types ^ (recursive_types & added_types)) > 0:
rec_t = list(recursive_types ^ (recursive_types & added_types))[0]
rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t)
recursive_types |= rec_rec_t
types_for_namespace |= rec_t_namespace
rec_functext = template.render(
funcname=_schema_to_deflator_func(rec_t),
from_type=typename(rec_t),
into_type=None,
root_unwrap=rec_unw,
hashname=hashname,
)
convertor_functext += '\n\n\n' + rec_functext
added_types.add(rec_t)
for t in types_for_namespace:
namespace[typename(t)] = t
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
return convertor_functext, namespace

View File

@@ -0,0 +1,82 @@
{% macro render_unwrap_object(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% for kv in unwrapping.fields %}
{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
{{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{into_container}} = []
for {{hashname(unwrapping)}} in {{from_container}}:
{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
{{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.cast != '' %}
{{into_container}} = {{unwrapping.cast}}({{from_container}})
{% else %}
{{into_container}} = {{from_container}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_union(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% for union_kind in unwrapping.union_kinds %}
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.kind == 'dict' %}
{{ render_unwrap_dict(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'list' %}
{{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %}
{{into_container}} = {{ unwrapping.fn }}({{from_container}})
{% else %}
{{ render_unwrap_other(unwrapping, from_container, into_container) }}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}:
"""
{{from_type}} -> {{into_type}}
"""
{{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }}
return ret

View File

@@ -71,3 +71,11 @@ def typename(tp: type) -> str:
if get_origin(tp) is None and hasattr(tp, '__name__'):
return tp.__name__
return str(tp)
def is_class_definition(obj):
return isinstance(obj, type) or inspect.isclass(obj)
def hashname(obj) -> str:
return '_' + str(hash(obj)).replace('-', '_')