Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ engine = create_engine(URL(
))
```

Note that this flag has been deprecated and removed, because our caching now uses the built-in SQLAlchemy reflection cache. Caching has also been improved: where possible, extra data will be fetched and cached.

### VARIANT, ARRAY and OBJECT Support

Snowflake SQLAlchemy supports fetching `VARIANT`, `ARRAY` and `OBJECT` data types. All types are converted into `str` in Python so that you can convert them to native data types using `json.loads`.
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/sqlalchemy/name_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class _NameUtils:

def __init__(self, identifier_preparer: IdentifierPreparer) -> None:
self.identifier_preparer = identifier_preparer

Expand All @@ -19,7 +18,9 @@ def normalize_name(self, name):
name.lower()
):
return name.lower()
elif name.lower() == name:
elif name.lower() == name and self.identifier_preparer._requires_quotes(
name.lower()
):
return quoted_name(name, quote=True)
else:
return name
Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _get_schema_columns(self, connection, schema, **kw):
elif issubclass(col_type, sqltypes.Numeric):
col_type_kw["precision"] = numeric_precision
col_type_kw["scale"] = numeric_scale
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
elif issubclass(col_type, sqltypes.String):
col_type_kw["length"] = character_maximum_length
elif issubclass(col_type, StructuredType):
column_info = structured_type_info_manager.get_column_info(
Expand Down Expand Up @@ -582,7 +582,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
if not schema:
_, schema = self._current_database_schema(connection, **kw)

schema_columns = self._get_schema_columns(connection, schema, **kw)
if self._cache_column_metadata:
schema_columns = self._get_schema_columns(connection, schema, **kw)
else:
schema_columns = None

if schema_columns is None:
column_info_manager = _StructuredTypeInfoManager(
connection, self.name_utils, self.default_schema_name
Expand Down
28 changes: 24 additions & 4 deletions src/snowflake/sqlalchemy/structured_type_info_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from sqlalchemy import util as sa_util
from sqlalchemy.sql import text
import sqlalchemy.sql.sqltypes as sqltypes
from sqlalchemy.sql.elements import quoted_name


from snowflake.sqlalchemy.name_utils import _NameUtils
from snowflake.sqlalchemy.parser.custom_type_parser import NullType, parse_type
Expand Down Expand Up @@ -45,7 +48,6 @@ def get_column_info(
def _load_structured_type_info(self, schema_name: str, table_name: str):
"""Get column information for a structured type"""
if (schema_name, table_name) not in self.full_columns_descriptions:

column_definitions = self.get_table_columns(table_name, schema_name)
if not column_definitions:
self.full_columns_descriptions[(schema_name, table_name)] = {}
Expand All @@ -68,8 +70,8 @@ def get_table_columns(self, table_name: str, schema: str = None):

schema = schema if schema else self.default_schema

table_schema = self.name_utils.denormalize_name(schema)
table_name = self.name_utils.denormalize_name(table_name)
table_schema = self.name_utils.normalize_name(schema)
table_name = self.name_utils.normalize_name(table_name)
result = self._execute_desc(table_schema, table_name)
if not result:
return []
Expand Down Expand Up @@ -100,10 +102,18 @@ def get_table_columns(self, table_name: str, schema: str = None):
identity = {
"start": int(match.group("start")),
"increment": int(match.group("increment")),
"order_type": match.group("order_type"),
"order": match.group("order_type"),
}
is_identity = identity is not None

# Normalize BINARY type length for consistency with _get_schema_columns().
# DESC TABLE returns the type with the length attribute, but information_schema.columns does not (character_maximum_length is None).
# Setting length to None ensures both code paths return identical column metadata,
# which is important when cache_column_metadata toggles between the two approaches.
# See: tests/test_core.py::test_column_metadata
if isinstance(type_instance, sqltypes.BINARY):
type_instance.length = None

ans.append(
{
"name": column_name,
Expand All @@ -129,6 +139,16 @@ def _execute_desc(self, table_schema: str, table_name: str):
Exception can be caused by another session dropping the table while
once this process has started"""
try:
table_schema = (
self.name_utils.identifier_preparer.quote(table_schema)
if isinstance(table_schema, quoted_name)
else table_schema
)
table_name = (
self.name_utils.identifier_preparer.quote(table_name)
if isinstance(table_name, quoted_name)
else table_name
)
return self.connection.execute(
text(
"DESC /* sqlalchemy:_get_schema_columns */"
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def help():
"protocol": "https",
"host": "<host>",
"port": "443",
"cache_column_metadata": False,
}


Expand Down
136 changes: 136 additions & 0 deletions tests/test_cache_column_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from unittest.mock import patch

import pytest
from sqlalchemy import Column, Integer, Sequence, String, inspect
from sqlalchemy.orm import declarative_base

from snowflake.sqlalchemy.custom_types import OBJECT


@pytest.mark.parametrize(
    "cache_column_metadata,expected_schema_count,expected_desc_count",
    [
        (False, 0, 1),
        (True, 1, 3),
    ],
)
def test_cache_column_metadata(
    cache_column_metadata,
    expected_schema_count,
    expected_desc_count,
    engine_testaccount,
):
    """
    Test cache_column_metadata behavior for column reflection.

    This test verifies that the _cache_column_metadata flag controls whether
    the dialect prefetches all columns from a schema or queries individual tables.

    When cache_column_metadata=False (default):
    - _get_schema_columns is NOT called
    - Only the requested table is queried via DESC TABLE
    - Results in 1 DESC call for the User table

    When cache_column_metadata=True:
    - _get_schema_columns IS called (fetches all columns via information_schema)
    - Additional DESC TABLE calls are made for tables with structured types
      (MAP, ARRAY, OBJECT) to get detailed type information
    - Results in 1 schema query + 3 DESC calls (User, OtherTableA, OtherTableB)

    Note: OtherTableC does not trigger a DESC call because it has no structured types.

    The tables created here are always dropped at the end of the test, even
    when an assertion fails, so that they do not leak into the shared test
    schema and influence reflection counts of other tests.
    """
    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
        name = Column(String)
        object = Column(OBJECT)

    class OtherTableA(Base):
        __tablename__ = "other_a"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableB(Base):
        __tablename__ = "other_b"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableC(Base):
        __tablename__ = "other_c"

        id = Column(Integer, primary_key=True)
        name = Column(String)

    models = [User, OtherTableA, OtherTableB, OtherTableC]

    Base.metadata.create_all(engine_testaccount)

    try:
        inspector = inspect(engine_testaccount)
        schema = inspector.default_schema_name

        # Verify cache_column_metadata is False by default
        assert not engine_testaccount.dialect._cache_column_metadata

        # Track calls to _get_schema_columns
        schema_columns_count = []
        original_schema_columns = engine_testaccount.dialect._get_schema_columns

        def tracked_schema_columns(*args, **kwargs):
            """Wrapper to count calls to _get_schema_columns."""
            schema_columns_count.append(1)
            return original_schema_columns(*args, **kwargs)

        # Track DESC TABLE commands executed by the dialect
        desc_call_count = []

        def tracked_execute(statement, *args, **kwargs):
            """
            Wrapper to count DESC TABLE commands for our test tables.

            Only counts DESC commands with the sqlalchemy:_get_schema_columns comment
            that target one of our test tables (filters out unrelated DESC calls).
            """
            stmt_str = str(statement)
            if (
                "DESC" in stmt_str
                and "sqlalchemy:_get_schema_columns" in stmt_str
                and any(model.__tablename__.lower() in stmt_str for model in models)
            ):
                desc_call_count.append(stmt_str)
            # original_execute is bound below, before the patched execute is
            # ever invoked, so the closure lookup succeeds at call time.
            return original_execute(statement, *args, **kwargs)

        with patch.object(
            engine_testaccount.dialect,
            "_cache_column_metadata",
            cache_column_metadata,
        ), patch.object(
            engine_testaccount.dialect,
            "_get_schema_columns",
            side_effect=tracked_schema_columns,
        ):
            with engine_testaccount.connect() as conn:
                # Capture the real execute before patching it so the wrapper
                # can delegate to it.
                original_execute = conn.execute

                with patch.object(conn, "execute", side_effect=tracked_execute):
                    tracked_inspector = inspect(conn)

                    # Reflect columns for User table
                    _ = tracked_inspector.get_columns(User.__tablename__, schema)

        # Verify expected behavior based on cache_column_metadata setting
        assert len(schema_columns_count) == expected_schema_count, (
            f"Expected {expected_schema_count} _get_schema_columns call(s), got {len(schema_columns_count)}"
        )
        assert len(desc_call_count) == expected_desc_count, (
            f"Expected {expected_desc_count} DESC call(s), got {len(desc_call_count)}"
        )
    finally:
        # Always remove the tables and sequence created for this test.
        Base.metadata.drop_all(engine_testaccount)
12 changes: 6 additions & 6 deletions tests/test_structured_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,12 @@ def test_structured_type_not_supported_in_table_columns_error(


@patch.object(_StructuredTypeInfoManager, "_execute_desc")
@patch.object(_NameUtils, "denormalize_name")
@patch.object(_NameUtils, "normalize_name")
def test_structured_type_on_dropped_table(
mocked_execute_desc_method, mocked_denormalize_name_method
mocked_normalize_name_method, mocked_execute_desc_method
):
mocked_execute_desc_method.return_value = None
mocked_denormalize_name_method.side_effect = lambda self, v: v
mocked_normalize_name_method.side_effect = lambda v: v
structured_type_info = _StructuredTypeInfoManager(
None, _NameUtils(None), "mySchema"
)
Expand All @@ -619,9 +619,9 @@ def test_structured_type_on_dropped_table(


@patch.object(_StructuredTypeInfoManager, "_execute_desc")
@patch.object(_NameUtils, "denormalize_name")
@patch.object(_NameUtils, "normalize_name")
def test_structured_type_on_table_with_map(
mocked_execute_desc_method, mocked_denormalize_name_method
mocked_normalize_name_method, mocked_execute_desc_method
):
mocked_execute_desc_method.return_value = [
[
Expand All @@ -637,7 +637,7 @@ def test_structured_type_on_table_with_map(
"MapColumn",
]
]
mocked_denormalize_name_method.side_effect = lambda self, v: v
mocked_normalize_name_method.side_effect = lambda v: v
structured_type_info = _StructuredTypeInfoManager(
None, _NameUtils(None), "mySchema"
)
Expand Down