From 9267f744d80038381546ef44f4d64d06b90808b8 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Thu, 19 Feb 2026 22:39:56 +0300 Subject: [PATCH] Dict inflator and deflator support --- pyproject.toml | 2 +- src/megasniff/deflator.py | 2 + src/megasniff/inflator.py | 46 ++++++--- src/megasniff/templates/inflator.jinja2 | 6 ++ .../templates/unwrap_type_data.jinja2 | 13 ++- tests/test_basic.py | 47 +++++++++ tests/test_dicts.py | 99 +++++++++++++++++++ 7 files changed, 198 insertions(+), 17 deletions(-) create mode 100644 tests/test_dicts.py diff --git a/pyproject.toml b/pyproject.toml index da289b5..5b6c551 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "megasniff" -version = "0.2.7" +version = "0.2.8" description = "Library for in-time codegened type validation" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } diff --git a/src/megasniff/deflator.py b/src/megasniff/deflator.py index 07f780a..04a0acb 100644 --- a/src/megasniff/deflator.py +++ b/src/megasniff/deflator.py @@ -283,6 +283,8 @@ class SchemaDeflatorGenerator: ret_unw = OtherUnwrapping() elif isinstance(schema, EnumType): ret_unw = EnumUnwrapping() + elif issubclass(schema, uuid.UUID): + ret_unw = FuncUnwrapping('str') elif is_class_definition(schema): hints = typing.get_type_hints(schema) fields = [] diff --git a/src/megasniff/inflator.py b/src/megasniff/inflator.py index 92020c8..74a4f2f 100644 --- a/src/megasniff/inflator.py +++ b/src/megasniff/inflator.py @@ -3,13 +3,14 @@ from __future__ import annotations import collections.abc import importlib.resources +import uuid from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass from enum import EnumType from types import NoneType, UnionType from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \ - OrderedDict + OrderedDict, Dict import jinja2 @@ -22,6 +23,7 @@ class TypeRenderData: typeref: list[TypeRenderData] | TypeRenderData | str allow_none: bool is_list: bool + is_dict: bool is_union: bool is_strict: bool @@ -30,6 +32,14 @@ class TypeRenderData: class IterableTypeRenderData(TypeRenderData): iterable_type: str is_list = True + is_dict = False + is_union = False + + +@dataclass +class DictTypeRenderData(TypeRenderData): + is_list = False + is_dict = True is_union = False @@ -152,21 +162,24 @@ class SchemaInflatorGenerator: allow_none = False argtypes = t, - if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])): + if any(map(lambda x: type_origin is x, + [Union, UnionType, Optional, Annotated, list, List, set, Set, dict, Dict])): argtypes = get_args(t) if NoneType in argtypes or None in argtypes: argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) allow_none = True - is_union = len(argtypes) > 1 - - if is_union: + if type_origin in [dict, Dict]: + k = self._unwrap_typeref(argtypes[0], strict_mode) + v = self._unwrap_typeref(argtypes[1], strict_mode) + return DictTypeRenderData([k, v], allow_none, False, True, False, False) + elif len(argtypes) > 1: typerefs = list(map(lambda x: self._unwrap_typeref(x, strict_mode), argtypes)) - return TypeRenderData(typerefs, allow_none, False, True, False) - elif type_origin in [list, set]: + return TypeRenderData(typerefs, allow_none, False, False, True, False) + elif type_origin in [list, set, List, Set]: rd = self._unwrap_typeref(argtypes[0], strict_mode) - return IterableTypeRenderData(rd, allow_none, True, False, False, type_origin.__name__) + return IterableTypeRenderData(rd, allow_none, True, False, False, False, type_origin.__name__) else: t = argtypes[0] @@ -179,6 +192,7 @@ class SchemaInflatorGenerator: allow_none, is_list, False, + False, strict_mode if is_builtin else False) def _schema_to_inflator(self, @@ -208,7 +222,8 @@ class SchemaInflatorGenerator: schema = new_schema if isinstance(schema, EnumType): - namespace[f'inflate_{schema.__name__}'] = schema + if not is_builtin_type(schema): + namespace[f'inflate_{schema.__name__}'] = schema return '\n', namespace if isinstance(schema, collections.abc.Iterable): @@ -239,8 +254,9 @@ class SchemaInflatorGenerator: return '', namespace if mode == 'object': - namespace[f'{_funcname}_tgt_type'] = schema - namespace[utils.typename(schema)] = schema + if not is_builtin_type(schema): + namespace[f'{_funcname}_tgt_type'] = schema + namespace[utils.typename(schema)] = schema if from_type_override is not None: namespace['_from_type'] = from_type_override @@ -258,8 +274,8 @@ class SchemaInflatorGenerator: while get_origin(argtype) is not None: type_origin = get_origin(argtype) - - if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])): + if any(map(lambda x: type_origin is x, + [Union, UnionType, Optional, Annotated, list, List, set, Set, dict, Dict])): argtypes = get_args(argtype) if len(argtypes) == 1: argtype = argtypes[0] @@ -297,12 +313,14 @@ class SchemaInflatorGenerator: elif argt is schema: pass else: - namespace[argt.__name__] = argt + if not is_builtin_type(argt): + namespace[argt.__name__] = argt convertor_functext = template.render( funcname=_funcname, conversions=render_data, tgt_type=utils.typename(schema), + is_opaque=len(type_hints) == 0, from_type='_from_type' if from_type_override is not None else None ) diff --git a/src/megasniff/templates/inflator.jinja2 b/src/megasniff/templates/inflator.jinja2 index 43d6e9f..eab741b 100644 --- a/src/megasniff/templates/inflator.jinja2 +++ b/src/megasniff/templates/inflator.jinja2 @@ -6,6 +6,7 @@ def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{ """ {{tgt_type}} """ +{% if not is_opaque %} from_data_keys = from_data.keys() {% for conv in conversions %} @@ -29,4 +30,9 @@ def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{ {% endfor %} + return {{funcname}}_tgt_type({% for conv in conversions %}{{conv.argname_escaped}}={{conv.argname_escaped}}, {% endfor %}) +{% else %} + return {{funcname}}_tgt_type(from_data) +{% endif %} + diff --git a/src/megasniff/templates/unwrap_type_data.jinja2 b/src/megasniff/templates/unwrap_type_data.jinja2 index 5374207..f966448 100644 --- a/src/megasniff/templates/unwrap_type_data.jinja2 +++ b/src/megasniff/templates/unwrap_type_data.jinja2 @@ -3,8 +3,8 @@ {{argname}} = [] if not isinstance({{conv_data}}, list): raise FieldValidationException('{{argname}}', "list", conv_data, []) -for item in {{conv_data}}: -{{ render_segment("_" + argname, typedef, "item", false ) | indent(4) }} +for {{"_" + argname + "_item"}} in {{conv_data}}: +{{ render_segment("_" + argname, typedef, "_" + argname + "_item", false ) | indent(4) }} {{argname}}.append(_{{argname}}) {%- endset %} {{out}} @@ -39,12 +39,21 @@ if not isinstance({{conv_data}}, {{typeref}}): {{argname}} = {{typeref}}({{conv_data}}) {% elif typeref.is_union %} +# union typeref {{render_union(argname, typeref, conv_data)}} {% elif typeref.is_list %} +# list typeref {{render_iterable(argname, typeref.typeref, conv_data)}} {{argname}} = {{typeref.iterable_type}}({{argname}}) +{% elif typeref.is_dict %} +# dict typeref +{{render_iterable(argname + "_k", typeref.typeref[0], "list(" + conv_data + ".keys())")}} +{{render_iterable(argname + "_v", typeref.typeref[1], "list(" + conv_data + ".values())")}} + +{{argname}} = dict(zip({{argname}}_k,{{argname}}_v)) + {% else %} {{render_segment(argname, typeref.typeref, conv_data, typeref.is_strict)}} diff --git a/tests/test_basic.py b/tests/test_basic.py index c3681ba..19124e7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import uuid from dataclasses import dataclass from enum import Enum from typing import Optional, get_type_hints @@ -66,6 +67,52 @@ def test_optional(): assert c.a is None +def test_uuid(): + @dataclass + class K: + a: uuid.UUID + b: list[uuid.UUID] + c: Optional[uuid.UUID] + d: dict[uuid.UUID, uuid.UUID] + + infl = SchemaInflatorGenerator(store_sources=True) + defl = SchemaDeflatorGenerator(store_sources=True) + infl_fn = infl.schema_to_inflator(K) + defl_fn = defl.schema_to_deflator(K) + okd = { + 'a': str(uuid.uuid4()), + 'b': [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())], + 'c': None, + 'd': {str(uuid.uuid4()): str(uuid.uuid4()), str(uuid.uuid4()): str(uuid.uuid4())} + } + + k = infl_fn(okd) + + kd = defl_fn(k) + + assert isinstance(kd['a'], str) + assert isinstance(kd['b'], list) + assert len(kd['b']) == 3 + assert isinstance(kd['b'][0], str) + assert isinstance(kd['b'][1], str) + assert isinstance(kd['b'][2], str) + assert kd['c'] is None + assert isinstance(kd['d'], dict) + assert len(kd['d']) == 2 + assert all(map(lambda x: isinstance(x, str), kd['d'].keys())) + assert all(map(lambda x: isinstance(x, str), kd['d'].values())) + + assert isinstance(k.a, uuid.UUID) + assert isinstance(k.b[0], uuid.UUID) + assert isinstance(k.b[1], uuid.UUID) + assert isinstance(k.b[2], uuid.UUID) + assert k.c is None + assert isinstance(k.d, dict) + assert len(k.d) == 2 + assert all(map(lambda x: isinstance(x, uuid.UUID), k.d.keys())) + assert all(map(lambda x: isinstance(x, uuid.UUID), k.d.values())) + + class AEnum(Enum): a = 'a' b = 'b' diff --git a/tests/test_dicts.py b/tests/test_dicts.py new file mode 100644 index 0000000..2011b5f --- /dev/null +++ b/tests/test_dicts.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import Optional, get_type_hints, Dict + +from megasniff import SchemaDeflatorGenerator +from src.megasniff import SchemaInflatorGenerator + + +def test_dicts(): + @dataclass() + class A: + a: dict[str, int] + b: Dict[int, int] + c: dict[str, list[int]] + d: dict[str, dict[str, int | str]] + + infl = SchemaInflatorGenerator(store_sources=True) + defl = SchemaDeflatorGenerator(store_sources=True) + + infl_fn = infl.schema_to_inflator(A) + defl_fn = defl.schema_to_deflator(A) + + a = infl_fn({ + 'a': { + 1: '42', + 2: '123', + 'asdf': 42 + }, + 'b': { + 1: 1, + '2': '2', + '3': 3, + 4: '4' + }, + 'c': { + 'a': [1, 2, 3, '4'] + }, + 'd': { + 'a': { + 'a': 1, + 'b': '1', + 'c': 'asdf' + } + } + }) + + assert a.a['1'] == 42 + assert a.a['2'] == 123 + assert a.a['asdf'] == 42 + + assert a.b[1] == 1 + assert a.b[2] == 2 + assert a.b[3] == 3 + assert a.b[4] == 4 + + assert a.c['a'][0] == 1 + assert a.c['a'][1] == 2 + assert a.c['a'][2] == 3 + assert a.c['a'][3] == 4 + + assert a.d['a']['a'] == 1 + assert a.d['a']['b'] == 1 + assert a.d['a']['c'] == 'asdf' + + +def test_uuid_dicts(): + @dataclass() + class A: + a: dict[uuid.UUID, uuid.UUID] + + infl = SchemaInflatorGenerator(store_sources=True) + defl = SchemaDeflatorGenerator(store_sources=True) + + infl_fn = infl.schema_to_inflator(A) + defl_fn = defl.schema_to_deflator(A) + + uuids = [uuid.uuid4() for _ in range(32)] + + a = infl_fn({ + 'a': { + str(uuids[0]): str(uuids[0]), + str(uuids[1]): str(uuids[2]), + str(uuids[3]): str(uuids[4]), + } + }) + + assert a.a[uuids[0]] == uuids[0] + assert a.a[uuids[1]] == uuids[2] + assert a.a[uuids[3]] == uuids[4] + + ad = json.loads(json.dumps(defl_fn(a), default=str)) + + assert ad['a'][str(uuids[0])] == str(uuids[0]) + assert ad['a'][str(uuids[1])] == str(uuids[2]) + assert ad['a'][str(uuids[3])] == str(uuids[4])