Compare commits

..

2 Commits

Author SHA1 Message Date
8e38c41aa5 Support Enum types 2026-02-18 20:10:23 +03:00
bc9da11db4 Add deflator "Unable to generate unwrapper for {schema}" message 2026-02-18 19:34:19 +03:00
5 changed files with 69 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "megasniff" name = "megasniff"
version = "0.2.6.post1" version = "0.2.7"
description = "Library for in-time codegened type validation" description = "Library for in-time codegened type validation"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -6,6 +6,7 @@ import importlib.resources
import typing import typing
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import EnumType
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \ from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
OrderedDict, TypeAlias OrderedDict, TypeAlias
@@ -17,6 +18,7 @@ import uuid
from pathlib import Path from pathlib import Path
import tempfile import tempfile
import importlib.util import importlib.util
import enum
JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']] JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]
@@ -48,6 +50,11 @@ class ObjectUnwrapping(Unwrapping):
self.fields = fields self.fields = fields
class EnumUnwrapping(Unwrapping):
def __init__(self, ):
self.kind = 'enum'
class ListUnwrapping(Unwrapping): class ListUnwrapping(Unwrapping):
item_unwrap: Unwrapping item_unwrap: Unwrapping
@@ -257,6 +264,7 @@ class SchemaDeflatorGenerator:
recurcive_types |= arg_rec recurcive_types |= arg_rec
ret_unw = UnionUnwrapping(union_unwraps) ret_unw = UnionUnwrapping(union_unwraps)
else: else:
print(f'Unable to generate unwrapper for {schema} -- origin {origin} is not supported')
raise NotImplementedError raise NotImplementedError
else: else:
if schema is int: if schema is int:
@@ -273,6 +281,8 @@ class SchemaDeflatorGenerator:
ret_unw = OtherUnwrapping() ret_unw = OtherUnwrapping()
elif schema is list: elif schema is list:
ret_unw = OtherUnwrapping() ret_unw = OtherUnwrapping()
elif isinstance(schema, EnumType):
ret_unw = EnumUnwrapping()
elif is_class_definition(schema): elif is_class_definition(schema):
hints = typing.get_type_hints(schema) hints = typing.get_type_hints(schema)
fields = [] fields = []

View File

@@ -6,6 +6,7 @@ import importlib.resources
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import EnumType
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \ from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \
OrderedDict OrderedDict
@@ -193,6 +194,11 @@ class SchemaInflatorGenerator:
else: else:
strict_mode = self._strict_mode strict_mode = self._strict_mode
if _namespace is None:
namespace = {}
else:
namespace = _namespace
template = self.object_template template = self.object_template
mode = 'object' mode = 'object'
if isinstance(schema, dict): if isinstance(schema, dict):
@@ -201,6 +207,10 @@ class SchemaInflatorGenerator:
new_schema.append((argname, argtype)) new_schema.append((argname, argtype))
schema = new_schema schema = new_schema
if isinstance(schema, EnumType):
namespace[f'inflate_{schema.__name__}'] = schema
return '\n', namespace
if isinstance(schema, collections.abc.Iterable): if isinstance(schema, collections.abc.Iterable):
template = self.tuple_template template = self.tuple_template
mode = 'tuple' mode = 'tuple'
@@ -225,11 +235,6 @@ class SchemaInflatorGenerator:
txt_segments = [] txt_segments = []
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None: if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace return '', namespace

View File

@@ -8,6 +8,13 @@
{{out}} {{out}}
{%- endmacro %} {%- endmacro %}
{% macro render_unwrap_enum(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {{ from_container }}.value
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%} {% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%} {%- set out -%}
{{ into_container }} = {} {{ into_container }} = {}
@@ -88,6 +95,8 @@ else:
{{ render_unwrap_list(unwrapping, from_container, into_container) }} {{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %} {% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }} {{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'enum' %}
{{ render_unwrap_enum(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %} {% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }} {{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %} {% elif unwrapping.kind == 'fn' %}

View File

@@ -1,7 +1,11 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import json
from dataclasses import dataclass
from enum import Enum
from typing import Optional, get_type_hints
from megasniff import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator from src.megasniff import SchemaInflatorGenerator
@@ -60,3 +64,36 @@ def test_optional():
fn = infl.schema_to_inflator(C) fn = infl.schema_to_inflator(C)
c = fn({}) c = fn({})
assert c.a is None assert c.a is None
class AEnum(Enum):
a = 'a'
b = 'b'
c = 42
e1 = {'a': 'b'}
e2 = ['a', 'b']
@dataclass
class Z:
a: Optional[AEnum] = None
def test_enum():
infl = SchemaInflatorGenerator()
defl = SchemaDeflatorGenerator()
infl_fn = infl.schema_to_inflator(Z)
defl_fn = defl.schema_to_deflator(Z)
for it in AEnum:
ref = {'a': it.value}
ref_str = json.dumps(ref)
z = infl_fn(json.loads(ref_str))
assert z.a is not None
assert z.a.value == it.value
assert z.a.name == it.name
zdict = defl_fn(z)
assert len(zdict) == 1
assert zdict['a'] == it.value
assert json.dumps(zdict) == ref_str
assert infl_fn(zdict) == z