diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py index 177df9e9bd..ee4883d74f 100644 --- a/google/cloud/spanner_dbapi/_helpers.py +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -13,7 +13,6 @@ # limitations under the License. from google.cloud.spanner_dbapi.parse_utils import get_param_types -from google.cloud.spanner_dbapi.parse_utils import parse_insert from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner from google.cloud.spanner_v1 import param_types @@ -51,44 +50,13 @@ def _execute_insert_heterogenous(transaction, sql_params_list): for sql, params in sql_params_list: sql, params = sql_pyformat_args_to_spanner(sql, params) - param_types = get_param_types(params) - transaction.execute_update(sql, params=params, param_types=param_types) - - -def _execute_insert_homogenous(transaction, parts): - # Perform an insert in one shot. - return transaction.insert( - parts.get("table"), parts.get("columns"), parts.get("values") - ) + transaction.execute_update(sql, params, get_param_types(params)) def handle_insert(connection, sql, params): - parts = parse_insert(sql, params) - - # The split between the two styles exists because: - # in the common case of multiple values being passed - # with simple pyformat arguments, - # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) - # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] - # we can take advantage of a single RPC with: - # transaction.insert(table, columns, values) - # instead of invoking: - # with transaction: - # for sql, params in sql_params_list: - # transaction.execute_sql(sql, params, param_types) - # which invokes more RPCs and is more costly. - - if parts.get("homogenous"): - # The common case of multiple values being passed in - # non-complex pyformat args and need to be uploaded in one RPC. - return connection.database.run_in_transaction(_execute_insert_homogenous, parts) - else: - # All the other cases that are esoteric and need - # transaction.execute_sql - sql_params_list = parts.get("sql_params_list") - return connection.database.run_in_transaction( - _execute_insert_heterogenous, sql_params_list - ) + return connection.database.run_in_transaction( + _execute_insert_heterogenous, ((sql, params),) + ) class ColumnInfo: diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 76f04338c4..91b63a2da1 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -24,8 +24,6 @@ from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous -from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous -from google.cloud.spanner_dbapi._helpers import parse_insert from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor @@ -436,23 +434,13 @@ def run_statement(self, statement, retried=False): self._statements.append(statement) if statement.is_insert: - parts = parse_insert(statement.sql, statement.params) - - if parts.get("homogenous"): - _execute_insert_homogenous(transaction, parts) - return ( - iter(()), - ResultsChecksum() if retried else statement.checksum, - ) - else: - _execute_insert_heterogenous( - transaction, - parts.get("sql_params_list"), - ) - return ( - iter(()), - ResultsChecksum() if retried else statement.checksum, - ) + _execute_insert_heterogenous( + transaction, ((statement.sql, statement.params),) + ) + return ( + iter(()), + ResultsChecksum() if retried else statement.checksum, + ) return ( transaction.execute_sql( diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 61bded4e80..e051f96a00 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -17,14 +17,12 @@ import datetime import decimal import re -from functools import reduce import sqlparse from google.cloud import spanner_v1 as spanner from google.cloud.spanner_v1 import JsonObject -from .exceptions import Error, ProgrammingError -from .parser import expect, VALUES +from .exceptions import Error from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload @@ -185,6 +183,12 @@ def classify_stmt(query): :rtype: str :returns: The query type name. """ + # sqlparse will strip Cloud Spanner comments, + # still, special commenting styles, like + # PostgreSQL dollar quoted comments are not + # supported and will not be stripped. + query = sqlparse.format(query, strip_comments=True).strip() + if RE_DDL.match(query): return STMT_DDL @@ -199,255 +203,6 @@ def classify_stmt(query): return STMT_UPDATING -def parse_insert(insert_sql, params): - """ - Parse an INSERT statement and generate a list of tuples of the form: - [ - (SQL, params_per_row1), - (SQL, params_per_row2), - (SQL, params_per_row3), - ... - ] - - There are 4 variants of an INSERT statement: - a) INSERT INTO (columns...) VALUES (): no params - b) INSERT INTO
(columns...) SELECT_STMT: no params - c) INSERT INTO
(columns...) VALUES (%s,...): with params - d) INSERT INTO
(columns...) VALUES (%s,.....) with params and expressions - - Thus given each of the forms, it will produce a dictionary describing - how to upload the contents to Cloud Spanner: - Case a) - SQL: INSERT INTO T (f1, f2) VALUES (1, 2) - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (f1, f2) VALUES (1, 2)', None), - ], - } - - Case b) - SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus WHERE col IN (%s, %s)', - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', ('a', 'b')), - ] - } - - Case c) - SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s) - Params: ['a', 'b', 'c', 'd'] - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('a', 'b')), - ('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('c', 'd')) - ], - } - - Case d) - SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s) - Params: ['a', 'b', 'c', 'd'] - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)), - ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) - ], - } - - :type insert_sql: str - :param insert_sql: A SQL insert request. - - :type params: list - :param params: A list of parameters. - - :rtype: dict - :returns: A dictionary that maps `sql_params_list` to the list of - parameters in cases a), b), d) or the dictionary with information - about the resulting table in case c). - """ # noqa - match = RE_INSERT.search(insert_sql) - - if not match: - raise ProgrammingError( - "Could not parse an INSERT statement from %s" % insert_sql - ) - - after_values_sql = RE_VALUES_TILL_END.findall(insert_sql) - if not after_values_sql: - # Case b) - insert_sql = sanitize_literals_for_upload(insert_sql) - return {"sql_params_list": [(insert_sql, params)]} - - if not params: - # Case a) perhaps? - # Check if any %s exists. - - # pyformat_str_count = after_values_sql.count("%s") - # if pyformat_str_count > 0: - # raise ProgrammingError( - # 'no params yet there are %d "%%s" tokens' % pyformat_str_count - # ) - for item in after_values_sql: - if item.count("%s") > 0: - raise ProgrammingError( - 'no params yet there are %d "%%s" tokens' % item.count("%s") - ) - - insert_sql = sanitize_literals_for_upload(insert_sql) - # Confirmed case of: - # SQL: INSERT INTO T (a1, a2) VALUES (1, 2) - # Params: None - return {"sql_params_list": [(insert_sql, None)]} - - _, values = expect(after_values_sql[0], VALUES) - - if values.homogenous(): - # Case c) - - columns = [mi.strip(" `") for mi in match.group("columns").split(",")] - sql_params_list = [] - insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % ( - match.group("table_name"), - match.group("columns"), - values.argv[0], - ) - values_pyformat = [str(arg) for arg in values.argv] - rows_list = rows_for_insert_or_update(columns, params, values_pyformat) - insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble) - for row in rows_list: - sql_params_list.append((insert_sql_preamble, row)) - - return {"sql_params_list": sql_params_list} - - # Case d) - # insert_sql is of the form: - # INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s)) - - # Sanity check: - # length(all_args) == len(params) - args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv]) - if args_len != len(params): - raise ProgrammingError( - "Invalid length: VALUES(...) len: %d != len(params): %d" - % (args_len, len(params)) - ) - - trim_index = insert_sql.find(after_values_sql[0]) - before_values_sql = insert_sql[:trim_index] - - sql_param_tuples = [] - for token_arg in values.argv: - row_sql = before_values_sql + " VALUES%s" % token_arg - row_sql = sanitize_literals_for_upload(row_sql) - row_params, params = ( - tuple(params[0 : len(token_arg)]), - params[len(token_arg) :], - ) - sql_param_tuples.append((row_sql, row_params)) - - return {"sql_params_list": sql_param_tuples} - - -def rows_for_insert_or_update(columns, params, pyformat_args=None): - """ - Create a tupled list of params to be used as a single value per - value that inserted from a statement such as - SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)' - Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] - - We'll have to convert both params types into: - Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - - :type columns: list - :param columns: A list of the columns of the table. - - :type params: list - :param params: A list of parameters. - - :rtype: list - :returns: A properly restructured list of the parameters. - """ # noqa - if not pyformat_args: - # This is the case where we have for example: - # SQL: 'INSERT INTO t (f1, f2, f3)' - # Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - # Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] - # - # We'll have to convert both params types into: - # [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - contains_all_list_or_tuples = True - for param in params: - if not (isinstance(param, list) or isinstance(param, tuple)): - contains_all_list_or_tuples = False - break - - if contains_all_list_or_tuples: - # The case with Params A: [(1, 2, 3), (4, 5, 6)] - # Ensure that each param's length == len(columns) - columns_len = len(columns) - for param in params: - if columns_len != len(param): - raise Error( - "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d" - % (param, len(param), columns, columns_len) - ) - return params - else: - # The case with Params B: [1, 2, 3] - # Insert statements' params are only passed as tuples or lists, - # yet for do_execute_update, we've got to pass in list of list. - # https://googleapis.dev/python/spanner/latest/transaction-api.html\ - # #google.cloud.spanner_v1.transaction.Transaction.insert - n_stride = len(columns) - else: - # This is the case where we have for example: - # SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), - # (%s, %s, %s), (%s, %s, %s)' - # Params: [1, 2, 3, 4, 5, 6, 7, 8, 9] - # which should become - # Columns: (f1, f2, f3) - # new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - - # Sanity check 1: all the pyformat_values should have the exact same - # length. - first, rest = pyformat_args[0], pyformat_args[1:] - n_stride = first.count("%s") - for pyfmt_value in rest: - n = pyfmt_value.count("%s") - if n_stride != n: - raise Error( - "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d" - % (first, n_stride, pyfmt_value, n) - ) - - # Sanity check 2: len(params) MUST be a multiple of n_stride aka - # len(count of %s). - # so that we can properly group for example: - # Given pyformat args: - # (%s, %s, %s) - # Params: - # [1, 2, 3, 4, 5, 6, 7, 8, 9] - # into - # [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - if (len(params) % n_stride) != 0: - raise ProgrammingError( - "Invalid length: len(params)=%d MUST be a multiple of " - "len(pyformat_args)=%d" % (len(params), n_stride) - ) - - # Now chop up the strides. - strides = [] - for step in range(0, len(params), n_stride): - stride = tuple(params[step : step + n_stride :]) - strides.append(stride) - - return strides - - def sql_pyformat_args_to_spanner(sql, params): """ Transform pyformat set SQL to named arguments for Cloud Spanner. diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py index 84d6b3e323..1782978d62 100644 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -32,23 +32,37 @@ def test__execute_insert_heterogenous(self): "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None ) as mock_param_types: transaction = mock.MagicMock() - transaction.execute_update = mock_execute = mock.MagicMock() - _helpers._execute_insert_heterogenous(transaction, [params]) + transaction.execute_update = mock_update = mock.MagicMock() + _helpers._execute_insert_heterogenous(transaction, (params,)) mock_pyformat.assert_called_once_with(params[0], params[1]) mock_param_types.assert_called_once_with(None) - mock_execute.assert_called_once_with(sql, params=None, param_types=None) + mock_update.assert_called_once_with(sql, None, None) - def test__execute_insert_homogenous(self): + def test__execute_insert_heterogenous_error(self): from google.cloud.spanner_dbapi import _helpers + from google.api_core.exceptions import Unknown - transaction = mock.MagicMock() - transaction.insert = mock.MagicMock() - parts = mock.MagicMock() - parts.get = mock.MagicMock(return_value=0) + sql = "sql" + params = (sql, None) + with mock.patch( + "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", + return_value=params, + ) as mock_pyformat: + with mock.patch( + "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None + ) as mock_param_types: + transaction = mock.MagicMock() + transaction.execute_update = mock_update = mock.MagicMock( + side_effect=Unknown("Unknown") + ) - _helpers._execute_insert_homogenous(transaction, parts) - transaction.insert.assert_called_once_with(0, 0, 0) + with self.assertRaises(Unknown): + _helpers._execute_insert_heterogenous(transaction, (params,)) + + mock_pyformat.assert_called_once_with(params[0], params[1]) + mock_param_types.assert_called_once_with(None) + mock_update.assert_called_once_with(sql, None, None) def test_handle_insert(self): from google.cloud.spanner_dbapi import _helpers @@ -56,19 +70,13 @@ def test_handle_insert(self): connection = mock.MagicMock() connection.database.run_in_transaction = mock_run_in = mock.MagicMock() sql = "sql" - parts = mock.MagicMock() - with mock.patch( - "google.cloud.spanner_dbapi._helpers.parse_insert", return_value=parts - ): - parts.get = mock.MagicMock(return_value=True) - mock_run_in.return_value = 0 - result = _helpers.handle_insert(connection, sql, None) - self.assertEqual(result, 0) - - parts.get = mock.MagicMock(return_value=False) - mock_run_in.return_value = 1 - result = _helpers.handle_insert(connection, sql, None) - self.assertEqual(result, 1) + mock_run_in.return_value = 0 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 0) + + mock_run_in.return_value = 1 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 1) class TestColumnInfo(unittest.TestCase): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 7902de6405..e15f6af33b 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -392,13 +392,17 @@ def test_run_statement_w_heterogenous_insert_statements(self): """Check that Connection executed heterogenous insert statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement + from google.rpc.status_pb2 import Status + from google.rpc.code_pb2 import OK sql = "INSERT INTO T (f1, f2) VALUES (1, 2)" params = None param_types = None connection = self._make_connection() - connection.transaction_checkout = mock.Mock() + transaction = mock.MagicMock() + connection.transaction_checkout = mock.Mock(return_value=transaction) + transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) statement = Statement(sql, params, param_types, ResultsChecksum(), True) connection.run_statement(statement, retried=True) @@ -409,13 +413,17 @@ def test_run_statement_w_homogeneous_insert_statements(self): """Check that Connection executed homogeneous insert statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement + from google.rpc.status_pb2 import Status + from google.rpc.code_pb2 import OK sql = "INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)" params = ["a", "b", "c", "d"] param_types = {"f1": str, "f2": str} connection = self._make_connection() - connection.transaction_checkout = mock.Mock() + transaction = mock.MagicMock() + connection.transaction_checkout = mock.Mock(return_value=transaction) + transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) statement = Statement(sql, params, param_types, ResultsChecksum(), True) connection.run_statement(statement, retried=True) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 71e4a96d6e..3f379f96ac 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -564,7 +564,7 @@ def test_executemany_insert_batch_aborted(self): transaction1 = mock.Mock(committed=False, rolled_back=False) transaction1.batch_update = mock.Mock( - side_effect=[(mock.Mock(code=ABORTED, details=err_details), [])] + side_effect=[(mock.Mock(code=ABORTED, message=err_details), [])] ) transaction2 = self._transaction_mock() @@ -732,15 +732,6 @@ def test_setoutputsize(self): with self.assertRaises(exceptions.InterfaceError): cursor.setoutputsize(size=None) - # def test_handle_insert(self): - # pass - # - # def test_do_execute_insert_heterogenous(self): - # pass - # - # def test_do_execute_insert_homogenous(self): - # pass - def test_handle_dql(self): from google.cloud.spanner_dbapi import utils from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index b0f363299b..511ad838cf 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -61,199 +61,6 @@ def test_classify_stmt(self): for query, want_class in cases: self.assertEqual(classify_stmt(query), want_class) - @unittest.skipIf(skip_condition, skip_message) - def test_parse_insert(self): - from google.cloud.spanner_dbapi.parse_utils import parse_insert - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - - with self.assertRaises(ProgrammingError): - parse_insert("bad-sql", None) - - cases = [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - [1, 2, 3, 4, 5, 6], - { - "sql_params_list": [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ] - }, - ), - ( - "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)", - [1, 2, 3, 4, 5, 6], - { - "sql_params_list": [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ] - }, - ), - ( - "INSERT INTO sales.addresses (street, city, state, zip_code) " - "SELECT street, city, state, zip_code FROM sales.customers" - "ORDER BY first_name, last_name", - None, - { - "sql_params_list": [ - ( - "INSERT INTO sales.addresses (street, city, state, zip_code) " - "SELECT street, city, state, zip_code FROM sales.customers" - "ORDER BY first_name, last_name", - None, - ) - ] - }, - ), - ( - "INSERT INTO ap (n, ct, cn) " - "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)", - (1, 2, 3, 4, 5, 6, 7, 8, 9), - { - "sql_params_list": [ - ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (1, 2, 3)), - ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (4, 5, 6)), - ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (7, 8, 9)), - ] - }, - ), - ( - "INSERT INTO `no` (`yes`) VALUES (%s)", - (1, 4, 5), - { - "sql_params_list": [ - ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)), - ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)), - ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)), - ] - }, - ), - ( - "INSERT INTO T (f1, f2) VALUES (1, 2)", - None, - {"sql_params_list": [("INSERT INTO T (f1, f2) VALUES (1, 2)", None)]}, - ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", - (1, "FOO", 5, 10, 11, 29), - { - "sql_params_list": [ - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", - (1, "FOO"), - ), - ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (5, 10)), - ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (11, 29)), - ] - }, - ), - ] - - sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" - with self.assertRaises(ProgrammingError): - parse_insert(sql, None) - - for sql, params, want in cases: - with self.subTest(sql=sql): - got = parse_insert(sql, params) - self.assertEqual(got, want, "Mismatch with parse_insert of `%s`" % sql) - - @unittest.skipIf(skip_condition, skip_message) - def test_parse_insert_invalid(self): - from google.cloud.spanner_dbapi import exceptions - from google.cloud.spanner_dbapi.parse_utils import parse_insert - - cases = [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", - [1, 2, 3, 4, 5, 6, 7], - "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3", - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))", - [1, 2, 3, 4, 5, 6, 7], - "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7", - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))", - [1, 2, 3, 4, 5, 6], - "VALUES: expected `,` got \\) in \\)", - ), - ] - - for sql, params, wantException in cases: - with self.subTest(sql=sql): - self.assertRaisesRegex( - exceptions.ProgrammingError, - wantException, - lambda: parse_insert(sql, params), - ) - - @unittest.skipIf(skip_condition, skip_message) - def test_rows_for_insert_or_update(self): - from google.cloud.spanner_dbapi.parse_utils import rows_for_insert_or_update - from google.cloud.spanner_dbapi.exceptions import Error - - with self.assertRaises(Error): - rows_for_insert_or_update([0], [[]]) - - with self.assertRaises(Error): - rows_for_insert_or_update([0], None, ["0", "%s"]) - - cases = [ - ( - ["id", "app", "name"], - [(5, "ap", "n"), (6, "bp", "m")], - None, - [(5, "ap", "n"), (6, "bp", "m")], - ), - ( - ["app", "name"], - [("ap", "n"), ("bp", "m")], - None, - [("ap", "n"), ("bp", "m")], - ), - ( - ["app", "name", "fn"], - ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"], - ["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"], - [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")], - ), - ( - ["app", "name", "fn", "ln"], - [ - ("ap", "n", (45, "nested"), "ll"), - ("bp", "m", "f2", "mt"), - ("fp", "cp", "o", "f3"), - ], - None, - [ - ("ap", "n", (45, "nested"), "ll"), - ("bp", "m", "f2", "mt"), - ("fp", "cp", "o", "f3"), - ], - ), - (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]), - ] - - for i, (columns, params, pyformat_args, want) in enumerate(cases): - with self.subTest(i=i): - got = rows_for_insert_or_update(columns, params, pyformat_args) - self.assertEqual(got, want) - @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self): from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner @@ -411,20 +218,3 @@ def test_escape_name(self): with self.subTest(name=name): got = escape_name(name) self.assertEqual(got, want) - - def test_insert_from_select(self): - """Check that INSERT from SELECT clause can be executed with arguments.""" - from google.cloud.spanner_dbapi.parse_utils import parse_insert - - SQL = """ -INSERT INTO tab_name (id, data) -SELECT tab_name.id + %s AS anon_1, tab_name.data -FROM tab_name -WHERE tab_name.data IN (%s, %s) -""" - ARGS = [5, "data2", "data3"] - - self.assertEqual( - parse_insert(SQL, ARGS), - {"sql_params_list": [(SQL, ARGS)]}, - )