10
10
from datetime import datetime , timezone
11
11
from pathlib import Path
12
12
from typing import TYPE_CHECKING , Any , Final , Optional , Union
13
+ from urllib .parse import unquote , urlparse
13
14
14
15
from sqlspec .core .cache import CacheKey , get_cache_config , get_default_cache
15
16
from sqlspec .core .statement import SQL
16
- from sqlspec .exceptions import (
17
- MissingDependencyError ,
18
- SQLFileNotFoundError ,
19
- SQLFileParseError ,
20
- StorageOperationFailedError ,
21
- )
17
+ from sqlspec .exceptions import SQLFileNotFoundError , SQLFileParseError , StorageOperationFailedError
22
18
from sqlspec .storage .registry import storage_registry as default_storage_registry
23
19
from sqlspec .utils .correlation import CorrelationContext
24
20
from sqlspec .utils .logging import get_logger
21
+ from sqlspec .utils .text import slugify
25
22
26
23
if TYPE_CHECKING :
27
24
from sqlspec .storage .registry import StorageRegistry
54
51
def _normalize_query_name (name : str ) -> str :
55
52
"""Normalize query name to be a valid Python identifier.
56
53
54
+ Convert hyphens to underscores, preserve dots for namespacing,
55
+ and remove invalid characters.
56
+
57
57
Args:
58
58
name: Raw query name from SQL file.
59
59
60
60
Returns:
61
61
Normalized query name suitable as Python identifier.
62
62
"""
63
- return TRIM_SPECIAL_CHARS .sub ("" , name ).replace ("-" , "_" )
63
+ # Handle namespace parts separately to preserve dots
64
+ parts = name .split ("." )
65
+ normalized_parts = []
66
+
67
+ for part in parts :
68
+ # Use slugify with underscore separator and remove any remaining invalid chars
69
+ normalized_part = slugify (part , separator = "_" )
70
+ normalized_parts .append (normalized_part )
71
+
72
+ return "." .join (normalized_parts )
64
73
65
74
66
75
def _normalize_dialect (dialect : str ) -> str :
@@ -76,19 +85,6 @@ def _normalize_dialect(dialect: str) -> str:
76
85
return DIALECT_ALIASES .get (normalized , normalized )
77
86
78
87
79
- def _normalize_dialect_for_sqlglot (dialect : str ) -> str :
80
- """Normalize dialect name for SQLGlot compatibility.
81
-
82
- Args:
83
- dialect: Dialect name from SQL file or parameter.
84
-
85
- Returns:
86
- SQLGlot-compatible dialect name.
87
- """
88
- normalized = dialect .lower ().strip ()
89
- return DIALECT_ALIASES .get (normalized , normalized )
90
-
91
-
92
88
class NamedStatement :
93
89
"""Represents a parsed SQL statement with metadata.
94
90
@@ -218,8 +214,7 @@ def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
218
214
SQLFileParseError: If file cannot be read.
219
215
"""
220
216
try :
221
- content = self ._read_file_content (path )
222
- return hashlib .md5 (content .encode (), usedforsecurity = False ).hexdigest ()
217
+ return hashlib .md5 (self ._read_file_content (path ).encode (), usedforsecurity = False ).hexdigest ()
223
218
except Exception as e :
224
219
raise SQLFileParseError (str (path ), str (path ), e ) from e
225
220
@@ -253,19 +248,22 @@ def _read_file_content(self, path: Union[str, Path]) -> str:
253
248
SQLFileNotFoundError: If file does not exist.
254
249
SQLFileParseError: If file cannot be read or parsed.
255
250
"""
256
-
257
251
path_str = str (path )
258
252
259
253
try :
260
254
backend = self .storage_registry .get (path )
255
+ # For file:// URIs, extract just the filename for the backend call
256
+ if path_str .startswith ("file://" ):
257
+ parsed = urlparse (path_str )
258
+ file_path = unquote (parsed .path )
259
+ # Handle Windows paths (file:///C:/path)
260
+ if file_path and len (file_path ) > 2 and file_path [2 ] == ":" : # noqa: PLR2004
261
+ file_path = file_path [1 :] # Remove leading slash for Windows
262
+ filename = Path (file_path ).name
263
+ return backend .read_text (filename , encoding = self .encoding )
261
264
return backend .read_text (path_str , encoding = self .encoding )
262
265
except KeyError as e :
263
266
raise SQLFileNotFoundError (path_str ) from e
264
- except MissingDependencyError :
265
- try :
266
- return path .read_text (encoding = self .encoding ) # type: ignore[union-attr]
267
- except FileNotFoundError as e :
268
- raise SQLFileNotFoundError (path_str ) from e
269
267
except StorageOperationFailedError as e :
270
268
if "not found" in str (e ).lower () or "no such file" in str (e ).lower ():
271
269
raise SQLFileNotFoundError (path_str ) from e
@@ -419,8 +417,7 @@ def _load_directory(self, dir_path: Path) -> int:
419
417
for file_path in sql_files :
420
418
relative_path = file_path .relative_to (dir_path )
421
419
namespace_parts = relative_path .parent .parts
422
- namespace = "." .join (namespace_parts ) if namespace_parts else None
423
- self ._load_single_file (file_path , namespace )
420
+ self ._load_single_file (file_path , "." .join (namespace_parts ) if namespace_parts else None )
424
421
return len (sql_files )
425
422
426
423
def _load_single_file (self , file_path : Union [str , Path ], namespace : Optional [str ]) -> None :
@@ -533,44 +530,6 @@ def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) ->
533
530
self ._queries [normalized_name ] = statement
534
531
self ._query_to_file [normalized_name ] = "<directly added>"
535
532
536
- def get_sql (self , name : str ) -> "SQL" :
537
- """Get a SQL object by statement name.
538
-
539
- Args:
540
- name: Name of the statement (from -- name: in SQL file).
541
- Hyphens in names are converted to underscores.
542
-
543
- Returns:
544
- SQL object ready for execution.
545
-
546
- Raises:
547
- SQLFileNotFoundError: If statement name not found.
548
- """
549
- correlation_id = CorrelationContext .get ()
550
-
551
- safe_name = _normalize_query_name (name )
552
-
553
- if safe_name not in self ._queries :
554
- available = ", " .join (sorted (self ._queries .keys ())) if self ._queries else "none"
555
- logger .error (
556
- "Statement not found: %s" ,
557
- name ,
558
- extra = {
559
- "statement_name" : name ,
560
- "safe_name" : safe_name ,
561
- "available_statements" : len (self ._queries ),
562
- "correlation_id" : correlation_id ,
563
- },
564
- )
565
- raise SQLFileNotFoundError (name , path = f"Statement '{ name } ' not found. Available statements: { available } " )
566
-
567
- parsed_statement = self ._queries [safe_name ]
568
- sqlglot_dialect = None
569
- if parsed_statement .dialect :
570
- sqlglot_dialect = _normalize_dialect_for_sqlglot (parsed_statement .dialect )
571
-
572
- return SQL (parsed_statement .sql , dialect = sqlglot_dialect )
573
-
574
533
def get_file (self , path : Union [str , Path ]) -> "Optional[SQLFile]" :
575
534
"""Get a loaded SQLFile object by path.
576
535
@@ -659,3 +618,41 @@ def get_query_text(self, name: str) -> str:
659
618
if safe_name not in self ._queries :
660
619
raise SQLFileNotFoundError (name )
661
620
return self ._queries [safe_name ].sql
621
+
622
+ def get_sql (self , name : str ) -> "SQL" :
623
+ """Get a SQL object by statement name.
624
+
625
+ Args:
626
+ name: Name of the statement (from -- name: in SQL file).
627
+ Hyphens in names are converted to underscores.
628
+
629
+ Returns:
630
+ SQL object ready for execution.
631
+
632
+ Raises:
633
+ SQLFileNotFoundError: If statement name not found.
634
+ """
635
+ correlation_id = CorrelationContext .get ()
636
+
637
+ safe_name = _normalize_query_name (name )
638
+
639
+ if safe_name not in self ._queries :
640
+ available = ", " .join (sorted (self ._queries .keys ())) if self ._queries else "none"
641
+ logger .error (
642
+ "Statement not found: %s" ,
643
+ name ,
644
+ extra = {
645
+ "statement_name" : name ,
646
+ "safe_name" : safe_name ,
647
+ "available_statements" : len (self ._queries ),
648
+ "correlation_id" : correlation_id ,
649
+ },
650
+ )
651
+ raise SQLFileNotFoundError (name , path = f"Statement '{ name } ' not found. Available statements: { available } " )
652
+
653
+ parsed_statement = self ._queries [safe_name ]
654
+ sqlglot_dialect = None
655
+ if parsed_statement .dialect :
656
+ sqlglot_dialect = _normalize_dialect (parsed_statement .dialect )
657
+
658
+ return SQL (parsed_statement .sql , dialect = sqlglot_dialect )
0 commit comments