diff --git a/check.sh b/check.sh index fd2cd44..39f10e0 100755 --- a/check.sh +++ b/check.sh @@ -14,7 +14,7 @@ flake8 tricycle/ \ || EXIT_STATUS=$? # Run mypy -mypy --strict -p tricycle || EXIT_STATUS=$? +mypy --strict --implicit-reexport -p tricycle || EXIT_STATUS=$? # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then diff --git a/ci.sh b/ci.sh index 4f12db3..4b2f828 100755 --- a/ci.sh +++ b/ci.sh @@ -120,7 +120,7 @@ else INSTALLDIR=$(python -c "import os, tricycle; print(os.path.dirname(tricycle.__file__))") cp ../setup.cfg $INSTALLDIR - pytest -W error -ra --junitxml=../test-results.xml --faulthandler-timeout=60 ${INSTALLDIR} --cov="$INSTALLDIR" --cov-config=../.coveragerc --verbose + pytest -W error -ra --junitxml=../test-results.xml -o faulthandler_timeout=60 ${INSTALLDIR} --cov="$INSTALLDIR" --cov-config=../.coveragerc --verbose # Disable coverage on 3.8 until we run 3.8 on Windows CI too # https://github.com/python-trio/trio/pull/784#issuecomment-446438407 diff --git a/tricycle/_meta.py b/tricycle/_meta.py index 9b15ba2..7171712 100644 --- a/tricycle/_meta.py +++ b/tricycle/_meta.py @@ -13,6 +13,7 @@ Optional, Type, TypeVar, + TYPE_CHECKING, ) from ._service_nursery import open_service_nursery @@ -125,14 +126,15 @@ async def __wrap__(self): async def __wrap__(self) -> AsyncIterator[None]: yield - # These are necessary to placate mypy, which doesn't understand - # the asynccontextmanager metaclass __call__. They should never - # actually get called. - async def __aenter__(self: T) -> T: - raise AssertionError + if TYPE_CHECKING: + # These are necessary to placate mypy, which doesn't understand + # the asynccontextmanager metaclass __call__. They should never + # actually get called. + async def __aenter__(self: T) -> T: + raise AssertionError - async def __aexit__(self, *exc: object) -> None: - raise AssertionError + async def __aexit__(self, *exc: object) -> None: + raise AssertionError class BackgroundObject(ScopedObject): diff --git a/tricycle/_multi_cancel.py b/tricycle/_multi_cancel.py index 15d5dad..58ed6cc 100644 --- a/tricycle/_multi_cancel.py +++ b/tricycle/_multi_cancel.py @@ -1,10 +1,10 @@ import attr import trio import weakref -from typing import MutableSet, Optional +from typing import Iterator, MutableSet, Optional -@attr.s(eq=False) +@attr.s(eq=False, repr=False) class MultiCancelScope: r"""Manages a dynamic set of :class:`trio.CancelScope`\s that can be shielded and cancelled as a unit. @@ -30,6 +30,14 @@ class MultiCancelScope: _shield: bool = attr.ib(default=False, kw_only=True) _cancel_called: bool = attr.ib(default=False, kw_only=True) + def __repr__(self) -> str: + descr = ["MultiCancelScope"] + if self._shield: + descr.append(" shielded") + if self._cancel_called: + descr.append(" cancelled") + return f"<{''.join(descr)}: {list(self._child_scopes)}>" + @property def cancel_called(self) -> bool: """Returns true if :meth:`cancel` has been called.""" diff --git a/tricycle/_tests/test_meta.py b/tricycle/_tests/test_meta.py new file mode 100644 index 0000000..11f7fe1 --- /dev/null +++ b/tricycle/_tests/test_meta.py @@ -0,0 +1,144 @@ +import attr +import pytest # type: ignore +import types +import trio +import trio.testing +from async_generator import asynccontextmanager +from typing import Any, AsyncIterator, Coroutine, List +from trio_typing import TaskStatus + +from .. import ScopedObject, BackgroundObject + + +def test_too_much_magic() -> None: + with pytest.raises(TypeError) as info: + class TooMuchMagic(ScopedObject): # pragma: no cover + async def __open__(self) -> None: + pass + + @asynccontextmanager + async def __wrap__(self) -> AsyncIterator[None]: + yield + + assert str(info.value) == ( + "ScopedObjects can define __open__/__close__, or __wrap__, but not both" + ) + + + +@types.coroutine +def async_yield(value: str) -> None: + yield value + + +def test_mro() -> None: + class A(ScopedObject): + async def __open__(self) -> None: + await async_yield("open A") + + class B(A): + async def __open__(self) -> None: + await async_yield("open B") + + async def __close__(self) -> None: + await async_yield("close B") + + class C(A): + async def __open__(self) -> None: + await async_yield("open C") + + async def __close__(self) -> None: + await async_yield("close C") + + class D(B, C): + def __init__(self, value: int): + self.value = value + + async def __close__(self) -> None: + await async_yield("close D") + + assert D.__mro__ == (D, B, C, A, ScopedObject, object) + d_mgr = D(42) + assert not isinstance(d_mgr, D) + assert not hasattr(d_mgr, "value") + assert hasattr(d_mgr, "__aenter__") + + async def use_it() -> None: + async with d_mgr as d: + assert isinstance(d, D) + assert d.value == 42 + await async_yield("body") + + coro: Coroutine[str, None, None] = use_it() + record = [] + while True: + try: + record.append(coro.send(None)) + except StopIteration: + break + assert record == [ + "open A", "open C", "open B", "body", "close D", "close B", "close C" + ] + + +@attr.s(auto_attribs=True) +class Example(BackgroundObject): + ticks: int = 0 + record: List[str] = attr.Factory(list) + exiting: bool = False + + def __attrs_post_init__(self) -> None: + assert not hasattr(self, "nursery") + self.record.append("attrs_post_init") + + async def __open__(self) -> None: + self.record.append("open") + await self.nursery.start(self._background_task) + self.record.append("started") + + async def __close__(self) -> None: + assert len(self.nursery.child_tasks) != 0 + # Make sure this doesn't raise AttributeError in aexit: + del self.nursery + self.record.append("close") + self.exiting = True + + async def _background_task(self, *, task_status: TaskStatus[None]) -> None: + self.record.append("background") + await trio.sleep(1) + self.record.append("starting") + task_status.started() + self.record.append("running") + while not self.exiting: + await trio.sleep(1) + self.ticks += 1 + self.record.append("stopping") + + +class DaemonExample(Example, daemon=True): + pass + + +async def test_background(autojump_clock: trio.testing.MockClock) -> None: + async with Example(ticks=100) as obj: + assert obj.record == [ + "attrs_post_init", "open", "background", "starting", "running", "started" + ] + del obj.record[:] + await trio.sleep(5.5) + assert obj.record == ["close", "stopping"] + # 1 sec start + 6 ticks + assert trio.current_time() == 7.0 + assert obj.ticks == 106 + assert not hasattr(obj, "nursery") + + # With daemon=True, the background tasks are cancelled when the parent exits + async with DaemonExample() as obj2: + assert obj2.record == [ + "attrs_post_init", "open", "background", "starting", "running", "started" + ] + del obj2.record[:] + await trio.sleep(5.5) + assert obj2.record == ["close"] + assert trio.current_time() == 13.5 + assert obj2.ticks == 5 diff --git a/tricycle/_tests/test_multi_cancel.py b/tricycle/_tests/test_multi_cancel.py new file mode 100644 index 0000000..f6a089e --- /dev/null +++ b/tricycle/_tests/test_multi_cancel.py @@ -0,0 +1,194 @@ +import pytest # type: ignore + +import trio +import trio.testing +from .. import MultiCancelScope + + +async def test_basic(autojump_clock: trio.testing.MockClock) -> None: + parent = MultiCancelScope() + finish_order = [] + + async def cancel_child_before_entering() -> None: + child = parent.open_child() + assert not child.cancel_called + child.cancel() + assert child.cancel_called + assert not child.cancelled_caught + await trio.sleep(0.2) + with child: + assert not child.cancelled_caught + await trio.sleep(1) + assert child.cancelled_caught + finish_order.append("cancel_child_before_entering") + + async def cancel_child_after_entering() -> None: + with parent.open_child() as child: + await trio.sleep(0.3) + child.cancel() + await trio.sleep(1) + assert child.cancel_called + assert child.cancelled_caught + finish_order.append("cancel_child_after_entering") + + async def cancel_child_via_local_deadline() -> None: + child = parent.open_child() + child.deadline = trio.current_time() + 0.4 + deadline_before_entering = child.deadline + with child: + assert child.deadline == deadline_before_entering + await trio.sleep(1) + assert child.cancel_called + assert child.cancelled_caught + finish_order.append("cancel_child_via_local_deadline") + + async def cancel_child_via_local_deadline_2() -> None: + child = parent.open_child() + child.deadline = trio.current_time() + 1.0 + with child: + child.deadline -= 0.9 + await trio.sleep(1) + assert child.cancel_called + assert child.cancelled_caught + finish_order.append("cancel_child_via_local_deadline_2") + + async def cancel_parent_before_entering() -> None: + child = parent.open_child() + await trio.sleep(0.6) + assert child.cancel_called + assert not child.cancelled_caught + with child: + await trio.sleep(1) + assert child.cancelled_caught + finish_order.append("cancel_parent_before_entering") + + async def cancel_parent_after_entering() -> None: + with parent.open_child() as child: + await trio.sleep(1) + assert child.cancel_called + assert child.cancelled_caught + finish_order.append("cancel_parent_after_entering") + + async with trio.open_nursery() as nursery: + nursery.start_soon(cancel_child_before_entering) + nursery.start_soon(cancel_child_after_entering) + nursery.start_soon(cancel_child_via_local_deadline) + nursery.start_soon(cancel_child_via_local_deadline_2) + nursery.start_soon(cancel_parent_before_entering) + nursery.start_soon(cancel_parent_after_entering) + await trio.sleep(0.5) + assert "MultiCancelScope cancelled" not in repr(parent) + assert not parent.cancel_called + parent.cancel() + assert parent.cancel_called + assert "MultiCancelScope cancelled" in repr(parent) + parent.cancel() + await trio.sleep(0.2) + + nursery.cancel_scope.deadline = trio.current_time() + 0.1 + with parent.open_child() as child: + child.deadline = nursery.cancel_scope.deadline + assert child.cancel_called + assert not child.cancelled_caught + await trio.sleep_forever() + assert child.cancelled_caught + finish_order.append("cancel_parent_before_creating") + + assert not nursery.cancel_scope.cancelled_caught + assert finish_order == [ + "cancel_child_via_local_deadline_2", # t+0.1 + "cancel_child_before_entering", # t+0.2 + "cancel_child_after_entering", # t+0.3 + "cancel_child_via_local_deadline", # t+0.4 + "cancel_parent_after_entering", # t+0.5 + "cancel_parent_before_entering", # t+0.6 + "cancel_parent_before_creating", # t+0.7 + ] + + +async def test_shielding(autojump_clock: trio.testing.MockClock) -> None: + parent = MultiCancelScope() + finish_order = [] + + async def shield_child_on_creation() -> None: + try: + with parent.open_child(shield=True) as child: + await trio.sleep(1) + assert False # pragma: no cover + finally: + finish_order.append("shield_child_on_creation") + + async def shield_child_before_entering() -> None: + child = parent.open_child() + child.shield = True + try: + with child: + await trio.sleep(1) + assert False # pragma: no cover + finally: + with trio.CancelScope(shield=True): + await trio.sleep(0.1) + finish_order.append("shield_child_before_entering") + + async def shield_child_after_entering() -> None: + try: + with parent.open_child() as child: + child.shield = True + await trio.sleep(1) + assert False # pragma: no cover + finally: + with trio.CancelScope(shield=True): + await trio.sleep(0.2) + finish_order.append("shield_child_after_entering") + + async def shield_child_when_parent_shielded() -> None: + try: + with trio.CancelScope(shield=True): + await trio.sleep(0.3) + with parent.open_child() as child: + await trio.sleep(1) + finally: + with trio.CancelScope(shield=True): + await trio.sleep(0.3) + finish_order.append("shield_child_when_parent_shielded") + + async def shield_child_after_parent_unshielded() -> None: + with parent.open_child(shield=True) as child: + this_task = trio.hazmat.current_task() + + def abort_fn(_): # type: ignore + trio.hazmat.reschedule(this_task) + return trio.hazmat.Abort.FAILED + + await trio.hazmat.wait_task_rescheduled(abort_fn) + child.shield = True + await trio.sleep(0.5) + assert not child.cancelled_caught + finish_order.append("shield_child_after_parent_unshielded") + + async with trio.open_nursery() as nursery: + nursery.start_soon(shield_child_on_creation) + nursery.start_soon(shield_child_before_entering) + nursery.start_soon(shield_child_after_entering) + nursery.start_soon(shield_child_when_parent_shielded) + nursery.start_soon(shield_child_after_parent_unshielded) + + nursery.cancel_scope.cancel() + assert parent.shield == False + with trio.CancelScope(shield=True): + await trio.sleep(0.2) + assert "MultiCancelScope shielded" not in repr(parent) + parent.shield = True + assert "MultiCancelScope shielded" in repr(parent) + assert parent.shield == True + with trio.CancelScope(shield=True): + await trio.sleep(0.2) + parent.shield = False + + assert finish_order == [ + "shield_child_on_creation", # t+0.4 + "shield_child_before_entering", # t+0.5 + "shield_child_after_entering", # t+0.6 + "shield_child_when_parent_shielded", # t+0.7 + "shield_child_after_parent_unshielded", # t+0.8 + ] diff --git a/tricycle/_tests/test_service_nursery.py b/tricycle/_tests/test_service_nursery.py new file mode 100644 index 0000000..b78ca96 --- /dev/null +++ b/tricycle/_tests/test_service_nursery.py @@ -0,0 +1,108 @@ +import pytest # type: ignore +from typing import Any +from trio_typing import TaskStatus + +import trio +import trio.testing +from .. import open_service_nursery + + +async def test_basic(autojump_clock: trio.testing.MockClock) -> None: + record = [] + async with open_service_nursery() as nursery: + @nursery.start_soon + async def background_task() -> None: + try: + await trio.sleep_forever() + finally: + record.append("background_task exiting") + + task, = nursery.child_tasks + assert "background_task" in task.name + + nursery.cancel_scope.cancel() + with trio.CancelScope(shield=True): + await trio.sleep(1) + record.append("body exiting") + await trio.sleep(0) + pytest.fail("should've been cancelled") # pragma: no cover + + assert nursery.cancel_scope.cancelled_caught + assert record == ["body exiting", "background_task exiting"] + + +async def test_start(autojump_clock: trio.testing.MockClock) -> None: + record = [] + + async def sleep_then_start(val: int, *, task_status: TaskStatus[int]) -> None: + await trio.sleep(1) + task_status.started(val) + try: + await trio.sleep(10) + record.append("background task finished") # pragma: no cover + finally: + record.append("background task exiting") + + async def shielded_sleep_then_start(*, task_status: TaskStatus[None]) -> None: + with trio.CancelScope(shield=True): + await trio.sleep(1) + task_status.started() + await trio.sleep(10) + + async with open_service_nursery() as nursery: + # Child can be cancelled normally while it's starting + with trio.move_on_after(0.5) as scope: + await nursery.start(sleep_then_start, 1) + assert scope.cancelled_caught + assert not nursery.child_tasks + + # If started() is the first thing to notice a cancellation, the task + # stays in the old nursery and remains unshielded + with trio.move_on_after(0.5) as scope: + await nursery.start(shielded_sleep_then_start) + assert scope.cancelled_caught + assert not nursery.child_tasks + + assert trio.current_time() == 1.5 + + # Otherwise, once started() is called the child is shielded until + # the 'async with' block exits. + assert 42 == await nursery.start(sleep_then_start, 42) + assert trio.current_time() == 2.5 + + nursery.cancel_scope.cancel() + with trio.CancelScope(shield=True): + await trio.sleep(1) + record.append("parent task finished") + + assert trio.current_time() == 3.5 + assert record == ["parent task finished", "background task exiting"] + + +async def test_problems() -> None: + async with open_service_nursery() as nursery: + with pytest.raises(TypeError) as info: + nursery.start_soon(trio.sleep) + assert "missing 1 required positional argument" in str(info.value) + + with pytest.raises(TypeError) as info: + nursery.start_soon(trio.sleep(1)) + assert "Trio was expecting an async function" in str(info.value) + + with pytest.raises(TypeError) as info: + nursery.start_soon(int, 42) + assert "appears to be synchronous" in str(info.value) + + first_call = True + + def evil() -> Any: + nonlocal first_call + if first_call: + first_call = False + return 42 + else: + return trio.sleep(0) + + with pytest.raises(trio.TrioInternalError) as info: + nursery.start_soon(evil) + assert "all bets are off at this point" in str(info.value)