Compare commits

22 Commits

Author SHA1 Message Date
89c4bcae90 Add flat types deflator tests 2025-10-17 00:26:10 +03:00
9775bc2cc6 Fix dict body validation x2 2025-10-16 22:19:42 +03:00
4068af462b Experimental support of a file-based inflators/deflators 2025-10-16 21:01:02 +03:00
d4d7a68d7a Fix dict body validation 2025-10-16 20:51:19 +03:00
b724e3c5dd Hotfix typename escape 2025-09-14 01:54:06 +03:00
4b77eb4217 Fix deflator and typename item types escaping 2025-09-14 01:51:41 +03:00
3aae5cf2d2 Escape deflaters names, extend signatures to allow root UnionTypes 2025-08-29 02:29:24 +03:00
8b29b941af Bump version 2025-08-29 01:29:24 +03:00
ebc296a270 Update README.md 2025-08-29 01:29:04 +03:00
de6362fa1d Create deflator strict mode and explicit casts flags with tests and default universal fallback unwrapper 2025-08-29 01:20:27 +03:00
51817784a3 Create basic deflator tests 2025-08-29 00:40:14 +03:00
cc77cc7012 Create basic deflator generator 2025-08-29 00:34:14 +03:00
36e343d3bc Make argnames escape 2025-08-20 21:59:46 +03:00
0786fc600a Fix default string option rendering 2025-08-20 03:08:37 +03:00
b11266990b Add store_sources option that stores rendered source in a __megasniff_sources__ property 2025-08-20 00:33:09 +03:00
c11a63c8a5 Allow constructing iflators for dict->tuple for further args unwrap 2025-08-19 16:51:52 +03:00
9e3d4d0a25 Add signature generation 2025-07-17 01:19:44 +03:00
9fc218e556 Clean __pycache__ 2025-07-14 17:04:31 +03:00
f8cacf9319 Bump version 2025-07-14 16:59:03 +03:00
9f54115160 Create toggle for strict-mode inflate 2025-07-14 16:54:34 +03:00
bc6acb099f Fix recursive union-iterable-*-types codegen 2025-07-14 16:27:55 +03:00
897eccd8d1 Remove lookup_table from inflator generated code, rename generating func 2025-07-12 05:50:56 +03:00
27 changed files with 1558 additions and 143 deletions

1
.gitignore vendored
View File

@@ -1,6 +1,7 @@
# ---> Python # ---> Python
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
**/__pycache__
*.py[cod] *.py[cod]
*$py.class *$py.class

View File

@@ -1,12 +1,16 @@
# megasniff # megasniff
### Автоматическая валидация данных по схеме и сборка объекта в одном флаконе
### Автоматическая валидация данных по схеме, сборка и разборка объекта в одном флаконе
#### Как применять: #### Как применять:
```python ```python
# 1. Объявляем схемы # 1. Объявляем схемы
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import typing import typing
@dataclasses.dataclass @dataclasses.dataclass
class SomeSchema1: class SomeSchema1:
a: int a: int
@@ -14,45 +18,110 @@ class SomeSchema1:
c: SomeSchema2 | str | None c: SomeSchema2 | str | None
class SomeSchema2(typing.TypedDict): @dataclasses.dataclass
class SomeSchema2:
field1: dict field1: dict
field2: float field2: float
field3: typing.Optional[SomeSchema1] field3: typing.Optional[SomeSchema1]
# 2. Генерируем метод для валидации и сборки # 2. Генерируем метод для валидации и сборки
import megasniff import megasniff
infl = megasniff.SchemaInflatorGenerator() infl = megasniff.SchemaInflatorGenerator()
fn = infl.schema_to_generator(SomeSchema1) defl = megasniff.SchemaDeflatorGenerator()
fn_in = infl.schema_to_inflator(SomeSchema1)
fn_out = defl.schema_to_deflator(SomeSchema1)
# 3. Проверяем что всё работает # 3. Проверяем что всё работает
fn({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1'}}) data = fn_in({'a': 1, 'b': 2, 'c': {'field1': {}, 'field2': '1.1', 'field3': None}})
# SomeSchema1(a=1, b=2.0, c={'field1': {}, 'field2': 1.1, 'field3': None}) # SomeSchema1(a=1, b=2.0, c={'field1': {}, 'field2': 1.1, 'field3': None})
fn_out(data)
# {'a': 1, 'b': 2.0, 'c': {'field1': {}, 'field2': 1.1, 'field3': None}}
``` ```
Особенности работы: Особенности работы:
- поддерживает циклические зависимости - поддерживает циклические зависимости
- проверяем `Union`-типы через ретрай на выбросе исключения - проверяет `Union`-типы через ретрай на выбросе исключения
- по умолчанию использует готовый щаблон для кодогенерации и исполняет его по запросу, требуется особое внимание к сохранности данного шаблона - по умолчанию использует готовый щаблон для кодогенерации и исполняет его по запросу, требуется особое внимание к
- не проверяет типы списков, словарей, кортежей (реализация ожидается) сохранности данного шаблона
- проверяет типы списков, может приводить списки к множествам
- не проверяет типы generic-словарей, кортежей (реализация ожидается)
- пользовательские проверки типов должны быть реализованы через наследование и проверки в конструкторе - пользовательские проверки типов должны быть реализованы через наследование и проверки в конструкторе
- опциональный `strict-mode`: выключение приведения базовых типов
- для inflation может генерировать кортежи верхнеуровневых объектов при наличии описания схемы (полезно при
развертывании аргументов)
- `TypedDict` поддерживается только для inflation из-за сложностей выбора варианта при сборке `Union`-полей
- для deflation поддерживается включение режима `explicit_casts`, приводящего типы к тем, которые указаны в
аннотациях (не распространяется на `Union`-типы, т.к. невозможно определить какой из них должен быть выбран)
---- ----
### Как установить: ### Как установить:
#### [uv](https://docs.astral.sh/uv/concepts/projects/dependencies/#dependency-sources): #### [uv](https://docs.astral.sh/uv/concepts/projects/dependencies/#dependency-sources):
```bash ```bash
uv add megasniff --index sniff_index=https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple uv add megasniff --index sniff_index=https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple
``` ```
#### [poetry](https://python-poetry.org/docs/repositories/#private-repository-example): #### [poetry](https://python-poetry.org/docs/repositories/#private-repository-example):
1. Добавить репозиторий в `pyproject.toml` 1. Добавить репозиторий в `pyproject.toml`
```bash ```bash
poetry source add --priority=supplemental sniff_index https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple poetry source add --priority=supplemental sniff_index https://git.nikto-b.ru/api/packages/nikto_b/pypi/simple
``` ```
2. Поставить пакет 2. Поставить пакет
```bash ```bash
poetry add --source sniff_index megasniff poetry add --source sniff_index megasniff
``` ```
----
### Strict-mode:
#### Strict-mode off:
```
@dataclass
class A:
a: list[int]
```
```
>>> {"a": [1, 1.1, "321"]}
<<< A(a=[1, 1, 321])
>>> A(a=[1, 1.1, "321"])
<<< {"a": [1, 1.1, "321"]} # explicit_casts=False
<<< {"a": [1, 1, 321]} # explicit_casts=True
```
#### Strict-mode on:
```
@dataclass
class A:
a: list[int]
```
```
>>> {"a": [1, 1.1, "321"]}
<<< FieldValidationException, т.к. 1.1 не является int
>>> A(a=[1, 1.1, "321"])
<<< FieldValidationException, т.к. 1.1 не является int
```
### Tuple unwrap
```
fn = infl.schema_to_inflator(
(('a', int), TupleSchemaItem(Optional[list[int]], key_name='b', has_default=True, default=None)))
```
Создаёт `fn: (dict[str,Any]) -> tuple[int, Optional[list[int]]]: ...` (сигнатура остаётся `(dict[str,Any])->tuple`)

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "megasniff" name = "megasniff"
version = "0.1.2" version = "0.2.6.post1"
description = "Library for in-time codegened type validation" description = "Library for in-time codegened type validation"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -1 +1,2 @@
from .inflator import SchemaInflatorGenerator from .inflator import SchemaInflatorGenerator
from .deflator import SchemaDeflatorGenerator

View File

@@ -1,13 +1,17 @@
from __future__ import annotations from __future__ import annotations
import json
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType
from typing import Optional from typing import Optional
from typing import TypedDict from typing import TypedDict
import megasniff.exceptions
from megasniff.deflator import SchemaDeflatorGenerator, JsonObject
from . import SchemaInflatorGenerator from . import SchemaInflatorGenerator
@dataclass @dataclass(frozen=True)
class ASchema: class ASchema:
a: int a: int
b: float | str b: float | str
@@ -15,19 +19,81 @@ class ASchema:
c: float = 1.1 c: float = 1.1
class BSchema(TypedDict): @dataclass
class BSchema:
a: int a: int
b: str b: str
c: float c: float
d: ASchema d: ASchema
@dataclass
class CSchema:
l: set[int | ASchema]
def main(): def main():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_generator(ASchema) fn = infl.schema_to_inflator(ASchema)
d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}} # print(t)
print(fn(d)) # print(n)
# exec(t, n)
# fn = n['inflate']
# fn = infl.schema_to_generator(ASchema)
# # d = {'a': '42', 'b': 'a0.3', 'bs': {'a': 1, 'b': 'a', 'c': 1, 'd': {'a': 1, 'b': ''}}}
# d = {'a': 1, 'b': 1, 'c': 0, 'bs': {'a': 1, 'b': 2, 'c': 3, 'd': {'a': 1, 'b': 2.1, 'bs': None}}}
# d = {'a': 2, 'b': 2, 'bs': {'a': 2, 'b': 'a', 'c': 0, 'd': {'a': 2, 'b': 2}}}
# d = {'l': ['1', {'a': 42, 'b': 1}]}
d = {'a': 2, 'b': '2', 'bs': None}
try:
o = fn(d)
print(o)
for k, v in o.__dict__.items():
print(f'field {k}: {v}')
print(f'type: {type(v)}')
if isinstance(v, list):
for vi in v:
print(f'\ttype: {type(vi)}')
except megasniff.exceptions.FieldValidationException as e:
print(e.exceptions)
print(e)
@dataclass
class DSchema:
a: dict
b: dict[str, int | float | dict]
c: str | float | ASchema
d: ESchema
@dataclass
class ESchema:
a: list[list[list[str]]]
b: str | int
@dataclass
class ZSchema:
z: ZSchema | None
d: ZSchema | int
def main_deflator():
deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True)
fn = deflator.schema_to_deflator(DSchema | int)
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
ret = fn(42)
# assert ret['a'] == 1
# assert ret['b'] == 1.1
# assert ret['c'] == 'a'
# assert ret['d']['a'][0][0][0] == 'a'
# assert ret['d']['b'] == 'b'
print(json.dumps(ret, indent=4))
pass
if __name__ == '__main__': if __name__ == '__main__':
main() main_deflator()

347
src/megasniff/deflator.py Normal file
View File

@@ -0,0 +1,347 @@
# Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import importlib.resources
import typing
from collections.abc import Callable
from dataclasses import dataclass
from types import NoneType, UnionType
from typing import get_args, Union, Annotated, Sequence, TypeAliasType, \
OrderedDict, TypeAlias
import jinja2
from .utils import *
import uuid
from pathlib import Path
import tempfile
import importlib.util
JsonObject: TypeAlias = Union[None, bool, int, float, str, list['JsonObject'], dict[str, 'JsonObject']]
class Unwrapping:
kind: str
class OtherUnwrapping(Unwrapping):
tp: str
def __init__(self, tp: str = ''):
self.kind = 'other'
self.tp = tp
@dataclass()
class ObjectFieldUnwrapping:
key: str
object_key: str
unwrapping: Unwrapping
class ObjectUnwrapping(Unwrapping):
fields: list[ObjectFieldUnwrapping]
def __init__(self, fields: list[ObjectFieldUnwrapping]):
self.kind = 'object'
self.fields = fields
class ListUnwrapping(Unwrapping):
item_unwrap: Unwrapping
def __init__(self, item_unwrap: Unwrapping):
self.kind = 'list'
self.item_unwrap = item_unwrap
class DictUnwrapping(Unwrapping):
key_unwrap: Unwrapping
value_unwrap: Unwrapping
def __init__(self, key_unwrap: Unwrapping, value_unwrap: Unwrapping):
self.kind = 'dict'
self.key_unwrap = key_unwrap
self.value_unwrap = value_unwrap
class FuncUnwrapping(Unwrapping):
fn: str
def __init__(self, fn: str):
self.kind = 'fn'
self.fn = fn
@dataclass
class UnionKindUnwrapping:
kind: str
unwrapping: Unwrapping
class UnionUnwrapping(Unwrapping):
union_kinds: list[UnionKindUnwrapping]
def __init__(self, union_kinds: list[UnionKindUnwrapping]):
self.kind = 'union'
self.union_kinds = union_kinds
def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]:
if isinstance(t, TypeAliasType):
return _flatten_type(t.__value__)
origin = get_origin(t)
if origin is Annotated:
args = get_args(t)
return _flatten_type(args[0])[0], args[1]
return t, None
def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
t, _ = _flatten_type(t)
return ('deflate_' + typename(t)
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_'))
def _fallback_unwrapper(obj: Any) -> JsonObject:
if isinstance(obj, (int, float, str, bool)):
return obj
elif isinstance(obj, list):
return list(map(_fallback_unwrapper, obj))
elif isinstance(obj, dict):
return dict(map(lambda x: (_fallback_unwrapper(x[0]), _fallback_unwrapper(x[1])), obj.items()))
elif hasattr(obj, '__dict__'):
ret = {}
for k, v in obj.__dict__:
if isinstance(k, str) and k.startswith('_'):
continue
k = _fallback_unwrapper(k)
v = _fallback_unwrapper(v)
ret[k] = v
return ret
return None
class SchemaDeflatorGenerator:
templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment
object_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
_explicit_casts: bool
_out_directory: str | None
def __init__(self,
loader: Optional[jinja2.BaseLoader] = None,
strict_mode: bool = False,
explicit_casts: bool = False,
store_sources: bool = False,
*,
object_template_filename: str = 'deflator.jinja2',
out_directory: str | None = None,
):
self._strict_mode = strict_mode
self._store_sources = store_sources
self._explicit_casts = explicit_casts
self._out_directory = out_directory
if loader is None:
template_path = importlib.resources.files('megasniff.templates')
loader = jinja2.FileSystemLoader(str(template_path))
self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.object_template = self.templateEnv.get_template(object_template_filename)
def schema_to_deflator(self,
schema: type,
strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None,
ignore_directory: bool = False,
out_directory_override: Optional[str] = None,
) -> Callable[[Any], dict[str, Any]]:
txt, namespace = self._schema_to_deflator(schema,
strict_mode_override=strict_mode_override,
explicit_casts_override=explicit_casts_override,
)
out_dir = self._out_directory
if out_directory_override:
out_dir = out_directory_override
if ignore_directory:
out_dir = None
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
if out_dir is not None:
filename = f"{uuid.uuid4()}.py"
filepath = Path(out_dir) / filename
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(txt)
spec = importlib.util.spec_from_file_location("generated_module", filepath)
module = importlib.util.module_from_spec(spec)
module.__dict__.update(namespace)
spec.loader.exec_module(module)
fn_name = _schema_to_deflator_func(schema)
fn = getattr(module, fn_name)
else:
exec(txt, namespace)
fn = namespace[_schema_to_deflator_func(schema)]
if self._store_sources:
setattr(fn, '__megasniff_sources__', txt)
return fn
def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None):
if _visited_types is None:
_visited_types = []
else:
_visited_types = _visited_types.copy()
schema, field_rename = _flatten_type(schema)
if schema in _visited_types:
return FuncUnwrapping(_schema_to_deflator_func(schema)), field_rename, set(), {schema}
_visited_types.append(schema)
ongoing_types = set()
recurcive_types = set()
origin = get_origin(schema)
ret_unw = None
if origin is not None:
if origin is list:
args = get_args(schema)
item_unw, arg_rename, ongoings, item_rec = self.schema_to_unwrapper(args[0],
_visited_types=_visited_types)
ret_unw = ListUnwrapping(item_unw)
recurcive_types |= item_rec
ongoing_types |= ongoings
elif origin is dict:
args = get_args(schema)
if len(args) != 2:
ret_unw = OtherUnwrapping()
else:
k, v = args
k_unw, _, k_ongoings, k_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
v_unw, _, v_ongoings, v_rec = self.schema_to_unwrapper(k, _visited_types=_visited_types)
ongoing_types |= k_ongoings | v_ongoings
recurcive_types |= k_rec | v_rec
ret_unw = DictUnwrapping(k_unw, v_unw)
elif origin is UnionType or origin is Union:
args = get_args(schema)
union_unwraps = []
for targ in args:
arg_unw, arg_rename, ongoings, arg_rec = self.schema_to_unwrapper(targ,
_visited_types=_visited_types)
union_unwraps.append(UnionKindUnwrapping(typename(targ), arg_unw))
ongoing_types |= ongoings
recurcive_types |= arg_rec
ret_unw = UnionUnwrapping(union_unwraps)
else:
raise NotImplementedError
else:
if schema is int:
ret_unw = OtherUnwrapping('int')
elif schema is float:
ret_unw = OtherUnwrapping('float')
elif schema is bool:
ret_unw = OtherUnwrapping('bool')
elif schema is str:
ret_unw = OtherUnwrapping('str')
elif schema is None or schema is NoneType:
ret_unw = OtherUnwrapping()
elif schema is dict:
ret_unw = OtherUnwrapping()
elif schema is list:
ret_unw = OtherUnwrapping()
elif is_class_definition(schema):
hints = typing.get_type_hints(schema)
fields = []
for k, f in hints.items():
f_unw, f_rename, ongoings, f_rec = self.schema_to_unwrapper(f, _visited_types=_visited_types)
fields.append(ObjectFieldUnwrapping(f_rename or k, k, f_unw))
ongoing_types |= ongoings
recurcive_types |= f_rec
ret_unw = ObjectUnwrapping(fields)
else:
raise NotImplementedError(f'type not implemented yet: {schema}')
return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types
def _schema_to_deflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None,
into_type_override: Optional[type | TypeAliasType] = None,
*,
_funcname='deflate',
_namespace=None,
) -> tuple[str, dict]:
if strict_mode_override is not None:
strict_mode = strict_mode_override
else:
strict_mode = self._strict_mode
if explicit_casts_override is not None:
explicit_casts = explicit_casts_override
else:
explicit_casts = self._explicit_casts
template = self.object_template
types_for_namespace = set()
recursive_types = {schema}
namespace = {
'JsonObject': JsonObject,
'fallback_unwrapper': _fallback_unwrapper,
}
convertor_functext = ''
added_types = set()
while len(recursive_types ^ (recursive_types & added_types)) > 0:
rec_t = list(recursive_types ^ (recursive_types & added_types))[0]
rec_unw, _, rec_t_namespace, rec_rec_t = self.schema_to_unwrapper(rec_t)
recursive_types |= rec_rec_t
types_for_namespace |= rec_t_namespace
rec_functext = template.render(
funcname=_schema_to_deflator_func(rec_t),
from_type=typename(rec_t),
into_type=None,
root_unwrap=rec_unw,
hashname=hashname,
strict_check=strict_mode,
explicit_cast=explicit_casts,
)
convertor_functext += '\n\n\n' + rec_functext
added_types.add(rec_t)
for t in types_for_namespace:
namespace[typename(t)] = t
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
return convertor_functext, namespace

View File

@@ -1,11 +1,14 @@
# Copyright (C) 2025 Shevchenko A # Copyright (C) 2025 Shevchenko A
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import annotations
import collections.abc
import importlib.resources import importlib.resources
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import Optional, get_origin, get_args, Union, Annotated from typing import Optional, get_origin, get_args, Union, Annotated, Literal, Sequence, List, Set, TypeAliasType, \
OrderedDict
import jinja2 import jinja2
@@ -14,35 +17,77 @@ from .utils import *
@dataclass @dataclass
class RenderData: class TypeRenderData:
typeref: list[TypeRenderData] | TypeRenderData | str
allow_none: bool
is_list: bool
is_union: bool
is_strict: bool
@dataclass
class IterableTypeRenderData(TypeRenderData):
iterable_type: str
is_list = True
is_union = False
def _escape_python_name(name: str) -> str:
name = name.replace('-', '__dash__').replace('+', '__plus__').replace('/', '__shash__')
if name[0].isnumeric():
name = '__num__' + name
return name
@dataclass
class FieldRenderData:
argname: str argname: str
constrs: list[tuple[str, bool]] # typecall / use lookup table argname_escaped: str
constrs: TypeRenderData
typename: str typename: str
is_optional: bool is_optional: bool
allow_none: bool allow_none: bool
default_option: Optional[str] default_option: Optional[str]
typeid: int
def __init__(self,
@dataclass argname: str,
class SchemaInflatorTemplateSettings: constrs: TypeRenderData,
general: str = 'inflator.jinja2' typename: str,
union: str = 'union.jinja2' is_optional: bool,
allow_none: bool,
default_option: Optional[str]):
self.argname = argname
self.constrs = constrs
self.typename = typename
self.is_optional = is_optional
self.allow_none = allow_none
self.default_option = default_option
self.argname_escaped = _escape_python_name(argname)
class SchemaInflatorGenerator: class SchemaInflatorGenerator:
templateLoader: jinja2.BaseLoader templateLoader: jinja2.BaseLoader
templateEnv: jinja2.Environment templateEnv: jinja2.Environment
template: jinja2.Template
union_template: jinja2.Template object_template: jinja2.Template
settings: SchemaInflatorTemplateSettings tuple_template: jinja2.Template
_store_sources: bool
_strict_mode: bool
_out_directory: str | None
def __init__(self, def __init__(self,
loader: Optional[jinja2.BaseLoader] = None, loader: Optional[jinja2.BaseLoader] = None,
template_settings: Optional[SchemaInflatorTemplateSettings] = None): strict_mode: bool = False,
store_sources: bool = False,
*,
object_template_filename: str = 'inflator.jinja2',
tuple_template_filename: str = 'inflator_tuple.jinja2',
out_directory: str | None = None,
):
if template_settings is None: self._strict_mode = strict_mode
template_settings = SchemaInflatorTemplateSettings() self._store_sources = store_sources
self._out_directory = out_directory
if loader is None: if loader is None:
template_path = importlib.resources.files('megasniff.templates') template_path = importlib.resources.files('megasniff.templates')
@@ -50,94 +95,214 @@ class SchemaInflatorGenerator:
self.templateLoader = loader self.templateLoader = loader
self.templateEnv = jinja2.Environment(loader=self.templateLoader) self.templateEnv = jinja2.Environment(loader=self.templateLoader)
self.template = self.templateEnv.get_template(template_settings.general) self.object_template = self.templateEnv.get_template(object_template_filename)
self.union_template = self.templateEnv.get_template(template_settings.union) self.tuple_template = self.templateEnv.get_template(tuple_template_filename)
def _union_inflator(self, def schema_to_inflator(self,
argname: str, schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
argtype: str, strict_mode_override: Optional[bool] = None,
constrs: list[tuple[str, bool]], from_type_override: Optional[type | TypeAliasType] = None,
lookup_table: dict[str, Any]): ignore_directory: bool = False,
txt = self.union_template.render( out_directory_override: Optional[str] = None,
argname=argname, ) -> Callable[[dict[str, Any]], Any]:
typename=argtype, if from_type_override is not None and '__getitem__' not in dir(from_type_override):
constrs=constrs raise RuntimeError('from_type_override must provide __getitem__')
txt, namespace = self._schema_to_inflator(schema,
_funcname='inflate',
strict_mode_override=strict_mode_override,
from_type_override=from_type_override,
) )
namespace = { out_dir = self._out_directory
'_lookup_table': lookup_table if out_directory_override:
} out_dir = out_directory_override
if ignore_directory:
out_dir = None
imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt
if out_dir is not None:
filename = f"{uuid.uuid4()}.py"
filepath = Path(out_dir) / filename
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, 'w', encoding='utf-8') as f:
f.write(txt)
spec = importlib.util.spec_from_file_location("generated_module", filepath)
module = importlib.util.module_from_spec(spec)
module.__dict__.update(namespace)
spec.loader.exec_module(module)
fn_name = _schema_to_deflator_func(schema)
fn = getattr(module, fn_name)
else:
exec(txt, namespace) exec(txt, namespace)
return namespace['inflate'] fn = namespace['inflate']
def schema_to_generator(self, if self._store_sources:
schema: type, setattr(fn, '__megasniff_sources__', txt)
return fn
def _unwrap_typeref(self, t: type, strict_mode: bool) -> TypeRenderData:
type_origin = get_origin(t)
allow_none = False
argtypes = t,
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(t)
if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True
is_union = len(argtypes) > 1
if is_union:
typerefs = list(map(lambda x: self._unwrap_typeref(x, strict_mode), argtypes))
return TypeRenderData(typerefs, allow_none, False, True, False)
elif type_origin in [list, set]:
rd = self._unwrap_typeref(argtypes[0], strict_mode)
return IterableTypeRenderData(rd, allow_none, True, False, False, type_origin.__name__)
else:
t = argtypes[0]
is_list = (type_origin or t) in [list, set]
if is_list:
t = type_origin or t
is_builtin = is_builtin_type(t)
return TypeRenderData(t.__name__ if is_builtin else f'inflate_{t.__name__}',
allow_none,
is_list,
False,
strict_mode if is_builtin else False)
def _schema_to_inflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type],
strict_mode_override: Optional[bool] = None,
from_type_override: Optional[type | TypeAliasType] = None,
*, *,
_base_lookup_table: Optional[dict] = None) -> Callable[[dict[str, Any]], Any]: _funcname='inflate',
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян _namespace=None,
type_hints = get_kwargs_type_hints(schema) ) -> tuple[str, dict]:
render_data = [] if strict_mode_override is not None:
lookup_table = _base_lookup_table or {} strict_mode = strict_mode_override
else:
strict_mode = self._strict_mode
if schema.__name__ not in lookup_table.keys(): template = self.object_template
lookup_table[schema.__name__] = None mode = 'object'
if isinstance(schema, dict):
new_schema = []
for argname, argtype in schema.items():
new_schema.append((argname, argtype))
schema = new_schema
if isinstance(schema, collections.abc.Iterable):
template = self.tuple_template
mode = 'tuple'
new_schema = []
for t in schema:
if isinstance(t, TupleSchemaItem):
new_schema.append(t)
else:
new_schema.append(TupleSchemaItem(t[1], key_name=t[0]))
schema = new_schema
# Я это написал, оно пока работает, и я не собираюсь это упрощать, сорян
if mode == 'object':
type_hints = get_kwargs_type_hints(schema)
else:
type_hints = {}
for i, t in enumerate(schema):
n = t.key_name or f'_arg_{i}'
type_hints[n] = t.schema
render_data = []
txt_segments = []
if _namespace is None:
namespace = {}
else:
namespace = _namespace
if namespace.get(f'{_funcname}_tgt_type') is not None:
return '', namespace
if mode == 'object':
namespace[f'{_funcname}_tgt_type'] = schema
namespace[utils.typename(schema)] = schema
if from_type_override is not None:
namespace['_from_type'] = from_type_override
for argname, argtype in type_hints.items(): for argname, argtype in type_hints.items():
if argname in {'return', 'self'}: if argname in {'return', 'self'}:
continue continue
has_default, default_option = get_field_default(schema, argname) has_default, default_option = get_field_default(schema, argname)
typeref = self._unwrap_typeref(argtype, strict_mode)
argtypes = argtype, argtypes = argtype,
type_origin = get_origin(argtype)
allow_none = False allow_none = False
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated])): while get_origin(argtype) is not None:
type_origin = get_origin(argtype)
if any(map(lambda x: type_origin is x, [Union, UnionType, Optional, Annotated, list, List, set, Set])):
argtypes = get_args(argtype) argtypes = get_args(argtype)
if len(argtypes) == 1:
argtype = argtypes[0]
else:
break
if NoneType in argtypes or None in argtypes: if NoneType in argtypes or None in argtypes:
argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes)) argtypes = tuple(filter(lambda x: x is not None and x is not NoneType, argtypes))
allow_none = True allow_none = True
out_argtypes: list[tuple[str, bool]] = []
for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
if argt.__name__ not in lookup_table.keys():
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
lookup_table[hash(argt)] = self.schema_to_generator(argt, _base_lookup_table=lookup_table)
if argt is schema:
out_argtypes.append(('inflate', True))
else:
out_argtypes.append((argt.__name__, is_builtin))
if len(argtypes) > 1:
lookup_table[hash(argtype)] = self._union_inflator('', '', out_argtypes, lookup_table)
render_data.append( render_data.append(
RenderData( FieldRenderData(
argname, argname,
out_argtypes, typeref,
utils.typename(argtype), utils.typename(argtype),
has_default, has_default,
allow_none, allow_none,
default_option, default_option if not isinstance(default_option, str) else f"'{default_option}'",
hash(argtype)
) )
) )
convertor_functext = self.template.render(conversions=render_data) for argt in argtypes:
is_builtin = is_builtin_type(argt)
if not is_builtin and argt is not schema:
# если случилась циклическая зависимость, мы не хотим бексконечную рекурсию
if argt.__name__ not in namespace.keys():
t, n = self._schema_to_inflator(argt,
_funcname=f'inflate_{argt.__name__}',
_namespace=namespace,
strict_mode_override=strict_mode_override)
namespace |= n
txt_segments.append(t)
elif argt is schema:
pass
else:
namespace[argt.__name__] = argt
convertor_functext = template.render(
funcname=_funcname,
conversions=render_data,
tgt_type=utils.typename(schema),
from_type='_from_type' if from_type_override is not None else None
)
convertor_functext = '\n'.join(txt_segments) + '\n\n' + convertor_functext
convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n')))) convertor_functext = '\n'.join(list(filter(lambda x: len(x.strip()), convertor_functext.split('\n'))))
convertor_functext = convertor_functext.replace(', )', ')')
namespace = {
'_tgt_type': schema,
'_lookup_table': lookup_table
}
exec(convertor_functext, namespace)
# пихаем сгенеренный метод в табличку, return convertor_functext, namespace
# ожидаем что она обновится во всех вложенных методах,
# разрешая циклические зависимости
lookup_table[schema.__name__] = namespace['inflate']
return namespace['inflate']

View File

@@ -0,0 +1,110 @@
{% macro render_unwrap_object(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% for kv in unwrapping.fields %}
{{ render_unwrap(kv.unwrapping, from_container + '.' + kv.object_key, into_container + "['" + kv.key + "']") }}
{% endfor %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_dict(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{ into_container }} = {}
{% if strict_check %}
if not isinstance({{from_container}}, dict):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'dict(' + from_container + ')' %}
{% endif %}
for k_{{hashname(unwrapping)}}, v_{{hashname(unwrapping)}} in {{from_container}}.items():
{{ render_unwrap(unwrapping.key_unwrap, 'k_' + hashname(unwrapping), 'k_' + hashname(unwrapping)) | indent(4) }}
{{ render_unwrap(unwrapping.value_unwrap, 'v_' + hashname(unwrapping), into_container + '[k_' + hashname(unwrapping) + ']') | indent(4) }}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_list(unwrapping, from_container, into_container) -%}
{%- set out -%}
{{into_container}} = []
{% if strict_check %}
if not isinstance({{from_container}}, list):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'list', str(type({{from_container}})))
{% endif %}
{% if explicit_cast %}
{% set from_container = 'list(' + from_container + ')' %}
{% endif %}
for {{hashname(unwrapping)}} in {{from_container}}:
{{ render_unwrap(unwrapping.item_unwrap, hashname(unwrapping), hashname(unwrapping)+'_tmp_container') | indent(4) }}
{{into_container}}.append({{hashname(unwrapping)}}_tmp_container)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_other(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.tp != '' and strict_check %}
if not isinstance({{from_container}}, {{unwrapping.tp}}):
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', '{{unwrapping.tp}}', str(type({{from_container}})))
{% endif %}
{% if unwrapping.tp != '' and explicit_cast %}
{{into_container}} = {{unwrapping.tp}}({{from_container}})
{% else %}
{{into_container}} = {{from_container}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap_union(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% for union_kind in unwrapping.union_kinds %}
{% if loop.index > 1 %}el{% endif %}if isinstance({{from_container}}, {{union_kind.kind}}):
{{render_unwrap(union_kind.unwrapping, from_container, into_container) | indent(4)}}
{% endfor %}
{% if strict_check %}
else:
raise FieldValidationException('{{from_container.replace("'", "\\'")}}', 'dict', str(type({{from_container}})))
{% elif explicit_cast %}
else:
{{render_unwrap(unwrapping.union_kinds[-1], from_container, into_container) | indent(4)}}
{% else %}
else:
{{into_container}} = fallback_unwrap({{from_container}})
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_unwrap(unwrapping, from_container, into_container) -%}
{%- set out -%}
{% if unwrapping.kind == 'dict' %}
{{ render_unwrap_dict(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'list' %}
{{ render_unwrap_list(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'object' %}
{{ render_unwrap_object(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'union' %}
{{ render_unwrap_union(unwrapping, from_container, into_container) }}
{% elif unwrapping.kind == 'fn' %}
{{into_container}} = {{ unwrapping.fn }}({{from_container}})
{% else %}
{{ render_unwrap_other(unwrapping, from_container, into_container) }}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}
def {{funcname}}(from_data{% if from_type is not none%}: {{from_type}}{%endif%}) -> {% if into_type is none %}JsonObject{%else%}{{into_type}}{%endif%}:
"""
{{from_type}} -> {{into_type}}
"""
{{ render_unwrap(root_unwrap, 'from_data', 'ret') | indent(4) }}
return ret

View File

@@ -1,14 +1,17 @@
{% set ns = namespace(retry_indent=0) %} {% set ns = namespace(retry_indent=0) %}
from typing import Any {% import "unwrap_type_data.jinja2" as unwrap_type_data %}
from megasniff.exceptions import MissingFieldException, FieldValidationException
def inflate(from_data: dict[str, Any]):
def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{from_type}}{% endif %}) {% if tgt_type is not none %} -> {{tgt_type}} {% endif %}:
"""
{{tgt_type}}
"""
from_data_keys = from_data.keys() from_data_keys = from_data.keys()
{% for conv in conversions %} {% for conv in conversions %}
if '{{conv.argname}}' not in from_data_keys: if '{{conv.argname}}' not in from_data_keys:
{% if conv.is_optional %} {% if conv.is_optional %}
{{conv.argname}} = {{conv.default_option}} {{conv.argname_escaped}} = {{conv.default_option}}
{% else %} {% else %}
raise MissingFieldException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}") raise MissingFieldException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}")
{% endif %} {% endif %}
@@ -18,20 +21,12 @@ def inflate(from_data: dict[str, Any]):
{% if not conv.allow_none %} {% if not conv.allow_none %}
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data) raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data)
{% else %} {% else %}
{{conv.argname}} = None {{conv.argname_escaped}} = None
{% endif %} {% endif %}
else: else:
try:
{% if conv.constrs | length > 1 or conv.constrs[0][1] is false %}
{{conv.argname}} = _lookup_table[{{conv.typeid}}](conv_data)
{% else %}
{{conv.argname}} = {{conv.constrs[0][0]}}(conv_data)
{% endif %}
except FieldValidationException as e:
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, e.exceptions)
{{ unwrap_type_data.render_segment(conv.argname_escaped, conv.constrs, "conv_data", false) | indent(4*3) }}
{% endfor %} {% endfor %}
return _tgt_type({% for conv in conversions %}{{conv.argname}}={{conv.argname}}, {% endfor %}) return {{funcname}}_tgt_type({% for conv in conversions %}{{conv.argname_escaped}}={{conv.argname_escaped}}, {% endfor %})

View File

@@ -0,0 +1,32 @@
{% set ns = namespace(retry_indent=0) %}
{% import "unwrap_type_data.jinja2" as unwrap_type_data %}
def {{funcname}}(from_data: {% if from_type is none %}dict[str, Any]{% else %}{{from_type}}{% endif %}) {% if tgt_type is not none %} -> tuple {% endif %}:
"""
{% for conv in conversions %}{{conv.argname_escaped}}:{{conv.typename}}, {% endfor %}
"""
from_data_keys = from_data.keys()
{% for conv in conversions %}
if '{{conv.argname}}' not in from_data_keys:
{% if conv.is_optional %}
{{conv.argname_escaped}} = {{conv.default_option}}
{% else %}
raise MissingFieldException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}")
{% endif %}
else:
conv_data = from_data['{{conv.argname}}']
if conv_data is None:
{% if not conv.allow_none %}
raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data)
{% else %}
{{conv.argname_escaped}} = None
{% endif %}
else:
{{ unwrap_type_data.render_segment(conv.argname_escaped, conv.constrs, "conv_data", false) | indent(4*3) }}
{% endfor %}
return ({% for conv in conversions %}{{conv.argname_escaped}}, {% endfor %})

View File

@@ -1,19 +0,0 @@
{% set ns = namespace(retry_indent=0) %}
from typing import Any
from megasniff.exceptions import FieldValidationException
def inflate(conv_data: Any):
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type, is_builtin in constrs %}
{{ ' ' * ns.retry_indent }}try:
{% if is_builtin %}
{{ ' ' * ns.retry_indent }} return {{union_type}}(conv_data)
{% else %}
{{ ' ' * ns.retry_indent }} return _lookup_table['{{union_type}}'](conv_data)
{% endif %}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
raise FieldValidationException('{{argname}}', "{{typename | replace('"', "'")}}", conv_data, all_conv_exceptions)

View File

@@ -0,0 +1,55 @@
{% macro render_iterable(argname, typedef, conv_data) -%}
{%- set out -%}
{{argname}} = []
if not isinstance({{conv_data}}, list):
raise FieldValidationException('{{argname}}', "list", conv_data, [])
for item in {{conv_data}}:
{{ render_segment("_" + argname, typedef, "item", false ) | indent(4) }}
{{argname}}.append(_{{argname}})
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_union(argname, conv, conv_data) -%}
{%- set out -%}
# unwrapping union {{conv}}
{% set ns = namespace(retry_indent=0) %}
{% set ns.retry_indent = 0 %}
all_conv_exceptions = []
{% for union_type in conv.typeref %}
{{ ' ' * ns.retry_indent }}try:
{{ render_segment(argname, union_type, conv_data, false) | indent((ns.retry_indent + 1) * 4) }}
{{ ' ' * ns.retry_indent }}except Exception as e:
{{ ' ' * ns.retry_indent }} all_conv_exceptions.append(e)
{% set ns.retry_indent = ns.retry_indent + 1 %}
{% endfor %}
{{ ' ' * ns.retry_indent }}raise FieldValidationException('{{conv.argname}}', "{{conv.typename | replace('"', "'")}}", conv_data, all_conv_exceptions)
{%- endset %}
{{out}}
{%- endmacro %}
{% macro render_segment(argname, typeref, conv_data, strict) -%}
{%- set out -%}
{% if typeref is string %}
{% if strict %}
if not isinstance({{conv_data}}, {{typeref}}):
raise FieldValidationException('{{argname}}', "{{typeref | replace('"', "'")}}", {{conv_data}}, [])
{% endif %}
{{argname}} = {{typeref}}({{conv_data}})
{% elif typeref.is_union %}
{{render_union(argname, typeref, conv_data)}}
{% elif typeref.is_list %}
{{render_iterable(argname, typeref.typeref, conv_data)}}
{{argname}} = {{typeref.iterable_type}}({{argname}})
{% else %}
{{render_segment(argname, typeref.typeref, conv_data, typeref.is_strict)}}
{% endif %}
{%- endset %}
{{out}}
{%- endmacro %}

View File

@@ -1,6 +1,15 @@
import collections.abc
import dataclasses import dataclasses
import inspect import inspect
from typing import get_type_hints, Any, get_origin from typing import get_type_hints, Any, get_origin, Iterable, Optional
@dataclasses.dataclass
class TupleSchemaItem:
schema: type
key_name: str
has_default: bool = False
default: Any = None
def is_typed_dict_type(tp: type) -> bool: def is_typed_dict_type(tp: type) -> bool:
@@ -19,7 +28,7 @@ def get_kwargs_type_hints(obj: type) -> dict[str, Any]:
return get_type_hints(obj.__init__) return get_type_hints(obj.__init__)
def get_field_default(cls: type[Any], field: str) -> tuple[bool, Any]: def get_field_default(cls: type[Any] | Iterable[TupleSchemaItem], field: str) -> tuple[bool, Any]:
if dataclasses.is_dataclass(cls): if dataclasses.is_dataclass(cls):
for f in dataclasses.fields(cls): for f in dataclasses.fields(cls):
if f.name == field: if f.name == field:
@@ -32,6 +41,12 @@ def get_field_default(cls: type[Any], field: str) -> tuple[bool, Any]:
# поле не объявлено в dataclass # поле не объявлено в dataclass
return False, None return False, None
if isinstance(cls, collections.abc.Iterable):
for i, t in enumerate(cls):
if (t.key_name or f'_arg_{i}') == field:
return t.has_default, t.default
return False, None
sig = inspect.signature(cls.__init__) sig = inspect.signature(cls.__init__)
params = list(sig.parameters.values())[1:] params = list(sig.parameters.values())[1:]
@@ -53,6 +68,26 @@ def is_builtin_type(tp: type) -> bool:
def typename(tp: type) -> str: def typename(tp: type) -> str:
if get_origin(tp) is None: ret = ''
return tp.__name__ if get_origin(tp) is None and hasattr(tp, '__name__'):
return str(tp) ret = tp.__name__
else:
ret = str(tp)
ret = (ret
.replace('.', '_')
.replace('[', '_of_')
.replace(']', '_of_')
.replace(',', '_and_')
.replace(' ', '_')
.replace('\'', '')
.replace('<', '')
.replace('>', ''))
return ret
def is_class_definition(obj):
return (isinstance(obj, type) or inspect.isclass(obj))
def hashname(obj) -> str:
return '_' + str(hash(obj)).replace('-', '_')

View File

@@ -11,7 +11,7 @@ def test_basic_constructor():
self.a = a self.a = a
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -23,7 +23,7 @@ def test_unions():
a: int | str a: int | str
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
a = fn({'a': 42}) a = fn({'a': 42})
assert a.a == 42 assert a.a == 42
@@ -45,10 +45,10 @@ class CircB:
def test_circular(): def test_circular():
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(CircA) fn = infl.schema_to_inflator(CircA)
a = fn({'b': {'a': None}}) a = fn({'b': {'a': None}})
return isinstance(a.b, CircB) assert isinstance(a.b, CircB)
def test_optional(): def test_optional():
@@ -57,6 +57,6 @@ def test_optional():
a: Optional[int] = None a: Optional[int] = None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(C) fn = infl.schema_to_inflator(C)
c = fn({}) c = fn({})
assert c.a is None assert c.a is None

View File

@@ -0,0 +1,91 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from megasniff.deflator import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator
def test_basic_deflator():
class A:
a: int
def __init__(self, a: int):
self.a = a
class B:
def __init__(self, b: int):
self.b = b
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
fnb = defl.schema_to_deflator(B)
b = fnb(B(11))
assert len(b) == 0
def test_unions():
@dataclass
class A:
a: int | str
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == '42'
a = fn(A('42a'))
assert a['a'] == '42a'
def test_dict_body():
@dataclass
class A:
a: dict[str, float]
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(A)
a = fn(A({'1': 1.1, '2': 2.2}))
print(a)
assert a['a']['1'] == 1.1
assert a['a']['2'] == 2.2
@dataclass
class CircA:
b: CircB
@dataclass
class CircB:
a: CircA | None
def test_circular():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(CircA)
a = fn(CircA(CircB(CircA(CircB(None)))))
assert isinstance(a['b'], dict)
assert isinstance(a['b']['a'], dict)
assert a['b']['a']['b']['a'] is None
def test_optional():
@dataclass
class C:
a: Optional[int] = None
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(C)
c = fn(C())
assert c['a'] is None
c = fn(C(123))
assert c['a'] == 123

View File

@@ -12,7 +12,7 @@ def test_missing_field():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(MissingFieldException): with pytest.raises(MissingFieldException):
fn({}) fn({})
@@ -23,7 +23,7 @@ def test_null():
a: int a: int
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': None}) fn({'a': None})
@@ -34,6 +34,6 @@ def test_invalid_field():
a: float | int | None a: float | int | None
infl = SchemaInflatorGenerator() infl = SchemaInflatorGenerator()
fn = infl.schema_to_generator(A) fn = infl.schema_to_inflator(A)
with pytest.raises(FieldValidationException): with pytest.raises(FieldValidationException):
fn({'a': {}}) fn({'a': {}})

View File

@@ -0,0 +1,94 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_explicit_casts_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(explicit_casts=False)
fn = defl.schema_to_deflator(A, explicit_casts_override=True)
a = fn(A(42))
assert a['a'] == 42
a = fn(A(42.0))
assert a['a'] == 42
a = fn(A('42'))
assert a['a'] == 42
with pytest.raises(TypeError):
fn(A(['42']))
def test_global_explicit_casts_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
a = fn(A([42.0, 42]))
assert len(a['a']) == 2
assert a['a'][0] == 42
assert a['a'][1] == 42
def test_global_explicit_casts_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(explicit_casts=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
b = fn(B([42.0]))
assert b['b'][0] == 42
b = fn(B([A([1.1])]))
assert b['b'][0]['a'][0] == 1

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import pytest
from megasniff.deflator import SchemaDeflatorGenerator
from src.megasniff import SchemaInflatorGenerator
def test_str_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(str, explicit_casts_override=True)
a = fn('asdf')
assert a == 'asdf'
a = fn(1234)
assert a == '1234'
fn1 = defl.schema_to_deflator(str, strict_mode_override=True)
with pytest.raises(Exception):
fn1(1234)
def test_int_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(int, explicit_casts_override=True)
a = fn(1234)
assert a == 1234
a = fn('1234')
assert a == 1234
fn1 = defl.schema_to_deflator(int, strict_mode_override=True)
with pytest.raises(Exception):
fn1('1234')
def test_float_deflator():
defl = SchemaDeflatorGenerator()
fn = defl.schema_to_deflator(float, explicit_casts_override=True)
a = fn(1234.1)
assert a == 1234.1
a = fn('1234')
assert a == 1234.0
fn1 = defl.schema_to_deflator(float, strict_mode_override=True)
with pytest.raises(Exception):
fn1(1234)

87
tests/test_iterables.py Normal file
View File

@@ -0,0 +1,87 @@
from dataclasses import dataclass
from megasniff import SchemaInflatorGenerator
def test_list_basic():
@dataclass
class A:
l: list[int]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, list)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: list[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == ['1', '2.1', '0']
def test_list_union():
@dataclass
class A:
l: list[int | str]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, list)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, list)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
def test_set_basic():
@dataclass
class A:
l: set[int]
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator(A)
a = fn({'l': []})
assert isinstance(a.l, set)
assert len(a.l) == 0
a = fn({'l': [1, 2.1, '0']})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, int), a.l))
@dataclass
class B:
l: set[str]
fn = infl.schema_to_inflator(B)
a = fn({'l': [1, 2.1, '0', 0]})
print(a.l)
assert isinstance(a.l, set)
assert len(a.l) == 3
assert all(map(lambda x: isinstance(x, str), a.l))
assert a.l == {'1', '2.1', '0'}

43
tests/test_signature.py Normal file
View File

@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import get_type_hints, Any, Annotated
from megasniff import SchemaInflatorGenerator
def test_return_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(A)
hints = get_type_hints(fn)
assert hints['return'] == A
assert len(hints) == 2
def test_argument_signature():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
type custom_from_type = dict[str, Any]
fn1 = infl.schema_to_inflator(A, from_type_override=custom_from_type)
fn2 = infl.schema_to_inflator(A)
hints = get_type_hints(fn1)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] == custom_from_type
assert hints['from_data'] != dict[str, Any]
hints = get_type_hints(fn2)
assert hints['return'] == A
assert len(hints) == 2
assert hints['from_data'] != custom_from_type
assert hints['from_data'] == dict[str, Any]

75
tests/test_strict_mode.py Normal file
View File

@@ -0,0 +1,75 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaInflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_strict_mode_basic():
class A:
def __init__(self, a: int):
self.a = a
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(A)
a = fn({'a': 42})
assert a.a == 42
with pytest.raises(FieldValidationException):
fn({'a': 42.0})
def test_global_strict_mode_basic_override():
class A:
def __init__(self, a: int):
self.a = a
infl = SchemaInflatorGenerator(strict_mode=False)
fn = infl.schema_to_inflator(A, strict_mode_override=True)
a = fn({'a': 42})
assert a.a == 42
with pytest.raises(FieldValidationException):
fn({'a': 42.0})
def test_global_strict_mode_list():
@dataclass
class A:
a: list[int]
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(A)
a = fn({'a': [42]})
assert a.a == [42]
with pytest.raises(FieldValidationException):
fn({'a': [42.0, 42]})
def test_global_strict_mode_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
infl = SchemaInflatorGenerator(strict_mode=True)
fn = infl.schema_to_inflator(B)
b = fn({'b': [{'a': []}, 42]})
assert len(b.b) == 2
assert isinstance(b.b[0], A)
assert isinstance(b.b[1], int)
with pytest.raises(FieldValidationException):
fn({'b': [42.0]})
with pytest.raises(FieldValidationException):
fn({'b': [{'a': [1.1]}]})

View File

@@ -0,0 +1,88 @@
from dataclasses import dataclass
import pytest
from megasniff import SchemaDeflatorGenerator
from megasniff.exceptions import FieldValidationException
def test_global_strict_mode_basic():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_basic_override():
class A:
a: int
def __init__(self, a):
self.a = a
defl = SchemaDeflatorGenerator(strict_mode=False)
fn = defl.schema_to_deflator(A, strict_mode_override=True)
a = fn(A(42))
assert a['a'] == 42
with pytest.raises(FieldValidationException):
fn(A(42.0))
with pytest.raises(FieldValidationException):
fn(A('42'))
with pytest.raises(FieldValidationException):
fn(A(['42']))
def test_global_strict_mode_list():
@dataclass
class A:
a: list[int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(A)
a = fn(A([42]))
assert a['a'] == [42]
with pytest.raises(FieldValidationException):
fn(A([42.0, 42]))
def test_global_strict_mode_circular():
@dataclass
class A:
a: list[int]
@dataclass
class B:
b: list[A | int]
defl = SchemaDeflatorGenerator(strict_mode=True)
fn = defl.schema_to_deflator(B)
b = fn(B([A([]), 42]))
assert len(b['b']) == 2
assert isinstance(b['b'][0], dict)
assert len(b['b'][0]['a']) == 0
assert isinstance(b['b'][1], int)
with pytest.raises(FieldValidationException):
fn(B([42.0]))
with pytest.raises(FieldValidationException):
fn(B([A([1.1])]))

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from src.megasniff.utils import TupleSchemaItem
from src.megasniff import SchemaInflatorGenerator
def test_basic_tuple():
infl = SchemaInflatorGenerator()
fn = infl.schema_to_inflator({'a': int, 'b': float, 'c': str, 'd': list[int]})
a = fn({'a': 42, 'b': 1.1, 'c': 123, 'd': []})
assert a[0] == 42
fn = infl.schema_to_inflator((('a', int), ('b', list[int])))
a = fn({'a': 42, 'b': ['1']})
assert a[1][0] == 1
fn = infl.schema_to_inflator(
(('a', int), TupleSchemaItem(Optional[list[int]], key_name='b', has_default=True, default=None)))
a = fn({'a': 42})
assert a[1] is None
assert a[0] == 42