Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 4 additions & 36 deletions google/cloud/spanner_dbapi/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import parse_insert
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_v1 import param_types

Expand Down Expand Up @@ -51,44 +50,13 @@
def _execute_insert_heterogenous(transaction, sql_params_list):
for sql, params in sql_params_list:
sql, params = sql_pyformat_args_to_spanner(sql, params)
param_types = get_param_types(params)
transaction.execute_update(sql, params=params, param_types=param_types)


def _execute_insert_homogenous(transaction, parts):
# Perform an insert in one shot.
return transaction.insert(
parts.get("table"), parts.get("columns"), parts.get("values")
)
transaction.execute_update(sql, params, get_param_types(params))


def handle_insert(connection, sql, params):
parts = parse_insert(sql, params)

# The split between the two styles exists because:
# in the common case of multiple values being passed
# with simple pyformat arguments,
# SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s)
# Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)]
# we can take advantage of a single RPC with:
# transaction.insert(table, columns, values)
# instead of invoking:
# with transaction:
# for sql, params in sql_params_list:
# transaction.execute_sql(sql, params, param_types)
# which invokes more RPCs and is more costly.

if parts.get("homogenous"):
# The common case of multiple values being passed in
# non-complex pyformat args and need to be uploaded in one RPC.
return connection.database.run_in_transaction(_execute_insert_homogenous, parts)
else:
# All the other cases that are esoteric and need
# transaction.execute_sql
sql_params_list = parts.get("sql_params_list")
return connection.database.run_in_transaction(
_execute_insert_heterogenous, sql_params_list
)
return connection.database.run_in_transaction(
_execute_insert_heterogenous, ((sql, params),)
)


class ColumnInfo:
Expand Down
26 changes: 7 additions & 19 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from google.cloud.spanner_v1.snapshot import Snapshot

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
from google.cloud.spanner_dbapi._helpers import parse_insert
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -436,23 +434,13 @@ def run_statement(self, statement, retried=False):
self._statements.append(statement)

if statement.is_insert:
parts = parse_insert(statement.sql, statement.params)

if parts.get("homogenous"):
_execute_insert_homogenous(transaction, parts)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)
else:
_execute_insert_heterogenous(
transaction,
parts.get("sql_params_list"),
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)
_execute_insert_heterogenous(
transaction, ((statement.sql, statement.params),)
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
Expand Down
259 changes: 7 additions & 252 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import datetime
import decimal
import re
from functools import reduce

import sqlparse
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1 import JsonObject

from .exceptions import Error, ProgrammingError
from .parser import expect, VALUES
from .exceptions import Error
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload

Expand Down Expand Up @@ -185,6 +183,12 @@ def classify_stmt(query):
:rtype: str
:returns: The query type name.
"""
# sqlparse will strip Cloud Spanner comments,
# still, special commenting styles, like
# PostgreSQL dollar quoted comments are not
# supported and will not be stripped.
query = sqlparse.format(query, strip_comments=True).strip()

if RE_DDL.match(query):
return STMT_DDL

Expand All @@ -199,255 +203,6 @@ def classify_stmt(query):
return STMT_UPDATING


def parse_insert(insert_sql, params):
"""
Parse an INSERT statement and generate a list of tuples of the form:
[
(SQL, params_per_row1),
(SQL, params_per_row2),
(SQL, params_per_row3),
...
]

There are 4 variants of an INSERT statement:
a) INSERT INTO <table> (columns...) VALUES (<inlined values>): no params
b) INSERT INTO <table> (columns...) SELECT_STMT: no params
c) INSERT INTO <table> (columns...) VALUES (%s,...): with params
d) INSERT INTO <table> (columns...) VALUES (%s,..<EXPR>...) with params and expressions

Thus given each of the forms, it will produce a dictionary describing
how to upload the contents to Cloud Spanner:
Case a)
SQL: INSERT INTO T (f1, f2) VALUES (1, 2)
it produces:
{
'sql_params_list': [
('INSERT INTO T (f1, f2) VALUES (1, 2)', None),
],
}

Case b)
SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus WHERE col IN (%s, %s)',
it produces:
{
'sql_params_list': [
('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', ('a', 'b')),
]
}

Case c)
SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)
Params: ['a', 'b', 'c', 'd']
it produces:
{
'sql_params_list': [
('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('a', 'b')),
('INSERT INTO T (f1, f2) VALUES (%s, %s)', ('c', 'd'))
],
}

Case d)
SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s)
Params: ['a', 'b', 'c', 'd']
it produces:
{
'sql_params_list': [
('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)),
('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',))
],
}

:type insert_sql: str
:param insert_sql: A SQL insert request.

:type params: list
:param params: A list of parameters.

:rtype: dict
:returns: A dictionary that maps `sql_params_list` to the list of
parameters in cases a), b), d) or the dictionary with information
about the resulting table in case c).
""" # noqa
match = RE_INSERT.search(insert_sql)

if not match:
raise ProgrammingError(
"Could not parse an INSERT statement from %s" % insert_sql
)

after_values_sql = RE_VALUES_TILL_END.findall(insert_sql)
if not after_values_sql:
# Case b)
insert_sql = sanitize_literals_for_upload(insert_sql)
return {"sql_params_list": [(insert_sql, params)]}

if not params:
# Case a) perhaps?
# Check if any %s exists.

# pyformat_str_count = after_values_sql.count("%s")
# if pyformat_str_count > 0:
# raise ProgrammingError(
# 'no params yet there are %d "%%s" tokens' % pyformat_str_count
# )
for item in after_values_sql:
if item.count("%s") > 0:
raise ProgrammingError(
'no params yet there are %d "%%s" tokens' % item.count("%s")
)

insert_sql = sanitize_literals_for_upload(insert_sql)
# Confirmed case of:
# SQL: INSERT INTO T (a1, a2) VALUES (1, 2)
# Params: None
return {"sql_params_list": [(insert_sql, None)]}

_, values = expect(after_values_sql[0], VALUES)

if values.homogenous():
# Case c)

columns = [mi.strip(" `") for mi in match.group("columns").split(",")]
sql_params_list = []
insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % (
match.group("table_name"),
match.group("columns"),
values.argv[0],
)
values_pyformat = [str(arg) for arg in values.argv]
rows_list = rows_for_insert_or_update(columns, params, values_pyformat)
insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble)
for row in rows_list:
sql_params_list.append((insert_sql_preamble, row))

return {"sql_params_list": sql_params_list}

# Case d)
# insert_sql is of the form:
# INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s))

# Sanity check:
# length(all_args) == len(params)
args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv])
if args_len != len(params):
raise ProgrammingError(
"Invalid length: VALUES(...) len: %d != len(params): %d"
% (args_len, len(params))
)

trim_index = insert_sql.find(after_values_sql[0])
before_values_sql = insert_sql[:trim_index]

sql_param_tuples = []
for token_arg in values.argv:
row_sql = before_values_sql + " VALUES%s" % token_arg
row_sql = sanitize_literals_for_upload(row_sql)
row_params, params = (
tuple(params[0 : len(token_arg)]),
params[len(token_arg) :],
)
sql_param_tuples.append((row_sql, row_params))

return {"sql_params_list": sql_param_tuples}


def rows_for_insert_or_update(columns, params, pyformat_args=None):
"""
Create a tupled list of params to be used as a single value per
value that inserted from a statement such as
SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)'
Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9]

We'll have to convert both params types into:
Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]

:type columns: list
:param columns: A list of the columns of the table.

:type params: list
:param params: A list of parameters.

:rtype: list
:returns: A properly restructured list of the parameters.
""" # noqa
if not pyformat_args:
# This is the case where we have for example:
# SQL: 'INSERT INTO t (f1, f2, f3)'
# Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
# Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9]
#
# We'll have to convert both params types into:
# [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]
contains_all_list_or_tuples = True
for param in params:
if not (isinstance(param, list) or isinstance(param, tuple)):
contains_all_list_or_tuples = False
break

if contains_all_list_or_tuples:
# The case with Params A: [(1, 2, 3), (4, 5, 6)]
# Ensure that each param's length == len(columns)
columns_len = len(columns)
for param in params:
if columns_len != len(param):
raise Error(
"\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d"
% (param, len(param), columns, columns_len)
)
return params
else:
# The case with Params B: [1, 2, 3]
# Insert statements' params are only passed as tuples or lists,
# yet for do_execute_update, we've got to pass in list of list.
# https://googleapis.dev/python/spanner/latest/transaction-api.html\
# #google.cloud.spanner_v1.transaction.Transaction.insert
n_stride = len(columns)
else:
# This is the case where we have for example:
# SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s),
# (%s, %s, %s), (%s, %s, %s)'
# Params: [1, 2, 3, 4, 5, 6, 7, 8, 9]
# which should become
# Columns: (f1, f2, f3)
# new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)]

# Sanity check 1: all the pyformat_values should have the exact same
# length.
first, rest = pyformat_args[0], pyformat_args[1:]
n_stride = first.count("%s")
for pyfmt_value in rest:
n = pyfmt_value.count("%s")
if n_stride != n:
raise Error(
"\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d"
% (first, n_stride, pyfmt_value, n)
)

# Sanity check 2: len(params) MUST be a multiple of n_stride aka
# len(count of %s).
# so that we can properly group for example:
# Given pyformat args:
# (%s, %s, %s)
# Params:
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
# into
# [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
if (len(params) % n_stride) != 0:
raise ProgrammingError(
"Invalid length: len(params)=%d MUST be a multiple of "
"len(pyformat_args)=%d" % (len(params), n_stride)
)

# Now chop up the strides.
strides = []
for step in range(0, len(params), n_stride):
stride = tuple(params[step : step + n_stride :])
strides.append(stride)

return strides


def sql_pyformat_args_to_spanner(sql, params):
"""
Transform pyformat set SQL to named arguments for Cloud Spanner.
Expand Down
Loading