diff --git a/README.md b/README.md index 55afca0..76804b2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # megasniff -### Автоматическая валидация данных по схеме и сборка объекта в одном флаконе +### Автоматическая валидация данных по схеме, сборка и разборка объекта в одном флаконе #### Как применять: @@ -18,7 +18,8 @@ class SomeSchema1: c: SomeSchema2 | str | None -class SomeSchema2(typing.TypedDict): +@dataclasses.dataclass +class SomeSchema2: field1: dict field2: float field3: typing.Optional[SomeSchema1] @@ -28,12 +29,16 @@ class SomeSchema2(typing.TypedDict): import megasniff infl = megasniff.SchemaInflatorGenerator() -fn = infl.schema_to_inflator(SomeSchema1) +defl = megasniff.SchemaDeflatorGenerator() +fn_in = infl.schema_to_inflator(SomeSchema1) +fn_out = defl.schema_to_deflator(SomeSchema1) # 3. Проверяем что всё работает -fn({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}}) +data = fn_in({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}}) # SomeSchema1(a=1, b=2.0, c={'field1': {}, 'field2': 1.1, 'field3': None}) +fn_out(data) +# {'a': 1, 'b': 2.0, 'c': {'field1': {}, 'field2': 1.1, 'field3': None}} ``` @@ -47,7 +52,11 @@ fn({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}}) - не проверяет типы generic-словарей, кортежей (реализация ожидается) - пользовательские проверки типов должны быть реализованы через наследование и проверки в конструкторе - опциональный `strict-mode`: выключение приведения базовых типов -- может генерировать кортежи верхнеуровневых объектов при наличии описания схемы (полезно при развертывании аргументов) +- для inflation может генерировать кортежи верхнеуровневых объектов при наличии описания схемы (полезно при + развертывании аргументов) +- `TypedDict` поддерживается только для inflation из-за сложностей выбора варианта при сборке `Union`-полей +- для deflation поддерживается включение режима `explicit_casts`, приводящего типы к тем, которые указаны в + аннотациях (не распространяется на `Union`-типы, т.к. невозможно определить какой из них должен быть выбран) ---- @@ -88,6 +97,9 @@ class A: ``` >>> {"a": [1, 1.1, "321"]} <<< A(a=[1, 1, 321]) +>>> A(a=[1, 1.1, "321"]) +<<< {"a": [1, 1.1, "321"]} # explicit_casts=False +<<< {"a": [1, 1, 321]} # explicit_casts=True ``` #### Strict-mode on: @@ -101,11 +113,15 @@ class A: ``` >>> {"a": [1, 1.1, "321"]} <<< FieldValidationException, т.к. 1.1 не является int +>>> A(a=[1, 1.1, "321"]) +<<< FieldValidationException, т.к. 1.1 не является int ``` ### Tuple unwrap + ``` fn = infl.schema_to_inflator( (('a', int), TupleSchemaItem(Optional[list[int]], key_name='b', has_default=True, default=None))) ``` + Создаёт `fn: (dict[str,Any]) -> tuple[int, Optional[list[int]]]: ...` (сигнатура остаётся `(dict[str,Any])->tuple`) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6033938..f497e26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "megasniff" -version = "0.2.3.post2" +version = "0.2.4" description = "Library for in-time codegened type validation" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } diff --git a/src/megasniff/__init__.py b/src/megasniff/__init__.py index 21ed1a6..504d217 100644 --- a/src/megasniff/__init__.py +++ b/src/megasniff/__init__.py @@ -1 +1,2 @@ from .inflator import SchemaInflatorGenerator +from .deflator import SchemaDeflatorGenerator diff --git a/src/megasniff/__main__.py b/src/megasniff/__main__.py index 41203a8..46601ca 100644 --- a/src/megasniff/__main__.py +++ b/src/megasniff/__main__.py @@ -1,10 +1,13 @@ from __future__ import annotations +import json from dataclasses import dataclass +from types import NoneType from typing import Optional from typing import TypedDict import megasniff.exceptions +from megasniff.deflator import SchemaDeflatorGenerator, JsonObject from . import SchemaInflatorGenerator @@ -16,7 +19,8 @@ class ASchema: c: float = 1.1 -class BSchema(TypedDict): +@dataclass +class BSchema: a: int b: str c: float @@ -55,5 +59,40 @@ def main(): print(e) +@dataclass +class DSchema: + a: dict + b: dict[str, int | float | dict] + c: str | float | ASchema + d: ESchema + + +@dataclass +class ESchema: + a: list[list[list[str]]] + b: str | int + + +@dataclass +class ZSchema: + z: ZSchema | None + d: ZSchema | int + + +def main_deflator(): + deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True) + fn = deflator.schema_to_deflator(DSchema) + print(getattr(fn, '__megasniff_sources__', '## No data')) + # ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42))) + ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], ['b']))) + # assert ret['a'] == 1 + # assert ret['b'] == 1.1 + # assert ret['c'] == 'a' + # assert ret['d']['a'][0][0][0] == 'a' + # assert ret['d']['b'] == 'b' + print(json.dumps(ret, indent=4)) + pass + + if __name__ == '__main__': - main() + main_deflator() diff --git a/src/megasniff/deflator.py b/src/megasniff/deflator.py new file mode 100644 index 0000000..b8ee8f3 --- /dev/null +++ b/src/megasniff/deflator.py @@ -0,0 +1,307 @@ +# Copyright (C) 2025 Shevchenko A +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import annotations + +import importlib.resources +import typing +from collections.abc import Callable +from dataclasses import dataclass +from types import NoneType, UnionType +from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \ + OrderedDict, TypeAlias + +import jinja2 + +from .utils import * + +JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']] + + +class Unwrapping: + kind: str + + +class OtherUnwrapping(Unwrapping): + tp: str + + def __init__(self, tp: str = ''): + self.kind = 'other' + self.tp = tp + + +@dataclass() +class ObjectFieldUnwrapping: + key: str + object_key: str + unwrapping: Unwrapping + + +class ObjectUnwrapping(Unwrapping): + fields: list[ObjectFieldUnwrapping] + + def __init__(self, fields: list[ObjectFieldUnwrapping]): + self.kind = 'object' + self.fields = fields + + +class ListUnwrapping(Unwrapping): + item_unwrap: Unwrapping + + def __init__(self, item_unwrap: Unwrapping): + self.kind = 'list' + self.item_unwrap = item_unwrap + + +class DictUnwrapping(Unwrapping): + key_unwrap: Unwrapping + value_unwrap: Unwrapping + + def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping): + self.kind = 'dict' + self.key_unwrap = key_unwrap + self.value_unwrap = value_unwrap + + +class FuncUnwrapping(Unwrapping): + fn: str + + def __init__(self, fn: str): + self.kind = 'fn' + self.fn = fn + + +@dataclass +class UnionKindUnwrapping: + kind: str + unwrapping: Unwrapping + + +class UnionUnwrapping(Unwrapping): + union_kinds: list[UnionKindUnwrapping] + + def __init__(self, union_kinds: list[UnionKindUnwrapping]): + self.kind = 'union' + self.union_kinds = union_kinds + + +def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]: + if isinstance(t, TypeAliasType): + return _flatten_type(t.__value__) + + origin = get_origin(t) + + if origin is Annotated: + args = get_args(t) + return _flatten_type(args[0])[0], args[1] + + return t, None + + +def _schema_to_deflator_func(t: type | TypeAliasType) -> str: + t, _ = _flatten_type(t) + return 'deflate_' + typename(t).replace('.', '_') + + +def _fallback_unwrapper(obj: Any) -> JsonObject: + if isinstance(obj, (int, float, str, bool)): + return obj + elif isinstance(obj, list): + return list(map(_fallback_unwrapper, obj)) + elif isinstance(obj, dict): + return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items())) + elif hasattr(obj, '__dict__'): + ret = {} + for k, v in obj.__dict__: + if isinstance(k, str) and k.startswith('_'): + continue + k = _fallback_unwrapper(k) + v = _fallback_unwrapper(v) + ret[k] = v + return ret + return None + + +class SchemaDeflatorGenerator: + templateLoader: jinja2.BaseLoader + templateEnv: jinja2.Environment + + object_template: jinja2.Template + _store_sources: bool + _strict_mode: bool + _explicit_casts: bool + + def __init__(self, + loader: Optional[jinja2.BaseLoader] = None, + strict_mode: bool = False, + explicit_casts: bool = False, + store_sources: bool = False, + *, + object_template_filename: str = 'deflator.jinja2', + ): + + self._strict_mode = strict_mode + self._store_sources = store_sources + self._explicit_casts = explicit_casts + + if loader is None: + template_path = importlib.resources.files('megasniff.templates') + loader = jinja2.FileSystemLoader(str(template_path)) + + self.templateLoader = loader + self.templateEnv = jinja2.Environment(loader=self.templateLoader) + self.object_template = self.templateEnv.get_template(object_template_filename) + + def schema_to_deflator(self, + schema: type, + strict_mode_override: Optional[bool] = None, + explicit_casts_override: Optional[bool] = None, + ) -> Callable[[Any], dict[str, Any]]: + txt, namespace = self._schema_to_deflator(schema, + strict_mode_override=strict_mode_override, + explicit_casts_override=explicit_casts_override, + ) + imports = ('from typing import Any\n' + 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt = imports + '\n' + txt + exec(txt, namespace) + fn = namespace[_schema_to_deflator_func(schema)] + if self._store_sources: + setattr(fn, '__megasniff_sources__', txt) + return fn + + def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None): + if _visited_types is None: + _visited_types = [] + else: + _visited_types = _visited_types.copy() + + schema, field_rename = _flatten_type(schema) + + if schema in _visited_types: + return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema} + _visited_types.append(schema) + ongoing_types = set() + recurcive_types = set() + + origin = get_origin(schema) + + ret_unw = None + + if origin is not None: + if origin is list: + args = get_args(schema) + item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0], + _visited_types=_visited_types) + ret_unw = ListUnwrapping(item_unw) + recurcive_types |= item_rec + ongoing_types |= ongoings + elif origin is dict: + args = get_args(schema) + if len(args) != 2: + ret_unw = OtherUnwrapping() + else: + k, v = args + k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types) + v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types) + ongoing_types |= k_ongoings | v_ongoings + recurcive_types |= k_rec | v_rec + ret_unw = DictUnwrapping(k_unw, v_unw) + elif origin is UnionType or origin is Union: + args = get_args(schema) + union_unwraps = [] + for targ in args: + arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ, + _visited_types=_visited_types) + union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw)) + ongoing_types |= ongoings + recurcive_types |= arg_rec + ret_unw = UnionUnwrapping(union_unwraps) + else: + raise NotImplementedError + else: + if schema is int: + ret_unw = OtherUnwrapping('int') + elif schema is float: + ret_unw = OtherUnwrapping('float') + elif schema is bool: + ret_unw = OtherUnwrapping('bool') + elif schema is str: + ret_unw = OtherUnwrapping('str') + elif schema is None or schema is NoneType: + ret_unw = OtherUnwrapping() + elif schema is dict: + ret_unw = OtherUnwrapping() + elif schema is list: + ret_unw = OtherUnwrapping() + elif is_class_definition(schema): + hints = typing.get_type_hints(schema) + fields = [] + for k, f in hints.items(): + f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types) + fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw)) + ongoing_types |= ongoings + recurcive_types |= f_rec + + ret_unw = ObjectUnwrapping(fields) + else: + raise NotImplementedError() + + return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types + + def _schema_to_deflator(self, + schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type], + strict_mode_override: Optional[bool] = None, + explicit_casts_override: Optional[bool] = None, + into_type_override: Optional[type | TypeAliasType] = None, + *, + _funcname='deflate', + _namespace=None, + ) -> tuple[str, dict]: + if strict_mode_override is not None: + strict_mode = strict_mode_override + else: + strict_mode = self._strict_mode + if explicit_casts_override is not None: + explicit_casts = explicit_casts_override + else: + explicit_casts = self._explicit_casts + + template = self.object_template + + types_for_namespace = set() + recursive_types = {schema} + + namespace = { + 'JsonObject': JsonObject, + 'fallback_unwrapper': _fallback_unwrapper, + } + + convertor_functext = '' + + added_types = set() + + while len(recursive_types ^ (recursive_types & added_types)) > 0: + rec_t = list(recursive_types ^ (recursive_types & added_types))[0] + rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t) + recursive_types |= rec_rec_t + types_for_namespace |= rec_t_namespace + + rec_functext = template.render( + funcname=_schema_to_deflator_func(rec_t), + from_type=typename(rec_t), + into_type=None, + root_unwrap=rec_unw, + hashname=hashname, + strict_check=strict_mode, + explicit_cast=explicit_casts, + ) + + convertor_functext += '\n\n\n' + rec_functext + added_types.add(rec_t) + + for t in types_for_namespace: + namespace[typename(t)] = t + + convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) + + return convertor_functext, namespace diff --git a/src/megasniff/templates/deflator.jinja2 b/src/megasniff/templates/deflator.jinja2 new file mode 100644 index 0000000..bf6ee5c --- /dev/null +++ b/src/megasniff/templates/deflator.jinja2 @@ -0,0 +1,110 @@ +{% macro render_unwrap_object(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{ into_container }} = {} +{% for kv in unwrapping.fields %} +{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }} +{% endfor %} +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{ into_container }} = {} +{% if strict_check %} +if not isinstance({{from_container}}, dict): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}}))) +{% endif %} +{% if explicit_cast %} +{% set from_container = 'dict(' + from_container + ')' %} +{% endif %} +for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items(): +{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }} +{{ render_unwrap(unwrapping.value_unwrap, into_container + '[v_' + hashname(unwrapping) + ']', 'v_' + hashname(unwrapping)) | indent(4) }} +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_unwrap_list(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{into_container}} = [] +{% if strict_check %} +if not isinstance({{from_container}}, list): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}}))) +{% endif %} +{% if explicit_cast %} +{% set from_container = 'list(' + from_container + ')' %} +{% endif %} +for {{hashname(unwrapping)}} in {{from_container}}: +{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }} + {{into_container}}.append({{hashname(unwrapping)}}_tmp_container) +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap_other(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% if unwrapping.tp != '' and strict_check %} +if not isinstance({{from_container}}, {{unwrapping.tp}}): + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}}))) +{% endif %} +{% if unwrapping.tp != '' and explicit_cast %} +{{into_container}} = {{unwrapping.tp}}({{from_container}}) +{% else %} +{{into_container}} = {{from_container}} +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap_union(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% for union_kind in unwrapping.union_kinds %} +{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}): +{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}} +{% endfor %} +{% if strict_check %} +else: + raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}}))) +{% elif explicit_cast %} +else: +{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}} +{% else %} +else: + {{into_container}} = fallback_unwrap({{from_container}}) +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + +{% macro render_unwrap(unwrapping, from_container, into_container) -%} +{%- set out -%} +{% if unwrapping.kind == 'dict' %} +{{ render_unwrap_dict(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'list' %} +{{ render_unwrap_list(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'object' %} +{{ render_unwrap_object(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'union' %} +{{ render_unwrap_union(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'fn' %} +{{into_container}} = {{ unwrapping.fn }}({{from_container}}) +{% else %} +{{ render_unwrap_other(unwrapping, from_container, into_container) }} +{% endif %} +{%- endset %} +{{out}} +{%- endmacro %} + + +def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}: + """ + {{from_type}} -> {{into_type}} + """ + {{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }} + return ret + + diff --git a/src/megasniff/utils.py b/src/megasniff/utils.py index 907bd4d..a6c00c2 100644 --- a/src/megasniff/utils.py +++ b/src/megasniff/utils.py @@ -71,3 +71,11 @@ def typename(tp: type) -> str: if get_origin(tp) is None and hasattr(tp, '__name__'): return tp.__name__ return str(tp) + + +def is_class_definition(obj): + return isinstance(obj, type) or inspect.isclass(obj) + + +def hashname(obj) -> str: + return '_' + str(hash(obj)).replace('-', '_') diff --git a/tests/test_basic_deflator.py b/tests/test_basic_deflator.py new file mode 100644 index 0000000..a3cfb46 --- /dev/null +++ b/tests/test_basic_deflator.py @@ -0,0 +1,77 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Optional + +from megasniff.deflator import SchemaDeflatorGenerator +from src.megasniff import SchemaInflatorGenerator + + +def test_basic_deflator(): + class A: + a: int + + def __init__(self, a: int): + self.a = a + + class B: + def __init__(self, b: int): + self.b = b + + defl = SchemaDeflatorGenerator() + fn = defl.schema_to_deflator(A) + a = fn(A(42)) + + assert a['a'] == 42 + + fnb = defl.schema_to_deflator(B) + b = fnb(B(11)) + assert len(b) == 0 + + +def test_unions(): + @dataclass + class A: + a: int | str + + defl = SchemaDeflatorGenerator() + fn = defl.schema_to_deflator(A) + + a = fn(A(42)) + assert a['a'] == 42 + a = fn(A('42')) + assert a['a'] == '42' + a = fn(A('42a')) + assert a['a'] == '42a' + + +@dataclass +class CircA: + b: CircB + + +@dataclass +class CircB: + a: CircA | None + + +def test_circular(): + defl = SchemaDeflatorGenerator() + fn = defl.schema_to_deflator(CircA) + a = fn(CircA(CircB(CircA(CircB(None))))) + + assert isinstance(a['b'], dict) + assert isinstance(a['b']['a'], dict) + assert a['b']['a']['b']['a'] is None + + +def test_optional(): + @dataclass + class C: + a: Optional[int] = None + + defl = SchemaDeflatorGenerator() + fn = defl.schema_to_deflator(C) + c = fn(C()) + assert c['a'] is None + c = fn(C(123)) + assert c['a'] == 123 diff --git a/tests/test_explicit_cast_deflator.py b/tests/test_explicit_cast_deflator.py new file mode 100644 index 0000000..b4a01dd --- /dev/null +++ b/tests/test_explicit_cast_deflator.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass + +import pytest + +from megasniff import SchemaDeflatorGenerator +from megasniff.exceptions import FieldValidationException + + +def test_global_explicit_casts_basic(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(A) + a = fn(A(42)) + + assert a['a'] == 42 + + a = fn(A(42.0)) + assert a['a'] == 42 + + a = fn(A('42')) + assert a['a'] == 42 + + with pytest.raises(TypeError): + fn(A(['42'])) + + +def test_global_explicit_casts_basic_override(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(explicit_casts=False) + fn = defl.schema_to_deflator(A, explicit_casts_override=True) + a = fn(A(42)) + + assert a['a'] == 42 + + a = fn(A(42.0)) + assert a['a'] == 42 + + a = fn(A('42')) + assert a['a'] == 42 + + with pytest.raises(TypeError): + fn(A(['42'])) + + +def test_global_explicit_casts_list(): + @dataclass + class A: + a: list[int] + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(A) + a = fn(A([42])) + + assert a['a'] == [42] + + a = fn(A([42.0, 42])) + assert len(a['a']) == 2 + assert a['a'][0] == 42 + assert a['a'][1] == 42 + + +def test_global_explicit_casts_circular(): + @dataclass + class A: + a: list[int] + + @dataclass + class B: + b: list[A | int] + + defl = SchemaDeflatorGenerator(explicit_casts=True) + fn = defl.schema_to_deflator(B) + b = fn(B([A([]), 42])) + + assert len(b['b']) == 2 + assert isinstance(b['b'][0], dict) + assert len(b['b'][0]['a']) == 0 + assert isinstance(b['b'][1], int) + + b = fn(B([42.0])) + assert b['b'][0] == 42 + + b = fn(B([A([1.1])])) + assert b['b'][0]['a'][0] == 1 diff --git a/tests/test_strict_mode_deflator.py b/tests/test_strict_mode_deflator.py new file mode 100644 index 0000000..0d2d509 --- /dev/null +++ b/tests/test_strict_mode_deflator.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass + +import pytest + +from megasniff import SchemaDeflatorGenerator +from megasniff.exceptions import FieldValidationException + + +def test_global_strict_mode_basic(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(A) + a = fn(A(42)) + + assert a['a'] == 42 + + with pytest.raises(FieldValidationException): + fn(A(42.0)) + with pytest.raises(FieldValidationException): + fn(A('42')) + with pytest.raises(FieldValidationException): + fn(A(['42'])) + + +def test_global_strict_mode_basic_override(): + class A: + a: int + + def __init__(self, a): + self.a = a + + defl = SchemaDeflatorGenerator(strict_mode=False) + fn = defl.schema_to_deflator(A, strict_mode_override=True) + a = fn(A(42)) + + assert a['a'] == 42 + + with pytest.raises(FieldValidationException): + fn(A(42.0)) + with pytest.raises(FieldValidationException): + fn(A('42')) + with pytest.raises(FieldValidationException): + fn(A(['42'])) + + +def test_global_strict_mode_list(): + @dataclass + class A: + a: list[int] + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(A) + a = fn(A([42])) + + assert a['a'] == [42] + + with pytest.raises(FieldValidationException): + fn(A([42.0, 42])) + + +def test_global_strict_mode_circular(): + @dataclass + class A: + a: list[int] + + @dataclass + class B: + b: list[A | int] + + defl = SchemaDeflatorGenerator(strict_mode=True) + fn = defl.schema_to_deflator(B) + b = fn(B([A([]), 42])) + + assert len(b['b']) == 2 + assert isinstance(b['b'][0], dict) + assert len(b['b'][0]['a']) == 0 + assert isinstance(b['b'][1], int) + + with pytest.raises(FieldValidationException): + fn(B([42.0])) + + with pytest.raises(FieldValidationException): + fn(B([A([1.1])]))