diff --git a/pyproject.toml b/pyproject.toml index 03359c6..a804cec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,5 +24,6 @@ packages = ["src/breakshaft"] dev = [ "mypy>=1.16.1", "pytest>=8.4.1", + "pytest-asyncio>=1.1.0", "pytest-cov>=6.2.1", ] diff --git a/tests/test_ctxmanager.py b/tests/test_ctxmanager.py new file mode 100644 index 0000000..6ece635 --- /dev/null +++ b/tests/test_ctxmanager.py @@ -0,0 +1,86 @@ +from contextlib import contextmanager, asynccontextmanager +from dataclasses import dataclass +from typing import Any, Generator, AsyncGenerator + +import pytest + +from src.breakshaft.convertor import ConvRepo + +pytest_plugins = ('pytest_asyncio',) + + +@dataclass +class A: + a: int + + +@dataclass +class B: + b: float + + +def test_sync_ctxmanager(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + int_to_a_finalized = [False] + + @repo.mark_injector() + @contextmanager + def int_to_a(i: int) -> Generator[A, Any, None]: + yield A(i) + int_to_a_finalized[0] = True + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn1(B(42.1)) + assert dep == 42 + assert not int_to_a_finalized[0] + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=False) + dep = fn2(123) + assert dep == 123 + assert int_to_a_finalized[0] + + +@pytest.mark.asyncio +async def test_async_ctxmanager(): + repo = ConvRepo() + + @repo.mark_injector() + def b_to_a(b: B) -> A: + return A(int(b.b)) + + @repo.mark_injector() + def a_to_b(a: A) -> B: + return B(float(a.a)) + + int_to_a_finalized = [False] + + @repo.mark_injector() + @asynccontextmanager + async def int_to_a(i: int) -> AsyncGenerator[A, Any]: + yield A(i) + int_to_a_finalized[0] = True + + def consumer(dep: A) -> int: + return dep.a + + fn1 = repo.get_conversion((B,), consumer, force_commutative=True, force_async=False, allow_async=True) + dep = fn1(B(42.1)) + assert dep == 42 + assert not int_to_a_finalized[0] + + fn2 = repo.get_conversion((int,), consumer, force_commutative=True, force_async=False, allow_async=True) + dep = await fn2(123) + assert dep == 123 + assert int_to_a_finalized[0] diff --git a/uv.lock b/uv.lock index 5907a37..7f61796 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.13" [[package]] name = "breakshaft" -version = "0.1.0.post1" +version = "0.1.0.post2" source = { editable = "." } dependencies = [ { name = "hatchling" }, @@ -15,6 +15,7 @@ dependencies = [ dev = [ { name = "mypy" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, ] @@ -28,6 +29,7 @@ requires-dist = [ dev = [ { name = "mypy", specifier = ">=1.16.1" }, { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=1.1.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, ] @@ -216,6 +218,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, +] + [[package]] name = "pytest-cov" version = "6.2.1"