diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py
index 177df9e9bd..ee4883d74f 100644
--- a/google/cloud/spanner_dbapi/_helpers.py
+++ b/google/cloud/spanner_dbapi/_helpers.py
@@ -13,7 +13,6 @@
# limitations under the License.
from google.cloud.spanner_dbapi.parse_utils import get_param_types
-from google.cloud.spanner_dbapi.parse_utils import parse_insert
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_v1 import param_types
@@ -51,44 +50,13 @@
def _execute_insert_heterogenous(transaction, sql_params_list):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
- param_types = get_param_types(params)
- transaction.execute_update(sql, params=params, param_types=param_types)
-
-
-def _execute_insert_homogenous(transaction, parts):
- # Perform an insert in one shot.
- return transaction.insert(
- parts.get("table"), parts.get("columns"), parts.get("values")
- )
+ transaction.execute_update(sql, params, get_param_types(params))
def handle_insert(connection, sql, params):
- parts = parse_insert(sql, params)
-
- # The split between the two styles exists because:
- # in the common case of multiple values being passed
- # with simple pyformat arguments,
- # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s)
- # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)]
- # we can take advantage of a single RPC with:
- # transaction.insert(table, columns, values)
- # instead of invoking:
- # with transaction:
- # for sql, params in sql_params_list:
- # transaction.execute_sql(sql, params, param_types)
- # which invokes more RPCs and is more costly.
-
- if parts.get("homogenous"):
- # The common case of multiple values being passed in
- # non-complex pyformat args and need to be uploaded in one RPC.
- return connection.database.run_in_transaction(_execute_insert_homogenous, parts)
- else:
- # All the other cases that are esoteric and need
- # transaction.execute_sql
- sql_params_list = parts.get("sql_params_list")
- return connection.database.run_in_transaction(
- _execute_insert_heterogenous, sql_params_list
- )
+ return connection.database.run_in_transaction(
+ _execute_insert_heterogenous, ((sql, params),)
+ )
class ColumnInfo:
diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 76f04338c4..91b63a2da1 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -24,8 +24,6 @@
from google.cloud.spanner_v1.snapshot import Snapshot
from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
-from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
-from google.cloud.spanner_dbapi._helpers import parse_insert
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
@@ -436,23 +434,13 @@ def run_statement(self, statement, retried=False):
self._statements.append(statement)
if statement.is_insert:
- parts = parse_insert(statement.sql, statement.params)
-
- if parts.get("homogenous"):
- _execute_insert_homogenous(transaction, parts)
- return (
- iter(()),
- ResultsChecksum() if retried else statement.checksum,
- )
- else:
- _execute_insert_heterogenous(
- transaction,
- parts.get("sql_params_list"),
- )
- return (
- iter(()),
- ResultsChecksum() if retried else statement.checksum,
- )
+ _execute_insert_heterogenous(
+ transaction, ((statement.sql, statement.params),)
+ )
+ return (
+ iter(()),
+ ResultsChecksum() if retried else statement.checksum,
+ )
return (
transaction.execute_sql(
diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py
index 61bded4e80..e051f96a00 100644
--- a/google/cloud/spanner_dbapi/parse_utils.py
+++ b/google/cloud/spanner_dbapi/parse_utils.py
@@ -17,14 +17,12 @@
import datetime
import decimal
import re
-from functools import reduce
import sqlparse
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1 import JsonObject
-from .exceptions import Error, ProgrammingError
-from .parser import expect, VALUES
+from .exceptions import Error
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload
@@ -185,6 +183,12 @@ def classify_stmt(query):
:rtype: str
:returns: The query type name.
"""
+ # sqlparse will strip Cloud Spanner comments,
+ # still, special commenting styles, like
+ # PostgreSQL dollar quoted comments are not
+ # supported and will not be stripped.
+ query = sqlparse.format(query, strip_comments=True).strip()
+
if RE_DDL.match(query):
return STMT_DDL
@@ -199,255 +203,6 @@ def classify_stmt(query):
return STMT_UPDATING
-def parse_insert(insert_sql, params):
- """
- Parse an INSERT statement and generate a list of tuples of the form:
- [
- (SQL, params_per_row1),
- (SQL, params_per_row2),
- (SQL, params_per_row3),
- ...
- ]
-
- There are 4 variants of an INSERT statement:
- a) INSERT INTO
(columns...) VALUES (): no params
- b) INSERT INTO (columns...) SELECT_STMT: no params
- c) INSERT INTO (columns...) VALUES (%s,...): with params
- d) INSERT INTO (columns...) VALUES (%s,.....) with params and expressions
-
- Thus given each of the forms, it will produce a dictionary describing
- how to upload the contents to Cloud Spanner:
- Case a)
- SQL: INSERT INTO T (f1, f2) VALUES (1, 2)
- it produces:
- {
- 'sql_params_list': [
- ('INSERT INTO T (f1, f2) VALUES (1, 2)', None),
- ],
- }
-
- Case b)
- SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus WHERE col IN (%s, %s)',
- it produces:
- {
- 'sql_params_list': [
- ('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', ('a', 'b')),
- ]
- }
-
- Case c)
- SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)
- Params: ['a', 'b', 'c', 'd']
- it produces:
- {
- 'sql_params_list': [
- ('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('a', 'b')),
- ('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('c', 'd'))
- ],
- }
-
- Case d)
- SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s)
- Params: ['a', 'b', 'c', 'd']
- it produces:
- {
- 'sql_params_list': [
- ('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)),
- ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',))
- ],
- }
-
- :type insert_sql: str
- :param insert_sql: A SQL insert request.
-
- :type params: list
- :param params: A list of parameters.
-
- :rtype: dict
- :returns: A dictionary that maps `sql_params_list` to the list of
- parameters in cases a), b), d) or the dictionary with information
- about the resulting table in case c).
- """ # noqa
- match = RE_INSERT.search(insert_sql)
-
- if not match:
- raise ProgrammingError(
- "Could not parse an INSERT statement from %s" % insert_sql
- )
-
- after_values_sql = RE_VALUES_TILL_END.findall(insert_sql)
- if not after_values_sql:
- # Case b)
- insert_sql = sanitize_literals_for_upload(insert_sql)
- return {"sql_params_list": [(insert_sql, params)]}
-
- if not params:
- # Case a) perhaps?
- # Check if any %s exists.
-
- # pyformat_str_count = after_values_sql.count("%s")
- # if pyformat_str_count > 0:
- # raise ProgrammingError(
- # 'no params yet there are %d "%%s" tokens' % pyformat_str_count
- # )
- for item in after_values_sql:
- if item.count("%s") > 0:
- raise ProgrammingError(
- 'no params yet there are %d "%%s" tokens' % item.count("%s")
- )
-
- insert_sql = sanitize_literals_for_upload(insert_sql)
- # Confirmed case of:
- # SQL: INSERT INTO T (a1, a2) VALUES (1, 2)
- # Params: None
- return {"sql_params_list": [(insert_sql, None)]}
-
- _, values = expect(after_values_sql[0], VALUES)
-
- if values.homogenous():
- # Case c)
-
- columns = [mi.strip(" `") for mi in match.group("columns").split(",")]
- sql_params_list = []
- insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % (
- match.group("table_name"),
- match.group("columns"),
- values.argv[0],
- )
- values_pyformat = [str(arg) for arg in values.argv]
- rows_list = rows_for_insert_or_update(columns, params, values_pyformat)
- insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble)
- for row in rows_list:
- sql_params_list.append((insert_sql_preamble, row))
-
- return {"sql_params_list": sql_params_list}
-
- # Case d)
- # insert_sql is of the form:
- # INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s))
-
- # Sanity check:
- # length(all_args) == len(params)
- args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv])
- if args_len != len(params):
- raise ProgrammingError(
- "Invalid length: VALUES(...) len: %d != len(params): %d"
- % (args_len, len(params))
- )
-
- trim_index = insert_sql.find(after_values_sql[0])
- before_values_sql = insert_sql[:trim_index]
-
- sql_param_tuples = []
- for token_arg in values.argv:
- row_sql = before_values_sql + " VALUES%s" % token_arg
- row_sql = sanitize_literals_for_upload(row_sql)
- row_params, params = (
- tuple(params[0 : len(token_arg)]),
- params[len(token_arg) :],
- )
- sql_param_tuples.append((row_sql, row_params))
-
- return {"sql_params_list": sql_param_tuples}
-
-
-def rows_for_insert_or_update(columns, params, pyformat_args=None):
- """
- Create a tupled list of params to be used as a single value per
- value that inserted from a statement such as
- SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)'
- Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
- Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9]
-
- We'll have to convert both params types into:
- Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]
-
- :type columns: list
- :param columns: A list of the columns of the table.
-
- :type params: list
- :param params: A list of parameters.
-
- :rtype: list
- :returns: A properly restructured list of the parameters.
- """ # noqa
- if not pyformat_args:
- # This is the case where we have for example:
- # SQL: 'INSERT INTO t (f1, f2, f3)'
- # Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
- # Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9]
- #
- # We'll have to convert both params types into:
- # [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]
- contains_all_list_or_tuples = True
- for param in params:
- if not (isinstance(param, list) or isinstance(param, tuple)):
- contains_all_list_or_tuples = False
- break
-
- if contains_all_list_or_tuples:
- # The case with Params A: [(1, 2, 3), (4, 5, 6)]
- # Ensure that each param's length == len(columns)
- columns_len = len(columns)
- for param in params:
- if columns_len != len(param):
- raise Error(
- "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d"
- % (param, len(param), columns, columns_len)
- )
- return params
- else:
- # The case with Params B: [1, 2, 3]
- # Insert statements' params are only passed as tuples or lists,
- # yet for do_execute_update, we've got to pass in list of list.
- # https://googleapis.dev/python/spanner/latest/transaction-api.html\
- # #google.cloud.spanner_v1.transaction.Transaction.insert
- n_stride = len(columns)
- else:
- # This is the case where we have for example:
- # SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s),
- # (%s, %s, %s), (%s, %s, %s)'
- # Params: [1, 2, 3, 4, 5, 6, 7, 8, 9]
- # which should become
- # Columns: (f1, f2, f3)
- # new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]
-
- # Sanity check 1: all the pyformat_values should have the exact same
- # length.
- first, rest = pyformat_args[0], pyformat_args[1:]
- n_stride = first.count("%s")
- for pyfmt_value in rest:
- n = pyfmt_value.count("%s")
- if n_stride != n:
- raise Error(
- "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d"
- % (first, n_stride, pyfmt_value, n)
- )
-
- # Sanity check 2: len(params) MUST be a multiple of n_stride aka
- # len(count of %s).
- # so that we can properly group for example:
- # Given pyformat args:
- # (%s, %s, %s)
- # Params:
- # [1, 2, 3, 4, 5, 6, 7, 8, 9]
- # into
- # [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
- if (len(params) % n_stride) != 0:
- raise ProgrammingError(
- "Invalid length: len(params)=%d MUST be a multiple of "
- "len(pyformat_args)=%d" % (len(params), n_stride)
- )
-
- # Now chop up the strides.
- strides = []
- for step in range(0, len(params), n_stride):
- stride = tuple(params[step : step + n_stride :])
- strides.append(stride)
-
- return strides
-
-
def sql_pyformat_args_to_spanner(sql, params):
"""
Transform pyformat set SQL to named arguments for Cloud Spanner.
diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py
index 84d6b3e323..1782978d62 100644
--- a/tests/unit/spanner_dbapi/test__helpers.py
+++ b/tests/unit/spanner_dbapi/test__helpers.py
@@ -32,23 +32,37 @@ def test__execute_insert_heterogenous(self):
"google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None
) as mock_param_types:
transaction = mock.MagicMock()
- transaction.execute_update = mock_execute = mock.MagicMock()
- _helpers._execute_insert_heterogenous(transaction, [params])
+ transaction.execute_update = mock_update = mock.MagicMock()
+ _helpers._execute_insert_heterogenous(transaction, (params,))
mock_pyformat.assert_called_once_with(params[0], params[1])
mock_param_types.assert_called_once_with(None)
- mock_execute.assert_called_once_with(sql, params=None, param_types=None)
+ mock_update.assert_called_once_with(sql, None, None)
- def test__execute_insert_homogenous(self):
+ def test__execute_insert_heterogenous_error(self):
from google.cloud.spanner_dbapi import _helpers
+ from google.api_core.exceptions import Unknown
- transaction = mock.MagicMock()
- transaction.insert = mock.MagicMock()
- parts = mock.MagicMock()
- parts.get = mock.MagicMock(return_value=0)
+ sql = "sql"
+ params = (sql, None)
+ with mock.patch(
+ "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner",
+ return_value=params,
+ ) as mock_pyformat:
+ with mock.patch(
+ "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None
+ ) as mock_param_types:
+ transaction = mock.MagicMock()
+ transaction.execute_update = mock_update = mock.MagicMock(
+ side_effect=Unknown("Unknown")
+ )
- _helpers._execute_insert_homogenous(transaction, parts)
- transaction.insert.assert_called_once_with(0, 0, 0)
+ with self.assertRaises(Unknown):
+ _helpers._execute_insert_heterogenous(transaction, (params,))
+
+ mock_pyformat.assert_called_once_with(params[0], params[1])
+ mock_param_types.assert_called_once_with(None)
+ mock_update.assert_called_once_with(sql, None, None)
def test_handle_insert(self):
from google.cloud.spanner_dbapi import _helpers
@@ -56,19 +70,13 @@ def test_handle_insert(self):
connection = mock.MagicMock()
connection.database.run_in_transaction = mock_run_in = mock.MagicMock()
sql = "sql"
- parts = mock.MagicMock()
- with mock.patch(
- "google.cloud.spanner_dbapi._helpers.parse_insert", return_value=parts
- ):
- parts.get = mock.MagicMock(return_value=True)
- mock_run_in.return_value = 0
- result = _helpers.handle_insert(connection, sql, None)
- self.assertEqual(result, 0)
-
- parts.get = mock.MagicMock(return_value=False)
- mock_run_in.return_value = 1
- result = _helpers.handle_insert(connection, sql, None)
- self.assertEqual(result, 1)
+ mock_run_in.return_value = 0
+ result = _helpers.handle_insert(connection, sql, None)
+ self.assertEqual(result, 0)
+
+ mock_run_in.return_value = 1
+ result = _helpers.handle_insert(connection, sql, None)
+ self.assertEqual(result, 1)
class TestColumnInfo(unittest.TestCase):
diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py
index 7902de6405..e15f6af33b 100644
--- a/tests/unit/spanner_dbapi/test_connection.py
+++ b/tests/unit/spanner_dbapi/test_connection.py
@@ -392,13 +392,17 @@ def test_run_statement_w_heterogenous_insert_statements(self):
"""Check that Connection executed heterogenous insert statements."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement
+ from google.rpc.status_pb2 import Status
+ from google.rpc.code_pb2 import OK
sql = "INSERT INTO T (f1, f2) VALUES (1, 2)"
params = None
param_types = None
connection = self._make_connection()
- connection.transaction_checkout = mock.Mock()
+ transaction = mock.MagicMock()
+ connection.transaction_checkout = mock.Mock(return_value=transaction)
+ transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1))
statement = Statement(sql, params, param_types, ResultsChecksum(), True)
connection.run_statement(statement, retried=True)
@@ -409,13 +413,17 @@ def test_run_statement_w_homogeneous_insert_statements(self):
"""Check that Connection executed homogeneous insert statements."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement
+ from google.rpc.status_pb2 import Status
+ from google.rpc.code_pb2 import OK
sql = "INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)"
params = ["a", "b", "c", "d"]
param_types = {"f1": str, "f2": str}
connection = self._make_connection()
- connection.transaction_checkout = mock.Mock()
+ transaction = mock.MagicMock()
+ connection.transaction_checkout = mock.Mock(return_value=transaction)
+ transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1))
statement = Statement(sql, params, param_types, ResultsChecksum(), True)
connection.run_statement(statement, retried=True)
diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py
index 71e4a96d6e..3f379f96ac 100644
--- a/tests/unit/spanner_dbapi/test_cursor.py
+++ b/tests/unit/spanner_dbapi/test_cursor.py
@@ -564,7 +564,7 @@ def test_executemany_insert_batch_aborted(self):
transaction1 = mock.Mock(committed=False, rolled_back=False)
transaction1.batch_update = mock.Mock(
- side_effect=[(mock.Mock(code=ABORTED, details=err_details), [])]
+ side_effect=[(mock.Mock(code=ABORTED, message=err_details), [])]
)
transaction2 = self._transaction_mock()
@@ -732,15 +732,6 @@ def test_setoutputsize(self):
with self.assertRaises(exceptions.InterfaceError):
cursor.setoutputsize(size=None)
- # def test_handle_insert(self):
- # pass
- #
- # def test_do_execute_insert_heterogenous(self):
- # pass
- #
- # def test_do_execute_insert_homogenous(self):
- # pass
-
def test_handle_dql(self):
from google.cloud.spanner_dbapi import utils
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py
index b0f363299b..511ad838cf 100644
--- a/tests/unit/spanner_dbapi/test_parse_utils.py
+++ b/tests/unit/spanner_dbapi/test_parse_utils.py
@@ -61,199 +61,6 @@ def test_classify_stmt(self):
for query, want_class in cases:
self.assertEqual(classify_stmt(query), want_class)
- @unittest.skipIf(skip_condition, skip_message)
- def test_parse_insert(self):
- from google.cloud.spanner_dbapi.parse_utils import parse_insert
- from google.cloud.spanner_dbapi.exceptions import ProgrammingError
-
- with self.assertRaises(ProgrammingError):
- parse_insert("bad-sql", None)
-
- cases = [
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
- [1, 2, 3, 4, 5, 6],
- {
- "sql_params_list": [
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
- (1, 2, 3),
- ),
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
- (4, 5, 6),
- ),
- ]
- },
- ),
- (
- "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)",
- [1, 2, 3, 4, 5, 6],
- {
- "sql_params_list": [
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
- (1, 2, 3),
- ),
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
- (4, 5, 6),
- ),
- ]
- },
- ),
- (
- "INSERT INTO sales.addresses (street, city, state, zip_code) "
- "SELECT street, city, state, zip_code FROM sales.customers"
- "ORDER BY first_name, last_name",
- None,
- {
- "sql_params_list": [
- (
- "INSERT INTO sales.addresses (street, city, state, zip_code) "
- "SELECT street, city, state, zip_code FROM sales.customers"
- "ORDER BY first_name, last_name",
- None,
- )
- ]
- },
- ),
- (
- "INSERT INTO ap (n, ct, cn) "
- "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)",
- (1, 2, 3, 4, 5, 6, 7, 8, 9),
- {
- "sql_params_list": [
- ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (1, 2, 3)),
- ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (4, 5, 6)),
- ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (7, 8, 9)),
- ]
- },
- ),
- (
- "INSERT INTO `no` (`yes`) VALUES (%s)",
- (1, 4, 5),
- {
- "sql_params_list": [
- ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)),
- ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)),
- ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)),
- ]
- },
- ),
- (
- "INSERT INTO T (f1, f2) VALUES (1, 2)",
- None,
- {"sql_params_list": [("INSERT INTO T (f1, f2) VALUES (1, 2)", None)]},
- ),
- (
- "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)",
- (1, "FOO", 5, 10, 11, 29),
- {
- "sql_params_list": [
- (
- "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))",
- (1, "FOO"),
- ),
- ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (5, 10)),
- ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (11, 29)),
- ]
- },
- ),
- ]
-
- sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
- with self.assertRaises(ProgrammingError):
- parse_insert(sql, None)
-
- for sql, params, want in cases:
- with self.subTest(sql=sql):
- got = parse_insert(sql, params)
- self.assertEqual(got, want, "Mismatch with parse_insert of `%s`" % sql)
-
- @unittest.skipIf(skip_condition, skip_message)
- def test_parse_insert_invalid(self):
- from google.cloud.spanner_dbapi import exceptions
- from google.cloud.spanner_dbapi.parse_utils import parse_insert
-
- cases = [
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)",
- [1, 2, 3, 4, 5, 6, 7],
- "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3",
- ),
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))",
- [1, 2, 3, 4, 5, 6, 7],
- "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7",
- ),
- (
- "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))",
- [1, 2, 3, 4, 5, 6],
- "VALUES: expected `,` got \\) in \\)",
- ),
- ]
-
- for sql, params, wantException in cases:
- with self.subTest(sql=sql):
- self.assertRaisesRegex(
- exceptions.ProgrammingError,
- wantException,
- lambda: parse_insert(sql, params),
- )
-
- @unittest.skipIf(skip_condition, skip_message)
- def test_rows_for_insert_or_update(self):
- from google.cloud.spanner_dbapi.parse_utils import rows_for_insert_or_update
- from google.cloud.spanner_dbapi.exceptions import Error
-
- with self.assertRaises(Error):
- rows_for_insert_or_update([0], [[]])
-
- with self.assertRaises(Error):
- rows_for_insert_or_update([0], None, ["0", "%s"])
-
- cases = [
- (
- ["id", "app", "name"],
- [(5, "ap", "n"), (6, "bp", "m")],
- None,
- [(5, "ap", "n"), (6, "bp", "m")],
- ),
- (
- ["app", "name"],
- [("ap", "n"), ("bp", "m")],
- None,
- [("ap", "n"), ("bp", "m")],
- ),
- (
- ["app", "name", "fn"],
- ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"],
- ["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"],
- [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")],
- ),
- (
- ["app", "name", "fn", "ln"],
- [
- ("ap", "n", (45, "nested"), "ll"),
- ("bp", "m", "f2", "mt"),
- ("fp", "cp", "o", "f3"),
- ],
- None,
- [
- ("ap", "n", (45, "nested"), "ll"),
- ("bp", "m", "f2", "mt"),
- ("fp", "cp", "o", "f3"),
- ],
- ),
- (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]),
- ]
-
- for i, (columns, params, pyformat_args, want) in enumerate(cases):
- with self.subTest(i=i):
- got = rows_for_insert_or_update(columns, params, pyformat_args)
- self.assertEqual(got, want)
-
@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner(self):
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
@@ -411,20 +218,3 @@ def test_escape_name(self):
with self.subTest(name=name):
got = escape_name(name)
self.assertEqual(got, want)
-
- def test_insert_from_select(self):
- """Check that INSERT from SELECT clause can be executed with arguments."""
- from google.cloud.spanner_dbapi.parse_utils import parse_insert
-
- SQL = """
-INSERT INTO tab_name (id, data)
-SELECT tab_name.id + %s AS anon_1, tab_name.data
-FROM tab_name
-WHERE tab_name.data IN (%s, %s)
-"""
- ARGS = [5, "data2", "data3"]
-
- self.assertEqual(
- parse_insert(SQL, ARGS),
- {"sql_params_list": [(SQL, ARGS)]},
- )