Create deflator strict mode and explicit casts flags with tests and default universal fallback unwrapper

This commit is contained in:
2025-08-29 01:20:27 +03:00
parent 51817784a3
commit de6362fa1d
6 changed files with 252 additions and 81 deletions

View File

@@ -1 +1,2 @@
from .inflator import SchemaInflatorGenerator
from .deflator import SchemaDeflatorGenerator

View File

@@ -70,7 +70,7 @@ class DSchema:
@dataclass
class ESchema:
a: list[list[list[str]]]
b: str
b: str | int
@dataclass
@@ -80,11 +80,11 @@ class ZSchema:
def main_deflator():
deflator = SchemaDeflatorGenerator(store_sources=True)
deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True)
fn = deflator.schema_to_deflator(DSchema)
print(getattr(fn, '__megasniff_sources__', '## No data'))
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], ['b'])))
# assert ret['a'] == 1
# assert ret['b'] == 1.1
# assert ret['c'] == 'a'
@@ -96,72 +96,3 @@ def main_deflator():
if __name__ == '__main__':
main_deflator()
from typing import Any
from megasniff.exceptions import MissingFieldException, FieldValidationException
def deflate_DSchema(from_data: DSchema) -> JsonObject:
"""
DSchema -> None
"""
ret = {}
ret['a'] = from_data.a
ret['b'] = {}
for k__8786182384272, v__8786182384272 in from_data.b.items():
k__8786182384272 = str(k__8786182384272)
v__8786182384272 = str(ret['b'][v__8786182384272])
if isinstance(from_data.c, str):
ret['c'] = str(from_data.c)
elif isinstance(from_data.c, float):
ret['c'] = float(from_data.c)
elif isinstance(from_data.c, ASchema):
ret['c'] = {}
ret['c']['a'] = int(from_data.c.a)
if isinstance(from_data.c.b, float):
ret['c']['b'] = float(from_data.c.b)
elif isinstance(from_data.c.b, str):
ret['c']['b'] = str(from_data.c.b)
if isinstance(from_data.c.bs, BSchema):
ret['c']['bs'] = {}
ret['c']['bs']['a'] = int(from_data.c.bs.a)
ret['c']['bs']['b'] = str(from_data.c.bs.b)
ret['c']['bs']['c'] = float(from_data.c.bs.c)
ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d)
elif isinstance(from_data.c.bs, NoneType):
ret['c']['bs'] = from_data.c.bs
ret['c']['c'] = float(from_data.c.c)
ret['d'] = {}
ret['d']['a'] = []
for _8786182524849 in from_data.d.a:
_8786182524849_tmp_container = []
for _8786182524629 in _8786182524849:
_8786182524629_tmp_container = []
for _8786182384293 in _8786182524629:
_8786182384293_tmp_container = str(_8786182384293)
_8786182524629_tmp_container.append(_8786182384293_tmp_container)
_8786182524849_tmp_container.append(_8786182524629_tmp_container)
ret['d']['a'].append(_8786182524849_tmp_container)
ret['d']['b'] = str(from_data.d.b)
return ret
def deflate_ASchema(from_data: ASchema) -> JsonObject:
"""
ASchema -> None
"""
ret = {}
ret['a'] = int(from_data.a)
if isinstance(from_data.b, float):
ret['b'] = float(from_data.b)
elif isinstance(from_data.b, str):
ret['b'] = str(from_data.b)
if isinstance(from_data.bs, BSchema):
ret['bs'] = {}
ret['bs']['a'] = int(from_data.bs.a)
ret['bs']['b'] = str(from_data.bs.b)
ret['bs']['c'] = float(from_data.bs.c)
ret['bs']['d'] = deflate_ASchema(from_data.bs.d)
elif isinstance(from_data.bs, NoneType):
ret['bs'] = from_data.bs
ret['c'] = float(from_data.c)
return ret

View File

@@ -22,11 +22,11 @@ class Unwrapping:
class OtherUnwrapping(Unwrapping):
cast: str
tp: str
def __init__(self, cast: str = ''):
def __init__(self, tp: str = ''):
self.kind = 'other'
self.cast = cast
self.tp = tp
@dataclass()
@@ -102,6 +102,25 @@ def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
return 'deflate_' + typename(t).replace('.', '_')
def _fallback_unwrapper(obj: Any) -> JsonObject:
if isinstance(obj, (int, float, str, bool)):
return obj
elif isinstance(obj, list):
return list(map(_fallback_unwrapper, obj))
elif isinstance(obj, dict):
return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items()))
elif hasattr(obj, '__dict__'):
ret = {}
for k, v in obj.__dict__:
if isinstance(k, str) and k.startswith('_'):
continue
k = _fallback_unwrapper(k)
v = _fallback_unwrapper(v)
ret[k] = v
return ret
return None
class SchemaDeflatorGenerator:
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
@@ -109,10 +128,12 @@ class SchemaDeflatorGenerator:
object_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
_explicit_casts: bool
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
strict_mode: bool = False,
explicit_casts: bool = False,
store_sources: bool = False,
*,
object_template_filename: str = 'deflator.jinja2',
@@ -120,6 +141,7 @@ class SchemaDeflatorGenerator:
self._strict_mode = strict_mode
self._store_sources = store_sources
self._explicit_casts = explicit_casts
if loader is None:
template_path = importlib.resources.files('megasniff.templates')
@@ -132,12 +154,11 @@ class SchemaDeflatorGenerator:
def schema_to_deflator(self,
schema: type,
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None
explicit_casts_override: Optional[bool] = None,
) -> Callable[[Any], dict[str, Any]]:
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
raise RuntimeError('from_type_override must provide __getitem__')
txt, namespace = self._schema_to_deflator(schema,
strict_mode_override=strict_mode_override,
explicit_casts_override=explicit_casts_override,
)
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
@@ -230,6 +251,7 @@ class SchemaDeflatorGenerator:
def _schema_to_deflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None,
into_type_override: Optional[type | TypeAliasType] = None,
*,
_funcname='deflate',
@@ -239,6 +261,10 @@ class SchemaDeflatorGenerator:
strict_mode = strict_mode_override
else:
strict_mode = self._strict_mode
if explicit_casts_override is not None:
explicit_casts = explicit_casts_override
else:
explicit_casts = self._explicit_casts
template = self.object_template
@@ -246,7 +272,8 @@ class SchemaDeflatorGenerator:
recursive_types = {schema}
namespace = {
'JsonObject': JsonObject
'JsonObject': JsonObject,
'fallback_unwrapper': _fallback_unwrapper,
}
convertor_functext = ''
@@ -265,6 +292,8 @@ class SchemaDeflatorGenerator:
into_type=None,
root_unwrap=rec_unw,
hashname=hashname,
strict_check=strict_mode,
explicit_cast=explicit_casts,
)
convertor_functext += '\n\n\n' + rec_functext

View File

@@ -11,6 +11,13 @@
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% if strict_check %}
if not isinstance({{from_container}}, dict):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'dict(' + from_container + ')' %}
{% endif %}
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
{{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }}
@@ -21,6 +28,13 @@ for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}
{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{into_container}} = []
{% if strict_check %}
if not isinstance({{from_container}}, list):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'list(' + from_container + ')' %}
{% endif %}
for {{hashname(unwrapping)}} in {{from_container}}:
{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
{{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
@@ -31,8 +45,12 @@ for {{hashname(unwrapping)}} in {{from_container}}:
{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.cast != '' %}
{{into_container}} = {{unwrapping.cast}}({{from_container}})
{% if unwrapping.tp != '' and strict_check %}
if not isinstance({{from_container}}, {{unwrapping.tp}}):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}})))
{% endif %}
{% if unwrapping.tp != '' and explicit_cast %}
{{into_container}} = {{unwrapping.tp}}({{from_container}})
{% else %}
{{into_container}} = {{from_container}}
{% endif %}
@@ -47,6 +65,16 @@ for {{hashname(unwrapping)}} in {{from_container}}:
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
{% endfor %}
{% if strict_check %}
else:
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% elif explicit_cast %}
else:
{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}}
{% else %}
else:
{{into_container}} = fallback_unwrap({{from_container}})
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -0,0 +1,94 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_explicit_casts_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=False)
fn = defl.schema_to_deflator(A, explicit_casts_override=True)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
a = fn(A([42.0, 42]))
assert len(a['a']) == 2
assert a['a'][0] == 42
assert a['a'][1] == 42
def test_global_explicit_casts_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
b = fn(B([42.0]))
assert b['b'][0] == 42
b = fn(B([A([1.1])]))
assert b['b'][0]['a'][0] == 1

View File

@@ -0,0 +1,88 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_strict_mode_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=False)
fn = defl.schema_to_deflator(A, strict_mode_override=True)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
with pytest.raises(FieldValidationException):
fn(A([42.0, 42]))
def test_global_strict_mode_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
with pytest.raises(FieldValidationException):
fn(B([42.0]))
with pytest.raises(FieldValidationException):
fn(B([A([1.1])]))