diff --git a/src/megasniff/__main__.py b/src/megasniff/__main__.py index 41203a8..38b4270 100644 --- a/src/megasniff/__main__.py +++ b/src/megasniff/__main__.py @@ -1,10 +1,13 @@ from __future__ import annotations +import json from dataclasses import dataclass +from types import NoneType from typing import Optional from typing import TypedDict import megasniff.exceptions +from megasniff.deflator import SchemaDeflatorGenerator, JsonObject from . import SchemaInflatorGenerator @@ -16,7 +19,8 @@ class ASchema: c: float = 1.1 -class BSchema(TypedDict): +@dataclass +class BSchema: a: int b: str c: float @@ -55,5 +59,109 @@ def main(): print(e) +@dataclass +class DSchema: + a: dict + b: dict[str, int | float | dict] + c: str | float | ASchema + d: ESchema + + +@dataclass +class ESchema: + a: list[list[list[str]]] + b: str + + +@dataclass +class ZSchema: + z: ZSchema | None + d: ZSchema | int + + +def main_deflator(): + deflator = SchemaDeflatorGenerator(store_sources=True) + fn = deflator.schema_to_deflator(DSchema) + print(getattr(fn, '__megasniff_sources__', '## No data')) + # ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42))) + ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b'))) + # assert ret['a'] == 1 + # assert ret['b'] == 1.1 + # assert ret['c'] == 'a' + # assert ret['d']['a'][0][0][0] == 'a' + # assert ret['d']['b'] == 'b' + print(json.dumps(ret, indent=4)) + pass + + if __name__ == '__main__': - main() + main_deflator() + + + +from typing import Any +from megasniff.exceptions import MissingFieldException, FieldValidationException + +def deflate_DSchema(from_data: DSchema) -> JsonObject: + """ + DSchema -> None + """ + ret = {} + ret['a'] = from_data.a + ret['b'] = {} + for k__8786182384272, v__8786182384272 in from_data.b.items(): + k__8786182384272 = str(k__8786182384272) + v__8786182384272 = str(ret['b'][v__8786182384272]) + if isinstance(from_data.c, str): + ret['c'] = str(from_data.c) + elif isinstance(from_data.c, float): + ret['c'] = float(from_data.c) + elif isinstance(from_data.c, ASchema): + ret['c'] = {} + ret['c']['a'] = int(from_data.c.a) + if isinstance(from_data.c.b, float): + ret['c']['b'] = float(from_data.c.b) + elif isinstance(from_data.c.b, str): + ret['c']['b'] = str(from_data.c.b) + if isinstance(from_data.c.bs, BSchema): + ret['c']['bs'] = {} + ret['c']['bs']['a'] = int(from_data.c.bs.a) + ret['c']['bs']['b'] = str(from_data.c.bs.b) + ret['c']['bs']['c'] = float(from_data.c.bs.c) + ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d) + elif isinstance(from_data.c.bs, NoneType): + ret['c']['bs'] = from_data.c.bs + ret['c']['c'] = float(from_data.c.c) + ret['d'] = {} + ret['d']['a'] = [] + for _8786182524849 in from_data.d.a: + _8786182524849_tmp_container = [] + for _8786182524629 in _8786182524849: + _8786182524629_tmp_container = [] + for _8786182384293 in _8786182524629: + _8786182384293_tmp_container = str(_8786182384293) + _8786182524629_tmp_container.append(_8786182384293_tmp_container) + _8786182524849_tmp_container.append(_8786182524629_tmp_container) + ret['d']['a'].append(_8786182524849_tmp_container) + ret['d']['b'] = str(from_data.d.b) + return ret +def deflate_ASchema(from_data: ASchema) -> JsonObject: + """ + ASchema -> None + """ + ret = {} + ret['a'] = int(from_data.a) + if isinstance(from_data.b, float): + ret['b'] = float(from_data.b) + elif isinstance(from_data.b, str): + ret['b'] = str(from_data.b) + if isinstance(from_data.bs, BSchema): + ret['bs'] = {} + ret['bs']['a'] = int(from_data.bs.a) + ret['bs']['b'] = str(from_data.bs.b) + ret['bs']['c'] = float(from_data.bs.c) + ret['bs']['d'] = deflate_ASchema(from_data.bs.d) + elif isinstance(from_data.bs, NoneType): + ret['bs'] = from_data.bs + ret['c'] = float(from_data.c) + return ret diff --git a/src/megasniff/deflator.py b/src/megasniff/deflator.py new file mode 100644 index 0000000..85a3503 --- /dev/null +++ b/src/megasniff/deflator.py @@ -0,0 +1,278 @@ +# Copyright (C) 2025 Shevchenko A +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import annotations + +import importlib.resources +import typing +from collections.abc import Callable +from dataclasses import dataclass +from types import NoneType, UnionType +from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \ + OrderedDict, TypeAlias + +import jinja2 + +from .utils import * + +JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']] + + +class Unwrapping: + kind: str + + +class OtherUnwrapping(Unwrapping): + cast: str + + def __init__(self, cast: str = ''): + self.kind = 'other' + self.cast = cast + + +@dataclass() +class ObjectFieldUnwrapping: + key: str + object_key: str + unwrapping: Unwrapping + + +class ObjectUnwrapping(Unwrapping): + fields: list[ObjectFieldUnwrapping] + + def __init__(self, fields: list[ObjectFieldUnwrapping]): + self.kind = 'object' + self.fields = fields + + +class ListUnwrapping(Unwrapping): + item_unwrap: Unwrapping + + def __init__(self, item_unwrap: Unwrapping): + self.kind = 'list' + self.item_unwrap = item_unwrap + + +class DictUnwrapping(Unwrapping): + key_unwrap: Unwrapping + value_unwrap: Unwrapping + + def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping): + self.kind = 'dict' + self.key_unwrap = key_unwrap + self.value_unwrap = value_unwrap + + +class FuncUnwrapping(Unwrapping): + fn: str + + def __init__(self, fn: str): + self.kind = 'fn' + self.fn = fn + + +@dataclass +class UnionKindUnwrapping: + kind: str + unwrapping: Unwrapping + + +class UnionUnwrapping(Unwrapping): + union_kinds: list[UnionKindUnwrapping] + + def __init__(self, union_kinds: list[UnionKindUnwrapping]): + self.kind = 'union' + self.union_kinds = union_kinds + + +def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]: + if isinstance(t, TypeAliasType): + return _flatten_type(t.__value__) + + origin = get_origin(t) + + if origin is Annotated: + args = get_args(t) + return _flatten_type(args[0])[0], args[1] + + return t, None + + +def _schema_to_deflator_func(t: type | TypeAliasType) -> str: + t, _ = _flatten_type(t) + return 'deflate_' + typename(t).replace('.', '_') + + +class SchemaDeflatorGenerator: + templateLoader: jinja2.BaseLoader + templateEnv: jinja2.Environment + + object_template: jinja2.Template + _store_sources: bool + _strict_mode: bool + + def __init__(self, + loader: Optional[jinja2.BaseLoader] = None, + strict_mode: bool = False, + store_sources: bool = False, + *, + object_template_filename: str = 'deflator.jinja2', + ): + + self._strict_mode = strict_mode + self._store_sources = store_sources + + if loader is None: + template_path = importlib.resources.files('megasniff.templates') + loader = jinja2.FileSystemLoader(str(template_path)) + + self.templateLoader = loader + self.templateEnv = jinja2.Environment(loader=self.templateLoader) + self.object_template = self.templateEnv.get_template(object_template_filename) + + def schema_to_deflator(self, + schema: type, + strict_mode_override: Optional[bool] = None, + from_type_override: Optional[type | TypeAliasType] = None + ) -> Callable[[Any], dict[str, Any]]: + if from_type_override is not None and '__getitem__' not in dir(from_type_override): + raise RuntimeError('from_type_override must provide __getitem__') + txt, namespace = self._schema_to_deflator(schema, + strict_mode_override=strict_mode_override, + ) + imports = ('from typing import Any\n' + 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt = imports + '\n' + txt + exec(txt, namespace) + fn = namespace[_schema_to_deflator_func(schema)] + if self._store_sources: + setattr(fn, '__megasniff_sources__', txt) + return fn + + def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None): + if _visited_types is None: + _visited_types = [] + else: + _visited_types = _visited_types.copy() + + schema, field_rename = _flatten_type(schema) + + if schema in _visited_types: + return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema} + _visited_types.append(schema) + ongoing_types = set() + recurcive_types = set() + + origin = get_origin(schema) + + ret_unw = None + + if origin is not None: + if origin is list: + args = get_args(schema) + item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0], + _visited_types=_visited_types) + ret_unw = ListUnwrapping(item_unw) + recurcive_types |= item_rec + ongoing_types |= ongoings + elif origin is dict: + args = get_args(schema) + if len(args) != 2: + ret_unw = OtherUnwrapping() + else: + k, v = args + k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types) + v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types) + ongoing_types |= k_ongoings | v_ongoings + recurcive_types |= k_rec | v_rec + ret_unw = DictUnwrapping(k_unw, v_unw) + elif origin is UnionType or origin is Union: + args = get_args(schema) + union_unwraps = [] + for targ in args: + arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ, + _visited_types=_visited_types) + union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw)) + ongoing_types |= ongoings + recurcive_types |= arg_rec + ret_unw = UnionUnwrapping(union_unwraps) + else: + raise NotImplementedError + else: + if schema is int: + ret_unw = OtherUnwrapping('int') + elif schema is float: + ret_unw = OtherUnwrapping('float') + elif schema is bool: + ret_unw = OtherUnwrapping('bool') + elif schema is str: + ret_unw = OtherUnwrapping('str') + elif schema is None or schema is NoneType: + ret_unw = OtherUnwrapping() + elif schema is dict: + ret_unw = OtherUnwrapping() + elif schema is list: + ret_unw = OtherUnwrapping() + elif is_class_definition(schema): + hints = typing.get_type_hints(schema) + fields = [] + for k, f in hints.items(): + f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types) + fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw)) + ongoing_types |= ongoings + recurcive_types |= f_rec + + ret_unw = ObjectUnwrapping(fields) + else: + raise NotImplementedError() + + return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types + + def _schema_to_deflator(self, + schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type], + strict_mode_override: Optional[bool] = None, + into_type_override: Optional[type | TypeAliasType] = None, + *, + _funcname='deflate', + _namespace=None, + ) -> tuple[str, dict]: + if strict_mode_override is not None: + strict_mode = strict_mode_override + else: + strict_mode = self._strict_mode + + template = self.object_template + + types_for_namespace = set() + recursive_types = {schema} + + namespace = { + 'JsonObject': JsonObject + } + + convertor_functext = '' + + added_types = set() + + while len(recursive_types ^ (recursive_types & added_types)) > 0: + rec_t = list(recursive_types ^ (recursive_types & added_types))[0] + rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t) + recursive_types |= rec_rec_t + types_for_namespace |= rec_t_namespace + + rec_functext = template.render( + funcname=_schema_to_deflator_func(rec_t), + from_type=typename(rec_t), + into_type=None, + root_unwrap=rec_unw, + hashname=hashname, + ) + + convertor_functext += '\n\n\n' + rec_functext + added_types.add(rec_t) + + for t in types_for_namespace: + namespace[typename(t)] = t + + convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) + + return convertor_functext, namespace diff --git a/src/megasniff/templates/deflator.jinja2 b/src/megasniff/templates/deflator.jinja2 new file mode 100644 index 0000000..b827b0a --- /dev/null +++ b/src/megasniff/templates/deflator.jinja2 @@ -0,0 +1,82 @@ +{% macro render_unwrap_object(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{ into_container }} = {} +{% for kv in unwrapping.fields %} +{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }} +{% endfor %} +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{ into_container }} = {} +for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items(): +{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }} +{{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }} +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_unwrap_list(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{into_container}} = [] +for {{hashname(unwrapping)}} in {{from_container}}: +{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }} + {{into_container}}.append({{hashname(unwrapping)}}_tmp_container) +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap_other(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% if unwrapping.cast != '' %} +{{into_container}} = {{unwrapping.cast}}({{from_container}}) +{% else %} +{{into_container}} = {{from_container}} +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap_union(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% for union_kind in unwrapping.union_kinds %} +{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}): +{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}} +{% endfor %} +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% if unwrapping.kind == 'dict' %} +{{ render_unwrap_dict(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'list' %} +{{ render_unwrap_list(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'object' %} +{{ render_unwrap_object(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'union' %} +{{ render_unwrap_union(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'fn' %} +{{into_container}} = {{ unwrapping.fn }}({{from_container}}) +{% else %} +{{ render_unwrap_other(unwrapping, from_container, into_container) }} +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + +def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}: + """ + {{from_type}} -> {{into_type}} + """ + {{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }} + return ret + + diff --git a/src/megasniff/utils.py b/src/megasniff/utils.py index 907bd4d..a6c00c2 100644 --- a/src/megasniff/utils.py +++ b/src/megasniff/utils.py @@ -71,3 +71,11 @@ def typename(tp: type) -> str: if get_origin(tp) is None and hasattr(tp, '__name__'): return tp.__name__ return str(tp) + + +def is_class_definition(obj): + return isinstance(obj, type) or inspect.isclass(obj) + + +def hashname(obj) -> str: + return '_' + str(hash(obj)).replace('-', '_')