Fix recursive union-iterable-*-types codegen

This commit is contained in:
2025-07-14 16:27:55 +03:00
parent 897eccd8d1
commit bc6acb099f
9 changed files with 258 additions and 115 deletions

View File

@@ -8,7 +8,7 @@ import megasniff.exceptions
from . import SchemaInflatorGenerator from . import SchemaInflatorGenerator
@dataclass @dataclass(frozen=True)
class ASchema: class ASchema:
a: int a: int
b: float | str b: float | str
@@ -23,18 +23,33 @@ class BSchema(TypedDict):
d: ASchema d: ASchema
@dataclass
class CSchema:
l: set[int | ASchema]
def main(): def main():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
t, n = infl._schema_to_generator(ASchema) fn = infl.schema_to_inflator(ASchema)
print(t) # print(t)
print(n) # print(n)
exec(t, n) # exec(t, n)
fn = n['inflate'] # fn = n['inflate']
# fn = infl.schema_to_generator(ASchema) # fn = infl.schema_to_generator(ASchema)
# # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}} # # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}} # d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}}
# d = {'a': 2, 'b': 2, 'bs': {'a': 2, 'b': 'a', 'c': 0, 'd': {'a': 2, 'b': 2}}}
# d = {'l': ['1', {'a': 42, 'b': 1}]}
d = {'a': 2, 'b': 2, 'bs': None}
try: try:
print(fn(d)) o = fn(d)
print(o)
for k, v in o.__dict__.items():
print(f'field {k}: {v}')
print(f'type: {type(v)}')
if isinstance(v, list):
for vi in v:
print(f'\ttype: {type(vi)}')
except megasniff.exceptions.FieldValidationException as e: except megasniff.exceptions.FieldValidationException as e:
print(e.exceptions) print(e.exceptions)
print(e) print(e)

View File

@@ -1,12 +1,13 @@
# Copyright (C) 2025 Shevchenko A # Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import collections.abc import collections.abc
import importlib.resources import importlib.resources
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set
import jinja2 import jinja2
@@ -15,21 +16,28 @@ from .utils import *
@dataclass @dataclass
class RenderData: class TypeRenderData:
typeref: list[TypeRenderData] | TypeRenderData | str
allow_none: bool
is_list: bool
is_union: bool
@dataclass
class IterableTypeRenderData(TypeRenderData):
iterable_type: str
is_list = True
is_union = False
@dataclass
class FieldRenderData:
argname: str argname: str
constrs: list[str] # typecall / use lookup table constrs: TypeRenderData
typename: str typename: str
is_optional: bool is_optional: bool
allow_none: bool allow_none: bool
default_option: Optional[str] default_option: Optional[str]
typeid: int
@dataclass
class SchemaInflatorTemplateSettings:
general: str = 'inflator.jinja2'
union: str = 'union.jinja2'
iterable: str = 'iterable.jinja2'
class SchemaInflatorGenerator: class SchemaInflatorGenerator:
@@ -37,17 +45,10 @@ class SchemaInflatorGenerator:
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
template: jinja2.Template template: jinja2.Template
union_template: jinja2.Template
iterable_template: jinja2.Template
settings: SchemaInflatorTemplateSettings
def __init__(self, def __init__(self,
loader: Optional[jinja2.BaseLoader] = None, loader: Optional[jinja2.BaseLoader] = None,
template_settings: Optional[SchemaInflatorTemplateSettings] = None): template_filename: str = 'inflator.jinja2'):
if template_settings is None:
template_settings = SchemaInflatorTemplateSettings()
if loader is None: if loader is None:
template_path = importlib.resources.files('megasniff.templates') template_path = importlib.resources.files('megasniff.templates')
@@ -55,17 +56,46 @@ class SchemaInflatorGenerator:
self.templateLoader = loader self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader) self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(template_settings.general) self.template = self.templateEnv.get_template(template_filename)
self.union_template = self.templateEnv.get_template(template_settings.union)
self.iterable_template = self.templateEnv.get_template(template_settings.iterable)
def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]: def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]:
txt, namespace = self._schema_to_inflator(schema, _funcname='inflate') txt, namespace = self._schema_to_inflator(schema, _funcname='inflate')
txt = ('from typing import Any\n' imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
exec(txt, namespace) exec(txt, namespace)
return namespace['inflate'] return namespace['inflate']
def _unwrap_typeref(self, t: type) -> TypeRenderData:
type_origin = get_origin(t)
allow_none = False
argtypes = t,
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(t)
if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True
is_union = len(argtypes) > 1
if is_union:
typerefs = list(map(lambda x: self._unwrap_typeref(x), argtypes))
return TypeRenderData(typerefs, allow_none, False, True)
elif type_origin in [list, set]:
rd = self._unwrap_typeref(argtypes[0])
return IterableTypeRenderData(rd, allow_none, True, False, type_origin.__name__)
else:
t = argtypes[0]
is_list = (type_origin or t) in [list, set]
if is_list:
t = type_origin or t
is_builtin = is_builtin_type(t)
return TypeRenderData(t.__name__ if is_builtin else f'inflate_{t.__name__}', allow_none, is_list, False)
def _schema_to_inflator(self, def _schema_to_inflator(self,
schema: type, schema: type,
*, *,
@@ -92,49 +122,67 @@ class SchemaInflatorGenerator:
continue continue
has_default, default_option = get_field_default(schema, argname) has_default, default_option = get_field_default(schema, argname)
typeref = self._unwrap_typeref(argtype)
argtypes = argtype, argtypes = argtype,
type_origin = get_origin(argtype)
allow_none = False allow_none = False
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated])): while get_origin(argtype) is not None:
type_origin = get_origin(argtype)
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(argtype) argtypes = get_args(argtype)
if len(argtypes) == 1:
argtype = argtypes[0]
else:
break
if NoneType in argtypes or None in argtypes: if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True allow_none = True
out_argtypes: list[str] = []
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt, _funcname=f'inflate_{argt.__name__}',
_namespace=namespace)
namespace |= n
txt_segments.append(t)
out_argtypes.append(f'inflate_{argt.__name__}')
# lookup_table[hash(argt)] = infl
# namespace[argt.__name__] = infl
elif argt is schema:
out_argtypes.append(_funcname)
else:
namespace[argt.__name__] = argt
out_argtypes.append(argt.__name__)
render_data.append( render_data.append(
RenderData( FieldRenderData(
argname, argname,
out_argtypes, typeref,
utils.typename(argtype), utils.typename(argtype),
has_default, has_default,
allow_none, allow_none,
default_option, default_option,
hash(argtype)
) )
) )
#
# out_argtypes: list[str] = []
#
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt,
_funcname=f'inflate_{argt.__name__}',
_namespace=namespace)
namespace |= n
txt_segments.append(t)
elif argt is schema:
pass
else:
namespace[argt.__name__] = argt
#
# render_data.append(
# FieldRenderData(
# argname,
# out_argtypes,
# utils.typename(argtype),
# has_default,
# allow_none,
# default_option,
# type_origin is not None and type_origin.__name__ == 'list',
# )
# )
convertor_functext = self.template.render( convertor_functext = self.template.render(
funcname=_funcname, funcname=_funcname,

View File

@@ -1,6 +0,0 @@
{% macro render_segment(argname, typename) -%}
{%- set out -%}
{{argname}} = {{typename}}(conv_data)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -1,6 +1,6 @@
{% set ns = namespace(retry_indent=0) %} {% set ns = namespace(retry_indent=0) %}
{% import "basic.jinja2" as basic %} {% import "unwrap_type_data.jinja2" as unwrap_type_data %}
{% import "union.jinja2" as union %}
def {{funcname}}(from_data: dict[str, Any]): def {{funcname}}(from_data: dict[str, Any]):
""" """
@@ -24,16 +24,8 @@ def {{funcname}}(from_data: dict[str, Any]):
{{conv.argname}} = None {{conv.argname}} = None
{% endif %} {% endif %}
else: else:
try:
{% if conv.constrs | length > 1 %}
{{ union.render_segment(conv) | indent(4*4) }}
{% else %}
{{ basic.render_segment(conv.argname, conv.constrs[0]) | indent(4*4) }}
{% endif %}
except FieldValidationException as e:
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions)
{{ unwrap_type_data.render_segment(conv.argname, conv.constrs, "conv_data") | indent(4*3) }}
{% endfor %} {% endfor %}

View File

@@ -1,5 +0,0 @@
def inflate(iterable):
ret = {{ rettype }}()
for item in iterable:
ret.{{ retadd }}(_lookup_table[{{item_id}}](item))
return ret

View File

@@ -1,17 +0,0 @@
{% import "basic.jinja2" as basic %}
{% macro render_segment(conv) -%}
{%- set out -%}
{% set ns = namespace(retry_indent=0) %}
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type in conv.constrs %}
{{ ' ' * ns.retry_indent }}try:
{{ basic.render_segment(conv.argname, union_type) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -0,0 +1,51 @@
{% macro render_iterable(argname, typedef, conv_data) -%}
{%- set out -%}
{{argname}} = []
if not isinstance({{conv_data}}, list):
raise FieldValidationException('{{argname}}', "list", conv_data, [])
for item in {{conv_data}}:
{{ render_segment("_" + argname, typedef, "item") | indent(4) }}
{{argname}}.append(_{{argname}})
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_union(argname, conv, conv_data) -%}
{%- set out -%}
# unwrapping union {{conv}}
{% set ns = namespace(retry_indent=0) %}
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type in conv.typeref %}
{{ ' ' * ns.retry_indent }}try:
{{ render_segment(argname, union_type, conv_data) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_segment(argname, typeref, conv_data) -%}
{%- set out -%}
{% if typeref is string %}
{{argname}} = {{typeref}}({{conv_data}})
{% elif typeref.is_union %}
{{render_union(argname, typeref, conv_data)}}
{% elif typeref.is_list %}
{{render_iterable(argname, typeref.typeref, conv_data)}}
{{argname}} = {{typeref.iterable_type}}({{argname}})
{% else %}
{{render_segment(argname, typeref.typeref, conv_data)}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -3,20 +3,85 @@ from dataclasses import dataclass
from megasniff import SchemaInflatorGenerator from megasniff import SchemaInflatorGenerator
# def test_list(): def test_list_basic():
# @dataclass @dataclass
# class A: class A:
# l: list[int] l: list[int]
#
# infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
# fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
#
# a = fn({'l': []}) a = fn({'l': []})
# assert isinstance(a.l, list) assert isinstance(a.l, list)
# assert len(a.l) == 0 assert len(a.l) == 0
#
# a = fn({'l': [1, 2.1, '0']}) a = fn({'l': [1, 2.1, '0']})
# print(a.l) print(a.l)
# assert isinstance(a.l, list) assert isinstance(a.l, list)
# assert len(a.l) == 3 assert len(a.l) == 3
# assert all(map(lambda x: isinstance(x, int), a.l)) assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: list[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == ['1', '2.1', '0']
def test_list_union():
@dataclass
class A:
l: list[int | str]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, list)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
def test_set_basic():
@dataclass
class A:
l: set[int]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, set)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: set[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0', 0]})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == {'1', '2.1', '0'}