diff --git a/pyproject.toml b/pyproject.toml index 214d32c..da289b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "megasniff" -version = "0.2.6.post1" +version = "0.2.7" description = "Library for in-time codegened type validation" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } diff --git a/src/megasniff/deflator.py b/src/megasniff/deflator.py index b5c02f3..07f780a 100644 --- a/src/megasniff/deflator.py +++ b/src/megasniff/deflator.py @@ -6,6 +6,7 @@ import importlib.resources import typing from collections.abc import Callable from dataclasses import dataclass +from enum import EnumType from types import NoneType, UnionType from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \ OrderedDict, TypeAlias @@ -17,6 +18,7 @@ import uuid from pathlib import Path import tempfile import importlib.util +import enum JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']] @@ -48,6 +50,11 @@ class ObjectUnwrapping(Unwrapping): self.fields = fields +class EnumUnwrapping(Unwrapping): + def __init__(self, ): + self.kind = 'enum' + + class ListUnwrapping(Unwrapping): item_unwrap: Unwrapping @@ -274,6 +281,8 @@ class SchemaDeflatorGenerator: ret_unw = OtherUnwrapping() elif schema is list: ret_unw = OtherUnwrapping() + elif isinstance(schema, EnumType): + ret_unw = EnumUnwrapping() elif is_class_definition(schema): hints = typing.get_type_hints(schema) fields = [] diff --git a/src/megasniff/inflator.py b/src/megasniff/inflator.py index 0c77a3d..92020c8 100644 --- a/src/megasniff/inflator.py +++ b/src/megasniff/inflator.py @@ -6,6 +6,7 @@ import importlib.resources from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass +from enum import EnumType from types import NoneType, UnionType from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \ OrderedDict @@ -193,6 +194,11 @@ class SchemaInflatorGenerator: else: strict_mode = self._strict_mode + if _namespace is None: + namespace = {} + else: + namespace = _namespace + template = self.object_template mode = 'object' if isinstance(schema, dict): @@ -201,6 +207,10 @@ class SchemaInflatorGenerator: new_schema.append((argname, argtype)) schema = new_schema + if isinstance(schema, EnumType): + namespace[f'inflate_{schema.__name__}'] = schema + return '\n', namespace + if isinstance(schema, collections.abc.Iterable): template = self.tuple_template mode = 'tuple' @@ -225,11 +235,6 @@ class SchemaInflatorGenerator: txt_segments = [] - if _namespace is None: - namespace = {} - else: - namespace = _namespace - if namespace.get(f'{_funcname}_tgt_type') is not None: return '', namespace diff --git a/src/megasniff/templates/deflator.jinja2 b/src/megasniff/templates/deflator.jinja2 index 8a470e7..359d909 100644 --- a/src/megasniff/templates/deflator.jinja2 +++ b/src/megasniff/templates/deflator.jinja2 @@ -8,6 +8,13 @@ {{out}} {%- endmacro %} +{% macro render_unwrap_enum(unwrapping, from_container, into_container) -%} +{%- set out -%} +{{ into_container }} = {{ from_container }}.value +{%- endset %} +{{out}} +{%- endmacro %} + {% macro render_unwrap_dict(unwrapping, from_container, into_container) -%} {%- set out -%} {{ into_container }} = {} @@ -88,6 +95,8 @@ else: {{ render_unwrap_list(unwrapping, from_container, into_container) }} {% elif unwrapping.kind == 'object' %} {{ render_unwrap_object(unwrapping, from_container, into_container) }} +{% elif unwrapping.kind == 'enum' %} +{{ render_unwrap_enum(unwrapping, from_container, into_container) }} {% elif unwrapping.kind == 'union' %} {{ render_unwrap_union(unwrapping, from_container, into_container) }} {% elif unwrapping.kind == 'fn' %} diff --git a/tests/test_basic.py b/tests/test_basic.py index 1d78056..c3681ba 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Optional +import json +from dataclasses import dataclass +from enum import Enum +from typing import Optional, get_type_hints + +from megasniff import SchemaDeflatorGenerator from src.megasniff import SchemaInflatorGenerator @@ -60,3 +64,36 @@ def test_optional(): fn = infl.schema_to_inflator(C) c = fn({}) assert c.a is None + + +class AEnum(Enum): + a = 'a' + b = 'b' + c = 42 + e1 = {'a': 'b'} + e2 = ['a', 'b'] + + +@dataclass +class Z: + a: Optional[AEnum] = None + + +def test_enum(): + infl = SchemaInflatorGenerator() + defl = SchemaDeflatorGenerator() + infl_fn = infl.schema_to_inflator(Z) + defl_fn = defl.schema_to_deflator(Z) + + for it in AEnum: + ref = {'a': it.value} + ref_str = json.dumps(ref) + z = infl_fn(json.loads(ref_str)) + assert z.a is not None + assert z.a.value == it.value + assert z.a.name == it.name + zdict = defl_fn(z) + assert len(zdict) == 1 + assert zdict['a'] == it.value + assert json.dumps(zdict) == ref_str + assert infl_fn(zdict) == z