-
Notifications
You must be signed in to change notification settings - Fork 14
feat: Implement python retrying connection, which generically retries stream errors #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| import pkg_resources | ||
|
|
||
| pkg_resources.declare_namespace(__name__) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| import pkg_resources | ||
|
|
||
| pkg_resources.declare_namespace(__name__) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| from typing import Generic, TypeVar, Coroutine, Any, AsyncContextManager | ||
| from abc import ABCMeta, abstractmethod | ||
| from google.api_core.exceptions import GoogleAPICallError | ||
|
|
||
| Request = TypeVar('Request') | ||
| Response = TypeVar('Response') | ||
|
|
||
|
|
||
| class Connection(Generic[Request, Response], AsyncContextManager): | ||
| """ | ||
| A connection to an underlying stream. Only one call to 'read' may be outstanding at a time. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| async def write(self, request: Request) -> None: | ||
| """ | ||
| Write a message to the stream. | ||
|
|
||
| Raises: | ||
| GoogleAPICallError: When the connection terminates in failure. | ||
| """ | ||
| raise NotImplementedError() | ||
|
|
||
| @abstractmethod | ||
| async def read(self) -> Response: | ||
| """ | ||
| Read a message off of the stream. | ||
|
|
||
| Raises: | ||
| GoogleAPICallError: When the connection terminates in failure. | ||
| """ | ||
| raise NotImplementedError() | ||
|
|
||
|
|
||
| class ConnectionFactory(Generic[Request, Response]): | ||
| """A factory for producing Connections.""" | ||
| def new(self) -> Connection[Request, Response]: | ||
| raise NotImplementedError() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from typing import Generic | ||
| from abc import ABCMeta, abstractmethod | ||
| from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response | ||
|
|
||
|
|
||
| class ConnectionReinitializer(Generic[Request, Response], metaclass=ABCMeta): | ||
| """A class capable of reinitializing a connection after a new one has been created.""" | ||
| @abstractmethod | ||
| def reinitialize(self, connection: Connection[Request, Response]): | ||
| """Reinitialize a connection. | ||
|
|
||
| Args: | ||
| connection: The connection to reinitialize | ||
|
|
||
| Raises: | ||
| GoogleAPICallError: If it fails to reinitialize. | ||
| """ | ||
| raise NotImplementedError() | ||
|
|
||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| from typing import AsyncIterator, TypeVar, Optional, Callable, AsyncIterable | ||
| import asyncio | ||
|
|
||
| from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory | ||
| from google.cloud.pubsublite.internal.wire.work_item import WorkItem | ||
| from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable | ||
|
|
||
| T = TypeVar('T') | ||
|
|
||
|
|
||
| class GapicConnection(Connection[Request, Response], AsyncIterator[Request], PermanentFailable): | ||
dpcollins-google marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """A Connection wrapping a gapic AsyncIterator[Request/Response] pair.""" | ||
| _write_queue: 'asyncio.Queue[WorkItem[Request]]' | ||
| _response_it: Optional[AsyncIterator[Response]] | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| self._write_queue = asyncio.Queue(maxsize=1) | ||
|
|
||
| def set_response_it(self, response_it: AsyncIterator[Response]): | ||
| self._response_it = response_it | ||
|
|
||
| async def write(self, request: Request) -> None: | ||
| item = WorkItem(request) | ||
| await self.await_or_fail(self._write_queue.put(item)) | ||
| await self.await_or_fail(item.response_future) | ||
|
|
||
| async def read(self) -> Response: | ||
| return await self.await_or_fail(self._response_it.__anext__()) | ||
|
|
||
| def __aenter__(self): | ||
| return self | ||
|
|
||
| def __aexit__(self, exc_type, exc_value, traceback) -> None: | ||
| pass | ||
|
|
||
| async def __anext__(self) -> Request: | ||
| item: WorkItem[Request] = await self.await_or_fail(self._write_queue.get()) | ||
| item.response_future.set_result(None) | ||
| return item.request | ||
|
|
||
| def __aiter__(self) -> AsyncIterator[Response]: | ||
| return self | ||
|
|
||
|
|
||
| class GapicConnectionFactory(ConnectionFactory[Request, Response]): | ||
| """A ConnectionFactory that produces GapicConnections.""" | ||
| _producer = Callable[[AsyncIterator[Request]], AsyncIterable[Response]] | ||
|
|
||
| def New(self) -> Connection[Request, Response]: | ||
| conn = GapicConnection[Request, Response]() | ||
| response_iterable = self._producer(conn) | ||
| conn.set_response_it(response_iterable.__aiter__()) | ||
| return conn | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import asyncio | ||
| from typing import Awaitable, TypeVar | ||
|
|
||
| from google.api_core.exceptions import GoogleAPICallError | ||
|
|
||
| T = TypeVar('T') | ||
|
|
||
|
|
||
| class PermanentFailable: | ||
| """A class that can experience permanent failures, with helpers for forwarding these to client actions.""" | ||
| _failure_task: asyncio.Future | ||
|
|
||
| def __init__(self): | ||
| self._failure_task = asyncio.Future() | ||
|
|
||
| async def await_or_fail(self, awaitable: Awaitable[T]) -> T: | ||
| if self._failure_task.done(): | ||
| raise self._failure_task.exception() | ||
| task = asyncio.ensure_future(awaitable) | ||
| done, _ = await asyncio.wait([task, self._failure_task], return_when=asyncio.FIRST_COMPLETED) | ||
| if task in done: | ||
| try: | ||
| return await task | ||
| except GoogleAPICallError as e: | ||
| self.fail(e) | ||
| task.cancel() | ||
| raise self._failure_task.exception() | ||
|
|
||
| def fail(self, err: GoogleAPICallError): | ||
| if not self._failure_task.done(): | ||
| self._failure_task.set_exception(err) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import asyncio | ||
|
|
||
| from typing import Awaitable | ||
| from google.api_core.exceptions import GoogleAPICallError, Cancelled | ||
| from google.cloud.pubsublite.status_codes import is_retryable | ||
| from google.cloud.pubsublite.internal.wire.connection_reinitializer import ConnectionReinitializer | ||
| from google.cloud.pubsublite.internal.wire.connection import Connection, Request, Response, ConnectionFactory | ||
| from google.cloud.pubsublite.internal.wire.work_item import WorkItem | ||
| from google.cloud.pubsublite.internal.wire.permanent_failable import PermanentFailable | ||
|
|
||
| _MIN_BACKOFF_SECS = .01 | ||
| _MAX_BACKOFF_SECS = 10 | ||
dpcollins-google marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class RetryingConnection(Connection[Request, Response], PermanentFailable): | ||
| """A connection which performs retries on an underlying stream when experiencing retryable errors.""" | ||
| _connection_factory: ConnectionFactory[Request, Response] | ||
| _reinitializer: ConnectionReinitializer[Request, Response] | ||
|
|
||
| _loop_task: asyncio.Future | ||
|
|
||
| _write_queue: 'asyncio.Queue[WorkItem[Request]]' | ||
| _read_queue: 'asyncio.Queue[Response]' | ||
|
|
||
| def __init__(self, connection_factory: ConnectionFactory[Request, Response], reinitializer: ConnectionReinitializer[Request, Response]): | ||
| super().__init__() | ||
| self._connection_factory = connection_factory | ||
| self._reinitializer = reinitializer | ||
| self._write_queue = asyncio.Queue(maxsize=1) | ||
| self._read_queue = asyncio.Queue(maxsize=1) | ||
|
|
||
| async def __aenter__(self): | ||
| self._loop_task = asyncio.ensure_future(self._run_loop()) | ||
| return self | ||
|
|
||
| async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
| self.fail(Cancelled("Connection shutting down.")) | ||
|
|
||
| async def write(self, request: Request) -> None: | ||
| item = WorkItem(request) | ||
| await self.await_or_fail(self._write_queue.put(item)) | ||
| return await self.await_or_fail(item.response_future) | ||
|
|
||
| async def read(self) -> Response: | ||
| return await self.await_or_fail(self._read_queue.get()) | ||
|
|
||
| async def _run_loop(self): | ||
| """ | ||
| Processes actions on this connection and handles retries until cancelled. | ||
| """ | ||
| try: | ||
| bad_retries = 0 | ||
| while True: | ||
| try: | ||
| async with self._connection_factory.new() as connection: | ||
| await self._reinitializer.reinitialize(connection) | ||
| bad_retries = 0 | ||
| await self._loop_connection(connection) | ||
| except (Exception, GoogleAPICallError) as e: | ||
| if not is_retryable(e): | ||
| self.fail(e) | ||
| return | ||
| await asyncio.sleep(min(_MAX_BACKOFF_SECS, _MIN_BACKOFF_SECS * (2**bad_retries))) | ||
| bad_retries += 1 | ||
|
|
||
| except asyncio.CancelledError: | ||
| return | ||
|
|
||
| async def _loop_connection(self, connection: Connection[Request, Response]): | ||
| read_task: Awaitable[Response] = asyncio.ensure_future(connection.read()) | ||
| write_task: Awaitable[WorkItem[Request]] = asyncio.ensure_future(self._write_queue.get()) | ||
| while True: | ||
| done, _ = await asyncio.wait([write_task, read_task], return_when=asyncio.FIRST_COMPLETED) | ||
| if write_task in done: | ||
| await self._handle_write(connection, await write_task) | ||
| write_task = asyncio.ensure_future(self._write_queue.get()) | ||
| if read_task in done: | ||
| await self._read_queue.put(await read_task) | ||
| read_task = asyncio.ensure_future(connection.read()) | ||
|
|
||
| @staticmethod | ||
| async def _handle_write(connection: Connection[Request, Response], to_write: WorkItem[Request]): | ||
| try: | ||
| await connection.write(to_write.request) | ||
| to_write.response_future.set_result(None) | ||
| except GoogleAPICallError as e: | ||
| to_write.response_future.set_exception(e) | ||
| raise e | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| import asyncio | ||
| from typing import Generic, TypeVar | ||
|
|
||
| T = TypeVar('T') | ||
|
|
||
|
|
||
| class WorkItem(Generic[T]): | ||
| """An item of work and a future to complete when it is finished.""" | ||
| request: T | ||
| response_future: "asyncio.Future[None]" | ||
|
|
||
| def __init__(self, request: T): | ||
| self.request = request | ||
| self.response_future = asyncio.Future() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| from grpc import StatusCode | ||
| from google.api_core.exceptions import GoogleAPICallError | ||
|
|
||
| retryable_codes = { | ||
| StatusCode.DEADLINE_EXCEEDED, StatusCode.ABORTED, StatusCode.INTERNAL, StatusCode.UNAVAILABLE, StatusCode.UNKNOWN | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you considered adding CANCELLED and/or RESOURCE_EXHAUSTED? The retry codes for the CPS library are here:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CANCELLED should not be retried, it is a bug if it leaks to the client on a non-client initiated action. |
||
| } | ||
|
|
||
|
|
||
| def is_retryable(error: GoogleAPICallError) -> bool: | ||
| return error.grpc_status_code in retryable_codes | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| from typing import List, Union, Any | ||
|
|
||
|
|
||
| async def async_iterable(elts: List[Union[Any, Exception]]): | ||
| for elt in elts: | ||
| if isinstance(elt, Exception): | ||
| raise elt | ||
| yield elt |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import asyncio | ||
|
|
||
| import pytest | ||
| from google.api_core.exceptions import InternalServerError | ||
| from google.cloud.pubsublite.internal.wire.gapic_connection import GapicConnection | ||
| from google.cloud.pubsublite.testing.test_utils import async_iterable | ||
|
|
||
| # All test coroutines will be treated as marked. | ||
| pytestmark = pytest.mark.asyncio | ||
|
|
||
|
|
||
| async def test_read_error_fails(): | ||
| conn = GapicConnection[int, int]() | ||
| conn.set_response_it(async_iterable([InternalServerError("abc")])) | ||
| with pytest.raises(InternalServerError): | ||
| await conn.read() | ||
| with pytest.raises(InternalServerError): | ||
| await conn.read() | ||
| with pytest.raises(InternalServerError): | ||
| await conn.write(3) | ||
|
|
||
|
|
||
| async def test_read_success(): | ||
| conn = GapicConnection[int, int]() | ||
| conn.set_response_it(async_iterable([3, 4, 5])) | ||
| assert [await conn.read() for _ in range(3)] == [3, 4, 5] | ||
|
|
||
|
|
||
| async def test_writes(): | ||
| conn = GapicConnection[int, int]() | ||
| conn.set_response_it(async_iterable([])) | ||
| task1 = asyncio.ensure_future(conn.write(1)) | ||
| task2 = asyncio.ensure_future(conn.write(2)) | ||
| assert not task1.done() | ||
| assert not task2.done() | ||
| assert await conn.__anext__() == 1 | ||
| await task1 | ||
| assert not task2.done() | ||
| assert await conn.__anext__() == 2 | ||
| await task2 |
Uh oh!
There was an error while loading. Please reload this page.