Escape deflaters names, extend signatures to allow root UnionTypes

This commit is contained in:
2025-08-29 02:23:13 +03:00
parent 8b29b941af
commit 3aae5cf2d2
4 changed files with 37 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "megasniff" name = "megasniff"
version = "0.2.4" version = "0.2.4.post1"
description = "Library for in-time codegened type validation" description = "Library for in-time codegened type validation"
authors = [ authors = [
{ name = "nikto_b", email = "niktob560@yandex.ru" } { name = "nikto_b", email = "niktob560@yandex.ru" }

View File

@@ -81,10 +81,11 @@ class ZSchema:
def main_deflator(): def main_deflator():
deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True) deflator = SchemaDeflatorGenerator(store_sources=True, explicit_casts=True, strict_mode=True)
fn = deflator.schema_to_deflator(DSchema) fn = deflator.schema_to_deflator(DSchema | int)
print(getattr(fn, '__megasniff_sources__', '## No data'))
# ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42))) # ret = fn(ZSchema(ZSchema(ZSchema(None, 42), 42), ZSchema(None, 42)))
ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], ['b']))) ret = fn(DSchema({'a': 34}, {}, ASchema(1, 'a', None), ESchema([[['a'], ['b']]], 'b')))
ret = fn(42)
# assert ret['a'] == 1 # assert ret['a'] == 1
# assert ret['b'] == 1.1 # assert ret['b'] == 1.1
# assert ret['c'] == 'a' # assert ret['c'] == 'a'

View File

@@ -99,7 +99,14 @@ def _flatten_type(t: type | TypeAliasType) -> tuple[type, Optional[str]]:
def _schema_to_deflator_func(t: type | TypeAliasType) -> str: def _schema_to_deflator_func(t: type | TypeAliasType) -> str:
t, _ = _flatten_type(t) t, _ = _flatten_type(t)
return 'deflate_' + typename(t).replace('.', '_') ret = 'deflate_' + typename(t).replace('.', '_')
ret = ret.replace(' | ', '__OR__')
ret = ret.replace('<', '')
ret = ret.replace('>', '')
ret = ret.replace(' ', '__')
ret = ret.replace("'", '')
ret = ret.replace('.', '_')
return ret
def _fallback_unwrapper(obj: Any) -> JsonObject: def _fallback_unwrapper(obj: Any) -> JsonObject:
@@ -152,7 +159,7 @@ class SchemaDeflatorGenerator:
self.object_template = self.templateEnv.get_template(object_template_filename) self.object_template = self.templateEnv.get_template(object_template_filename)
def schema_to_deflator(self, def schema_to_deflator(self,
schema: type, schema: type | UnionType,
strict_mode_override: Optional[bool] = None, strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None, explicit_casts_override: Optional[bool] = None,
) -> Callable[[Any], dict[str, Any]]: ) -> Callable[[Any], dict[str, Any]]:
@@ -163,13 +170,15 @@ class SchemaDeflatorGenerator:
imports = ('from typing import Any\n' imports = ('from typing import Any\n'
'from megasniff.exceptions import MissingFieldException, FieldValidationException\n') 'from megasniff.exceptions import MissingFieldException, FieldValidationException\n')
txt = imports + '\n' + txt txt = imports + '\n' + txt
print(txt)
exec(txt, namespace) exec(txt, namespace)
fn = namespace[_schema_to_deflator_func(schema)] fn = namespace[_schema_to_deflator_func(schema)]
if self._store_sources: if self._store_sources:
setattr(fn, '__megasniff_sources__', txt) setattr(fn, '__megasniff_sources__', txt)
return fn return fn
def schema_to_unwrapper(self, schema: type | TypeAliasType, *, _visited_types: Optional[list[type]] = None): def schema_to_unwrapper(self, schema: type | UnionType | TypeAliasType, *,
_visited_types: Optional[list[type]] = None):
if _visited_types is None: if _visited_types is None:
_visited_types = [] _visited_types = []
else: else:
@@ -249,12 +258,11 @@ class SchemaDeflatorGenerator:
return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types return ret_unw, field_rename, set(_visited_types) | ongoing_types, recurcive_types
def _schema_to_deflator(self, def _schema_to_deflator(self,
schema: type | Sequence[TupleSchemaItem | tuple[str, type]] | OrderedDict[str, type], schema: type | UnionType,
strict_mode_override: Optional[bool] = None, strict_mode_override: Optional[bool] = None,
explicit_casts_override: Optional[bool] = None, explicit_casts_override: Optional[bool] = None,
into_type_override: Optional[type | TypeAliasType] = None, into_type_override: Optional[type | TypeAliasType] = None,
*, *,
_funcname='deflate',
_namespace=None, _namespace=None,
) -> tuple[str, dict]: ) -> tuple[str, dict]:
if strict_mode_override is not None: if strict_mode_override is not None:

View File

@@ -68,9 +68,18 @@ def is_builtin_type(tp: type) -> bool:
def typename(tp: type) -> str: def typename(tp: type) -> str:
ret = ''
if get_origin(tp) is None and hasattr(tp, '__name__'): if get_origin(tp) is None and hasattr(tp, '__name__'):
return tp.__name__ ret = tp.__name__
return str(tp) ret = str(tp)
ret = ret.replace(' | ', '__OR__')
ret = ret.replace('<', '')
ret = ret.replace('>', '')
ret = ret.replace(' ', '__')
ret = ret.replace("'", '')
ret = ret.replace('.', '_')
return ret
def is_class_definition(obj): def is_class_definition(obj):
@@ -78,4 +87,11 @@ def is_class_definition(obj):
def hashname(obj) -> str: def hashname(obj) -> str:
return '_' + str(hash(obj)).replace('-', '_') ret = '_' + str(hash(obj)).replace('-', '_')
ret = ret.replace(' | ', '__OR__')
ret = ret.replace('<', '')
ret = ret.replace('>', '')
ret = ret.replace(' ', '__')
ret = ret.replace("'", '')
ret = ret.replace('.', '_')
return ret