Add signature generation

This commit is contained in:
2025-07-17 01:19:44 +03:00
parent 9fc218e556
commit 9e3d4d0a25
4 changed files with 65 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "megasniff" name = "megasniff"
version = "0.2.0" version = "0.2.1"
description = "Library for in-time codegened type validation" description = "Library for in-time codegened type validation"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType
import jinja2 import jinja2
@@ -66,13 +66,19 @@ class SchemaInflatorGenerator:
def schema_to_inflator(self, def schema_to_inflator(self,
schema: type, schema: type,
strict_mode_override: Optional[bool] = None) -> Callable[[dict[str, Any]], Any]: strict_mode_override: Optional[bool] = None,
txt, namespace = self._schema_to_inflator(schema, _funcname='inflate', from_type_override: Optional[type | TypeAliasType] = None
strict_mode_override=strict_mode_override) ) -> Callable[[dict[str, Any]], Any]:
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
raise RuntimeError('from_type_override must provide __getitem__')
txt, namespace = self._schema_to_inflator(schema,
_funcname='inflate',
strict_mode_override=strict_mode_override,
from_type_override=from_type_override,
)
imports = ('from typing import Any\n' imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt txt = imports + '\n' + txt
print(txt)
exec(txt, namespace) exec(txt, namespace)
return namespace['inflate'] return namespace['inflate']
@@ -113,9 +119,11 @@ class SchemaInflatorGenerator:
def _schema_to_inflator(self, def _schema_to_inflator(self,
schema: type, schema: type,
strict_mode_override: Optional[bool] = None, strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None,
*, *,
_funcname='inflate', _funcname='inflate',
_namespace=None) -> tuple[str, dict]: _namespace=None
) -> tuple[str, dict]:
if strict_mode_override is not None: if strict_mode_override is not None:
strict_mode = strict_mode_override strict_mode = strict_mode_override
else: else:
@@ -136,6 +144,9 @@ class SchemaInflatorGenerator:
return '', namespace return '', namespace
namespace[f'{_funcname}_tgt_type'] = schema namespace[f'{_funcname}_tgt_type'] = schema
namespace[utils.typename(schema)] = schema
if from_type_override is not None:
namespace['_from_type'] = from_type_override
for argname, argtype in type_hints.items(): for argname, argtype in type_hints.items():
if argname in {'return', 'self'}: if argname in {'return', 'self'}:
@@ -172,9 +183,7 @@ class SchemaInflatorGenerator:
default_option, default_option,
) )
) )
#
# out_argtypes: list[str] = []
#
for argt in argtypes: for argt in argtypes:
is_builtin = is_builtin_type(argt) is_builtin = is_builtin_type(argt)
@@ -192,23 +201,12 @@ class SchemaInflatorGenerator:
pass pass
else: else:
namespace[argt.__name__] = argt namespace[argt.__name__] = argt
#
# render_data.append(
# FieldRenderData(
# argname,
# out_argtypes,
# utils.typename(argtype),
# has_default,
# allow_none,
# default_option,
# type_origin is not None and type_origin.__name__ == 'list',
# )
# )
convertor_functext = self.template.render( convertor_functext = self.template.render(
funcname=_funcname, funcname=_funcname,
conversions=render_data, conversions=render_data,
tgt_type=utils.typename(schema) tgt_type=utils.typename(schema),
from_type='_from_type' if from_type_override is not None else None
) )
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext

View File

@@ -2,7 +2,7 @@
{% import "unwrap_type_data.jinja2" as unwrap_type_data %} {% import "unwrap_type_data.jinja2" as unwrap_type_data %}
def {{funcname}}(from_data: dict[str, Any]): def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{from_type}}{% endif %}) {% if tgt_type is not none %} -> {{tgt_type}} {% endif %}:
""" """
{{tgt_type}} {{tgt_type}}
""" """

43
tests/test_signature.py Normal file
View File

@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import get_type_hints, Any, Annotated
from megasniff import SchemaInflatorGenerator
def test_return_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(A)
hints = get_type_hints(fn)
assert hints['return'] == A
assert len(hints) == 2
def test_argument_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
type custom_from_type = dict[str, Any]
fn1 = infl.schema_to_inflator(A, from_type_override=custom_from_type)
fn2 = infl.schema_to_inflator(A)
hints = get_type_hints(fn1)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] == custom_from_type
assert hints['from_data'] != dict[str, Any]
hints = get_type_hints(fn2)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] != custom_from_type
assert hints['from_data'] == dict[str, Any]