Compare commits

12 Commits

12 changed files with 903 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
# megasniff
### Автоматическая валидация данных по схеме и сборка объекта в одном флаконе
### Автоматическая валидация данных по схеме, сборка и разборка объекта в одном флаконе
#### Как применять:
@@ -18,7 +18,8 @@ class SomeSchema1:
c: SomeSchema2 | str | None
class SomeSchema2(typing.TypedDict):
@dataclasses.dataclass
class SomeSchema2:
field1: dict
field2: float
field3: typing.Optional[SomeSchema1]
@@ -28,12 +29,16 @@ class SomeSchema2(typing.TypedDict):
import megasniff
infl = megasniff.SchemaInflatorGenerator()
fn = infl.schema_to_inflator(SomeSchema1)
defl = megasniff.SchemaDeflatorGenerator()
fn_in = infl.schema_to_inflator(SomeSchema1)
fn_out = defl.schema_to_deflator(SomeSchema1)
# 3. Проверяем что всё работает
fn({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}})
data = fn_in({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}})
# SomeSchema1(a=1, b=2.0, c={'field1': {}, 'field2': 1.1, 'field3': None})
fn_out(data)
# {'a': 1, 'b': 2.0, 'c': {'field1': {}, 'field2': 1.1, 'field3': None}}
```
@@ -47,7 +52,11 @@ fn({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}})
- не проверяет типы generic-словарей, кортежей (реализация ожидается)
- пользовательские проверки типов должны быть реализованы через наследование и проверки в конструкторе
- опциональный `strict-mode`: выключение приведения базовых типов
- может генерировать кортежи верхнеуровневых объектов при наличии описания схемы (полезно при развертывании аргументов)
- для inflation может генерировать кортежи верхнеуровневых объектов при наличии описания схемы (полезно при
развертывании аргументов)
- `TypedDict` поддерживается только для inflation из-за сложностей выбора варианта при сборке `Union`-полей
- для deflation поддерживается включение режима `explicit_casts`, приводящего типы к тем, которые указаны в
аннотациях (не распространяется на `Union`-типы, т.к. невозможно определить какой из них должен быть выбран)
----
@@ -88,6 +97,9 @@ class A:
```
>>> {"a": [1, 1.1, "321"]}
<<< A(a=[1, 1, 321])
>>> A(a=[1, 1.1, "321"])
<<< {"a": [1, 1.1, "321"]} # explicit_casts=False
<<< {"a": [1, 1, 321]} # explicit_casts=True
```
#### Strict-mode on:
@@ -101,11 +113,15 @@ class A:
```
>>> {"a": [1, 1.1, "321"]}
<<< FieldValidationException, т.к. 1.1 не является int
>>> A(a=[1, 1.1, "321"])
<<< FieldValidationException, т.к. 1.1 не является int
```
### Tuple unwrap
```
fn = infl.schema_to_inflator(
(('a', int), TupleSchemaItem(Optional[list[int]], key_name='b', has_default=True, default=None)))
```
Создаёт `fn: (dict[str,Any]) -> tuple[int, Optional[list[int]]]: ...` (сигнатура остаётся `(dict[str,Any])->tuple`)

View File

@@ -1,6 +1,6 @@
[project]
name = "megasniff"
version = "0.2.3.post2"
version = "0.2.6.post1"
description = "Library for in-time codegened type validation"
authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -1 +1,2 @@
from .inflator import SchemaInflatorGenerator
from .deflator import SchemaDeflatorGenerator

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from types import NoneType
from typing import Optional
from typing import TypedDict
import megasniff.exceptions
from megasniff.deflator import SchemaDeflatorGenerator, JsonObject
from . import SchemaInflatorGenerator
@@ -16,7 +19,8 @@ class ASchema:
c: float = 1.1
class BSchema(TypedDict):
@dataclass
class BSchema:
a: int
b: str
c: float
@@ -55,5 +59,41 @@ def main():
print(e)
@dataclass
class DSchema:
a: dict
b: dict[str, int | float | dict]
c: str | float | ASchema
d: ESchema
@dataclass
class ESchema:
a: list[list[list[str]]]
b: str | int
@dataclass
class ZSchema:
z: ZSchema | None
d: ZSchema | int
def main_deflator():
deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True)
fn = deflator.schema_to_deflator(DSchema | int)
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
ret = fn(42)
# assert ret['a'] == 1
# assert ret['b'] == 1.1
# assert ret['c'] == 'a'
# assert ret['d']['a'][0][0][0] == 'a'
# assert ret['d']['b'] == 'b'
print(json.dumps(ret, indent=4))
pass
if __name__ == '__main__':
main()
main_deflator()

347
src/megasniff/deflator.py Normal file
View File

@@ -0,0 +1,347 @@
# Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import importlib.resources
import typing
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
OrderedDict, TypeAlias
import jinja2
from .utils import *
import uuid
from pathlib import Path
import tempfile
import importlib.util
JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]
class Unwrapping:
kind: str
class OtherUnwrapping(Unwrapping):
tp: str
def __init__(self, tp: str = ''):
self.kind = 'other'
self.tp = tp
@dataclass()
class ObjectFieldUnwrapping:
key: str
object_key: str
unwrapping: Unwrapping
class ObjectUnwrapping(Unwrapping):
fields: list[ObjectFieldUnwrapping]
def __init__(self, fields: list[ObjectFieldUnwrapping]):
self.kind = 'object'
self.fields = fields
class ListUnwrapping(Unwrapping):
item_unwrap: Unwrapping
def __init__(self, item_unwrap: Unwrapping):
self.kind = 'list'
self.item_unwrap = item_unwrap
class DictUnwrapping(Unwrapping):
key_unwrap: Unwrapping
value_unwrap: Unwrapping
def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping):
self.kind = 'dict'
self.key_unwrap = key_unwrap
self.value_unwrap = value_unwrap
class FuncUnwrapping(Unwrapping):
fn: str
def __init__(self, fn: str):
self.kind = 'fn'
self.fn = fn
@dataclass
class UnionKindUnwrapping:
kind: str
unwrapping: Unwrapping
class UnionUnwrapping(Unwrapping):
union_kinds: list[UnionKindUnwrapping]
def __init__(self, union_kinds: list[UnionKindUnwrapping]):
self.kind = 'union'
self.union_kinds = union_kinds
def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]:
if isinstance(t, TypeAliasType):
return _flatten_type(t.__value__)
origin = get_origin(t)
if origin is Annotated:
args = get_args(t)
return _flatten_type(args[0])[0], args[1]
return t, None
def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
t, _ = _flatten_type(t)
return ('deflate_' + typename(t)
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_'))
def _fallback_unwrapper(obj: Any) -> JsonObject:
if isinstance(obj, (int, float, str, bool)):
return obj
elif isinstance(obj, list):
return list(map(_fallback_unwrapper, obj))
elif isinstance(obj, dict):
return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items()))
elif hasattr(obj, '__dict__'):
ret = {}
for k, v in obj.__dict__:
if isinstance(k, str) and k.startswith('_'):
continue
k = _fallback_unwrapper(k)
v = _fallback_unwrapper(v)
ret[k] = v
return ret
return None
class SchemaDeflatorGenerator:
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
object_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
_explicit_casts: bool
_out_directory: str | None
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
strict_mode: bool = False,
explicit_casts: bool = False,
store_sources: bool = False,
*,
object_template_filename: str = 'deflator.jinja2',
out_directory: str | None = None,
):
self._strict_mode = strict_mode
self._store_sources = store_sources
self._explicit_casts = explicit_casts
self._out_directory = out_directory
if loader is None:
template_path = importlib.resources.files('megasniff.templates')
loader = jinja2.FileSystemLoader(str(template_path))
self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.object_template = self.templateEnv.get_template(object_template_filename)
def schema_to_deflator(self,
schema: type,
strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None,
ignore_directory: bool = False,
out_directory_override: Optional[str] = None,
) -> Callable[[Any], dict[str, Any]]:
txt, namespace = self._schema_to_deflator(schema,
strict_mode_override=strict_mode_override,
explicit_casts_override=explicit_casts_override,
)
out_dir = self._out_directory
if out_directory_override:
out_dir = out_directory_override
if ignore_directory:
out_dir = None
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
if out_dir is not None:
filename = f"{uuid.uuid4()}.py"
filepath = Path(out_dir) / filename
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(txt)
spec = importlib.util.spec_from_file_location("generated_module", filepath)
module = importlib.util.module_from_spec(spec)
module.__dict__.update(namespace)
spec.loader.exec_module(module)
fn_name = _schema_to_deflator_func(schema)
fn = getattr(module, fn_name)
else:
exec(txt, namespace)
fn = namespace[_schema_to_deflator_func(schema)]
if self._store_sources:
setattr(fn, '__megasniff_sources__', txt)
return fn
def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None):
if _visited_types is None:
_visited_types = []
else:
_visited_types = _visited_types.copy()
schema, field_rename = _flatten_type(schema)
if schema in _visited_types:
return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema}
_visited_types.append(schema)
ongoing_types = set()
recurcive_types = set()
origin = get_origin(schema)
ret_unw = None
if origin is not None:
if origin is list:
args = get_args(schema)
item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0],
_visited_types=_visited_types)
ret_unw = ListUnwrapping(item_unw)
recurcive_types |= item_rec
ongoing_types |= ongoings
elif origin is dict:
args = get_args(schema)
if len(args) != 2:
ret_unw = OtherUnwrapping()
else:
k, v = args
k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
ongoing_types |= k_ongoings | v_ongoings
recurcive_types |= k_rec | v_rec
ret_unw = DictUnwrapping(k_unw, v_unw)
elif origin is UnionType or origin is Union:
args = get_args(schema)
union_unwraps = []
for targ in args:
arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ,
_visited_types=_visited_types)
union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw))
ongoing_types |= ongoings
recurcive_types |= arg_rec
ret_unw = UnionUnwrapping(union_unwraps)
else:
raise NotImplementedError
else:
if schema is int:
ret_unw = OtherUnwrapping('int')
elif schema is float:
ret_unw = OtherUnwrapping('float')
elif schema is bool:
ret_unw = OtherUnwrapping('bool')
elif schema is str:
ret_unw = OtherUnwrapping('str')
elif schema is None or schema is NoneType:
ret_unw = OtherUnwrapping()
elif schema is dict:
ret_unw = OtherUnwrapping()
elif schema is list:
ret_unw = OtherUnwrapping()
elif is_class_definition(schema):
hints = typing.get_type_hints(schema)
fields = []
for k, f in hints.items():
f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types)
fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw))
ongoing_types |= ongoings
recurcive_types |= f_rec
ret_unw = ObjectUnwrapping(fields)
else:
raise NotImplementedError(f'type not implemented yet: {schema}')
return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types
def _schema_to_deflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None,
into_type_override: Optional[type | TypeAliasType] = None,
*,
_funcname='deflate',
_namespace=None,
) -> tuple[str, dict]:
if strict_mode_override is not None:
strict_mode = strict_mode_override
else:
strict_mode = self._strict_mode
if explicit_casts_override is not None:
explicit_casts = explicit_casts_override
else:
explicit_casts = self._explicit_casts
template = self.object_template
types_for_namespace = set()
recursive_types = {schema}
namespace = {
'JsonObject': JsonObject,
'fallback_unwrapper': _fallback_unwrapper,
}
convertor_functext = ''
added_types = set()
while len(recursive_types ^ (recursive_types & added_types)) > 0:
rec_t = list(recursive_types ^ (recursive_types & added_types))[0]
rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t)
recursive_types |= rec_rec_t
types_for_namespace |= rec_t_namespace
rec_functext = template.render(
funcname=_schema_to_deflator_func(rec_t),
from_type=typename(rec_t),
into_type=None,
root_unwrap=rec_unw,
hashname=hashname,
strict_check=strict_mode,
explicit_cast=explicit_casts,
)
convertor_functext += '\n\n\n' + rec_functext
added_types.add(rec_t)
for t in types_for_namespace:
namespace[typename(t)] = t
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
return convertor_functext, namespace

View File

@@ -73,6 +73,7 @@ class SchemaInflatorGenerator:
tuple_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
_out_directory: str | None
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
@@ -81,10 +82,12 @@ class SchemaInflatorGenerator:
*,
object_template_filename: str = 'inflator.jinja2',
tuple_template_filename: str = 'inflator_tuple.jinja2',
out_directory: str | None = None,
):
self._strict_mode = strict_mode
self._store_sources = store_sources
self._out_directory = out_directory
if loader is None:
template_path = importlib.resources.files('megasniff.templates')
@@ -98,7 +101,9 @@ class SchemaInflatorGenerator:
def schema_to_inflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None
from_type_override: Optional[type | TypeAliasType] = None,
ignore_directory: bool = False,
out_directory_override: Optional[str] = None,
) -> Callable[[dict[str, Any]], Any]:
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
raise RuntimeError('from_type_override must provide __getitem__')
@@ -107,11 +112,36 @@ class SchemaInflatorGenerator:
strict_mode_override=strict_mode_override,
from_type_override=from_type_override,
)
out_dir = self._out_directory
if out_directory_override:
out_dir = out_directory_override
if ignore_directory:
out_dir = None
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
exec(txt, namespace)
fn = namespace['inflate']
if out_dir is not None:
filename = f"{uuid.uuid4()}.py"
filepath = Path(out_dir) / filename
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(txt)
spec = importlib.util.spec_from_file_location("generated_module", filepath)
module = importlib.util.module_from_spec(spec)
module.__dict__.update(namespace)
spec.loader.exec_module(module)
fn_name = _schema_to_deflator_func(schema)
fn = getattr(module, fn_name)
else:
exec(txt, namespace)
fn = namespace['inflate']
if self._store_sources:
setattr(fn, '__megasniff_sources__', txt)
return fn

View File

@@ -0,0 +1,110 @@
{% macro render_unwrap_object(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% for kv in unwrapping.fields %}
{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% if strict_check %}
if not isinstance({{from_container}}, dict):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'dict(' + from_container + ')' %}
{% endif %}
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
{{ render_unwrap(unwrapping.value_unwrap, 'v_' + hashname(unwrapping), into_container + '[k_' + hashname(unwrapping) + ']') | indent(4) }}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{into_container}} = []
{% if strict_check %}
if not isinstance({{from_container}}, list):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'list(' + from_container + ')' %}
{% endif %}
for {{hashname(unwrapping)}} in {{from_container}}:
{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
{{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.tp != '' and strict_check %}
if not isinstance({{from_container}}, {{unwrapping.tp}}):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}})))
{% endif %}
{% if unwrapping.tp != '' and explicit_cast %}
{{into_container}} = {{unwrapping.tp}}({{from_container}})
{% else %}
{{into_container}} = {{from_container}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_union(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% for union_kind in unwrapping.union_kinds %}
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
{% endfor %}
{% if strict_check %}
else:
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% elif explicit_cast %}
else:
{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}}
{% else %}
else:
{{into_container}} = fallback_unwrap({{from_container}})
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.kind == 'dict' %}
{{ render_unwrap_dict(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'list' %}
{{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %}
{{into_container}} = {{ unwrapping.fn }}({{from_container}})
{% else %}
{{ render_unwrap_other(unwrapping, from_container, into_container) }}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}:
"""
{{from_type}} -> {{into_type}}
"""
{{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }}
return ret

View File

@@ -68,6 +68,26 @@ def is_builtin_type(tp: type) -> bool:
def typename(tp: type) -> str:
ret = ''
if get_origin(tp) is None and hasattr(tp, '__name__'):
return tp.__name__
return str(tp)
ret = tp.__name__
else:
ret = str(tp)
ret = (ret
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_')
.replace('\'', '')
.replace('<', '')
.replace('>', ''))
return ret
def is_class_definition(obj):
return (isinstance(obj, type) or inspect.isclass(obj))
def hashname(obj) -> str:
return '_' + str(hash(obj)).replace('-', '_')

View File

@@ -0,0 +1,91 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from megasniff.deflator import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator
def test_basic_deflator():
class A:
a: int
def __init__(self, a: int):
self.a = a
class B:
def __init__(self, b: int):
self.b = b
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
fnb = defl.schema_to_deflator(B)
b = fnb(B(11))
assert len(b) == 0
def test_unions():
@dataclass
class A:
a: int | str
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == '42'
a = fn(A('42a'))
assert a['a'] == '42a'
def test_dict_body():
@dataclass
class A:
a: dict[str, float]
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A({'1': 1.1, '2': 2.2}))
print(a)
assert a['a']['1'] == 1.1
assert a['a']['2'] == 2.2
@dataclass
class CircA:
b: CircB
@dataclass
class CircB:
a: CircA | None
def test_circular():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(CircA)
a = fn(CircA(CircB(CircA(CircB(None)))))
assert isinstance(a['b'], dict)
assert isinstance(a['b']['a'], dict)
assert a['b']['a']['b']['a'] is None
def test_optional():
@dataclass
class C:
a: Optional[int] = None
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(C)
c = fn(C())
assert c['a'] is None
c = fn(C(123))
assert c['a'] == 123

View File

@@ -0,0 +1,94 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_explicit_casts_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=False)
fn = defl.schema_to_deflator(A, explicit_casts_override=True)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
a = fn(A([42.0, 42]))
assert len(a['a']) == 2
assert a['a'][0] == 42
assert a['a'][1] == 42
def test_global_explicit_casts_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
b = fn(B([42.0]))
assert b['b'][0] == 42
b = fn(B([A([1.1])]))
assert b['b'][0]['a'][0] == 1

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import pytest
from megasniff.deflator import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator
def test_str_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(str, explicit_casts_override=True)
a = fn('asdf')
assert a == 'asdf'
a = fn(1234)
assert a == '1234'
fn1 = defl.schema_to_deflator(str, strict_mode_override=True)
with pytest.raises(Exception):
fn1(1234)
def test_int_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(int, explicit_casts_override=True)
a = fn(1234)
assert a == 1234
a = fn('1234')
assert a == 1234
fn1 = defl.schema_to_deflator(int, strict_mode_override=True)
with pytest.raises(Exception):
fn1('1234')
def test_float_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(float, explicit_casts_override=True)
a = fn(1234.1)
assert a == 1234.1
a = fn('1234')
assert a == 1234.0
fn1 = defl.schema_to_deflator(float, strict_mode_override=True)
with pytest.raises(Exception):
fn1(1234)

View File

@@ -0,0 +1,88 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_strict_mode_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=False)
fn = defl.schema_to_deflator(A, strict_mode_override=True)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
with pytest.raises(FieldValidationException):
fn(A([42.0, 42]))
def test_global_strict_mode_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
with pytest.raises(FieldValidationException):
fn(B([42.0]))
with pytest.raises(FieldValidationException):
fn(B([A([1.1])]))