Compare commits

...

2 Commits

9 changed files with 311 additions and 116 deletions

View File

@@ -4,10 +4,11 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
from typing import TypedDict from typing import TypedDict
import megasniff.exceptions
from . import SchemaInflatorGenerator from . import SchemaInflatorGenerator
@dataclass @dataclass(frozen=True)
class ASchema: class ASchema:
a: int a: int
b: float | str b: float | str
@@ -22,11 +23,36 @@ class BSchema(TypedDict):
d: ASchema d: ASchema
@dataclass
class CSchema:
l: set[int | ASchema]
def main(): def main():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(ASchema) fn = infl.schema_to_inflator(ASchema)
d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}} # print(t)
print(fn(d)) # print(n)
# exec(t, n)
# fn = n['inflate']
# fn = infl.schema_to_generator(ASchema)
# # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
# d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}}
# d = {'a': 2, 'b': 2, 'bs': {'a': 2, 'b': 'a', 'c': 0, 'd': {'a': 2, 'b': 2}}}
# d = {'l': ['1', {'a': 42, 'b': 1}]}
d = {'a': 2, 'b': 2, 'bs': None}
try:
o = fn(d)
print(o)
for k, v in o.__dict__.items():
print(f'field {k}: {v}')
print(f'type: {type(v)}')
if isinstance(v, list):
for vi in v:
print(f'\ttype: {type(vi)}')
except megasniff.exceptions.FieldValidationException as e:
print(e.exceptions)
print(e)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,11 +1,13 @@
# Copyright (C) 2025 Shevchenko A # Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import collections.abc
import importlib.resources import importlib.resources
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set
import jinja2 import jinja2
@@ -14,35 +16,39 @@ from .utils import *
@dataclass @dataclass
class RenderData: class TypeRenderData:
typeref: list[TypeRenderData] | TypeRenderData | str
allow_none: bool
is_list: bool
is_union: bool
@dataclass
class IterableTypeRenderData(TypeRenderData):
iterable_type: str
is_list = True
is_union = False
@dataclass
class FieldRenderData:
argname: str argname: str
constrs: list[tuple[str, bool]] # typecall / use lookup table constrs: TypeRenderData
typename: str typename: str
is_optional: bool is_optional: bool
allow_none: bool allow_none: bool
default_option: Optional[str] default_option: Optional[str]
typeid: int
@dataclass
class SchemaInflatorTemplateSettings:
general: str = 'inflator.jinja2'
union: str = 'union.jinja2'
class SchemaInflatorGenerator: class SchemaInflatorGenerator:
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
template: jinja2.Template template: jinja2.Template
union_template: jinja2.Template
settings: SchemaInflatorTemplateSettings
def __init__(self, def __init__(self,
loader: Optional[jinja2.BaseLoader] = None, loader: Optional[jinja2.BaseLoader] = None,
template_settings: Optional[SchemaInflatorTemplateSettings] = None): template_filename: str = 'inflator.jinja2'):
if template_settings is None:
template_settings = SchemaInflatorTemplateSettings()
if loader is None: if loader is None:
template_path = importlib.resources.files('megasniff.templates') template_path = importlib.resources.files('megasniff.templates')
@@ -50,94 +56,143 @@ class SchemaInflatorGenerator:
self.templateLoader = loader self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader) self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(template_settings.general) self.template = self.templateEnv.get_template(template_filename)
self.union_template = self.templateEnv.get_template(template_settings.union)
def _union_inflator(self, def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]:
argname: str, txt, namespace = self._schema_to_inflator(schema, _funcname='inflate')
argtype: str, imports = ('from typing import Any\n'
constrs: list[tuple[str, bool]], 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
lookup_table: dict[str, Any]): txt = imports + '\n' + txt
txt = self.union_template.render(
argname=argname,
typename=argtype,
constrs=constrs
)
namespace = {
'_lookup_table': lookup_table
}
exec(txt, namespace) exec(txt, namespace)
return namespace['inflate'] return namespace['inflate']
def schema_to_generator(self, def _unwrap_typeref(self, t: type) -> TypeRenderData:
type_origin = get_origin(t)
allow_none = False
argtypes = t,
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(t)
if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True
is_union = len(argtypes) > 1
if is_union:
typerefs = list(map(lambda x: self._unwrap_typeref(x), argtypes))
return TypeRenderData(typerefs, allow_none, False, True)
elif type_origin in [list, set]:
rd = self._unwrap_typeref(argtypes[0])
return IterableTypeRenderData(rd, allow_none, True, False, type_origin.__name__)
else:
t = argtypes[0]
is_list = (type_origin or t) in [list, set]
if is_list:
t = type_origin or t
is_builtin = is_builtin_type(t)
return TypeRenderData(t.__name__ if is_builtin else f'inflate_{t.__name__}', allow_none, is_list, False)
def _schema_to_inflator(self,
schema: type, schema: type,
*, *,
_base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]: _funcname='inflate',
_namespace=None) -> tuple[str, dict]:
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян # Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
type_hints = get_kwargs_type_hints(schema) type_hints = get_kwargs_type_hints(schema)
render_data = [] render_data = []
lookup_table = _base_lookup_table or {}
if schema.__name__ not in lookup_table.keys(): txt_segments = []
lookup_table[schema.__name__] = None
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace
namespace[f'{_funcname}_tgt_type'] = schema
for argname, argtype in type_hints.items(): for argname, argtype in type_hints.items():
if argname in {'return', 'self'}: if argname in {'return', 'self'}:
continue continue
has_default, default_option = get_field_default(schema, argname) has_default, default_option = get_field_default(schema, argname)
typeref = self._unwrap_typeref(argtype)
argtypes = argtype, argtypes = argtype,
type_origin = get_origin(argtype)
allow_none = False allow_none = False
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated])): while get_origin(argtype) is not None:
type_origin = get_origin(argtype)
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(argtype) argtypes = get_args(argtype)
if len(argtypes) == 1:
argtype = argtypes[0]
else:
break
if NoneType in argtypes or None in argtypes: if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True allow_none = True
out_argtypes: list[tuple[str, bool]] = []
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
if argt.__name__ not in lookup_table.keys():
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table)
if argt is schema:
out_argtypes.append(('inflate', True))
else:
out_argtypes.append((argt.__name__, is_builtin))
if len(argtypes) > 1:
lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table)
render_data.append( render_data.append(
RenderData( FieldRenderData(
argname, argname,
out_argtypes, typeref,
utils.typename(argtype), utils.typename(argtype),
has_default, has_default,
allow_none, allow_none,
default_option, default_option,
hash(argtype)
) )
) )
#
# out_argtypes: list[str] = []
#
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt,
_funcname=f'inflate_{argt.__name__}',
_namespace=namespace)
namespace |= n
txt_segments.append(t)
elif argt is schema:
pass
else:
namespace[argt.__name__] = argt
#
# render_data.append(
# FieldRenderData(
# argname,
# out_argtypes,
# utils.typename(argtype),
# has_default,
# allow_none,
# default_option,
# type_origin is not None and type_origin.__name__ == 'list',
# )
# )
convertor_functext = self.template.render(
funcname=_funcname,
conversions=render_data,
tgt_type=utils.typename(schema)
)
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext
convertor_functext = self.template.render(conversions=render_data)
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
convertor_functext = convertor_functext.replace(', )', ')') convertor_functext = convertor_functext.replace(', )', ')')
namespace = {
'_tgt_type': schema,
'_lookup_table': lookup_table
}
exec(convertor_functext, namespace)
# пихаем сгенеренный метод в табличку, return convertor_functext, namespace
# ожидаем что она обновится во всех вложенных методах,
# разрешая циклические зависимости
lookup_table[schema.__name__] = namespace['inflate']
return namespace['inflate']

View File

@@ -1,8 +1,11 @@
{% set ns = namespace(retry_indent=0) %} {% set ns = namespace(retry_indent=0) %}
from typing import Any {% import "unwrap_type_data.jinja2" as unwrap_type_data %}
from megasniff.exceptions import MissingFieldException, FieldValidationException
def inflate(from_data: dict[str, Any]):
def {{funcname}}(from_data: dict[str, Any]):
"""
{{tgt_type}}
"""
from_data_keys = from_data.keys() from_data_keys = from_data.keys()
{% for conv in conversions %} {% for conv in conversions %}
@@ -21,17 +24,9 @@ def inflate(from_data: dict[str, Any]):
{{conv.argname}} = None {{conv.argname}} = None
{% endif %} {% endif %}
else: else:
try:
{% if conv.constrs | length > 1 or conv.constrs[0][1] is false %}
{{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data)
{% else %}
{{conv.argname}} = {{conv.constrs[0][0]}}(conv_data)
{% endif %}
except FieldValidationException as e:
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions)
{{ unwrap_type_data.render_segment(conv.argname, conv.constrs, "conv_data") | indent(4*3) }}
{% endfor %} {% endfor %}
return _tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %}) return {{funcname}}_tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %})

View File

@@ -1,19 +0,0 @@
{% set ns = namespace(retry_indent=0) %}
from typing import Any
from megasniff.exceptions import FieldValidationException
def inflate(conv_data: Any):
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type, is_builtin in constrs %}
{{ ' ' * ns.retry_indent }}try:
{% if is_builtin %}
{{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data)
{% else %}
{{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data)
{% endif %}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions)

View File

@@ -0,0 +1,51 @@
{% macro render_iterable(argname, typedef, conv_data) -%}
{%- set out -%}
{{argname}} = []
if not isinstance({{conv_data}}, list):
raise FieldValidationException('{{argname}}', "list", conv_data, [])
for item in {{conv_data}}:
{{ render_segment("_" + argname, typedef, "item") | indent(4) }}
{{argname}}.append(_{{argname}})
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_union(argname, conv, conv_data) -%}
{%- set out -%}
# unwrapping union {{conv}}
{% set ns = namespace(retry_indent=0) %}
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type in conv.typeref %}
{{ ' ' * ns.retry_indent }}try:
{{ render_segment(argname, union_type, conv_data) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_segment(argname, typeref, conv_data) -%}
{%- set out -%}
{% if typeref is string %}
{{argname}} = {{typeref}}({{conv_data}})
{% elif typeref.is_union %}
{{render_union(argname, typeref, conv_data)}}
{% elif typeref.is_list %}
{{render_iterable(argname, typeref.typeref, conv_data)}}
{{argname}} = {{typeref.iterable_type}}({{argname}})
{% else %}
{{render_segment(argname, typeref.typeref, conv_data)}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -11,7 +11,7 @@ def test_basic_constructor():
self.a = a self.a = a
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -23,7 +23,7 @@ def test_unions():
a: int | str a: int | str
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -45,10 +45,10 @@ class CircB:
def test_circular(): def test_circular():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(CircA) fn = infl.schema_to_inflator(CircA)
a = fn({'b': {'a': None}}) a = fn({'b': {'a': None}})
return isinstance(a.b, CircB) assert isinstance(a.b, CircB)
def test_optional(): def test_optional():
@@ -57,6 +57,6 @@ def test_optional():
a: Optional[int] = None a: Optional[int] = None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(C) fn = infl.schema_to_inflator(C)
c = fn({}) c = fn({})
assert c.a is None assert c.a is None

View File

@@ -12,7 +12,7 @@ def test_missing_field():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(MissingFieldException): with pytest.raises(MissingFieldException):
fn({}) fn({})
@@ -23,7 +23,7 @@ def test_null():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': None}) fn({'a': None})
@@ -34,6 +34,6 @@ def test_invalid_field():
a: float | int | None a: float | int | None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': {}}) fn({'a': {}})

87
tests/test_iterables.py Normal file
View File

@@ -0,0 +1,87 @@
from dataclasses import dataclass
from megasniff import SchemaInflatorGenerator
def test_list_basic():
@dataclass
class A:
l: list[int]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, list)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: list[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == ['1', '2.1', '0']
def test_list_union():
@dataclass
class A:
l: list[int | str]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, list)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
def test_set_basic():
@dataclass
class A:
l: set[int]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, set)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: set[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0', 0]})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == {'1', '2.1', '0'}