Remove lookup_table from inflator generated code, rename generating func

This commit is contained in:
2025-07-12 05:50:56 +03:00
parent aee6dcf3d3
commit 897eccd8d1
10 changed files with 131 additions and 79 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
from typing import TypedDict from typing import TypedDict
import megasniff.exceptions
from . import SchemaInflatorGenerator from . import SchemaInflatorGenerator
@@ -24,9 +25,19 @@ class BSchema(TypedDict):
def main(): def main():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(ASchema) t, n = infl._schema_to_generator(ASchema)
d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}} print(t)
print(n)
exec(t, n)
fn = n['inflate']
# fn = infl.schema_to_generator(ASchema)
# # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}}
try:
print(fn(d)) print(fn(d))
except megasniff.exceptions.FieldValidationException as e:
print(e.exceptions)
print(e)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,11 +1,12 @@
# Copyright (C) 2025 Shevchenko A # Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
import collections.abc
import importlib.resources import importlib.resources
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence
import jinja2 import jinja2
@@ -16,7 +17,7 @@ from .utils import *
@dataclass @dataclass
class RenderData: class RenderData:
argname: str argname: str
constrs: list[tuple[str, bool]] # typecall / use lookup table constrs: list[str] # typecall / use lookup table
typename: str typename: str
is_optional: bool is_optional: bool
allow_none: bool allow_none: bool
@@ -28,13 +29,17 @@ class RenderData:
class SchemaInflatorTemplateSettings: class SchemaInflatorTemplateSettings:
general: str = 'inflator.jinja2' general: str = 'inflator.jinja2'
union: str = 'union.jinja2' union: str = 'union.jinja2'
iterable: str = 'iterable.jinja2'
class SchemaInflatorGenerator: class SchemaInflatorGenerator:
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
template: jinja2.Template template: jinja2.Template
union_template: jinja2.Template union_template: jinja2.Template
iterable_template: jinja2.Template
settings: SchemaInflatorTemplateSettings settings: SchemaInflatorTemplateSettings
def __init__(self, def __init__(self,
@@ -52,34 +57,35 @@ class SchemaInflatorGenerator:
self.templateEnv = jinja2.Environment(loader=self.templateLoader) self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(template_settings.general) self.template = self.templateEnv.get_template(template_settings.general)
self.union_template = self.templateEnv.get_template(template_settings.union) self.union_template = self.templateEnv.get_template(template_settings.union)
self.iterable_template = self.templateEnv.get_template(template_settings.iterable)
def _union_inflator(self, def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]:
argname: str, txt, namespace = self._schema_to_inflator(schema, _funcname='inflate')
argtype: str, txt = ('from typing import Any\n'
constrs: list[tuple[str, bool]], 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt
lookup_table: dict[str, Any]):
txt = self.union_template.render(
argname=argname,
typename=argtype,
constrs=constrs
)
namespace = {
'_lookup_table': lookup_table
}
exec(txt, namespace) exec(txt, namespace)
return namespace['inflate'] return namespace['inflate']
def schema_to_generator(self, def _schema_to_inflator(self,
schema: type, schema: type,
*, *,
_base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]: _funcname='inflate',
_namespace=None) -> tuple[str, dict]:
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян # Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
type_hints = get_kwargs_type_hints(schema) type_hints = get_kwargs_type_hints(schema)
render_data = [] render_data = []
lookup_table = _base_lookup_table or {}
if schema.__name__ not in lookup_table.keys(): txt_segments = []
lookup_table[schema.__name__] = None
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace
namespace[f'{_funcname}_tgt_type'] = schema
for argname, argtype in type_hints.items(): for argname, argtype in type_hints.items():
if argname in {'return', 'self'}: if argname in {'return', 'self'}:
@@ -97,22 +103,26 @@ class SchemaInflatorGenerator:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True allow_none = True
out_argtypes: list[tuple[str, bool]] = [] out_argtypes: list[str] = []
for argt in argtypes: for argt in argtypes:
is_builtin = is_builtin_type(argt) is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema: if not is_builtin and argt is not schema:
if argt.__name__ not in lookup_table.keys():
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию # если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table) if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt, _funcname=f'inflate_{argt.__name__}',
_namespace=namespace)
namespace |= n
txt_segments.append(t)
out_argtypes.append(f'inflate_{argt.__name__}')
# lookup_table[hash(argt)] = infl
# namespace[argt.__name__] = infl
if argt is schema: elif argt is schema:
out_argtypes.append(('inflate', True)) out_argtypes.append(_funcname)
else: else:
out_argtypes.append((argt.__name__, is_builtin)) namespace[argt.__name__] = argt
out_argtypes.append(argt.__name__)
if len(argtypes) > 1:
lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table)
render_data.append( render_data.append(
RenderData( RenderData(
@@ -126,18 +136,15 @@ class SchemaInflatorGenerator:
) )
) )
convertor_functext = self.template.render(conversions=render_data) convertor_functext = self.template.render(
funcname=_funcname,
conversions=render_data,
tgt_type=utils.typename(schema)
)
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
convertor_functext = convertor_functext.replace(', )', ')') convertor_functext = convertor_functext.replace(', )', ')')
namespace = {
'_tgt_type': schema,
'_lookup_table': lookup_table
}
exec(convertor_functext, namespace)
# пихаем сгенеренный метод в табличку, return convertor_functext, namespace
# ожидаем что она обновится во всех вложенных методах,
# разрешая циклические зависимости
lookup_table[schema.__name__] = namespace['inflate']
return namespace['inflate']

View File

@@ -0,0 +1,6 @@
{% macro render_segment(argname, typename) -%}
{%- set out -%}
{{argname}} = {{typename}}(conv_data)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -1,8 +1,11 @@
{% set ns = namespace(retry_indent=0) %} {% set ns = namespace(retry_indent=0) %}
from typing import Any {% import "basic.jinja2" as basic %}
from megasniff.exceptions import MissingFieldException, FieldValidationException {% import "union.jinja2" as union %}
def inflate(from_data: dict[str, Any]): def {{funcname}}(from_data: dict[str, Any]):
"""
{{tgt_type}}
"""
from_data_keys = from_data.keys() from_data_keys = from_data.keys()
{% for conv in conversions %} {% for conv in conversions %}
@@ -22,10 +25,10 @@ def inflate(from_data: dict[str, Any]):
{% endif %} {% endif %}
else: else:
try: try:
{% if conv.constrs | length > 1 or conv.constrs[0][1] is false %} {% if conv.constrs | length > 1 %}
{{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data) {{ union.render_segment(conv) | indent(4*4) }}
{% else %} {% else %}
{{conv.argname}} = {{conv.constrs[0][0]}}(conv_data) {{ basic.render_segment(conv.argname, conv.constrs[0]) | indent(4*4) }}
{% endif %} {% endif %}
except FieldValidationException as e: except FieldValidationException as e:
@@ -34,4 +37,4 @@ def inflate(from_data: dict[str, Any]):
{% endfor %} {% endfor %}
return _tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %}) return {{funcname}}_tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %})

View File

@@ -0,0 +1,5 @@
def inflate(iterable):
ret = {{ rettype }}()
for item in iterable:
ret.{{ retadd }}(_lookup_table[{{item_id}}](item))
return ret

View File

@@ -1,19 +1,17 @@
{% import "basic.jinja2" as basic %}
{% macro render_segment(conv) -%}
{%- set out -%}
{% set ns = namespace(retry_indent=0) %} {% set ns = namespace(retry_indent=0) %}
from typing import Any
from megasniff.exceptions import FieldValidationException
def inflate(conv_data: Any):
{% set ns.retry_indent = 0 %} {% set ns.retry_indent = 0 %}
all_conv_exceptions = [] all_conv_exceptions = []
{% for union_type, is_builtin in constrs %} {% for union_type in conv.constrs %}
{{ ' ' * ns.retry_indent }}try: {{ ' ' * ns.retry_indent }}try:
{% if is_builtin %} {{ basic.render_segment(conv.argname, union_type) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data)
{% else %}
{{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data)
{% endif %}
{{ ' ' * ns.retry_indent }}except Exception as e: {{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) {{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %} {% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %} {% endfor %}
raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions) {{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -11,7 +11,7 @@ def test_basic_constructor():
self.a = a self.a = a
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -23,7 +23,7 @@ def test_unions():
a: int | str a: int | str
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -45,10 +45,10 @@ class CircB:
def test_circular(): def test_circular():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(CircA) fn = infl.schema_to_inflator(CircA)
a = fn({'b': {'a': None}}) a = fn({'b': {'a': None}})
return isinstance(a.b, CircB) assert isinstance(a.b, CircB)
def test_optional(): def test_optional():
@@ -57,6 +57,6 @@ def test_optional():
a: Optional[int] = None a: Optional[int] = None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(C) fn = infl.schema_to_inflator(C)
c = fn({}) c = fn({})
assert c.a is None assert c.a is None

View File

@@ -12,7 +12,7 @@ def test_missing_field():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(MissingFieldException): with pytest.raises(MissingFieldException):
fn({}) fn({})
@@ -23,7 +23,7 @@ def test_null():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': None}) fn({'a': None})
@@ -34,6 +34,6 @@ def test_invalid_field():
a: float | int | None a: float | int | None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': {}}) fn({'a': {}})

22
tests/test_iterables.py Normal file
View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
from megasniff import SchemaInflatorGenerator
# def test_list():
# @dataclass
# class A:
# l: list[int]
#
# infl = SchemaInflatorGenerator()
# fn = infl.schema_to_generator(A)
#
# a = fn({'l': []})
# assert isinstance(a.l, list)
# assert len(a.l) == 0
#
# a = fn({'l': [1, 2.1, '0']})
# print(a.l)
# assert isinstance(a.l, list)
# assert len(a.l) == 3
# assert all(map(lambda x: isinstance(x, int), a.l))