From de6362fa1d4e612586d24d31872d3fd5893061cf Mon Sep 17 00:00:00 2001 From: nikto_b Date: Fri, 29 Aug 2025 01:20:27 +0300 Subject: [PATCH] Create deflator strict mode and explicit casts flags with tests and default universal fallback unwrapper --- src/megasniff/__init__.py | 1 + src/megasniff/__main__.py | 75 +------------------- src/megasniff/deflator.py | 43 +++++++++-- src/megasniff/templates/deflator.jinja2 | 32 ++++++++- tests/test_explicit_cast_deflator.py | 94 +++++++++++++++++++++++++ tests/test_strict_mode_deflator.py | 88 +++++++++++++++++++++++ 6 files changed, 252 insertions(+), 81 deletions(-) create mode 100644 tests/test_explicit_cast_deflator.py create mode 100644 tests/test_strict_mode_deflator.py diff --git a/src/megasniff/__init__.py b/src/megasniff/__init__.py index 21ed1a6..504d217 100644 --- a/src/megasniff/__init__.py +++ b/src/megasniff/__init__.py @@ -1 +1,2 @@ from .inflator import SchemaInflatorGenerator +from .deflator import SchemaDeflatorGenerator diff --git a/src/megasniff/__main__.py b/src/megasniff/__main__.py index 38b4270..46601ca 100644 --- a/src/megasniff/__main__.py +++ b/src/megasniff/__main__.py @@ -70,7 +70,7 @@ class DSchema: @dataclass class ESchema: a: list[list[list[str]]] - b: str + b: str | int @dataclass @@ -80,11 +80,11 @@ class ZSchema: def main_deflator(): - deflator = SchemaDeflatorGenerator(store_sources=True) + deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True) fn = deflator.schema_to_deflator(DSchema) print(getattr(fn, '__megasniff_sources__', '## No data')) # ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42))) - ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b'))) + ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], ['b']))) # assert ret['a'] == 1 # assert ret['b'] == 1.1 # assert ret['c'] == 'a' @@ -96,72 +96,3 @@ def main_deflator(): if __name__ == '__main__': main_deflator() - - - -from typing import Any -from megasniff.exceptions import MissingFieldException, FieldValidationException - -def deflate_DSchema(from_data: DSchema) -> JsonObject: - """ - DSchema -> None - """ - ret = {} - ret['a'] = from_data.a - ret['b'] = {} - for k__8786182384272, v__8786182384272 in from_data.b.items(): - k__8786182384272 = str(k__8786182384272) - v__8786182384272 = str(ret['b'][v__8786182384272]) - if isinstance(from_data.c, str): - ret['c'] = str(from_data.c) - elif isinstance(from_data.c, float): - ret['c'] = float(from_data.c) - elif isinstance(from_data.c, ASchema): - ret['c'] = {} - ret['c']['a'] = int(from_data.c.a) - if isinstance(from_data.c.b, float): - ret['c']['b'] = float(from_data.c.b) - elif isinstance(from_data.c.b, str): - ret['c']['b'] = str(from_data.c.b) - if isinstance(from_data.c.bs, BSchema): - ret['c']['bs'] = {} - ret['c']['bs']['a'] = int(from_data.c.bs.a) - ret['c']['bs']['b'] = str(from_data.c.bs.b) - ret['c']['bs']['c'] = float(from_data.c.bs.c) - ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d) - elif isinstance(from_data.c.bs, NoneType): - ret['c']['bs'] = from_data.c.bs - ret['c']['c'] = float(from_data.c.c) - ret['d'] = {} - ret['d']['a'] = [] - for _8786182524849 in from_data.d.a: - _8786182524849_tmp_container = [] - for _8786182524629 in _8786182524849: - _8786182524629_tmp_container = [] - for _8786182384293 in _8786182524629: - _8786182384293_tmp_container = str(_8786182384293) - _8786182524629_tmp_container.append(_8786182384293_tmp_container) - _8786182524849_tmp_container.append(_8786182524629_tmp_container) - ret['d']['a'].append(_8786182524849_tmp_container) - ret['d']['b'] = str(from_data.d.b) - return ret -def deflate_ASchema(from_data: ASchema) -> JsonObject: - """ - ASchema -> None - """ - ret = {} - ret['a'] = int(from_data.a) - if isinstance(from_data.b, float): - ret['b'] = float(from_data.b) - elif isinstance(from_data.b, str): - ret['b'] = str(from_data.b) - if isinstance(from_data.bs, BSchema): - ret['bs'] = {} - ret['bs']['a'] = int(from_data.bs.a) - ret['bs']['b'] = str(from_data.bs.b) - ret['bs']['c'] = float(from_data.bs.c) - ret['bs']['d'] = deflate_ASchema(from_data.bs.d) - elif isinstance(from_data.bs, NoneType): - ret['bs'] = from_data.bs - ret['c'] = float(from_data.c) - return ret diff --git a/src/megasniff/deflator.py b/src/megasniff/deflator.py index 85a3503..b8ee8f3 100644 --- a/src/megasniff/deflator.py +++ b/src/megasniff/deflator.py @@ -22,11 +22,11 @@ class Unwrapping: class OtherUnwrapping(Unwrapping): - cast: str + tp: str - def __init__(self, cast: str = ''): + def __init__(self, tp: str = ''): self.kind = 'other' - self.cast = cast + self.tp = tp @dataclass() @@ -102,6 +102,25 @@ def _schema_to_deflator_func(t: type | TypeAliasType) -> str: return 'deflate_' + typename(t).replace('.', '_') +def _fallback_unwrapper(obj: Any) -> JsonObject: + if isinstance(obj, (int, float, str, bool)): + return obj + elif isinstance(obj, list): + return list(map(_fallback_unwrapper, obj)) + elif isinstance(obj, dict): + return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items())) + elif hasattr(obj, '__dict__'): + ret = {} + for k, v in obj.__dict__: + if isinstance(k, str) and k.startswith('_'): + continue + k = _fallback_unwrapper(k) + v = _fallback_unwrapper(v) + ret[k] = v + return ret + return None + + class SchemaDeflatorGenerator: templateLoader: jinja2.BaseLoader templateEnv: jinja2.Environment @@ -109,10 +128,12 @@ class SchemaDeflatorGenerator: object_template: jinja2.Template _store_sources: bool _strict_mode: bool + _explicit_casts: bool def __init__(self, loader: Optional[jinja2.BaseLoader] = None, strict_mode: bool = False, + explicit_casts: bool = False, store_sources: bool = False, *, object_template_filename: str = 'deflator.jinja2', @@ -120,6 +141,7 @@ class SchemaDeflatorGenerator: self._strict_mode = strict_mode self._store_sources = store_sources + self._explicit_casts = explicit_casts if loader is None: template_path = importlib.resources.files('megasniff.templates') @@ -132,12 +154,11 @@ class SchemaDeflatorGenerator: def schema_to_deflator(self, schema: type, strict_mode_override: Optional[bool] = None, - from_type_override: Optional[type | TypeAliasType] = None + explicit_casts_override: Optional[bool] = None, ) -> Callable[[Any], dict[str, Any]]: - if from_type_override is not None and '__getitem__' not in dir(from_type_override): - raise RuntimeError('from_type_override must provide __getitem__') txt, namespace = self._schema_to_deflator(schema, strict_mode_override=strict_mode_override, + explicit_casts_override=explicit_casts_override, ) imports = ('from typing import Any\n' 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') @@ -230,6 +251,7 @@ class SchemaDeflatorGenerator: def _schema_to_deflator(self, schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type], strict_mode_override: Optional[bool] = None, + explicit_casts_override: Optional[bool] = None, into_type_override: Optional[type | TypeAliasType] = None, *, _funcname='deflate', @@ -239,6 +261,10 @@ class SchemaDeflatorGenerator: strict_mode = strict_mode_override else: strict_mode = self._strict_mode + if explicit_casts_override is not None: + explicit_casts = explicit_casts_override + else: + explicit_casts = self._explicit_casts template = self.object_template @@ -246,7 +272,8 @@ class SchemaDeflatorGenerator: recursive_types = {schema} namespace = { - 'JsonObject': JsonObject + 'JsonObject': JsonObject, + 'fallback_unwrapper': _fallback_unwrapper, } convertor_functext = '' @@ -265,6 +292,8 @@ class SchemaDeflatorGenerator: into_type=None, root_unwrap=rec_unw, hashname=hashname, + strict_check=strict_mode, + explicit_cast=explicit_casts, ) convertor_functext += '\n\n\n' + rec_functext diff --git a/src/megasniff/templates/deflator.jinja2 b/src/megasniff/templates/deflator.jinja2 index b827b0a..bf6ee5c 100644 --- a/src/megasniff/templates/deflator.jinja2 +++ b/src/megasniff/templates/deflator.jinja2 @@ -11,6 +11,13 @@ {% macro render_unwrap_dict(unwrapping, from_container, into_container) -%} {%- set out -%} {{ into_container }} = {} +{% if strict_check %} +if not isinstance({{from_container}}, dict): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}}))) +{% endif %} +{% if explicit_cast %} +{% set from_container = 'dict(' + from_container + ')' %} +{% endif %} for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items(): {{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }} {{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }} @@ -21,6 +28,13 @@ for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}} {% macro render_unwrap_list(unwrapping, from_container, into_container) -%} {%- set out -%} {{into_container}} = [] +{% if strict_check %} +if not isinstance({{from_container}}, list): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}}))) +{% endif %} +{% if explicit_cast %} +{% set from_container = 'list(' + from_container + ')' %} +{% endif %} for {{hashname(unwrapping)}} in {{from_container}}: {{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }} {{into_container}}.append({{hashname(unwrapping)}}_tmp_container) @@ -31,8 +45,12 @@ for {{hashname(unwrapping)}} in {{from_container}}: {% macro render_unwrap_other(unwrapping, from_container, into_container) -%} {%- set out -%} -{% if unwrapping.cast != '' %} -{{into_container}} = {{unwrapping.cast}}({{from_container}}) +{% if unwrapping.tp != '' and strict_check %} +if not isinstance({{from_container}}, {{unwrapping.tp}}): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}}))) +{% endif %} +{% if unwrapping.tp != '' and explicit_cast %} +{{into_container}} = {{unwrapping.tp}}({{from_container}}) {% else %} {{into_container}} = {{from_container}} {% endif %} @@ -47,6 +65,16 @@ for {{hashname(unwrapping)}} in {{from_container}}: {% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}): {{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}} {% endfor %} +{% if strict_check %} +else: + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}}))) +{% elif explicit_cast %} +else: +{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}} +{% else %} +else: + {{into_container}} = fallback_unwrap({{from_container}}) +{% endif %} {%- endset %} {{out}} {%- endmacro %} diff --git a/tests/test_explicit_cast_deflator.py b/tests/test_explicit_cast_deflator.py new file mode 100644 index 0000000..b4a01dd --- /dev/null +++ b/tests/test_explicit_cast_deflator.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass + +import pytest + +from megasniff import SchemaDeflatorGenerator +from megasniff.exceptions import FieldValidationException + + +def test_global_explicit_casts_basic(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(A) + a = fn(A(42)) + + assert a['a'] == 42 + + a = fn(A(42.0)) + assert a['a'] == 42 + + a = fn(A('42')) + assert a['a'] == 42 + + with pytest.raises(TypeError): + fn(A(['42'])) + + +def test_global_explicit_casts_basic_override(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(explicit_casts=False) + fn = defl.schema_to_deflator(A, explicit_casts_override=True) + a = fn(A(42)) + + assert a['a'] == 42 + + a = fn(A(42.0)) + assert a['a'] == 42 + + a = fn(A('42')) + assert a['a'] == 42 + + with pytest.raises(TypeError): + fn(A(['42'])) + + +def test_global_explicit_casts_list(): + @dataclass + class A: + a: list[int] + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(A) + a = fn(A([42])) + + assert a['a'] == [42] + + a = fn(A([42.0, 42])) + assert len(a['a']) == 2 + assert a['a'][0] == 42 + assert a['a'][1] == 42 + + +def test_global_explicit_casts_circular(): + @dataclass + class A: + a: list[int] + + @dataclass + class B: + b: list[A | int] + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(B) + b = fn(B([A([]), 42])) + + assert len(b['b']) == 2 + assert isinstance(b['b'][0], dict) + assert len(b['b'][0]['a']) == 0 + assert isinstance(b['b'][1], int) + + b = fn(B([42.0])) + assert b['b'][0] == 42 + + b = fn(B([A([1.1])])) + assert b['b'][0]['a'][0] == 1 diff --git a/tests/test_strict_mode_deflator.py b/tests/test_strict_mode_deflator.py new file mode 100644 index 0000000..0d2d509 --- /dev/null +++ b/tests/test_strict_mode_deflator.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass + +import pytest + +from megasniff import SchemaDeflatorGenerator +from megasniff.exceptions import FieldValidationException + + +def test_global_strict_mode_basic(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(A) + a = fn(A(42)) + + assert a['a'] == 42 + + with pytest.raises(FieldValidationException): + fn(A(42.0)) + with pytest.raises(FieldValidationException): + fn(A('42')) + with pytest.raises(FieldValidationException): + fn(A(['42'])) + + +def test_global_strict_mode_basic_override(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(strict_mode=False) + fn = defl.schema_to_deflator(A, strict_mode_override=True) + a = fn(A(42)) + + assert a['a'] == 42 + + with pytest.raises(FieldValidationException): + fn(A(42.0)) + with pytest.raises(FieldValidationException): + fn(A('42')) + with pytest.raises(FieldValidationException): + fn(A(['42'])) + + +def test_global_strict_mode_list(): + @dataclass + class A: + a: list[int] + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(A) + a = fn(A([42])) + + assert a['a'] == [42] + + with pytest.raises(FieldValidationException): + fn(A([42.0, 42])) + + +def test_global_strict_mode_circular(): + @dataclass + class A: + a: list[int] + + @dataclass + class B: + b: list[A | int] + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(B) + b = fn(B([A([]), 42])) + + assert len(b['b']) == 2 + assert isinstance(b['b'][0], dict) + assert len(b['b'][0]['a']) == 0 + assert isinstance(b['b'][1], int) + + with pytest.raises(FieldValidationException): + fn(B([42.0])) + + with pytest.raises(FieldValidationException): + fn(B([A([1.1])]))