Skip to content

Commit 94043b1

Browse files
authored
Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#8650)
Closes #7826.
1 parent 562deea commit 94043b1

File tree

5 files changed

+207
-28
lines changed

5 files changed

+207
-28
lines changed

api_core/google/api_core/bidi.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ def pending_requests(self):
349349
return self._request_queue.qsize()
350350

351351

352+
def _never_terminate(future_or_error):
353+
"""By default, no errors cause BiDi termination."""
354+
return False
355+
356+
352357
class ResumableBidiRpc(BidiRpc):
353358
"""A :class:`BidiRpc` that can automatically resume the stream on errors.
354359
@@ -391,6 +396,9 @@ def should_recover(exc):
391396
should_recover (Callable[[Exception], bool]): A function that returns
392397
True if the stream should be recovered. This will be called
393398
whenever an error is encountered on the stream.
399+
should_terminate (Callable[[Exception], bool]): A function that returns
400+
True if the stream should be terminated. This will be called
401+
whenever an error is encountered on the stream.
394402
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
395403
the request.
396404
throttle_reopen (bool): If ``True``, throttling will be applied to
@@ -401,12 +409,14 @@ def __init__(
401409
self,
402410
start_rpc,
403411
should_recover,
412+
should_terminate=_never_terminate,
404413
initial_request=None,
405414
metadata=None,
406415
throttle_reopen=False,
407416
):
408417
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
409418
self._should_recover = should_recover
419+
self._should_terminate = should_terminate
410420
self._operational_lock = threading.RLock()
411421
self._finalized = False
412422
self._finalize_lock = threading.Lock()
@@ -433,7 +443,9 @@ def _on_call_done(self, future):
433443
# error, not for errors that we can recover from. Note that grpc's
434444
# "future" here is also a grpc.RpcError.
435445
with self._operational_lock:
436-
if not self._should_recover(future):
446+
if self._should_terminate(future):
447+
self._finalize(future)
448+
elif not self._should_recover(future):
437449
self._finalize(future)
438450
else:
439451
_LOGGER.debug("Re-opening stream from gRPC callback.")
@@ -496,6 +508,12 @@ def _recoverable(self, method, *args, **kwargs):
496508
with self._operational_lock:
497509
_LOGGER.debug("Call to retryable %r caused %s.", method, exc)
498510

511+
if self._should_terminate(exc):
512+
self.close()
513+
_LOGGER.debug("Terminating %r due to %s.", method, exc)
514+
self._finalize(exc)
515+
break
516+
499517
if not self._should_recover(exc):
500518
self.close()
501519
_LOGGER.debug("Not retrying %r due to %s.", method, exc)

api_core/tests/unit/test_bidi.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,33 +370,111 @@ def cancel(self):
370370

371371

372372
class TestResumableBidiRpc(object):
373-
def test_initial_state(self):
374-
callback = mock.Mock()
375-
callback.return_value = True
376-
bidi_rpc = bidi.ResumableBidiRpc(None, callback)
373+
def test_ctor_defaults(self):
374+
start_rpc = mock.Mock()
375+
should_recover = mock.Mock()
376+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
377+
378+
assert bidi_rpc.is_active is False
379+
assert bidi_rpc._finalized is False
380+
assert bidi_rpc._start_rpc is start_rpc
381+
assert bidi_rpc._should_recover is should_recover
382+
assert bidi_rpc._should_terminate is bidi._never_terminate
383+
assert bidi_rpc._initial_request is None
384+
assert bidi_rpc._rpc_metadata is None
385+
assert bidi_rpc._reopen_throttle is None
386+
387+
def test_ctor_explicit(self):
388+
start_rpc = mock.Mock()
389+
should_recover = mock.Mock()
390+
should_terminate = mock.Mock()
391+
initial_request = mock.Mock()
392+
metadata = {"x-foo": "bar"}
393+
bidi_rpc = bidi.ResumableBidiRpc(
394+
start_rpc,
395+
should_recover,
396+
should_terminate=should_terminate,
397+
initial_request=initial_request,
398+
metadata=metadata,
399+
throttle_reopen=True,
400+
)
377401

378402
assert bidi_rpc.is_active is False
403+
assert bidi_rpc._finalized is False
404+
assert bidi_rpc._should_recover is should_recover
405+
assert bidi_rpc._should_terminate is should_terminate
406+
assert bidi_rpc._initial_request is initial_request
407+
assert bidi_rpc._rpc_metadata == metadata
408+
assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)
409+
410+
def test_done_callbacks_terminate(self):
411+
cancellation = mock.Mock()
412+
start_rpc = mock.Mock()
413+
should_recover = mock.Mock(spec=["__call__"], return_value=True)
414+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
415+
bidi_rpc = bidi.ResumableBidiRpc(
416+
start_rpc, should_recover, should_terminate=should_terminate
417+
)
418+
callback = mock.Mock(spec=["__call__"])
419+
420+
bidi_rpc.add_done_callback(callback)
421+
bidi_rpc._on_call_done(cancellation)
422+
423+
should_terminate.assert_called_once_with(cancellation)
424+
should_recover.assert_not_called()
425+
callback.assert_called_once_with(cancellation)
426+
assert not bidi_rpc.is_active
379427

380428
def test_done_callbacks_recoverable(self):
381429
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
382-
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
430+
should_recover = mock.Mock(spec=["__call__"], return_value=True)
431+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
383432
callback = mock.Mock(spec=["__call__"])
384433

385434
bidi_rpc.add_done_callback(callback)
386435
bidi_rpc._on_call_done(mock.sentinel.future)
387436

388437
callback.assert_not_called()
389438
start_rpc.assert_called_once()
439+
should_recover.assert_called_once_with(mock.sentinel.future)
390440
assert bidi_rpc.is_active
391441

392442
def test_done_callbacks_non_recoverable(self):
393-
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
443+
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
444+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
445+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
394446
callback = mock.Mock(spec=["__call__"])
395447

396448
bidi_rpc.add_done_callback(callback)
397449
bidi_rpc._on_call_done(mock.sentinel.future)
398450

399451
callback.assert_called_once_with(mock.sentinel.future)
452+
should_recover.assert_called_once_with(mock.sentinel.future)
453+
assert not bidi_rpc.is_active
454+
455+
def test_send_terminate(self):
456+
cancellation = ValueError()
457+
call_1 = CallStub([cancellation], active=False)
458+
call_2 = CallStub([])
459+
start_rpc = mock.create_autospec(
460+
grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
461+
)
462+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
463+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
464+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
465+
466+
bidi_rpc.open()
467+
468+
bidi_rpc.send(mock.sentinel.request)
469+
470+
assert bidi_rpc.pending_requests == 1
471+
assert bidi_rpc._request_queue.get() is None
472+
473+
should_recover.assert_not_called()
474+
should_terminate.assert_called_once_with(cancellation)
475+
assert bidi_rpc.call == call_1
476+
assert bidi_rpc.is_active is False
477+
assert call_1.cancelled is True
400478

401479
def test_send_recover(self):
402480
error = ValueError()
@@ -441,6 +519,26 @@ def test_send_failure(self):
441519
assert bidi_rpc.pending_requests == 1
442520
assert bidi_rpc._request_queue.get() is None
443521

522+
def test_recv_terminate(self):
523+
cancellation = ValueError()
524+
call = CallStub([cancellation])
525+
start_rpc = mock.create_autospec(
526+
grpc.StreamStreamMultiCallable, instance=True, return_value=call
527+
)
528+
should_recover = mock.Mock(spec=["__call__"], return_value=False)
529+
should_terminate = mock.Mock(spec=["__call__"], return_value=True)
530+
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
531+
532+
bidi_rpc.open()
533+
534+
bidi_rpc.recv()
535+
536+
should_recover.assert_not_called()
537+
should_terminate.assert_called_once_with(cancellation)
538+
assert bidi_rpc.call == call
539+
assert bidi_rpc.is_active is False
540+
assert call.cancelled is True
541+
444542
def test_recv_recover(self):
445543
error = ValueError()
446544
call_1 = CallStub([1, error])

firestore/google/cloud/firestore_v1/watch.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,8 @@
5757
"DO_NOT_USE": -1,
5858
}
5959
_RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated"
60-
_RETRYABLE_STREAM_ERRORS = (
61-
exceptions.DeadlineExceeded,
62-
exceptions.ServiceUnavailable,
63-
exceptions.InternalServerError,
64-
exceptions.Unknown,
65-
exceptions.GatewayTimeout,
66-
)
60+
_RECOVERABLE_STREAM_EXCEPTIONS = (exceptions.ServiceUnavailable,)
61+
_TERMINATING_STREAM_EXCEPTIONS = (exceptions.Cancelled,)
6762

6863
DocTreeEntry = collections.namedtuple("DocTreeEntry", ["value", "index"])
6964

@@ -153,6 +148,16 @@ def document_watch_comparator(doc1, doc2):
153148
return 0
154149

155150

151+
def _should_recover(exception):
152+
wrapped = _maybe_wrap_exception(exception)
153+
return isinstance(wrapped, _RECOVERABLE_STREAM_EXCEPTIONS)
154+
155+
156+
def _should_terminate(exception):
157+
wrapped = _maybe_wrap_exception(exception)
158+
return isinstance(wrapped, _TERMINATING_STREAM_EXCEPTIONS)
159+
160+
156161
class Watch(object):
157162

158163
BackgroundConsumer = BackgroundConsumer # FBO unit tests
@@ -199,12 +204,6 @@ def __init__(
199204
self._closing = threading.Lock()
200205
self._closed = False
201206

202-
def should_recover(exc): # pragma: NO COVER
203-
return (
204-
isinstance(exc, grpc.RpcError)
205-
and exc.code() == grpc.StatusCode.UNAVAILABLE
206-
)
207-
208207
initial_request = firestore_pb2.ListenRequest(
209208
database=self._firestore._database_string, add_target=self._targets
210209
)
@@ -214,8 +213,9 @@ def should_recover(exc): # pragma: NO COVER
214213

215214
self._rpc = ResumableBidiRpc(
216215
self._api.transport.listen,
216+
should_recover=_should_recover,
217+
should_terminate=_should_terminate,
217218
initial_request=initial_request,
218-
should_recover=should_recover,
219219
metadata=self._firestore._rpc_metadata,
220220
)
221221

firestore/tests/unit/v1/test_cross_language.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,18 @@ def convert_precondition(precond):
343343

344344

345345
class DummyRpc(object): # pragma: NO COVER
346-
def __init__(self, listen, initial_request, should_recover, metadata=None):
346+
def __init__(
347+
self,
348+
listen,
349+
should_recover,
350+
should_terminate=None,
351+
initial_request=None,
352+
metadata=None,
353+
):
347354
self.listen = listen
348355
self.initial_request = initial_request
349356
self.should_recover = should_recover
357+
self.should_terminate = should_terminate
350358
self.closed = False
351359
self.callbacks = []
352360
self._metadata = metadata

firestore/tests/unit/v1/test_watch.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,44 @@ def test_diff_doc(self):
110110
self.assertRaises(AssertionError, self._callFUT, 1, 2)
111111

112112

113+
class Test_should_recover(unittest.TestCase):
114+
def _callFUT(self, exception):
115+
from google.cloud.firestore_v1.watch import _should_recover
116+
117+
return _should_recover(exception)
118+
119+
def test_w_unavailable(self):
120+
from google.api_core.exceptions import ServiceUnavailable
121+
122+
exception = ServiceUnavailable("testing")
123+
124+
self.assertTrue(self._callFUT(exception))
125+
126+
def test_w_non_recoverable(self):
127+
exception = ValueError("testing")
128+
129+
self.assertFalse(self._callFUT(exception))
130+
131+
132+
class Test_should_terminate(unittest.TestCase):
133+
def _callFUT(self, exception):
134+
from google.cloud.firestore_v1.watch import _should_terminate
135+
136+
return _should_terminate(exception)
137+
138+
def test_w_unavailable(self):
139+
from google.api_core.exceptions import Cancelled
140+
141+
exception = Cancelled("testing")
142+
143+
self.assertTrue(self._callFUT(exception))
144+
145+
def test_w_non_recoverable(self):
146+
exception = ValueError("testing")
147+
148+
self.assertFalse(self._callFUT(exception))
149+
150+
113151
class TestWatch(unittest.TestCase):
114152
def _makeOne(
115153
self,
@@ -161,17 +199,26 @@ def _snapshot_callback(self, docs, changes, read_time):
161199
self.snapshotted = (docs, changes, read_time)
162200

163201
def test_ctor(self):
202+
from google.cloud.firestore_v1.proto import firestore_pb2
203+
from google.cloud.firestore_v1.watch import _should_recover
204+
from google.cloud.firestore_v1.watch import _should_terminate
205+
164206
inst = self._makeOne()
165207
self.assertTrue(inst._consumer.started)
166208
self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done])
209+
self.assertIs(inst._rpc.start_rpc, inst._api.transport.listen)
210+
self.assertIs(inst._rpc.should_recover, _should_recover)
211+
self.assertIs(inst._rpc.should_terminate, _should_terminate)
212+
self.assertIsInstance(inst._rpc.initial_request, firestore_pb2.ListenRequest)
213+
self.assertEqual(inst._rpc.metadata, DummyFirestore._rpc_metadata)
167214

168215
def test__on_rpc_done(self):
216+
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME
217+
169218
inst = self._makeOne()
170219
threading = DummyThreading()
171220
with mock.patch("google.cloud.firestore_v1.watch.threading", threading):
172221
inst._on_rpc_done(True)
173-
from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME
174-
175222
self.assertTrue(threading.threads[_RPC_ERROR_THREAD_NAME].started)
176223

177224
def test_close(self):
@@ -835,13 +882,21 @@ def Thread(self, name, target, kwargs):
835882

836883

837884
class DummyRpc(object):
838-
def __init__(self, listen, initial_request, should_recover, metadata=None):
839-
self.listen = listen
840-
self.initial_request = initial_request
885+
def __init__(
886+
self,
887+
start_rpc,
888+
should_recover,
889+
should_terminate=None,
890+
initial_request=None,
891+
metadata=None,
892+
):
893+
self.start_rpc = start_rpc
841894
self.should_recover = should_recover
895+
self.should_terminate = should_terminate
896+
self.initial_request = initial_request
897+
self.metadata = metadata
842898
self.closed = False
843899
self.callbacks = []
844-
self._metadata = metadata
845900

846901
def add_done_callback(self, callback):
847902
self.callbacks.append(callback)

0 commit comments

Comments
 (0)