Skip to content

Commit

Permalink
Refactored cancellation across all backends
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Apr 6, 2019
1 parent 6b44e3d commit b14a72e
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 235 deletions.
16 changes: 8 additions & 8 deletions anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ def open_cancel_scope(*, shield: bool = False) -> 'typing.AsyncContextManager[Ca
:return: an asynchronous context manager that yields a cancel scope
"""
return _get_asynclib().open_cancel_scope(shield=shield)
return _get_asynclib().CancelScope(shield=shield)


def fail_after(delay: Optional[float], *,
shield: bool = False) -> 'typing.AsyncContextManager[CancelScope]':
"""
Create a context manager which raises an exception if does not finish in time.
Create an async context manager which raises an exception if does not finish in time.
:param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to
disable the timeout
Expand All @@ -160,15 +160,15 @@ def fail_after(delay: Optional[float], *,
"""
if delay is None:
return _get_asynclib().open_cancel_scope(shield=shield)
return _get_asynclib().CancelScope(shield=shield)
else:
return _get_asynclib().fail_after(delay, shield=shield)


def move_on_after(delay: Optional[float], *,
shield: bool = False) -> 'typing.AsyncContextManager[CancelScope]':
"""
Create a context manager which is exited if it does not complete within the given time.
Create an async context manager which is exited if it does not complete within the given time.
:param delay: maximum allowed time (in seconds) before exiting the context block, or ``None``
to disable the timeout
Expand All @@ -177,7 +177,7 @@ def move_on_after(delay: Optional[float], *,
"""
if delay is None:
return _get_asynclib().open_cancel_scope(shield=shield)
return _get_asynclib().CancelScope(shield=shield)
else:
return _get_asynclib().move_on_after(delay, shield=shield)

Expand All @@ -198,14 +198,14 @@ def current_effective_deadline() -> Coroutine[Any, Any, float]:
# Task groups
#

def create_task_group() -> 'typing.AsyncContextManager[TaskGroup]':
def create_task_group() -> TaskGroup:
"""
Create a task group.
:return: an asynchronous context manager that yields a task group
:return: a task group
"""
return _get_asynclib().create_task_group()
return _get_asynclib().TaskGroup()


#
Expand Down
190 changes: 102 additions & 88 deletions anyio/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,51 @@ async def sleep(delay: float) -> None:
# Timeouts and cancellation
#

class CancelScope(abc.CancelScope):
def __init__(self, host_task: asyncio.Task, deadline: float,
parent_scope: Optional['CancelScope'], shield: bool = False) -> None:
self._host_task = host_task
class CancelScope:
__slots__ = ('_deadline', '_shield', '_parent_scope', '_cancel_called', '_host_task',
'_timeout_task', '_timeout_expired')

def __init__(self, deadline: float = float('inf'), shield: bool = False):
self._deadline = deadline
self._parent_scope = parent_scope
self._shield = shield
self._parent_scope = None
self._cancel_called = False
self._host_task = None
self._timeout_task = None

async def __aenter__(self):
async def timeout():
await asyncio.sleep(self._deadline - get_running_loop().time())
self._timeout_expired = True
await self.cancel()

if self._host_task:
raise RuntimeError(
"Each CancelScope may only be used for a single 'async with' block"
)

self._host_task = current_task()
self._parent_scope = get_cancel_scope(self._host_task)
set_cancel_scope(self._host_task, self)
self._timeout_expired = False

if self._deadline != float('inf'):
self._timeout_task = get_running_loop().create_task(timeout())

return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._timeout_task:
self._timeout_task.cancel()

set_cancel_scope(self._host_task, self._parent_scope)

if isinstance(exc_val, asyncio.CancelledError):
if self._timeout_expired:
return True
elif self._cancel_called:
# This scope was directly cancelled
return True

async def cancel(self):
if not self._cancel_called:
Expand Down Expand Up @@ -199,6 +236,9 @@ def shield(self) -> bool:
return self._shield


abc.CancelScope.register(CancelScope)


def get_cancel_scope(task: asyncio.Task) -> Optional[CancelScope]:
try:
return _local.cancel_scopes_by_task.get(task)
Expand All @@ -225,57 +265,27 @@ def check_cancelled():
raise CancelledError


@asynccontextmanager
@async_generator
async def open_cancel_scope(deadline: float = float('inf'), shield: bool = False):
async def timeout():
nonlocal timeout_expired
await asyncio.sleep(deadline - get_running_loop().time())
timeout_expired = True
await scope.cancel()

host_task = cast(asyncio.Task, current_task())
scope = CancelScope(host_task, deadline, get_cancel_scope(host_task), shield)
set_cancel_scope(host_task, scope)
timeout_expired = False

timeout_task = None
if deadline != float('inf'):
timeout_task = get_running_loop().create_task(timeout())

try:
await yield_(scope)
except asyncio.CancelledError as exc:
if timeout_expired:
raise TimeoutError().with_traceback(exc.__traceback__) from None
elif not scope._cancel_called:
raise
finally:
if timeout_task:
timeout_task.cancel()

set_cancel_scope(host_task, scope._parent_scope)
def open_cancel_scope(deadline: float = float('inf'), shield: bool = False) -> CancelScope:
return CancelScope(deadline, shield)


@asynccontextmanager
@async_generator
async def fail_after(delay: float, shield: bool):
deadline = get_running_loop().time() + delay
async with open_cancel_scope(deadline, shield) as cancel_scope:
await yield_(cancel_scope)
async with CancelScope(deadline, shield) as scope:
await yield_(scope)

if scope._timeout_expired:
raise TimeoutError


@asynccontextmanager
@async_generator
async def move_on_after(delay: float, shield: bool):
deadline = get_running_loop().time() + delay
cancel_scope = None
try:
async with open_cancel_scope(deadline, shield) as cancel_scope:
await yield_(cancel_scope)
except TimeoutError:
if not cancel_scope or not cancel_scope.cancel_called:
raise
async with CancelScope(deadline=deadline, shield=shield) as scope:
await yield_(scope)


async def current_effective_deadline():
Expand All @@ -293,14 +303,57 @@ async def current_effective_deadline():
#

class TaskGroup:
__slots__ = 'cancel_scope', '_active', '_tasks', '_host_task'
__slots__ = 'cancel_scope', '_active', '_tasks'

def __init__(self, cancel_scope: 'CancelScope', host_task: asyncio.Task) -> None:
self.cancel_scope = cancel_scope
self._host_task = host_task
self._active = True
def __init__(self) -> None:
self.cancel_scope = CancelScope()
self._active = False
self._tasks = set() # type: Set[asyncio.Task]

async def __aenter__(self):
await self.cancel_scope.__aenter__()
self._active = True
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
exceptions = []
ignore_exception = False
if exc_val is not None:
await self.cancel_scope.cancel()
if not isinstance(exc_val, (asyncio.CancelledError, CancelledError)):
exceptions.append(exc_val)
elif not self.cancel_scope._parent_scope:
ignore_exception = True
elif not self.cancel_scope._parent_scope.cancel_called:
ignore_exception = True

if self.cancel_scope.cancel_called:
for task in self._tasks:
if task._coro.cr_await is not None:
task.cancel()

while self._tasks:
for task in set(self._tasks):
try:
await task
except (asyncio.CancelledError, CancelledError):
set_cancel_scope(task, None)
self._tasks.remove(task)
except BaseException as exc:
set_cancel_scope(task, None)
self._tasks.remove(task)
exceptions.append(exc)

self._active = False
await self.cancel_scope.__aexit__(exc_type, exc_val, exc_tb)

if len(exceptions) > 1:
raise ExceptionGroup(exceptions)
elif exceptions and exceptions[0] is not exc_val:
raise exceptions[0]

return ignore_exception

async def _run_wrapped_task(self, func, *args):
try:
await func(*args)
Expand Down Expand Up @@ -332,45 +385,6 @@ async def spawn(self, func: Callable, *args, name=None) -> None:
abc.TaskGroup.register(TaskGroup)


@asynccontextmanager
@async_generator
async def create_task_group():
async with open_cancel_scope() as cancel_scope:
group = TaskGroup(cancel_scope, current_task())
exceptions = []
try:
try:
await yield_(group)
except (CancelledError, asyncio.CancelledError):
await cancel_scope.cancel()
except BaseException as exc:
exceptions.append(exc)
await cancel_scope.cancel()

if cancel_scope.cancel_called:
for task in group._tasks:
if task._coro.cr_await is not None:
task.cancel()

while group._tasks:
for task in set(group._tasks):
try:
await task
except (CancelledError, asyncio.CancelledError):
group._tasks.remove(task)
set_cancel_scope(task, None)
except BaseException as exc:
group._tasks.remove(task)
set_cancel_scope(task, None)
exceptions.append(exc)
finally:
group._active = False

if len(exceptions) > 1:
raise ExceptionGroup(exceptions)
elif exceptions:
raise exceptions[0]

#
# Threads
#
Expand Down
Loading

0 comments on commit b14a72e

Please sign in to comment.