Remove lookup_table from inflator generated code, rename generating func

This commit is contained in:
2025-07-12 05:50:56 +03:00
parent aee6dcf3d3
commit 897eccd8d1
10 changed files with 131 additions and 79 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Optional
from typing import TypedDict
import megasniff.exceptions
from . import SchemaInflatorGenerator
@@ -24,9 +25,19 @@ class BSchema(TypedDict):
def main():
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(ASchema)
d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
print(fn(d))
t, n = infl._schema_to_generator(ASchema)
print(t)
print(n)
exec(t, n)
fn = n['inflate']
# fn = infl.schema_to_generator(ASchema)
# # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}}
try:
print(fn(d))
except megasniff.exceptions.FieldValidationException as e:
print(e.exceptions)
print(e)
if __name__ == '__main__':

View File

@@ -1,11 +1,12 @@
# Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later
import collections.abc
import importlib.resources
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence
import jinja2
@@ -16,7 +17,7 @@ from .utils import *
@dataclass
class RenderData:
argname: str
constrs: list[tuple[str, bool]] # typecall / use lookup table
constrs: list[str] # typecall / use lookup table
typename: str
is_optional: bool
allow_none: bool
@@ -28,13 +29,17 @@ class RenderData:
class SchemaInflatorTemplateSettings:
general: str = 'inflator.jinja2'
union: str = 'union.jinja2'
iterable: str = 'iterable.jinja2'
class SchemaInflatorGenerator:
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
template: jinja2.Template
union_template: jinja2.Template
iterable_template: jinja2.Template
settings: SchemaInflatorTemplateSettings
def __init__(self,
@@ -52,34 +57,35 @@ class SchemaInflatorGenerator:
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(template_settings.general)
self.union_template = self.templateEnv.get_template(template_settings.union)
self.iterable_template = self.templateEnv.get_template(template_settings.iterable)
def _union_inflator(self,
argname: str,
argtype: str,
constrs: list[tuple[str, bool]],
lookup_table: dict[str, Any]):
txt = self.union_template.render(
argname=argname,
typename=argtype,
constrs=constrs
)
namespace = {
'_lookup_table': lookup_table
}
def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]:
txt, namespace = self._schema_to_inflator(schema, _funcname='inflate')
txt = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt
exec(txt, namespace)
return namespace['inflate']
def schema_to_generator(self,
def _schema_to_inflator(self,
schema: type,
*,
_base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]:
_funcname='inflate',
_namespace=None) -> tuple[str, dict]:
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
type_hints = get_kwargs_type_hints(schema)
render_data = []
lookup_table = _base_lookup_table or {}
if schema.__name__ not in lookup_table.keys():
lookup_table[schema.__name__] = None
txt_segments = []
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace
namespace[f'{_funcname}_tgt_type'] = schema
for argname, argtype in type_hints.items():
if argname in {'return', 'self'}:
@@ -97,22 +103,26 @@ class SchemaInflatorGenerator:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True
out_argtypes: list[tuple[str, bool]] = []
out_argtypes: list[str] = []
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
if argt.__name__ not in lookup_table.keys():
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table)
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt, _funcname=f'inflate_{argt.__name__}',
_namespace=namespace)
namespace |= n
txt_segments.append(t)
out_argtypes.append(f'inflate_{argt.__name__}')
# lookup_table[hash(argt)] = infl
# namespace[argt.__name__] = infl
if argt is schema:
out_argtypes.append(('inflate', True))
elif argt is schema:
out_argtypes.append(_funcname)
else:
out_argtypes.append((argt.__name__, is_builtin))
if len(argtypes) > 1:
lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table)
namespace[argt.__name__] = argt
out_argtypes.append(argt.__name__)
render_data.append(
RenderData(
@@ -126,18 +136,15 @@ class SchemaInflatorGenerator:
)
)
convertor_functext = self.template.render(conversions=render_data)
convertor_functext = self.template.render(
funcname=_funcname,
conversions=render_data,
tgt_type=utils.typename(schema)
)
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
convertor_functext = convertor_functext.replace(', )', ')')
namespace = {
'_tgt_type': schema,
'_lookup_table': lookup_table
}
exec(convertor_functext, namespace)
# пихаем сгенеренный метод в табличку,
# ожидаем что она обновится во всех вложенных методах,
# разрешая циклические зависимости
lookup_table[schema.__name__] = namespace['inflate']
return namespace['inflate']
return convertor_functext, namespace

View File

@@ -0,0 +1,6 @@
{% macro render_segment(argname, typename) -%}
{%- set out -%}
{{argname}} = {{typename}}(conv_data)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -1,8 +1,11 @@
{% set ns = namespace(retry_indent=0) %}
from typing import Any
from megasniff.exceptions import MissingFieldException, FieldValidationException
{% import "basic.jinja2" as basic %}
{% import "union.jinja2" as union %}
def inflate(from_data: dict[str, Any]):
def {{funcname}}(from_data: dict[str, Any]):
"""
{{tgt_type}}
"""
from_data_keys = from_data.keys()
{% for conv in conversions %}
@@ -22,10 +25,10 @@ def inflate(from_data: dict[str, Any]):
{% endif %}
else:
try:
{% if conv.constrs | length > 1 or conv.constrs[0][1] is false %}
{{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data)
{% if conv.constrs | length > 1 %}
{{ union.render_segment(conv) | indent(4*4) }}
{% else %}
{{conv.argname}} = {{conv.constrs[0][0]}}(conv_data)
{{ basic.render_segment(conv.argname, conv.constrs[0]) | indent(4*4) }}
{% endif %}
except FieldValidationException as e:
@@ -34,4 +37,4 @@ def inflate(from_data: dict[str, Any]):
{% endfor %}
return _tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %})
return {{funcname}}_tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %})

View File

@@ -0,0 +1,5 @@
def inflate(iterable):
ret = {{ rettype }}()
for item in iterable:
ret.{{ retadd }}(_lookup_table[{{item_id}}](item))
return ret

View File

@@ -1,19 +1,17 @@
{% import "basic.jinja2" as basic %}
{% macro render_segment(conv) -%}
{%- set out -%}
{% set ns = namespace(retry_indent=0) %}
from typing import Any
from megasniff.exceptions import FieldValidationException
def inflate(conv_data: Any):
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type, is_builtin in constrs %}
{{ ' ' * ns.retry_indent }}try:
{% if is_builtin %}
{{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data)
{% else %}
{{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data)
{% endif %}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type in conv.constrs %}
{{ ' ' * ns.retry_indent }}try:
{{ basic.render_segment(conv.argname, union_type) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -11,7 +11,7 @@ def test_basic_constructor():
self.a = a
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A)
fn = infl.schema_to_inflator(A)
a = fn({'a': 42})
assert a.a == 42
@@ -23,7 +23,7 @@ def test_unions():
a: int | str
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A)
fn = infl.schema_to_inflator(A)
a = fn({'a': 42})
assert a.a == 42
@@ -45,10 +45,10 @@ class CircB:
def test_circular():
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(CircA)
fn = infl.schema_to_inflator(CircA)
a = fn({'b': {'a': None}})
return isinstance(a.b, CircB)
assert isinstance(a.b, CircB)
def test_optional():
@@ -57,6 +57,6 @@ def test_optional():
a: Optional[int] = None
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(C)
fn = infl.schema_to_inflator(C)
c = fn({})
assert c.a is None

View File

@@ -12,7 +12,7 @@ def test_missing_field():
a: int
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A)
fn = infl.schema_to_inflator(A)
with pytest.raises(MissingFieldException):
fn({})
@@ -23,7 +23,7 @@ def test_null():
a: int
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A)
fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException):
fn({'a': None})
@@ -34,6 +34,6 @@ def test_invalid_field():
a: float | int | None
infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A)
fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException):
fn({'a': {}})

22
tests/test_iterables.py Normal file
View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
from megasniff import SchemaInflatorGenerator
# def test_list():
# @dataclass
# class A:
# l: list[int]
#
# infl = SchemaInflatorGenerator()
# fn = infl.schema_to_generator(A)
#
# a = fn({'l': []})
# assert isinstance(a.l, list)
# assert len(a.l) == 0
#
# a = fn({'l': [1, 2.1, '0']})
# print(a.l)
# assert isinstance(a.l, list)
# assert len(a.l) == 3
# assert all(map(lambda x: isinstance(x, int), a.l))