diff --git a/src/breakshaft/models.py b/src/breakshaft/models.py index 24bff02..abd4b2a 100644 --- a/src/breakshaft/models.py +++ b/src/breakshaft/models.py @@ -71,6 +71,7 @@ class ConversionPoint: raise ValueError(f'Function {func.__qualname__} provided as injector, but return-type is not specified') rettype_origin = get_origin(rettype) + fn_rettype_origin = get_origin(fn_rettype) cm_out_origins = [ typing.Generator, typing.Iterator, @@ -83,6 +84,8 @@ class ConversionPoint: ] if any(map(lambda x: rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): rettype = get_args(rettype)[0] + if any(map(lambda x: fn_rettype_origin is x, cm_out_origins)) and is_context_manager_factory(func): + fn_rettype = get_args(fn_rettype)[0] if is_basic_type_annot(rettype): return [] diff --git a/src/breakshaft/renderer.py b/src/breakshaft/renderer.py index 79a7aa3..473686c 100644 --- a/src/breakshaft/renderer.py +++ b/src/breakshaft/renderer.py @@ -146,7 +146,7 @@ class InTimeGenerationConvertorRenderer(ConvertorRenderer): conversion_models = deduplicate_callseq(conversion_models) - ret_hash = hashname(callseq[-1].injects) + ret_hash = hashname(callseq[-1].rettype) conv_args = [] for i, from_type in enumerate(from_types):