From aee6dcf3d38c5d00204e563d5d85288f735ffaa3 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Sat, 12 Jul 2025 04:16:19 +0300 Subject: [PATCH] Extract complex type creation into separate template --- .../__pycache__/utils.cpython-313.pyc | Bin 2752 -> 3019 bytes src/megasniff/inflator.py | 48 +++++++++++++++--- src/megasniff/templates/inflator.jinja2 | 23 ++++----- src/megasniff/templates/union.jinja2 | 19 +++++++ src/megasniff/utils.py | 8 ++- tests/test_basic.py | 16 ++++++ 6 files changed, 93 insertions(+), 21 deletions(-) create mode 100644 src/megasniff/templates/union.jinja2 diff --git a/src/megasniff/__pycache__/utils.cpython-313.pyc b/src/megasniff/__pycache__/utils.cpython-313.pyc index 57ac6b9889ace4d6bdd62eae1b292a39c85e7a2a..fe9bd03069e1949566ba01bdb8128f4e8d49a9de 100644 GIT binary patch delta 722 zcmZWmL1+^}6rIUzb~l@3W14JjQ;VgkExCr+QYj*M5f44M7cXAeT4_3r;m(-s_<_doBw|L4!l7vtMS{V1K*7%u(e zgL^lXo<70f_7;lW+i{4`FjH8vDOpNIGUEkyUjo&XBPL;KmQ;zwu@E&S(vqfX$(BT% zi8|J$EalWCn_A>bpX-^782yCNuv02xGR-6}rGL{{t0mFLTUrG{(6(fA$Rdg$xD~o&G;nAQF&fi)NfoE!>Q?SAxHyV` zz*WRqNCh{+P0;S%ogm)u4A0kl&r{=}?tIy{h4_07&yF6oiPMDZNoRGuf_x=VeH>^W z_SGeHjH&Kxk}(1^z+Oq^A+G6CvwR~^11vb-3{0LY2il)Y^DXJB{Oc@B#*V>9V%bZZYz})+yWhn`fd;$Lx3ooWr%UWzZ$p*YMg!LGTSg0@ z#=7QS#}Vp0hyUesaYhbCagxWWET Callable[[dict[str, Any]], Any]: + _base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]: # Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян type_hints = get_kwargs_type_hints(schema) render_data = [] @@ -72,21 +104,25 @@ class SchemaInflatorGenerator: if not is_builtin and argt is not schema: if argt.__name__ not in lookup_table.keys(): # если случилась циклическая зависимость, мы не хотим бексконечную рекурсию - lookup_table[argt.__name__] = self.schema_to_generator(argt, _base_lookup_table=lookup_table) + lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table) if argt is schema: out_argtypes.append(('inflate', True)) else: out_argtypes.append((argt.__name__, is_builtin)) + if len(argtypes) > 1: + lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table) + render_data.append( RenderData( argname, out_argtypes, - repr(argtype), + utils.typename(argtype), has_default, allow_none, - default_option + default_option, + hash(argtype) ) ) diff --git a/src/megasniff/templates/inflator.jinja2 b/src/megasniff/templates/inflator.jinja2 index 19605e4..ec5fde0 100644 --- a/src/megasniff/templates/inflator.jinja2 +++ b/src/megasniff/templates/inflator.jinja2 @@ -21,21 +21,16 @@ def inflate(from_data: dict[str, Any]): {{conv.argname}} = None {% endif %} else: + try: + {% if conv.constrs | length > 1 or conv.constrs[0][1] is false %} + {{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data) + {% else %} + {{conv.argname}} = {{conv.constrs[0][0]}}(conv_data) + {% endif %} + + except FieldValidationException as e: + raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions) - {% set ns.retry_indent = 0 %} - all_conv_exceptions = [] - {% for union_type, is_builtin in conv.constrs %} - {{ ' ' * ns.retry_indent }}try: - {% if is_builtin %} - {{ ' ' * ns.retry_indent }} {{conv.argname}} = {{union_type}}(conv_data) - {% else %} - {{ ' ' * ns.retry_indent }} {{conv.argname}} = _lookup_table['{{union_type}}'](conv_data) - {% endif %} - {{ ' ' * ns.retry_indent }}except Exception as e: - {{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) - {% set ns.retry_indent = ns.retry_indent + 1 %} - {% endfor %} - {{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions) {% endfor %} diff --git a/src/megasniff/templates/union.jinja2 b/src/megasniff/templates/union.jinja2 new file mode 100644 index 0000000..d180295 --- /dev/null +++ b/src/megasniff/templates/union.jinja2 @@ -0,0 +1,19 @@ +{% set ns = namespace(retry_indent=0) %} +from typing import Any +from megasniff.exceptions import FieldValidationException + +def inflate(conv_data: Any): + {% set ns.retry_indent = 0 %} + all_conv_exceptions = [] + {% for union_type, is_builtin in constrs %} + {{ ' ' * ns.retry_indent }}try: + {% if is_builtin %} + {{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data) + {% else %} + {{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data) + {% endif %} + {{ ' ' * ns.retry_indent }}except Exception as e: + {{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) + {% set ns.retry_indent = ns.retry_indent + 1 %} + {% endfor %} + raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions) diff --git a/src/megasniff/utils.py b/src/megasniff/utils.py index 1319f4e..aaadab0 100644 --- a/src/megasniff/utils.py +++ b/src/megasniff/utils.py @@ -1,6 +1,6 @@ import dataclasses import inspect -from typing import get_type_hints, Any +from typing import get_type_hints, Any, get_origin def is_typed_dict_type(tp: type) -> bool: @@ -50,3 +50,9 @@ def get_field_default(cls: type[Any], field: str) -> tuple[bool, Any]: def is_builtin_type(tp: type) -> bool: return getattr(tp, '__module__', None) == 'builtins' + + +def typename(tp: type) -> str: + if get_origin(tp) is None: + return tp.__name__ + return str(tp) diff --git a/tests/test_basic.py b/tests/test_basic.py index db7202c..a0aca82 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -17,6 +17,22 @@ def test_basic_constructor(): assert a.a == 42 +def test_unions(): + @dataclass + class A: + a: int | str + + infl = SchemaInflatorGenerator() + fn = infl.schema_to_generator(A) + + a = fn({'a': 42}) + assert a.a == 42 + a = fn({'a': '42'}) + assert a.a == 42 + a = fn({'a': '42a'}) + assert a.a == '42a' + + @dataclass class CircA: b: CircB