Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
"BYTES": types.LargeBinary,
"DATE": types.DATE,
"DATETIME": types.DATETIME,
"FLOAT32": types.REAL,
"FLOAT64": types.Float,
"INT64": types.BIGINT,
"NUMERIC": types.NUMERIC(precision=38, scale=9),
Expand All @@ -101,6 +102,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
types.LargeBinary: "BYTES(MAX)",
types.DATE: "DATE",
types.DATETIME: "DATETIME",
types.REAL: "FLOAT32",
types.Float: "FLOAT64",
types.BIGINT: "INT64",
types.DECIMAL: "NUMERIC",
Expand Down Expand Up @@ -540,9 +542,18 @@ class SpannerTypeCompiler(GenericTypeCompiler):
def visit_INTEGER(self, type_, **kw):
return "INT64"

def visit_DOUBLE(self, type_, **kw):
return "FLOAT64"

def visit_FLOAT(self, type_, **kw):
# Note: This was added before Spanner supported FLOAT32.
# Changing this now to generate a FLOAT32 would be a breaking change.
# Users therefore have to use REAL to generate a FLOAT32 column.
return "FLOAT64"

def visit_REAL(self, type_, **kw):
return "FLOAT32"

def visit_TEXT(self, type_, **kw):
return "STRING({})".format(type_.length or "MAX")

Expand Down
30 changes: 30 additions & 0 deletions test/mockserver_tests/float32_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.types import REAL


class Base(DeclarativeBase):
pass


class Number(Base):
__tablename__ = "numbers"
number: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(30))
ln: Mapped[float] = mapped_column(REAL)
73 changes: 73 additions & 0 deletions test/mockserver_tests/test_float32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy.orm import Session
from sqlalchemy.testing import (
eq_,
is_instance_of,
is_false,
)
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
ExecuteSqlRequest,
ResultSet,
ResultSetStats,
BeginTransactionRequest,
CommitRequest,
TypeCode,
)
from test.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_result,
)


class TestFloat32(MockServerTestBase):
def test_insert_data(self):
from test.mockserver_tests.float32_model import Number

update_count = ResultSet(
dict(
stats=ResultSetStats(
dict(
row_count_exact=1,
)
)
)
)
add_result(
"INSERT INTO numbers (number, name, ln) VALUES (@a0, @a1, @a2)",
update_count,
)

engine = self.create_engine()
with Session(engine) as session:
n1 = Number(number=1, name="One", ln=0.0)
session.add_all([n1])
session.commit()

requests = self.spanner_service.requests
eq_(4, len(requests))
is_instance_of(requests[0], BatchCreateSessionsRequest)
is_instance_of(requests[1], BeginTransactionRequest)
is_instance_of(requests[2], ExecuteSqlRequest)
is_instance_of(requests[3], CommitRequest)
request: ExecuteSqlRequest = requests[2]
eq_(3, len(request.params))
eq_("1", request.params["a0"])
eq_("One", request.params["a1"])
eq_(0.0, request.params["a2"])
eq_(TypeCode.INT64, request.param_types["a0"].code)
eq_(TypeCode.STRING, request.param_types["a1"].code)
is_false("a2" in request.param_types)
1 change: 0 additions & 1 deletion test/mockserver_tests/test_quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class TestQuickStart(MockServerTestBase):
def test_create_tables(self):
from test.mockserver_tests.quickstart_model import Base

# TODO: Fix the double quotes inside these SQL fragments.
add_result(
"""SELECT true
FROM INFORMATION_SCHEMA.TABLES
Expand Down
39 changes: 37 additions & 2 deletions test/system/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
Index,
MetaData,
Boolean,
BIGINT,
)
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import REAL
from sqlalchemy.testing import eq_
from sqlalchemy.testing.plugin.plugin_base import fixtures

Expand All @@ -37,6 +40,7 @@ def define_tables(cls, metadata):
Column("name", String(20)),
Column("alternative_name", String(20)),
Column("prime", Boolean),
Column("ln", REAL),
PrimaryKeyConstraint("number"),
)
Index(
Expand All @@ -53,8 +57,8 @@ def test_hello_world(self, connection):
def test_insert_number(self, connection):
connection.execute(
text(
"""insert or update into numbers (number, name, prime)
values (1, 'One', false)"""
"""insert or update into numbers (number, name, prime, ln)
values (1, 'One', false, cast(ln(1) as float32))"""
)
)
name = connection.execute(text("select name from numbers where number=1"))
Expand All @@ -66,6 +70,17 @@ def test_reflect(self, connection):
meta.reflect(bind=engine)
eq_(1, len(meta.tables))
table = meta.tables["numbers"]
eq_(5, len(table.columns))
eq_("number", table.columns[0].name)
eq_(BIGINT, type(table.columns[0].type))
eq_("name", table.columns[1].name)
eq_(String, type(table.columns[1].type))
eq_("alternative_name", table.columns[2].name)
eq_(String, type(table.columns[2].type))
eq_("prime", table.columns[3].name)
eq_(Boolean, type(table.columns[3].type))
eq_("ln", table.columns[4].name)
eq_(REAL, type(table.columns[4].type))
eq_(1, len(table.indexes))
index = next(iter(table.indexes))
eq_(2, len(index.columns))
Expand All @@ -74,3 +89,23 @@ def test_reflect(self, connection):
dialect_options = index.dialect_options["spanner"]
eq_(1, len(dialect_options["storing"]))
eq_("alternative_name", dialect_options["storing"][0])

def test_orm(self, connection):
class Base(DeclarativeBase):
pass

class Number(Base):
__tablename__ = "numbers"
number: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(20))
alternative_name: Mapped[str] = mapped_column(String(20))
prime: Mapped[bool] = mapped_column(Boolean)
ln: Mapped[float] = mapped_column(REAL)

engine = connection.engine
with Session(engine) as session:
number = Number(
number=1, name="One", alternative_name="Uno", prime=False, ln=0.0
)
session.add(number)
session.commit()
Loading