Support Enum types

This commit is contained in:
2026-02-18 20:10:23 +03:00
parent bc9da11db4
commit 8e38c41aa5
5 changed files with 68 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "megasniff"
version = "0.2.6.post1"
version = "0.2.7"
description = "Library for in-time codegened type validation"
authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -6,6 +6,7 @@ import importlib.resources
import typing
from collections.abc import Callable
from dataclasses import dataclass
from enum import EnumType
from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
OrderedDict, TypeAlias
@@ -17,6 +18,7 @@ import uuid
from pathlib import Path
import tempfile
import importlib.util
import enum
JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]
@@ -48,6 +50,11 @@ class ObjectUnwrapping(Unwrapping):
self.fields = fields
class EnumUnwrapping(Unwrapping):
def __init__(self, ):
self.kind = 'enum'
class ListUnwrapping(Unwrapping):
item_unwrap: Unwrapping
@@ -274,6 +281,8 @@ class SchemaDeflatorGenerator:
ret_unw = OtherUnwrapping()
elif schema is list:
ret_unw = OtherUnwrapping()
elif isinstance(schema, EnumType):
ret_unw = EnumUnwrapping()
elif is_class_definition(schema):
hints = typing.get_type_hints(schema)
fields = []

View File

@@ -6,6 +6,7 @@ import importlib.resources
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from enum import EnumType
from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \
OrderedDict
@@ -193,6 +194,11 @@ class SchemaInflatorGenerator:
else:
strict_mode = self._strict_mode
if _namespace is None:
namespace = {}
else:
namespace = _namespace
template = self.object_template
mode = 'object'
if isinstance(schema, dict):
@@ -201,6 +207,10 @@ class SchemaInflatorGenerator:
new_schema.append((argname, argtype))
schema = new_schema
if isinstance(schema, EnumType):
namespace[f'inflate_{schema.__name__}'] = schema
return '\n', namespace
if isinstance(schema, collections.abc.Iterable):
template = self.tuple_template
mode = 'tuple'
@@ -225,11 +235,6 @@ class SchemaInflatorGenerator:
txt_segments = []
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace

View File

@@ -8,6 +8,13 @@
{{out}}
{%- endmacro %}
{% macro render_unwrap_enum(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {{ from_container }}.value
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
@@ -88,6 +95,8 @@ else:
{{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'enum' %}
{{ render_unwrap_enum(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %}

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import json
from dataclasses import dataclass
from enum import Enum
from typing import Optional, get_type_hints
from megasniff import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator
@@ -60,3 +64,36 @@ def test_optional():
fn = infl.schema_to_inflator(C)
c = fn({})
assert c.a is None
class AEnum(Enum):
a = 'a'
b = 'b'
c = 42
e1 = {'a': 'b'}
e2 = ['a', 'b']
@dataclass
class Z:
a: Optional[AEnum] = None
def test_enum():
infl = SchemaInflatorGenerator()
defl = SchemaDeflatorGenerator()
infl_fn = infl.schema_to_inflator(Z)
defl_fn = defl.schema_to_deflator(Z)
for it in AEnum:
ref = {'a': it.value}
ref_str = json.dumps(ref)
z = infl_fn(json.loads(ref_str))
assert z.a is not None
assert z.a.value == it.value
assert z.a.name == it.name
zdict = defl_fn(z)
assert len(zdict) == 1
assert zdict['a'] == it.value
assert json.dumps(zdict) == ref_str
assert infl_fn(zdict) == z