diff --git a/src/megasniff/__main__.py b/src/megasniff/__main__.py index 3fa9088..e2e51d6 100644 --- a/src/megasniff/__main__.py +++ b/src/megasniff/__main__.py @@ -8,7 +8,7 @@ import megasniff.exceptions from . import SchemaInflatorGenerator -@dataclass +@dataclass(frozen=True) class ASchema: a: int b: float | str @@ -23,18 +23,33 @@ class BSchema(TypedDict): d: ASchema +@dataclass +class CSchema: + l: set[int | ASchema] + + def main(): infl = SchemaInflatorGenerator() - t, n = infl._schema_to_generator(ASchema) - print(t) - print(n) - exec(t, n) - fn = n['inflate'] + fn = infl.schema_to_inflator(ASchema) + # print(t) + # print(n) + # exec(t, n) + # fn = n['inflate'] # fn = infl.schema_to_generator(ASchema) # # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}} - d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}} + # d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}} + # d = {'a': 2, 'b': 2, 'bs': {'a': 2, 'b': 'a', 'c': 0, 'd': {'a': 2, 'b': 2}}} + # d = {'l': ['1', {'a': 42, 'b': 1}]} + d = {'a': 2, 'b': 2, 'bs': None} try: - print(fn(d)) + o = fn(d) + print(o) + for k, v in o.__dict__.items(): + print(f'field {k}: {v}') + print(f'type: {type(v)}') + if isinstance(v, list): + for vi in v: + print(f'\ttype: {type(vi)}') except megasniff.exceptions.FieldValidationException as e: print(e.exceptions) print(e) diff --git a/src/megasniff/__pycache__/__main__.cpython-313.pyc b/src/megasniff/__pycache__/__main__.cpython-313.pyc index c445e90..e22ffca 100644 Binary files a/src/megasniff/__pycache__/__main__.cpython-313.pyc and b/src/megasniff/__pycache__/__main__.cpython-313.pyc differ diff --git a/src/megasniff/inflator.py b/src/megasniff/inflator.py index b8135e2..ea3e803 100644 --- a/src/megasniff/inflator.py +++ b/src/megasniff/inflator.py @@ -1,12 +1,13 @@ # Copyright (C) 2025 Shevchenko A # SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import annotations import collections.abc import importlib.resources from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass from types import NoneType, UnionType -from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence +from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set import jinja2 @@ -15,21 +16,28 @@ from .utils import * @dataclass -class RenderData: +class TypeRenderData: + typeref: list[TypeRenderData] | TypeRenderData | str + allow_none: bool + is_list: bool + is_union: bool + + +@dataclass +class IterableTypeRenderData(TypeRenderData): + iterable_type: str + is_list = True + is_union = False + + +@dataclass +class FieldRenderData: argname: str - constrs: list[str] # typecall / use lookup table + constrs: TypeRenderData typename: str is_optional: bool allow_none: bool default_option: Optional[str] - typeid: int - - -@dataclass -class SchemaInflatorTemplateSettings: - general: str = 'inflator.jinja2' - union: str = 'union.jinja2' - iterable: str = 'iterable.jinja2' class SchemaInflatorGenerator: @@ -37,17 +45,10 @@ class SchemaInflatorGenerator: templateEnv: jinja2.Environment template: jinja2.Template - union_template: jinja2.Template - iterable_template: jinja2.Template - - settings: SchemaInflatorTemplateSettings def __init__(self, loader: Optional[jinja2.BaseLoader] = None, - template_settings: Optional[SchemaInflatorTemplateSettings] = None): - - if template_settings is None: - template_settings = SchemaInflatorTemplateSettings() + template_filename: str = 'inflator.jinja2'): if loader is None: template_path = importlib.resources.files('megasniff.templates') @@ -55,17 +56,46 @@ class SchemaInflatorGenerator: self.templateLoader = loader self.templateEnv = jinja2.Environment(loader=self.templateLoader) - self.template = self.templateEnv.get_template(template_settings.general) - self.union_template = self.templateEnv.get_template(template_settings.union) - self.iterable_template = self.templateEnv.get_template(template_settings.iterable) + self.template = self.templateEnv.get_template(template_filename) def schema_to_inflator(self, schema: type) -> Callable[[dict[str, Any]], Any]: txt, namespace = self._schema_to_inflator(schema, _funcname='inflate') - txt = ('from typing import Any\n' - 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt + imports = ('from typing import Any\n' + 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') + txt = imports + '\n' + txt exec(txt, namespace) return namespace['inflate'] + def _unwrap_typeref(self, t: type) -> TypeRenderData: + type_origin = get_origin(t) + allow_none = False + argtypes = t, + + if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])): + argtypes = get_args(t) + + if NoneType in argtypes or None in argtypes: + argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) + allow_none = True + + is_union = len(argtypes) > 1 + + if is_union: + typerefs = list(map(lambda x: self._unwrap_typeref(x), argtypes)) + return TypeRenderData(typerefs, allow_none, False, True) + elif type_origin in [list, set]: + rd = self._unwrap_typeref(argtypes[0]) + return IterableTypeRenderData(rd, allow_none, True, False, type_origin.__name__) + else: + t = argtypes[0] + + is_list = (type_origin or t) in [list, set] + if is_list: + t = type_origin or t + + is_builtin = is_builtin_type(t) + return TypeRenderData(t.__name__ if is_builtin else f'inflate_{t.__name__}', allow_none, is_list, False) + def _schema_to_inflator(self, schema: type, *, @@ -92,49 +122,67 @@ class SchemaInflatorGenerator: continue has_default, default_option = get_field_default(schema, argname) + + typeref = self._unwrap_typeref(argtype) + argtypes = argtype, - type_origin = get_origin(argtype) allow_none = False - if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated])): - argtypes = get_args(argtype) + while get_origin(argtype) is not None: + type_origin = get_origin(argtype) + + if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])): + argtypes = get_args(argtype) + if len(argtypes) == 1: + argtype = argtypes[0] + else: + break if NoneType in argtypes or None in argtypes: argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) allow_none = True - out_argtypes: list[str] = [] - - for argt in argtypes: - is_builtin = is_builtin_type(argt) - if not is_builtin and argt is not schema: - # если случилась циклическая зависимость, мы не хотим бексконечную рекурсию - if argt.__name__ not in namespace.keys(): - t, n = self._schema_to_inflator(argt, _funcname=f'inflate_{argt.__name__}', - _namespace=namespace) - namespace |= n - txt_segments.append(t) - out_argtypes.append(f'inflate_{argt.__name__}') - # lookup_table[hash(argt)] = infl - # namespace[argt.__name__] = infl - - elif argt is schema: - out_argtypes.append(_funcname) - else: - namespace[argt.__name__] = argt - out_argtypes.append(argt.__name__) - render_data.append( - RenderData( + FieldRenderData( argname, - out_argtypes, + typeref, utils.typename(argtype), has_default, allow_none, default_option, - hash(argtype) ) ) + # + # out_argtypes: list[str] = [] + # + for argt in argtypes: + + is_builtin = is_builtin_type(argt) + if not is_builtin and argt is not schema: + # если случилась циклическая зависимость, мы не хотим бексконечную рекурсию + if argt.__name__ not in namespace.keys(): + t, n = self._schema_to_inflator(argt, + _funcname=f'inflate_{argt.__name__}', + _namespace=namespace) + namespace |= n + txt_segments.append(t) + + elif argt is schema: + pass + else: + namespace[argt.__name__] = argt + # + # render_data.append( + # FieldRenderData( + # argname, + # out_argtypes, + # utils.typename(argtype), + # has_default, + # allow_none, + # default_option, + # type_origin is not None and type_origin.__name__ == 'list', + # ) + # ) convertor_functext = self.template.render( funcname=_funcname, diff --git a/src/megasniff/templates/basic.jinja2 b/src/megasniff/templates/basic.jinja2 deleted file mode 100644 index c94b1be..0000000 --- a/src/megasniff/templates/basic.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{% macro render_segment(argname, typename) -%} -{%- set out -%} -{{argname}} = {{typename}}(conv_data) -{%- endset %} -{{out}} -{%- endmacro %} \ No newline at end of file diff --git a/src/megasniff/templates/inflator.jinja2 b/src/megasniff/templates/inflator.jinja2 index 17dcfe0..f093ce7 100644 --- a/src/megasniff/templates/inflator.jinja2 +++ b/src/megasniff/templates/inflator.jinja2 @@ -1,6 +1,6 @@ {% set ns = namespace(retry_indent=0) %} -{% import "basic.jinja2" as basic %} -{% import "union.jinja2" as union %} +{% import "unwrap_type_data.jinja2" as unwrap_type_data %} + def {{funcname}}(from_data: dict[str, Any]): """ @@ -24,16 +24,8 @@ def {{funcname}}(from_data: dict[str, Any]): {{conv.argname}} = None {% endif %} else: - try: - {% if conv.constrs | length > 1 %} -{{ union.render_segment(conv) | indent(4*4) }} - {% else %} -{{ basic.render_segment(conv.argname, conv.constrs[0]) | indent(4*4) }} - {% endif %} - - except FieldValidationException as e: - raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions) +{{ unwrap_type_data.render_segment(conv.argname, conv.constrs, "conv_data") | indent(4*3) }} {% endfor %} diff --git a/src/megasniff/templates/iterable.jinja2 b/src/megasniff/templates/iterable.jinja2 deleted file mode 100644 index 1a88b22..0000000 --- a/src/megasniff/templates/iterable.jinja2 +++ /dev/null @@ -1,5 +0,0 @@ -def inflate(iterable): - ret = {{ rettype }}() - for item in iterable: - ret.{{ retadd }}(_lookup_table[{{item_id}}](item)) - return ret \ No newline at end of file diff --git a/src/megasniff/templates/union.jinja2 b/src/megasniff/templates/union.jinja2 deleted file mode 100644 index 17a5b1c..0000000 --- a/src/megasniff/templates/union.jinja2 +++ /dev/null @@ -1,17 +0,0 @@ -{% import "basic.jinja2" as basic %} -{% macro render_segment(conv) -%} -{%- set out -%} -{% set ns = namespace(retry_indent=0) %} -{% set ns.retry_indent = 0 %} -all_conv_exceptions = [] -{% for union_type in conv.constrs %} -{{ ' ' * ns.retry_indent }}try: -{{ basic.render_segment(conv.argname, union_type) | indent((ns.retry_indent + 1) * 4) }} -{{ ' ' * ns.retry_indent }}except Exception as e: -{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) -{% set ns.retry_indent = ns.retry_indent + 1 %} -{% endfor %} -{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions) -{%- endset %} -{{out}} -{%- endmacro %} \ No newline at end of file diff --git a/src/megasniff/templates/unwrap_type_data.jinja2 b/src/megasniff/templates/unwrap_type_data.jinja2 new file mode 100644 index 0000000..8eb296d --- /dev/null +++ b/src/megasniff/templates/unwrap_type_data.jinja2 @@ -0,0 +1,51 @@ +{% macro render_iterable(argname, typedef, conv_data) -%} +{%- set out -%} +{{argname}} = [] +if not isinstance({{conv_data}}, list): + raise FieldValidationException('{{argname}}', "list", conv_data, []) +for item in {{conv_data}}: +{{ render_segment("_" + argname, typedef, "item") | indent(4) }} + {{argname}}.append(_{{argname}}) +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_union(argname, conv, conv_data) -%} +{%- set out -%} +# unwrapping union {{conv}} +{% set ns = namespace(retry_indent=0) %} +{% set ns.retry_indent = 0 %} +all_conv_exceptions = [] +{% for union_type in conv.typeref %} +{{ ' ' * ns.retry_indent }}try: +{{ render_segment(argname, union_type, conv_data) | indent((ns.retry_indent + 1) * 4) }} +{{ ' ' * ns.retry_indent }}except Exception as e: +{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e) +{% set ns.retry_indent = ns.retry_indent + 1 %} +{% endfor %} +{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions) +{%- endset %} +{{out}} +{%- endmacro %} + +{% macro render_segment(argname, typeref, conv_data) -%} +{%- set out -%} + +{% if typeref is string %} +{{argname}} = {{typeref}}({{conv_data}}) + +{% elif typeref.is_union %} +{{render_union(argname, typeref, conv_data)}} + +{% elif typeref.is_list %} +{{render_iterable(argname, typeref.typeref, conv_data)}} +{{argname}} = {{typeref.iterable_type}}({{argname}}) + +{% else %} +{{render_segment(argname, typeref.typeref, conv_data)}} + +{% endif %} + +{%- endset %} +{{out}} +{%- endmacro %} \ No newline at end of file diff --git a/tests/test_iterables.py b/tests/test_iterables.py index 9493feb..9079d22 100644 --- a/tests/test_iterables.py +++ b/tests/test_iterables.py @@ -3,20 +3,85 @@ from dataclasses import dataclass from megasniff import SchemaInflatorGenerator -# def test_list(): -# @dataclass -# class A: -# l: list[int] -# -# infl = SchemaInflatorGenerator() -# fn = infl.schema_to_generator(A) -# -# a = fn({'l': []}) -# assert isinstance(a.l, list) -# assert len(a.l) == 0 -# -# a = fn({'l': [1, 2.1, '0']}) -# print(a.l) -# assert isinstance(a.l, list) -# assert len(a.l) == 3 -# assert all(map(lambda x: isinstance(x, int), a.l)) +def test_list_basic(): + @dataclass + class A: + l: list[int] + + infl = SchemaInflatorGenerator() + fn = infl.schema_to_inflator(A) + + a = fn({'l': []}) + assert isinstance(a.l, list) + assert len(a.l) == 0 + + a = fn({'l': [1, 2.1, '0']}) + print(a.l) + assert isinstance(a.l, list) + assert len(a.l) == 3 + assert all(map(lambda x: isinstance(x, int), a.l)) + + @dataclass + class B: + l: list[str] + + fn = infl.schema_to_inflator(B) + + a = fn({'l': [1, 2.1, '0']}) + print(a.l) + assert isinstance(a.l, list) + assert len(a.l) == 3 + assert all(map(lambda x: isinstance(x, str), a.l)) + assert a.l == ['1', '2.1', '0'] + + +def test_list_union(): + @dataclass + class A: + l: list[int | str] + + infl = SchemaInflatorGenerator() + fn = infl.schema_to_inflator(A) + + a = fn({'l': []}) + assert isinstance(a.l, list) + assert len(a.l) == 0 + + a = fn({'l': [1, 2.1, '0']}) + print(a.l) + assert isinstance(a.l, list) + assert len(a.l) == 3 + assert all(map(lambda x: isinstance(x, int), a.l)) + + +def test_set_basic(): + + @dataclass + class A: + l: set[int] + + infl = SchemaInflatorGenerator() + fn = infl.schema_to_inflator(A) + + a = fn({'l': []}) + assert isinstance(a.l, set) + assert len(a.l) == 0 + + a = fn({'l': [1, 2.1, '0']}) + print(a.l) + assert isinstance(a.l, set) + assert len(a.l) == 3 + assert all(map(lambda x: isinstance(x, int), a.l)) + + @dataclass + class B: + l: set[str] + + fn = infl.schema_to_inflator(B) + + a = fn({'l': [1, 2.1, '0', 0]}) + print(a.l) + assert isinstance(a.l, set) + assert len(a.l) == 3 + assert all(map(lambda x: isinstance(x, str), a.l)) + assert a.l == {'1', '2.1', '0'}