Add signature generation

This commit is contained in:
2025-07-17 01:19:44 +03:00
parent 9fc218e556
commit 9e3d4d0a25
4 changed files with 65 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "megasniff"
version = "0.2.0"
version = "0.2.1"
description = "Library for in-time codegened type validation"
authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set
from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType
import jinja2
@@ -66,13 +66,19 @@ class SchemaInflatorGenerator:
def schema_to_inflator(self,
schema: type,
strict_mode_override: Optional[bool] = None) -> Callable[[dict[str, Any]], Any]:
txt, namespace = self._schema_to_inflator(schema, _funcname='inflate',
strict_mode_override=strict_mode_override)
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None
) -> Callable[[dict[str, Any]], Any]:
if from_type_override is not None and '__getitem__' not in dir(from_type_override):
raise RuntimeError('from_type_override must provide __getitem__')
txt, namespace = self._schema_to_inflator(schema,
_funcname='inflate',
strict_mode_override=strict_mode_override,
from_type_override=from_type_override,
)
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
print(txt)
exec(txt, namespace)
return namespace['inflate']
@@ -113,9 +119,11 @@ class SchemaInflatorGenerator:
def _schema_to_inflator(self,
schema: type,
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None,
*,
_funcname='inflate',
_namespace=None) -> tuple[str, dict]:
_namespace=None
) -> tuple[str, dict]:
if strict_mode_override is not None:
strict_mode = strict_mode_override
else:
@@ -136,6 +144,9 @@ class SchemaInflatorGenerator:
return '', namespace
namespace[f'{_funcname}_tgt_type'] = schema
namespace[utils.typename(schema)] = schema
if from_type_override is not None:
namespace['_from_type'] = from_type_override
for argname, argtype in type_hints.items():
if argname in {'return', 'self'}:
@@ -172,9 +183,7 @@ class SchemaInflatorGenerator:
default_option,
)
)
#
# out_argtypes: list[str] = []
#
for argt in argtypes:
is_builtin = is_builtin_type(argt)
@@ -192,23 +201,12 @@ class SchemaInflatorGenerator:
pass
else:
namespace[argt.__name__] = argt
#
# render_data.append(
# FieldRenderData(
# argname,
# out_argtypes,
# utils.typename(argtype),
# has_default,
# allow_none,
# default_option,
# type_origin is not None and type_origin.__name__ == 'list',
# )
# )
convertor_functext = self.template.render(
funcname=_funcname,
conversions=render_data,
tgt_type=utils.typename(schema)
tgt_type=utils.typename(schema),
from_type='_from_type' if from_type_override is not None else None
)
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext

View File

@@ -2,7 +2,7 @@
{% import "unwrap_type_data.jinja2" as unwrap_type_data %}
def {{funcname}}(from_data: dict[str, Any]):
def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{from_type}}{% endif %}) {% if tgt_type is not none %} -> {{tgt_type}} {% endif %}:
"""
{{tgt_type}}
"""

43
tests/test_signature.py Normal file
View File

@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import get_type_hints, Any, Annotated
from megasniff import SchemaInflatorGenerator
def test_return_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(A)
hints = get_type_hints(fn)
assert hints['return'] == A
assert len(hints) == 2
def test_argument_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
type custom_from_type = dict[str, Any]
fn1 = infl.schema_to_inflator(A, from_type_override=custom_from_type)
fn2 = infl.schema_to_inflator(A)
hints = get_type_hints(fn1)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] == custom_from_type
assert hints['from_data'] != dict[str, Any]
hints = get_type_hints(fn2)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] != custom_from_type
assert hints['from_data'] == dict[str, Any]