From 9e3d4d0a2577046a3be356c810cea5e928aea1c2 Mon Sep 17 00:00:00 2001 From: nikto_b Date: Thu, 17 Jul 2025 01:19:44 +0300 Subject: [PATCH] Add signature generation --- pyproject.toml | 2 +- src/megasniff/inflator.py | 42 ++++++++++++------------ src/megasniff/templates/inflator.jinja2 | 2 +- tests/test_signature.py | 43 +++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 24 deletions(-) create mode 100644 tests/test_signature.py diff --git a/pyproject.toml b/pyproject.toml index 25b9371..ca0639d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "megasniff" -version = "0.2.0" +version = "0.2.1" description = "Library for in-time codegened type validation" authors = [ { name = "nikto_b", email = "niktob560@yandex.ru" } diff --git a/src/megasniff/inflator.py b/src/megasniff/inflator.py index 903f380..42e5803 100644 --- a/src/megasniff/inflator.py +++ b/src/megasniff/inflator.py @@ -7,7 +7,7 @@ from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass from types import NoneType, UnionType -from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set +from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType import jinja2 @@ -66,13 +66,19 @@ class SchemaInflatorGenerator: def schema_to_inflator(self, schema: type, - strict_mode_override: Optional[bool] = None) -> Callable[[dict[str, Any]], Any]: - txt, namespace = self._schema_to_inflator(schema, _funcname='inflate', - strict_mode_override=strict_mode_override) + strict_mode_override: Optional[bool] = None, + from_type_override: Optional[type | TypeAliasType] = None + ) -> Callable[[dict[str, Any]], Any]: + if from_type_override is not None and '__getitem__' not in dir(from_type_override): + raise RuntimeError('from_type_override must provide __getitem__') + txt, namespace = self._schema_to_inflator(schema, + _funcname='inflate', + strict_mode_override=strict_mode_override, + from_type_override=from_type_override, + ) imports = ('from typing import Any\n' 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') txt = imports + '\n' + txt - print(txt) exec(txt, namespace) return namespace['inflate'] @@ -113,9 +119,11 @@ class SchemaInflatorGenerator: def _schema_to_inflator(self, schema: type, strict_mode_override: Optional[bool] = None, + from_type_override: Optional[type | TypeAliasType] = None, *, _funcname='inflate', - _namespace=None) -> tuple[str, dict]: + _namespace=None + ) -> tuple[str, dict]: if strict_mode_override is not None: strict_mode = strict_mode_override else: @@ -136,6 +144,9 @@ class SchemaInflatorGenerator: return '', namespace namespace[f'{_funcname}_tgt_type'] = schema + namespace[utils.typename(schema)] = schema + if from_type_override is not None: + namespace['_from_type'] = from_type_override for argname, argtype in type_hints.items(): if argname in {'return', 'self'}: @@ -172,9 +183,7 @@ class SchemaInflatorGenerator: default_option, ) ) - # - # out_argtypes: list[str] = [] - # + for argt in argtypes: is_builtin = is_builtin_type(argt) @@ -192,23 +201,12 @@ class SchemaInflatorGenerator: pass else: namespace[argt.__name__] = argt - # - # render_data.append( - # FieldRenderData( - # argname, - # out_argtypes, - # utils.typename(argtype), - # has_default, - # allow_none, - # default_option, - # type_origin is not None and type_origin.__name__ == 'list', - # ) - # ) convertor_functext = self.template.render( funcname=_funcname, conversions=render_data, - tgt_type=utils.typename(schema) + tgt_type=utils.typename(schema), + from_type='_from_type' if from_type_override is not None else None ) convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext diff --git a/src/megasniff/templates/inflator.jinja2 b/src/megasniff/templates/inflator.jinja2 index e662c89..21351b8 100644 --- a/src/megasniff/templates/inflator.jinja2 +++ b/src/megasniff/templates/inflator.jinja2 @@ -2,7 +2,7 @@ {% import "unwrap_type_data.jinja2" as unwrap_type_data %} -def {{funcname}}(from_data: dict[str, Any]): +def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{from_type}}{% endif %}) {% if tgt_type is not none %} -> {{tgt_type}} {% endif %}: """ {{tgt_type}} """ diff --git a/tests/test_signature.py b/tests/test_signature.py new file mode 100644 index 0000000..aa03135 --- /dev/null +++ b/tests/test_signature.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import get_type_hints, Any, Annotated + +from megasniff import SchemaInflatorGenerator + + +def test_return_signature(): + @dataclass + class A: + a: list[int] + + infl = SchemaInflatorGenerator(strict_mode=True) + fn = infl.schema_to_inflator(A) + + hints = get_type_hints(fn) + assert hints['return'] == A + assert len(hints) == 2 + + +def test_argument_signature(): + @dataclass + class A: + a: list[int] + + infl = SchemaInflatorGenerator(strict_mode=True) + + type custom_from_type = dict[str, Any] + + fn1 = infl.schema_to_inflator(A, from_type_override=custom_from_type) + + fn2 = infl.schema_to_inflator(A) + + hints = get_type_hints(fn1) + assert hints['return'] == A + assert len(hints) == 2 + assert hints['from_data'] == custom_from_type + assert hints['from_data'] != dict[str, Any] + + hints = get_type_hints(fn2) + assert hints['return'] == A + assert len(hints) == 2 + assert hints['from_data'] != custom_from_type + assert hints['from_data'] == dict[str, Any]