Skip to content

Commit 81bd325

Browse files
authored
feat: update storage backend and add tests (#71)
Refactor the storage backend capabilities and introduce unit and integration tests for better coverage. Update dependencies and improve logging for configuration management.
1 parent bec4d37 commit 81bd325

24 files changed

+4413
-1426
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.10"
20+
rev: "v0.12.11"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }]
1313
name = "sqlspec"
1414
readme = "README.md"
1515
requires-python = ">=3.9, <4.0"
16-
version = "0.21.1"
16+
version = "0.22.0"
1717

1818
[project.urls]
1919
Discord = "https://discord.gg/litestar"
@@ -83,6 +83,7 @@ doc = [
8383
]
8484
extras = [
8585
"adbc_driver_manager",
86+
"fsspec[s3]",
8687
"pgvector",
8788
"pyarrow",
8889
"polars",
@@ -341,6 +342,7 @@ module = [
341342
"sqlglot.*",
342343
"pgvector",
343344
"pgvector.*",
345+
"minio",
344346
]
345347

346348
[[tool.mypy.overrides]]

sqlspec/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _cleanup_sync_pools(self) -> None:
6464
config.close_pool()
6565
cleaned_count += 1
6666
except Exception as e:
67-
logger.warning("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
67+
logger.debug("Failed to clean up sync pool for config %s: %s", config_type.__name__, e)
6868

6969
if cleaned_count > 0:
7070
logger.debug("Sync pool cleanup completed. Cleaned %d pools.", cleaned_count)
@@ -87,14 +87,14 @@ async def close_all_pools(self) -> None:
8787
else:
8888
sync_configs.append((config_type, config))
8989
except Exception as e:
90-
logger.warning("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
90+
logger.debug("Failed to prepare cleanup for config %s: %s", config_type.__name__, e)
9191

9292
if cleanup_tasks:
9393
try:
9494
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
9595
logger.debug("Async pool cleanup completed. Cleaned %d pools.", len(cleanup_tasks))
9696
except Exception as e:
97-
logger.warning("Failed to complete async pool cleanup: %s", e)
97+
logger.debug("Failed to complete async pool cleanup: %s", e)
9898

9999
for _config_type, config in sync_configs:
100100
config.close_pool()
@@ -129,7 +129,7 @@ def add_config(self, config: "Union[SyncConfigT, AsyncConfigT]") -> "type[Union[
129129
"""
130130
config_type = type(config)
131131
if config_type in self._configs:
132-
logger.warning("Configuration for %s already exists. Overwriting.", config_type.__name__)
132+
logger.debug("Configuration for %s already exists. Overwriting.", config_type.__name__)
133133
self._configs[config_type] = config
134134
return config_type
135135

sqlspec/loader.py

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,15 @@
1010
from datetime import datetime, timezone
1111
from pathlib import Path
1212
from typing import TYPE_CHECKING, Any, Final, Optional, Union
13+
from urllib.parse import unquote, urlparse
1314

1415
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
1516
from sqlspec.core.statement import SQL
16-
from sqlspec.exceptions import (
17-
MissingDependencyError,
18-
SQLFileNotFoundError,
19-
SQLFileParseError,
20-
StorageOperationFailedError,
21-
)
17+
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
2218
from sqlspec.storage.registry import storage_registry as default_storage_registry
2319
from sqlspec.utils.correlation import CorrelationContext
2420
from sqlspec.utils.logging import get_logger
21+
from sqlspec.utils.text import slugify
2522

2623
if TYPE_CHECKING:
2724
from sqlspec.storage.registry import StorageRegistry
@@ -54,13 +51,25 @@
5451
def _normalize_query_name(name: str) -> str:
5552
"""Normalize query name to be a valid Python identifier.
5653
54+
Convert hyphens to underscores, preserve dots for namespacing,
55+
and remove invalid characters.
56+
5757
Args:
5858
name: Raw query name from SQL file.
5959
6060
Returns:
6161
Normalized query name suitable as Python identifier.
6262
"""
63-
return TRIM_SPECIAL_CHARS.sub("", name).replace("-", "_")
63+
# Handle namespace parts separately to preserve dots
64+
parts = name.split(".")
65+
normalized_parts = []
66+
67+
for part in parts:
68+
# Use slugify with underscore separator and remove any remaining invalid chars
69+
normalized_part = slugify(part, separator="_")
70+
normalized_parts.append(normalized_part)
71+
72+
return ".".join(normalized_parts)
6473

6574

6675
def _normalize_dialect(dialect: str) -> str:
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
7685
return DIALECT_ALIASES.get(normalized, normalized)
7786

7887

79-
def _normalize_dialect_for_sqlglot(dialect: str) -> str:
80-
"""Normalize dialect name for SQLGlot compatibility.
81-
82-
Args:
83-
dialect: Dialect name from SQL file or parameter.
84-
85-
Returns:
86-
SQLGlot-compatible dialect name.
87-
"""
88-
normalized = dialect.lower().strip()
89-
return DIALECT_ALIASES.get(normalized, normalized)
90-
91-
9288
class NamedStatement:
9389
"""Represents a parsed SQL statement with metadata.
9490
@@ -218,8 +214,7 @@ def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
218214
SQLFileParseError: If file cannot be read.
219215
"""
220216
try:
221-
content = self._read_file_content(path)
222-
return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
217+
return hashlib.md5(self._read_file_content(path).encode(), usedforsecurity=False).hexdigest()
223218
except Exception as e:
224219
raise SQLFileParseError(str(path), str(path), e) from e
225220

@@ -253,19 +248,22 @@ def _read_file_content(self, path: Union[str, Path]) -> str:
253248
SQLFileNotFoundError: If file does not exist.
254249
SQLFileParseError: If file cannot be read or parsed.
255250
"""
256-
257251
path_str = str(path)
258252

259253
try:
260254
backend = self.storage_registry.get(path)
255+
# For file:// URIs, extract just the filename for the backend call
256+
if path_str.startswith("file://"):
257+
parsed = urlparse(path_str)
258+
file_path = unquote(parsed.path)
259+
# Handle Windows paths (file:///C:/path)
260+
if file_path and len(file_path) > 2 and file_path[2] == ":": # noqa: PLR2004
261+
file_path = file_path[1:] # Remove leading slash for Windows
262+
filename = Path(file_path).name
263+
return backend.read_text(filename, encoding=self.encoding)
261264
return backend.read_text(path_str, encoding=self.encoding)
262265
except KeyError as e:
263266
raise SQLFileNotFoundError(path_str) from e
264-
except MissingDependencyError:
265-
try:
266-
return path.read_text(encoding=self.encoding) # type: ignore[union-attr]
267-
except FileNotFoundError as e:
268-
raise SQLFileNotFoundError(path_str) from e
269267
except StorageOperationFailedError as e:
270268
if "not found" in str(e).lower() or "no such file" in str(e).lower():
271269
raise SQLFileNotFoundError(path_str) from e
@@ -419,8 +417,7 @@ def _load_directory(self, dir_path: Path) -> int:
419417
for file_path in sql_files:
420418
relative_path = file_path.relative_to(dir_path)
421419
namespace_parts = relative_path.parent.parts
422-
namespace = ".".join(namespace_parts) if namespace_parts else None
423-
self._load_single_file(file_path, namespace)
420+
self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
424421
return len(sql_files)
425422

426423
def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
@@ -533,44 +530,6 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
533530
self._queries[normalized_name] = statement
534531
self._query_to_file[normalized_name] = "<directly added>"
535532

536-
def get_sql(self, name: str) -> "SQL":
537-
"""Get a SQL object by statement name.
538-
539-
Args:
540-
name: Name of the statement (from -- name: in SQL file).
541-
Hyphens in names are converted to underscores.
542-
543-
Returns:
544-
SQL object ready for execution.
545-
546-
Raises:
547-
SQLFileNotFoundError: If statement name not found.
548-
"""
549-
correlation_id = CorrelationContext.get()
550-
551-
safe_name = _normalize_query_name(name)
552-
553-
if safe_name not in self._queries:
554-
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
555-
logger.error(
556-
"Statement not found: %s",
557-
name,
558-
extra={
559-
"statement_name": name,
560-
"safe_name": safe_name,
561-
"available_statements": len(self._queries),
562-
"correlation_id": correlation_id,
563-
},
564-
)
565-
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
566-
567-
parsed_statement = self._queries[safe_name]
568-
sqlglot_dialect = None
569-
if parsed_statement.dialect:
570-
sqlglot_dialect = _normalize_dialect_for_sqlglot(parsed_statement.dialect)
571-
572-
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)
573-
574533
def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
575534
"""Get a loaded SQLFile object by path.
576535
@@ -659,3 +618,41 @@ def get_query_text(self, name: str) -> str:
659618
if safe_name not in self._queries:
660619
raise SQLFileNotFoundError(name)
661620
return self._queries[safe_name].sql
621+
622+
def get_sql(self, name: str) -> "SQL":
623+
"""Get a SQL object by statement name.
624+
625+
Args:
626+
name: Name of the statement (from -- name: in SQL file).
627+
Hyphens in names are converted to underscores.
628+
629+
Returns:
630+
SQL object ready for execution.
631+
632+
Raises:
633+
SQLFileNotFoundError: If statement name not found.
634+
"""
635+
correlation_id = CorrelationContext.get()
636+
637+
safe_name = _normalize_query_name(name)
638+
639+
if safe_name not in self._queries:
640+
available = ", ".join(sorted(self._queries.keys())) if self._queries else "none"
641+
logger.error(
642+
"Statement not found: %s",
643+
name,
644+
extra={
645+
"statement_name": name,
646+
"safe_name": safe_name,
647+
"available_statements": len(self._queries),
648+
"correlation_id": correlation_id,
649+
},
650+
)
651+
raise SQLFileNotFoundError(name, path=f"Statement '{name}' not found. Available statements: {available}")
652+
653+
parsed_statement = self._queries[safe_name]
654+
sqlglot_dialect = None
655+
if parsed_statement.dialect:
656+
sqlglot_dialect = _normalize_dialect(parsed_statement.dialect)
657+
658+
return SQL(parsed_statement.sql, dialect=sqlglot_dialect)

sqlspec/protocols.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
and runtime isinstance() checks.
55
"""
66

7-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, Union, runtime_checkable
7+
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
88

99
from typing_extensions import Self
1010

@@ -14,7 +14,6 @@
1414

1515
from sqlglot import exp
1616

17-
from sqlspec.storage.capabilities import StorageCapabilities
1817
from sqlspec.typing import ArrowRecordBatch, ArrowTable
1918

2019
__all__ = (
@@ -194,9 +193,8 @@ class ObjectStoreItemProtocol(Protocol):
194193
class ObjectStoreProtocol(Protocol):
195194
"""Protocol for object storage operations."""
196195

197-
capabilities: ClassVar["StorageCapabilities"]
198-
199196
protocol: str
197+
backend_type: str
200198

201199
def __init__(self, uri: str, **kwargs: Any) -> None:
202200
return
@@ -330,7 +328,7 @@ async def write_arrow_async(self, path: "Union[str, Path]", table: "ArrowTable",
330328
msg = "Async arrow writing not implemented"
331329
raise NotImplementedError(msg)
332330

333-
async def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
331+
def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
334332
"""Async stream Arrow record batches from matching objects."""
335333
msg = "Async arrow streaming not implemented"
336334
raise NotImplementedError(msg)

sqlspec/storage/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@
88
- Capability-based backend selection
99
"""
1010

11-
from sqlspec.protocols import ObjectStoreProtocol
12-
from sqlspec.storage.capabilities import HasStorageCapabilities, StorageCapabilities
13-
from sqlspec.storage.registry import StorageRegistry
11+
from sqlspec.storage.registry import StorageRegistry, storage_registry
1412

15-
storage_registry = StorageRegistry()
16-
17-
__all__ = (
18-
"HasStorageCapabilities",
19-
"ObjectStoreProtocol",
20-
"StorageCapabilities",
21-
"StorageRegistry",
22-
"storage_registry",
23-
)
13+
__all__ = ("StorageRegistry", "storage_registry")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Storage backends."""

0 commit comments

Comments
 (0)