diff --git a/src/megasniff/__pycache__/utils.cpython-313.pyc b/src/megasniff/__pycache__/utils.cpython-313.pyc index 57ac6b9..fe9bd03 100644 Binary files a/src/megasniff/__pycache__/utils.cpython-313.pyc and b/src/megasniff/__pycache__/utils.cpython-313.pyc differ diff --git a/src/megasniff/inflator.py b/src/megasniff/inflator.py index 1cde799..47ae61b 100644 --- a/src/megasniff/inflator.py +++ b/src/megasniff/inflator.py @@ -9,6 +9,7 @@ from typing import Optional, get_origin, get_args, Union, Annotated import jinja2 +from . import utils from .utils import * @@ -20,27 +21,58 @@ class RenderData: is_optional: bool allow_none: bool default_option: Optional[str] + typeid: int + + +@dataclass +class SchemaInflatorTemplateSettings: + general: str = 'inflator.jinja2' + union: str = 'union.jinja2' class SchemaInflatorGenerator: templateLoader: jinja2.BaseLoader templateEnv: jinja2.Environment template: jinja2.Template + union_template: jinja2.Template + settings: SchemaInflatorTemplateSettings def __init__(self, loader: Optional[jinja2.BaseLoader] = None, - convertor_template: str = 'inflator.jinja2'): + template_settings: Optional[SchemaInflatorTemplateSettings] = None): + + if template_settings is None: + template_settings = SchemaInflatorTemplateSettings() + if loader is None: template_path = importlib.resources.files('megasniff.templates') loader = jinja2.FileSystemLoader(str(template_path)) + self.templateLoader = loader self.templateEnv = jinja2.Environment(loader=self.templateLoader) - self.template = self.templateEnv.get_template(convertor_template) + self.template = self.templateEnv.get_template(template_settings.general) + self.union_template = self.templateEnv.get_template(template_settings.union) + + def _union_inflator(self, + argname: str, + argtype: str, + constrs: list[tuple[str, bool]], + lookup_table: dict[str, Any]): + txt = self.union_template.render( + argname=argname, + typename=argtype, + constrs=constrs + ) + namespace = { + '_lookup_table': lookup_table + } + exec(txt, namespace) + return namespace['inflate'] def schema_to_generator(self, schema: type, *, - _base_lookup_table: Optional[dict[str, Any]] = None) -> Callable[[dict[str, Any]], Any]: + _base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]: # Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян type_hints = get_kwargs_type_hints(schema) render_data = [] @@ -72,21 +104,25 @@ class SchemaInflatorGenerator: if not is_builtin and argt is not schema: if argt.__name__ not in lookup_table.keys(): # если случилась циклическая зависимость, мы не хотим бексконечную рекурсию - lookup_table[argt.__name__] = self.schema_to_generator(argt, _base_lookup_table=lookup_table) + lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table) if argt is schema: out_argtypes.append(('inflate', True)) else: out_argtypes.append((argt.__name__, is_builtin)) + if len(argtypes) > 1: + lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table) + render_data.append( RenderData( argname, out_argtypes, - repr(argtype), + utils.typename(argtype), has_default, allow_none, - default_option + default_option, + hash(argtype) ) ) diff --git a/src/megasniff/templates/inflator.jinja2 b/src/megasniff/templates/inflator.jinja2 index 19605e4..ec5fde0 100644 --- a/src/megasniff/templates/inflator.jinja2 +++ b/src/megasniff/templates/inflator.jinja2 @@ -21,21 +21,16 @@ def inflate(from_data: dict[str, Any]): {{conv.argname}} = None {% endif %} else: + try: + {% if conv.constrs | length > 1 or conv.constrs[0][1] is false %} + {{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data) + {% else %} + {{conv.argname}} = {{conv.constrs[0][0]}}(conv_data) + {% endif %} + + except FieldValidationException as e: + raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions) - {% set ns.retry_indent = 0 %} - all_conv_exceptions = [] - {% for union_type, is_builtin in conv.constrs %} - {{ ' ' * ns.retry_indent }}try: - {% if is_builtin %} - {{ ' ' * ns.retry_indent }} {{conv.argname}} = {{union_type}}(conv_data) - {% else %} - {{ ' ' * ns.retry_indent }} {{conv.argname}} = _lookup_table['{{union_type}}'](conv_data) - {% endif %} - {{ ' ' * ns.retry_indent }}except Exception as e: - {{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) - {% set ns.retry_indent = ns.retry_indent + 1 %} - {% endfor %} - {{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions) {% endfor %} diff --git a/src/megasniff/templates/union.jinja2 b/src/megasniff/templates/union.jinja2 new file mode 100644 index 0000000..d180295 --- /dev/null +++ b/src/megasniff/templates/union.jinja2 @@ -0,0 +1,19 @@ +{% set ns = namespace(retry_indent=0) %} +from typing import Any +from megasniff.exceptions import FieldValidationException + +def inflate(conv_data: Any): + {% set ns.retry_indent = 0 %} + all_conv_exceptions = [] + {% for union_type, is_builtin in constrs %} + {{ ' ' * ns.retry_indent }}try: + {% if is_builtin %} + {{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data) + {% else %} + {{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data) + {% endif %} + {{ ' ' * ns.retry_indent }}except Exception as e: + {{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) + {% set ns.retry_indent = ns.retry_indent + 1 %} + {% endfor %} + raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions) diff --git a/src/megasniff/utils.py b/src/megasniff/utils.py index 1319f4e..aaadab0 100644 --- a/src/megasniff/utils.py +++ b/src/megasniff/utils.py @@ -1,6 +1,6 @@ import dataclasses import inspect -from typing import get_type_hints, Any +from typing import get_type_hints, Any, get_origin def is_typed_dict_type(tp: type) -> bool: @@ -50,3 +50,9 @@ def get_field_default(cls: type[Any], field: str) -> tuple[bool, Any]: def is_builtin_type(tp: type) -> bool: return getattr(tp, '__module__', None) == 'builtins' + + +def typename(tp: type) -> str: + if get_origin(tp) is None: + return tp.__name__ + return str(tp) diff --git a/tests/test_basic.py b/tests/test_basic.py index db7202c..a0aca82 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -17,6 +17,22 @@ def test_basic_constructor(): assert a.a == 42 +def test_unions(): + @dataclass + class A: + a: int | str + + infl = SchemaInflatorGenerator() + fn = infl.schema_to_generator(A) + + a = fn({'a': 42}) + assert a.a == 42 + a = fn({'a': '42'}) + assert a.a == 42 + a = fn({'a': '42a'}) + assert a.a == '42a' + + @dataclass class CircA: b: CircB