Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import threading

import google.auth.credentials
from google.api_core.retry import if_exception_type
from google.protobuf.struct_pb2 import Struct
from google.cloud.exceptions import NotFound
from google.api_core.exceptions import Aborted
import six

# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -394,29 +396,36 @@ def execute_partitioned_dml(

metadata = _metadata_with_prefix(self.name)

with SessionCheckout(self._pool) as session:
def execute_pdml():
with SessionCheckout(self._pool) as session:

txn = api.begin_transaction(
session.name, txn_options, metadata=metadata
)

txn = api.begin_transaction(session.name, txn_options, metadata=metadata)
txn_selector = TransactionSelector(id=txn.id)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)

txn_selector = TransactionSelector(id=txn.id)
iterator = _restart_on_unavailable(restart)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)
result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials

iterator = _restart_on_unavailable(restart)
return result_set.stats.row_count_lower_bound

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
retry_config = api._method_configs["ExecuteStreamingSql"].retry

return result_set.stats.row_count_lower_bound
return _retry_on_aborted(execute_pdml, retry_config)()

def session(self, labels=None):
"""Factory to create a session for this database.
Expand Down Expand Up @@ -976,3 +985,19 @@ def __init__(self, source_type, backup_info):
@classmethod
def from_pb(cls, pb):
return cls(pb.source_type, pb.backup_info)


def _retry_on_aborted(func, retry_config):
"""Helper for :meth:`Database.execute_partitioned_dml`.

Wrap function in a Retry that will retry on Aborted exceptions
with the retry config specified.

:type func: callable
:param func: the function to be retried on Aborted exceptions

:type retry_config: Retry
:param retry_config: retry object with the settings to be used
"""
retry = retry_config.with_predicate(if_exception_type(Aborted))
return retry(func)
46 changes: 41 additions & 5 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase):
SESSION_ID = "session_id"
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
TRANSACTION_ID = b"transaction_id"
RETRY_TRANSACTION_ID = b"transaction_id_retry"
BACKUP_ID = "backup_id"
BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID

Expand Down Expand Up @@ -735,8 +736,10 @@ def test_drop_success(self):
)

def _execute_partitioned_dml_helper(
self, dml, params=None, param_types=None, query_options=None
self, dml, params=None, param_types=None, query_options=None, retried=False
):
from google.api_core.exceptions import Aborted
from google.api_core.retry import Retry
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1.proto.result_set_pb2 import (
PartialResultSet,
Expand All @@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper(
_merge_query_options,
)

import collections

MethodConfig = collections.namedtuple("MethodConfig", ["retry"])

transaction_pb = TransactionPB(id=self.TRANSACTION_ID)

stats_pb = ResultSetStats(row_count_lower_bound=2)
Expand All @@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper(
pool.put(session)
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
api = database._spanner_api = self._make_spanner_api()
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator
api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())}
if retried:
retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID)
api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb]
api.execute_streaming_sql.side_effect = [Aborted("test"), iterator]
else:
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator

row_count = database.execute_partitioned_dml(
dml, params, param_types, query_options
Expand All @@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper(
partitioned_dml=TransactionOptions.PartitionedDml()
)

api.begin_transaction.assert_called_once_with(
api.begin_transaction.assert_called_with(
session.name,
txn_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
self.assertEqual(api.begin_transaction.call_count, 2)
else:
self.assertEqual(api.begin_transaction.call_count, 1)

if params:
expected_params = Struct(
Expand All @@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper(
expected_query_options, query_options
)

api.execute_streaming_sql.assert_called_once_with(
api.execute_streaming_sql.assert_any_call(
self.SESSION_NAME,
dml,
transaction=expected_transaction,
Expand All @@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper(
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
expected_retry_transaction = TransactionSelector(
id=self.RETRY_TRANSACTION_ID
)
api.execute_streaming_sql.assert_called_with(
self.SESSION_NAME,
dml,
transaction=expected_retry_transaction,
params=expected_params,
param_types=param_types,
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
self.assertEqual(api.execute_streaming_sql.call_count, 2)
else:
self.assertEqual(api.execute_streaming_sql.call_count, 1)

def test_execute_partitioned_dml_wo_params(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM)
Expand All @@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self):
query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"),
)

def test_execute_partitioned_dml_wo_params_retry_aborted(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True)

def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session

Expand Down