Create deflator strict mode and explicit casts flags with tests and default universal fallback unwrapper
This commit is contained in:
@@ -1 +1,2 @@
|
||||
from .inflator import SchemaInflatorGenerator
|
||||
from .deflator import SchemaDeflatorGenerator
|
||||
|
||||
@@ -70,7 +70,7 @@ class DSchema:
|
||||
@dataclass
|
||||
class ESchema:
|
||||
a: list[list[list[str]]]
|
||||
b: str
|
||||
b: str | int
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,11 +80,11 @@ class ZSchema:
|
||||
|
||||
|
||||
def main_deflator():
|
||||
deflator = SchemaDeflatorGenerator(store_sources=True)
|
||||
deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True)
|
||||
fn = deflator.schema_to_deflator(DSchema)
|
||||
print(getattr(fn, '__megasniff_sources__', '## No data'))
|
||||
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
|
||||
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
|
||||
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], ['b'])))
|
||||
# assert ret['a'] == 1
|
||||
# assert ret['b'] == 1.1
|
||||
# assert ret['c'] == 'a'
|
||||
@@ -96,72 +96,3 @@ def main_deflator():
|
||||
|
||||
if __name__ == '__main__':
|
||||
main_deflator()
|
||||
|
||||
|
||||
|
||||
from typing import Any
|
||||
from megasniff.exceptions import MissingFieldException, FieldValidationException
|
||||
|
||||
def deflate_DSchema(from_data: DSchema) -> JsonObject:
|
||||
"""
|
||||
DSchema -> None
|
||||
"""
|
||||
ret = {}
|
||||
ret['a'] = from_data.a
|
||||
ret['b'] = {}
|
||||
for k__8786182384272, v__8786182384272 in from_data.b.items():
|
||||
k__8786182384272 = str(k__8786182384272)
|
||||
v__8786182384272 = str(ret['b'][v__8786182384272])
|
||||
if isinstance(from_data.c, str):
|
||||
ret['c'] = str(from_data.c)
|
||||
elif isinstance(from_data.c, float):
|
||||
ret['c'] = float(from_data.c)
|
||||
elif isinstance(from_data.c, ASchema):
|
||||
ret['c'] = {}
|
||||
ret['c']['a'] = int(from_data.c.a)
|
||||
if isinstance(from_data.c.b, float):
|
||||
ret['c']['b'] = float(from_data.c.b)
|
||||
elif isinstance(from_data.c.b, str):
|
||||
ret['c']['b'] = str(from_data.c.b)
|
||||
if isinstance(from_data.c.bs, BSchema):
|
||||
ret['c']['bs'] = {}
|
||||
ret['c']['bs']['a'] = int(from_data.c.bs.a)
|
||||
ret['c']['bs']['b'] = str(from_data.c.bs.b)
|
||||
ret['c']['bs']['c'] = float(from_data.c.bs.c)
|
||||
ret['c']['bs']['d'] = deflate_ASchema(from_data.c.bs.d)
|
||||
elif isinstance(from_data.c.bs, NoneType):
|
||||
ret['c']['bs'] = from_data.c.bs
|
||||
ret['c']['c'] = float(from_data.c.c)
|
||||
ret['d'] = {}
|
||||
ret['d']['a'] = []
|
||||
for _8786182524849 in from_data.d.a:
|
||||
_8786182524849_tmp_container = []
|
||||
for _8786182524629 in _8786182524849:
|
||||
_8786182524629_tmp_container = []
|
||||
for _8786182384293 in _8786182524629:
|
||||
_8786182384293_tmp_container = str(_8786182384293)
|
||||
_8786182524629_tmp_container.append(_8786182384293_tmp_container)
|
||||
_8786182524849_tmp_container.append(_8786182524629_tmp_container)
|
||||
ret['d']['a'].append(_8786182524849_tmp_container)
|
||||
ret['d']['b'] = str(from_data.d.b)
|
||||
return ret
|
||||
def deflate_ASchema(from_data: ASchema) -> JsonObject:
|
||||
"""
|
||||
ASchema -> None
|
||||
"""
|
||||
ret = {}
|
||||
ret['a'] = int(from_data.a)
|
||||
if isinstance(from_data.b, float):
|
||||
ret['b'] = float(from_data.b)
|
||||
elif isinstance(from_data.b, str):
|
||||
ret['b'] = str(from_data.b)
|
||||
if isinstance(from_data.bs, BSchema):
|
||||
ret['bs'] = {}
|
||||
ret['bs']['a'] = int(from_data.bs.a)
|
||||
ret['bs']['b'] = str(from_data.bs.b)
|
||||
ret['bs']['c'] = float(from_data.bs.c)
|
||||
ret['bs']['d'] = deflate_ASchema(from_data.bs.d)
|
||||
elif isinstance(from_data.bs, NoneType):
|
||||
ret['bs'] = from_data.bs
|
||||
ret['c'] = float(from_data.c)
|
||||
return ret
|
||||
|
||||
@@ -22,11 +22,11 @@ class Unwrapping:
|
||||
|
||||
|
||||
class OtherUnwrapping(Unwrapping):
|
||||
cast: str
|
||||
tp: str
|
||||
|
||||
def __init__(self, cast: str = ''):
|
||||
def __init__(self, tp: str = ''):
|
||||
self.kind = 'other'
|
||||
self.cast = cast
|
||||
self.tp = tp
|
||||
|
||||
|
||||
@dataclass()
|
||||
@@ -102,6 +102,25 @@ def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
|
||||
return 'deflate_' + typename(t).replace('.', '_')
|
||||
|
||||
|
||||
def _fallback_unwrapper(obj: Any) -> JsonObject:
|
||||
if isinstance(obj, (int, float, str, bool)):
|
||||
return obj
|
||||
elif isinstance(obj, list):
|
||||
return list(map(_fallback_unwrapper, obj))
|
||||
elif isinstance(obj, dict):
|
||||
return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items()))
|
||||
elif hasattr(obj, '__dict__'):
|
||||
ret = {}
|
||||
for k, v in obj.__dict__:
|
||||
if isinstance(k, str) and k.startswith('_'):
|
||||
continue
|
||||
k = _fallback_unwrapper(k)
|
||||
v = _fallback_unwrapper(v)
|
||||
ret[k] = v
|
||||
return ret
|
||||
return None
|
||||
|
||||
|
||||
class SchemaDeflatorGenerator:
|
||||
templateLoader: jinja2.BaseLoader
|
||||
templateEnv: jinja2.Environment
|
||||
@@ -109,10 +128,12 @@ class SchemaDeflatorGenerator:
|
||||
object_template: jinja2.Template
|
||||
_store_sources: bool
|
||||
_strict_mode: bool
|
||||
_explicit_casts: bool
|
||||
|
||||
def __init__(self,
|
||||
loader: Optional[jinja2.BaseLoader] = None,
|
||||
strict_mode: bool = False,
|
||||
explicit_casts: bool = False,
|
||||
store_sources: bool = False,
|
||||
*,
|
||||
object_template_filename: str = 'deflator.jinja2',
|
||||
@@ -120,6 +141,7 @@ class SchemaDeflatorGenerator:
|
||||
|
||||
self._strict_mode = strict_mode
|
||||
self._store_sources = store_sources
|
||||
self._explicit_casts = explicit_casts
|
||||
|
||||
if loader is None:
|
||||
template_path = importlib.resources.files('megasniff.templates')
|
||||
@@ -132,12 +154,11 @@ class SchemaDeflatorGenerator:
|
||||
def schema_to_deflator(self,
|
||||
schema: type,
|
||||
strict_mode_override: Optional[bool] = None,
|
||||
from_type_override: Optional[type | TypeAliasType] = None
|
||||
explicit_casts_override: Optional[bool] = None,
|
||||
) -> Callable[[Any], dict[str, Any]]:
|
||||
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
|
||||
raise RuntimeError('from_type_override must provide __getitem__')
|
||||
txt, namespace = self._schema_to_deflator(schema,
|
||||
strict_mode_override=strict_mode_override,
|
||||
explicit_casts_override=explicit_casts_override,
|
||||
)
|
||||
imports = ('from typing import Any\n'
|
||||
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
|
||||
@@ -230,6 +251,7 @@ class SchemaDeflatorGenerator:
|
||||
def _schema_to_deflator(self,
|
||||
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
|
||||
strict_mode_override: Optional[bool] = None,
|
||||
explicit_casts_override: Optional[bool] = None,
|
||||
into_type_override: Optional[type | TypeAliasType] = None,
|
||||
*,
|
||||
_funcname='deflate',
|
||||
@@ -239,6 +261,10 @@ class SchemaDeflatorGenerator:
|
||||
strict_mode = strict_mode_override
|
||||
else:
|
||||
strict_mode = self._strict_mode
|
||||
if explicit_casts_override is not None:
|
||||
explicit_casts = explicit_casts_override
|
||||
else:
|
||||
explicit_casts = self._explicit_casts
|
||||
|
||||
template = self.object_template
|
||||
|
||||
@@ -246,7 +272,8 @@ class SchemaDeflatorGenerator:
|
||||
recursive_types = {schema}
|
||||
|
||||
namespace = {
|
||||
'JsonObject': JsonObject
|
||||
'JsonObject': JsonObject,
|
||||
'fallback_unwrapper': _fallback_unwrapper,
|
||||
}
|
||||
|
||||
convertor_functext = ''
|
||||
@@ -265,6 +292,8 @@ class SchemaDeflatorGenerator:
|
||||
into_type=None,
|
||||
root_unwrap=rec_unw,
|
||||
hashname=hashname,
|
||||
strict_check=strict_mode,
|
||||
explicit_cast=explicit_casts,
|
||||
)
|
||||
|
||||
convertor_functext += '\n\n\n' + rec_functext
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
|
||||
{%- set out -%}
|
||||
{{ into_container }} = {}
|
||||
{% if strict_check %}
|
||||
if not isinstance({{from_container}}, dict):
|
||||
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
|
||||
{% endif %}
|
||||
{% if explicit_cast %}
|
||||
{% set from_container = 'dict(' + from_container + ')' %}
|
||||
{% endif %}
|
||||
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
|
||||
{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
|
||||
{{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }}
|
||||
@@ -21,6 +28,13 @@ for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}
|
||||
{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
|
||||
{%- set out -%}
|
||||
{{into_container}} = []
|
||||
{% if strict_check %}
|
||||
if not isinstance({{from_container}}, list):
|
||||
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}})))
|
||||
{% endif %}
|
||||
{% if explicit_cast %}
|
||||
{% set from_container = 'list(' + from_container + ')' %}
|
||||
{% endif %}
|
||||
for {{hashname(unwrapping)}} in {{from_container}}:
|
||||
{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
|
||||
{{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
|
||||
@@ -31,8 +45,12 @@ for {{hashname(unwrapping)}} in {{from_container}}:
|
||||
|
||||
{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
|
||||
{%- set out -%}
|
||||
{% if unwrapping.cast != '' %}
|
||||
{{into_container}} = {{unwrapping.cast}}({{from_container}})
|
||||
{% if unwrapping.tp != '' and strict_check %}
|
||||
if not isinstance({{from_container}}, {{unwrapping.tp}}):
|
||||
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}})))
|
||||
{% endif %}
|
||||
{% if unwrapping.tp != '' and explicit_cast %}
|
||||
{{into_container}} = {{unwrapping.tp}}({{from_container}})
|
||||
{% else %}
|
||||
{{into_container}} = {{from_container}}
|
||||
{% endif %}
|
||||
@@ -47,6 +65,16 @@ for {{hashname(unwrapping)}} in {{from_container}}:
|
||||
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
|
||||
{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
|
||||
{% endfor %}
|
||||
{% if strict_check %}
|
||||
else:
|
||||
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
|
||||
{% elif explicit_cast %}
|
||||
else:
|
||||
{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}}
|
||||
{% else %}
|
||||
else:
|
||||
{{into_container}} = fallback_unwrap({{from_container}})
|
||||
{% endif %}
|
||||
{%- endset %}
|
||||
{{out}}
|
||||
{%- endmacro %}
|
||||
|
||||
94
tests/test_explicit_cast_deflator.py
Normal file
94
tests/test_explicit_cast_deflator.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from megasniff import SchemaDeflatorGenerator
|
||||
from megasniff.exceptions import FieldValidationException
|
||||
|
||||
|
||||
def test_global_explicit_casts_basic():
|
||||
class A:
|
||||
a: int
|
||||
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
defl = SchemaDeflatorGenerator(explicit_casts=True)
|
||||
fn = defl.schema_to_deflator(A)
|
||||
a = fn(A(42))
|
||||
|
||||
assert a['a'] == 42
|
||||
|
||||
a = fn(A(42.0))
|
||||
assert a['a'] == 42
|
||||
|
||||
a = fn(A('42'))
|
||||
assert a['a'] == 42
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
fn(A(['42']))
|
||||
|
||||
|
||||
def test_global_explicit_casts_basic_override():
|
||||
class A:
|
||||
a: int
|
||||
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
defl = SchemaDeflatorGenerator(explicit_casts=False)
|
||||
fn = defl.schema_to_deflator(A, explicit_casts_override=True)
|
||||
a = fn(A(42))
|
||||
|
||||
assert a['a'] == 42
|
||||
|
||||
a = fn(A(42.0))
|
||||
assert a['a'] == 42
|
||||
|
||||
a = fn(A('42'))
|
||||
assert a['a'] == 42
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
fn(A(['42']))
|
||||
|
||||
|
||||
def test_global_explicit_casts_list():
|
||||
@dataclass
|
||||
class A:
|
||||
a: list[int]
|
||||
|
||||
defl = SchemaDeflatorGenerator(explicit_casts=True)
|
||||
fn = defl.schema_to_deflator(A)
|
||||
a = fn(A([42]))
|
||||
|
||||
assert a['a'] == [42]
|
||||
|
||||
a = fn(A([42.0, 42]))
|
||||
assert len(a['a']) == 2
|
||||
assert a['a'][0] == 42
|
||||
assert a['a'][1] == 42
|
||||
|
||||
|
||||
def test_global_explicit_casts_circular():
|
||||
@dataclass
|
||||
class A:
|
||||
a: list[int]
|
||||
|
||||
@dataclass
|
||||
class B:
|
||||
b: list[A | int]
|
||||
|
||||
defl = SchemaDeflatorGenerator(explicit_casts=True)
|
||||
fn = defl.schema_to_deflator(B)
|
||||
b = fn(B([A([]), 42]))
|
||||
|
||||
assert len(b['b']) == 2
|
||||
assert isinstance(b['b'][0], dict)
|
||||
assert len(b['b'][0]['a']) == 0
|
||||
assert isinstance(b['b'][1], int)
|
||||
|
||||
b = fn(B([42.0]))
|
||||
assert b['b'][0] == 42
|
||||
|
||||
b = fn(B([A([1.1])]))
|
||||
assert b['b'][0]['a'][0] == 1
|
||||
88
tests/test_strict_mode_deflator.py
Normal file
88
tests/test_strict_mode_deflator.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from megasniff import SchemaDeflatorGenerator
|
||||
from megasniff.exceptions import FieldValidationException
|
||||
|
||||
|
||||
def test_global_strict_mode_basic():
|
||||
class A:
|
||||
a: int
|
||||
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
defl = SchemaDeflatorGenerator(strict_mode=True)
|
||||
fn = defl.schema_to_deflator(A)
|
||||
a = fn(A(42))
|
||||
|
||||
assert a['a'] == 42
|
||||
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A(42.0))
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A('42'))
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A(['42']))
|
||||
|
||||
|
||||
def test_global_strict_mode_basic_override():
|
||||
class A:
|
||||
a: int
|
||||
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
|
||||
defl = SchemaDeflatorGenerator(strict_mode=False)
|
||||
fn = defl.schema_to_deflator(A, strict_mode_override=True)
|
||||
a = fn(A(42))
|
||||
|
||||
assert a['a'] == 42
|
||||
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A(42.0))
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A('42'))
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A(['42']))
|
||||
|
||||
|
||||
def test_global_strict_mode_list():
|
||||
@dataclass
|
||||
class A:
|
||||
a: list[int]
|
||||
|
||||
defl = SchemaDeflatorGenerator(strict_mode=True)
|
||||
fn = defl.schema_to_deflator(A)
|
||||
a = fn(A([42]))
|
||||
|
||||
assert a['a'] == [42]
|
||||
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(A([42.0, 42]))
|
||||
|
||||
|
||||
def test_global_strict_mode_circular():
|
||||
@dataclass
|
||||
class A:
|
||||
a: list[int]
|
||||
|
||||
@dataclass
|
||||
class B:
|
||||
b: list[A | int]
|
||||
|
||||
defl = SchemaDeflatorGenerator(strict_mode=True)
|
||||
fn = defl.schema_to_deflator(B)
|
||||
b = fn(B([A([]), 42]))
|
||||
|
||||
assert len(b['b']) == 2
|
||||
assert isinstance(b['b'][0], dict)
|
||||
assert len(b['b'][0]['a']) == 0
|
||||
assert isinstance(b['b'][1], int)
|
||||
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(B([42.0]))
|
||||
|
||||
with pytest.raises(FieldValidationException):
|
||||
fn(B([A([1.1])]))
|
||||
Reference in New Issue
Block a user