Skip to content

Commit 6c3cb42

Browse files
authored
feat: support float32 (#531)
* feat: support float32 Adds support for FLOAT32 columns. Applications should use the SQLAlchemy type REAL to create a FLOAT32 column, as FLOAT is already reserved for FLOAT64. Fixes #409 * chore: run code formatter * fix: remove DOUBLE reference which is SQLAlchemy 2.0-only
1 parent dbb19c4 commit 6c3cb42

File tree

5 files changed

+151
-3
lines changed

5 files changed

+151
-3
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
8484
"BYTES": types.LargeBinary,
8585
"DATE": types.DATE,
8686
"DATETIME": types.DATETIME,
87+
"FLOAT32": types.REAL,
8788
"FLOAT64": types.Float,
8889
"INT64": types.BIGINT,
8990
"NUMERIC": types.NUMERIC(precision=38, scale=9),
@@ -101,6 +102,7 @@ def reset_connection(dbapi_conn, connection_record, reset_state=None):
101102
types.LargeBinary: "BYTES(MAX)",
102103
types.DATE: "DATE",
103104
types.DATETIME: "DATETIME",
105+
types.REAL: "FLOAT32",
104106
types.Float: "FLOAT64",
105107
types.BIGINT: "INT64",
106108
types.DECIMAL: "NUMERIC",
@@ -540,9 +542,18 @@ class SpannerTypeCompiler(GenericTypeCompiler):
540542
def visit_INTEGER(self, type_, **kw):
541543
return "INT64"
542544

545+
def visit_DOUBLE(self, type_, **kw):
546+
return "FLOAT64"
547+
543548
def visit_FLOAT(self, type_, **kw):
549+
# Note: This was added before Spanner supported FLOAT32.
550+
# Changing this now to generate a FLOAT32 would be a breaking change.
551+
# Users therefore have to use REAL to generate a FLOAT32 column.
544552
return "FLOAT64"
545553

554+
def visit_REAL(self, type_, **kw):
555+
return "FLOAT32"
556+
546557
def visit_TEXT(self, type_, **kw):
547558
return "STRING({})".format(type_.length or "MAX")
548559

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy import String
16+
from sqlalchemy.orm import DeclarativeBase
17+
from sqlalchemy.orm import Mapped
18+
from sqlalchemy.orm import mapped_column
19+
from sqlalchemy.types import REAL
20+
21+
22+
class Base(DeclarativeBase):
23+
pass
24+
25+
26+
class Number(Base):
27+
__tablename__ = "numbers"
28+
number: Mapped[int] = mapped_column(primary_key=True)
29+
name: Mapped[str] = mapped_column(String(30))
30+
ln: Mapped[float] = mapped_column(REAL)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from sqlalchemy.orm import Session
16+
from sqlalchemy.testing import (
17+
eq_,
18+
is_instance_of,
19+
is_false,
20+
)
21+
from google.cloud.spanner_v1 import (
22+
BatchCreateSessionsRequest,
23+
ExecuteSqlRequest,
24+
ResultSet,
25+
ResultSetStats,
26+
BeginTransactionRequest,
27+
CommitRequest,
28+
TypeCode,
29+
)
30+
from test.mockserver_tests.mock_server_test_base import (
31+
MockServerTestBase,
32+
add_result,
33+
)
34+
35+
36+
class TestFloat32(MockServerTestBase):
37+
def test_insert_data(self):
38+
from test.mockserver_tests.float32_model import Number
39+
40+
update_count = ResultSet(
41+
dict(
42+
stats=ResultSetStats(
43+
dict(
44+
row_count_exact=1,
45+
)
46+
)
47+
)
48+
)
49+
add_result(
50+
"INSERT INTO numbers (number, name, ln) VALUES (@a0, @a1, @a2)",
51+
update_count,
52+
)
53+
54+
engine = self.create_engine()
55+
with Session(engine) as session:
56+
n1 = Number(number=1, name="One", ln=0.0)
57+
session.add_all([n1])
58+
session.commit()
59+
60+
requests = self.spanner_service.requests
61+
eq_(4, len(requests))
62+
is_instance_of(requests[0], BatchCreateSessionsRequest)
63+
is_instance_of(requests[1], BeginTransactionRequest)
64+
is_instance_of(requests[2], ExecuteSqlRequest)
65+
is_instance_of(requests[3], CommitRequest)
66+
request: ExecuteSqlRequest = requests[2]
67+
eq_(3, len(request.params))
68+
eq_("1", request.params["a0"])
69+
eq_("One", request.params["a1"])
70+
eq_(0.0, request.params["a2"])
71+
eq_(TypeCode.INT64, request.param_types["a0"].code)
72+
eq_(TypeCode.STRING, request.param_types["a1"].code)
73+
is_false("a2" in request.param_types)

test/mockserver_tests/test_quickstart.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class TestQuickStart(MockServerTestBase):
3030
def test_create_tables(self):
3131
from test.mockserver_tests.quickstart_model import Base
3232

33-
# TODO: Fix the double quotes inside these SQL fragments.
3433
add_result(
3534
"""SELECT true
3635
FROM INFORMATION_SCHEMA.TABLES

test/system/test_basics.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
Index,
2323
MetaData,
2424
Boolean,
25+
BIGINT,
2526
)
27+
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
28+
from sqlalchemy.types import REAL
2629
from sqlalchemy.testing import eq_
2730
from sqlalchemy.testing.plugin.plugin_base import fixtures
2831

@@ -37,6 +40,7 @@ def define_tables(cls, metadata):
3740
Column("name", String(20)),
3841
Column("alternative_name", String(20)),
3942
Column("prime", Boolean),
43+
Column("ln", REAL),
4044
PrimaryKeyConstraint("number"),
4145
)
4246
Index(
@@ -53,8 +57,8 @@ def test_hello_world(self, connection):
5357
def test_insert_number(self, connection):
5458
connection.execute(
5559
text(
56-
"""insert or update into numbers (number, name, prime)
57-
values (1, 'One', false)"""
60+
"""insert or update into numbers (number, name, prime, ln)
61+
values (1, 'One', false, cast(ln(1) as float32))"""
5862
)
5963
)
6064
name = connection.execute(text("select name from numbers where number=1"))
@@ -66,6 +70,17 @@ def test_reflect(self, connection):
6670
meta.reflect(bind=engine)
6771
eq_(1, len(meta.tables))
6872
table = meta.tables["numbers"]
73+
eq_(5, len(table.columns))
74+
eq_("number", table.columns[0].name)
75+
eq_(BIGINT, type(table.columns[0].type))
76+
eq_("name", table.columns[1].name)
77+
eq_(String, type(table.columns[1].type))
78+
eq_("alternative_name", table.columns[2].name)
79+
eq_(String, type(table.columns[2].type))
80+
eq_("prime", table.columns[3].name)
81+
eq_(Boolean, type(table.columns[3].type))
82+
eq_("ln", table.columns[4].name)
83+
eq_(REAL, type(table.columns[4].type))
6984
eq_(1, len(table.indexes))
7085
index = next(iter(table.indexes))
7186
eq_(2, len(index.columns))
@@ -74,3 +89,23 @@ def test_reflect(self, connection):
7489
dialect_options = index.dialect_options["spanner"]
7590
eq_(1, len(dialect_options["storing"]))
7691
eq_("alternative_name", dialect_options["storing"][0])
92+
93+
def test_orm(self, connection):
94+
class Base(DeclarativeBase):
95+
pass
96+
97+
class Number(Base):
98+
__tablename__ = "numbers"
99+
number: Mapped[int] = mapped_column(primary_key=True)
100+
name: Mapped[str] = mapped_column(String(20))
101+
alternative_name: Mapped[str] = mapped_column(String(20))
102+
prime: Mapped[bool] = mapped_column(Boolean)
103+
ln: Mapped[float] = mapped_column(REAL)
104+
105+
engine = connection.engine
106+
with Session(engine) as session:
107+
number = Number(
108+
number=1, name="One", alternative_name="Uno", prime=False, ln=0.0
109+
)
110+
session.add(number)
111+
session.commit()

0 commit comments

Comments
 (0)