Extract complex type creation into separate template
This commit is contained in:
Binary file not shown.
@@ -9,6 +9,7 @@ from typing import Optional, get_origin, get_args, Union, Annotated
|
|||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
from . import utils
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
|
||||||
|
|
||||||
@@ -20,27 +21,58 @@ class RenderData:
|
|||||||
is_optional: bool
|
is_optional: bool
|
||||||
allow_none: bool
|
allow_none: bool
|
||||||
default_option: Optional[str]
|
default_option: Optional[str]
|
||||||
|
typeid: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SchemaInflatorTemplateSettings:
|
||||||
|
general: str = 'inflator.jinja2'
|
||||||
|
union: str = 'union.jinja2'
|
||||||
|
|
||||||
|
|
||||||
class SchemaInflatorGenerator:
|
class SchemaInflatorGenerator:
|
||||||
templateLoader: jinja2.BaseLoader
|
templateLoader: jinja2.BaseLoader
|
||||||
templateEnv: jinja2.Environment
|
templateEnv: jinja2.Environment
|
||||||
template: jinja2.Template
|
template: jinja2.Template
|
||||||
|
union_template: jinja2.Template
|
||||||
|
settings: SchemaInflatorTemplateSettings
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
loader: Optional[jinja2.BaseLoader] = None,
|
loader: Optional[jinja2.BaseLoader] = None,
|
||||||
convertor_template: str = 'inflator.jinja2'):
|
template_settings: Optional[SchemaInflatorTemplateSettings] = None):
|
||||||
|
|
||||||
|
if template_settings is None:
|
||||||
|
template_settings = SchemaInflatorTemplateSettings()
|
||||||
|
|
||||||
if loader is None:
|
if loader is None:
|
||||||
template_path = importlib.resources.files('megasniff.templates')
|
template_path = importlib.resources.files('megasniff.templates')
|
||||||
loader = jinja2.FileSystemLoader(str(template_path))
|
loader = jinja2.FileSystemLoader(str(template_path))
|
||||||
|
|
||||||
self.templateLoader = loader
|
self.templateLoader = loader
|
||||||
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
|
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
|
||||||
self.template = self.templateEnv.get_template(convertor_template)
|
self.template = self.templateEnv.get_template(template_settings.general)
|
||||||
|
self.union_template = self.templateEnv.get_template(template_settings.union)
|
||||||
|
|
||||||
|
def _union_inflator(self,
|
||||||
|
argname: str,
|
||||||
|
argtype: str,
|
||||||
|
constrs: list[tuple[str, bool]],
|
||||||
|
lookup_table: dict[str, Any]):
|
||||||
|
txt = self.union_template.render(
|
||||||
|
argname=argname,
|
||||||
|
typename=argtype,
|
||||||
|
constrs=constrs
|
||||||
|
)
|
||||||
|
namespace = {
|
||||||
|
'_lookup_table': lookup_table
|
||||||
|
}
|
||||||
|
exec(txt, namespace)
|
||||||
|
return namespace['inflate']
|
||||||
|
|
||||||
def schema_to_generator(self,
|
def schema_to_generator(self,
|
||||||
schema: type,
|
schema: type,
|
||||||
*,
|
*,
|
||||||
_base_lookup_table: Optional[dict[str, Any]] = None) -> Callable[[dict[str, Any]], Any]:
|
_base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]:
|
||||||
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
|
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
|
||||||
type_hints = get_kwargs_type_hints(schema)
|
type_hints = get_kwargs_type_hints(schema)
|
||||||
render_data = []
|
render_data = []
|
||||||
@@ -72,21 +104,25 @@ class SchemaInflatorGenerator:
|
|||||||
if not is_builtin and argt is not schema:
|
if not is_builtin and argt is not schema:
|
||||||
if argt.__name__ not in lookup_table.keys():
|
if argt.__name__ not in lookup_table.keys():
|
||||||
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
|
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
|
||||||
lookup_table[argt.__name__] = self.schema_to_generator(argt, _base_lookup_table=lookup_table)
|
lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table)
|
||||||
|
|
||||||
if argt is schema:
|
if argt is schema:
|
||||||
out_argtypes.append(('inflate', True))
|
out_argtypes.append(('inflate', True))
|
||||||
else:
|
else:
|
||||||
out_argtypes.append((argt.__name__, is_builtin))
|
out_argtypes.append((argt.__name__, is_builtin))
|
||||||
|
|
||||||
|
if len(argtypes) > 1:
|
||||||
|
lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table)
|
||||||
|
|
||||||
render_data.append(
|
render_data.append(
|
||||||
RenderData(
|
RenderData(
|
||||||
argname,
|
argname,
|
||||||
out_argtypes,
|
out_argtypes,
|
||||||
repr(argtype),
|
utils.typename(argtype),
|
||||||
has_default,
|
has_default,
|
||||||
allow_none,
|
allow_none,
|
||||||
default_option
|
default_option,
|
||||||
|
hash(argtype)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -21,21 +21,16 @@ def inflate(from_data: dict[str, Any]):
|
|||||||
{{conv.argname}} = None
|
{{conv.argname}} = None
|
||||||
{% endif %}
|
{% endif %}
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
|
{% if conv.constrs | length > 1 or conv.constrs[0][1] is false %}
|
||||||
|
{{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data)
|
||||||
|
{% else %}
|
||||||
|
{{conv.argname}} = {{conv.constrs[0][0]}}(conv_data)
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
except FieldValidationException as e:
|
||||||
|
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions)
|
||||||
|
|
||||||
{% set ns.retry_indent = 0 %}
|
|
||||||
all_conv_exceptions = []
|
|
||||||
{% for union_type, is_builtin in conv.constrs %}
|
|
||||||
{{ ' ' * ns.retry_indent }}try:
|
|
||||||
{% if is_builtin %}
|
|
||||||
{{ ' ' * ns.retry_indent }} {{conv.argname}} = {{union_type}}(conv_data)
|
|
||||||
{% else %}
|
|
||||||
{{ ' ' * ns.retry_indent }} {{conv.argname}} = _lookup_table['{{union_type}}'](conv_data)
|
|
||||||
{% endif %}
|
|
||||||
{{ ' ' * ns.retry_indent }}except Exception as e:
|
|
||||||
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
|
|
||||||
{% set ns.retry_indent = ns.retry_indent + 1 %}
|
|
||||||
{% endfor %}
|
|
||||||
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
|
|
||||||
|
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|||||||
19
src/megasniff/templates/union.jinja2
Normal file
19
src/megasniff/templates/union.jinja2
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{% set ns = namespace(retry_indent=0) %}
|
||||||
|
from typing import Any
|
||||||
|
from megasniff.exceptions import FieldValidationException
|
||||||
|
|
||||||
|
def inflate(conv_data: Any):
|
||||||
|
{% set ns.retry_indent = 0 %}
|
||||||
|
all_conv_exceptions = []
|
||||||
|
{% for union_type, is_builtin in constrs %}
|
||||||
|
{{ ' ' * ns.retry_indent }}try:
|
||||||
|
{% if is_builtin %}
|
||||||
|
{{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data)
|
||||||
|
{% else %}
|
||||||
|
{{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data)
|
||||||
|
{% endif %}
|
||||||
|
{{ ' ' * ns.retry_indent }}except Exception as e:
|
||||||
|
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
|
||||||
|
{% set ns.retry_indent = ns.retry_indent + 1 %}
|
||||||
|
{% endfor %}
|
||||||
|
raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
from typing import get_type_hints, Any
|
from typing import get_type_hints, Any, get_origin
|
||||||
|
|
||||||
|
|
||||||
def is_typed_dict_type(tp: type) -> bool:
|
def is_typed_dict_type(tp: type) -> bool:
|
||||||
@@ -50,3 +50,9 @@ def get_field_default(cls: type[Any], field: str) -> tuple[bool, Any]:
|
|||||||
|
|
||||||
def is_builtin_type(tp: type) -> bool:
|
def is_builtin_type(tp: type) -> bool:
|
||||||
return getattr(tp, '__module__', None) == 'builtins'
|
return getattr(tp, '__module__', None) == 'builtins'
|
||||||
|
|
||||||
|
|
||||||
|
def typename(tp: type) -> str:
|
||||||
|
if get_origin(tp) is None:
|
||||||
|
return tp.__name__
|
||||||
|
return str(tp)
|
||||||
|
|||||||
@@ -17,6 +17,22 @@ def test_basic_constructor():
|
|||||||
assert a.a == 42
|
assert a.a == 42
|
||||||
|
|
||||||
|
|
||||||
|
def test_unions():
|
||||||
|
@dataclass
|
||||||
|
class A:
|
||||||
|
a: int | str
|
||||||
|
|
||||||
|
infl = SchemaInflatorGenerator()
|
||||||
|
fn = infl.schema_to_generator(A)
|
||||||
|
|
||||||
|
a = fn({'a': 42})
|
||||||
|
assert a.a == 42
|
||||||
|
a = fn({'a': '42'})
|
||||||
|
assert a.a == 42
|
||||||
|
a = fn({'a': '42a'})
|
||||||
|
assert a.a == '42a'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CircA:
|
class CircA:
|
||||||
b: CircB
|
b: CircB
|
||||||
|
|||||||
Reference in New Issue
Block a user