Skip to content

Commit c0125cb

Browse files
authored
feat(duckdb): add automatic extension configuration (#9)
* feat: duckdb extension installer
1 parent 1a62c36 commit c0125cb

File tree

23 files changed

+869
-207
lines changed

23 files changed

+869
-207
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ exclude_lines = [
128128

129129
[tool.pytest.ini_options]
130130
addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config"
131+
asyncio_default_fixture_loop_scope = "function"
132+
asyncio_mode = "auto"
131133
testpaths = ["tests"]
132134
xfail_strict = true
133135

sqlspec/_typing.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,26 @@
55

66
from __future__ import annotations
77

8+
from enum import Enum
89
from typing import (
910
Any,
1011
ClassVar,
12+
Final,
1113
Protocol,
14+
Union,
1215
cast,
1316
runtime_checkable,
1417
)
1518

16-
from typing_extensions import TypeVar, dataclass_transform
19+
from typing_extensions import Literal, TypeVar, dataclass_transform
20+
21+
22+
@runtime_checkable
23+
class DataclassProtocol(Protocol):
24+
"""Protocol for instance checking dataclasses."""
25+
26+
__dataclass_fields__: ClassVar[dict[str, Any]]
27+
1728

1829
T = TypeVar("T")
1930
T_co = TypeVar("T_co", covariant=True)
@@ -99,11 +110,26 @@ class UnsetType(enum.Enum): # type: ignore[no-redef]
99110
UNSET = UnsetType.UNSET # pyright: ignore[reportConstantRedefinition]
100111
MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
101112

113+
114+
class EmptyEnum(Enum):
115+
"""A sentinel enum used as placeholder."""
116+
117+
EMPTY = 0
118+
119+
120+
EmptyType = Union[Literal[EmptyEnum.EMPTY], UnsetType]
121+
Empty: Final = EmptyEnum.EMPTY
122+
123+
102124
__all__ = (
103125
"MSGSPEC_INSTALLED",
104126
"PYDANTIC_INSTALLED",
105127
"UNSET",
106128
"BaseModel",
129+
"DataclassProtocol",
130+
"Empty",
131+
"EmptyEnum",
132+
"EmptyType",
107133
"FailFast",
108134
"Struct",
109135
"TypeAdapter",

sqlspec/adapters/adbc/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
from typing import TYPE_CHECKING, TypeVar
66

77
from sqlspec.config import GenericDatabaseConfig
8-
from sqlspec.utils.empty import Empty
8+
from sqlspec.typing import Empty, EmptyType
99

1010
if TYPE_CHECKING:
1111
from collections.abc import Generator
1212
from typing import Any
1313

1414
from adbc_driver_manager.dbapi import Connection, Cursor
1515

16-
from sqlspec.utils.empty import EmptyType
17-
1816
__all__ = ("AdbcDatabaseConfig",)
1917

2018
ConnectionT = TypeVar("ConnectionT", bound="Connection")

sqlspec/adapters/aiosqlite/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
from sqlspec.config import GenericDatabaseConfig
88
from sqlspec.exceptions import ImproperConfigurationError
9-
from sqlspec.utils.dataclass import simple_asdict
10-
from sqlspec.utils.empty import Empty, EmptyType
9+
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
1110

1211
if TYPE_CHECKING:
1312
from collections.abc import AsyncGenerator
@@ -60,7 +59,7 @@ def connection_config_dict(self) -> dict[str, Any]:
6059
Returns:
6160
A string keyed dict of config kwargs for the aiosqlite.connect() function.
6261
"""
63-
return simple_asdict(self, exclude_empty=True, convert_nested=False)
62+
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)
6463

6564
async def create_connection(self) -> Connection:
6665
"""Create and return a new database connection.

sqlspec/adapters/asyncmy/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from typing import TYPE_CHECKING, TypeVar
66

77
from sqlspec.exceptions import ImproperConfigurationError
8-
from sqlspec.utils.dataclass import simple_asdict
9-
from sqlspec.utils.empty import Empty, EmptyType
8+
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
109

1110
if TYPE_CHECKING:
1211
from collections.abc import AsyncGenerator
@@ -101,7 +100,7 @@ def pool_config_dict(self) -> dict[str, Any]:
101100
Returns:
102101
A string keyed dict of config kwargs for the Asyncmy create_pool function.
103102
"""
104-
return simple_asdict(self, exclude_empty=True, convert_nested=False)
103+
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)
105104

106105

107106
@dataclass
@@ -125,7 +124,7 @@ def pool_config_dict(self) -> dict[str, Any]:
125124
A string keyed dict of config kwargs for the Asyncmy create_pool function.
126125
"""
127126
if self.pool_config:
128-
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
127+
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
129128
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
130129
raise ImproperConfigurationError(msg)
131130

sqlspec/adapters/asyncpg/config.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from sqlspec._serialization import decode_json, encode_json
1111
from sqlspec.config import GenericDatabaseConfig, GenericPoolConfig
1212
from sqlspec.exceptions import ImproperConfigurationError
13-
from sqlspec.utils.dataclass import simple_asdict
14-
from sqlspec.utils.empty import Empty, EmptyType
13+
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
1514

1615
if TYPE_CHECKING:
1716
from asyncio import AbstractEventLoop
@@ -98,7 +97,7 @@ def pool_config_dict(self) -> dict[str, Any]:
9897
function.
9998
"""
10099
if self.pool_config:
101-
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
100+
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
102101
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
103102
raise ImproperConfigurationError(msg)
104103

@@ -125,15 +124,15 @@ async def create_pool(self) -> Pool:
125124
return self.pool_instance
126125

127126
@asynccontextmanager
128-
async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]:
127+
async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]:
129128
db_pool = await self.create_pool()
130129
try:
131130
yield
132131
finally:
133132
db_pool.terminate()
134133
await db_pool.close()
135134

136-
def provide_pool(self, *args: Any, **kwargs) -> Awaitable[Pool]:
135+
def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[Pool]:
137136
"""Create a pool instance.
138137
139138
Returns:

sqlspec/adapters/duckdb/config.py

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,66 @@
22

33
from contextlib import contextmanager
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, cast
66

77
from sqlspec.config import GenericDatabaseConfig
88
from sqlspec.exceptions import ImproperConfigurationError
9-
from sqlspec.utils.dataclass import simple_asdict
10-
from sqlspec.utils.empty import Empty, EmptyType
9+
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict
1110

1211
if TYPE_CHECKING:
13-
from collections.abc import Generator
12+
from collections.abc import Generator, Sequence
1413

1514
from duckdb import DuckDBPyConnection
1615

17-
__all__ = ("DuckDBConfig",)
16+
__all__ = ("DuckDBConfig", "ExtensionConfig")
17+
18+
19+
@dataclass
20+
class ExtensionConfig:
21+
"""Configuration for a DuckDB extension.
22+
23+
This class provides configuration options for DuckDB extensions, including installation
24+
and post-install configuration settings.
25+
26+
Args:
27+
name: The name of the extension to install
28+
config: Optional configuration settings to apply after installation
29+
force_install: Whether to force reinstall if already present
30+
repository: Optional repository name to install from
31+
repository_url: Optional repository URL to install from
32+
version: Optional version of the extension to install
33+
"""
34+
35+
name: str
36+
config: dict[str, Any] | None = None
37+
force_install: bool = False
38+
repository: str | None = None
39+
repository_url: str | None = None
40+
version: str | None = None
41+
42+
@classmethod
43+
def from_dict(cls, name: str, config: dict[str, Any] | bool | None = None) -> ExtensionConfig:
44+
"""Create an ExtensionConfig from a configuration dictionary.
45+
46+
Args:
47+
name: The name of the extension
48+
config: Configuration dictionary that may contain settings
49+
50+
Returns:
51+
A new ExtensionConfig instance
52+
"""
53+
if config is None:
54+
return cls(name=name)
55+
56+
if not isinstance(config, dict):
57+
config = {"force_install": bool(config)}
58+
59+
install_args = {
60+
key: config.pop(key)
61+
for key in ["force_install", "repository", "repository_url", "version", "config", "name"]
62+
if key in config
63+
}
64+
return cls(name=name, **install_args)
1865

1966

2067
@dataclass
@@ -39,31 +86,100 @@ class DuckDBConfig(GenericDatabaseConfig):
3986
For details see: https://duckdb.org/docs/api/python/overview#connection-options
4087
"""
4188

89+
extensions: Sequence[ExtensionConfig] | EmptyType = Empty
90+
"""A sequence of extension configurations to install and configure upon connection creation."""
91+
92+
def __post_init__(self) -> None:
93+
"""Post-initialization validation and processing.
94+
95+
This method handles merging extension configurations from both the extensions field
96+
and the config dictionary, if present. The config['extensions'] field can be either:
97+
- A dictionary mapping extension names to their configurations
98+
- A list of extension names (which will be installed with force_install=True)
99+
100+
Raises:
101+
ImproperConfigurationError: If there are duplicate extension configurations.
102+
"""
103+
if self.config is Empty:
104+
self.config = {}
105+
106+
if self.extensions is Empty:
107+
self.extensions = []
108+
# this is purely for mypy
109+
assert isinstance(self.config, dict) # noqa: S101
110+
assert isinstance(self.extensions, list) # noqa: S101
111+
112+
_e = self.config.pop("extensions", {})
113+
if not isinstance(_e, (dict, list, tuple)):
114+
msg = "When configuring extensions in the 'config' dictionary, the value must be a dictionary or sequence of extension names"
115+
raise ImproperConfigurationError(msg)
116+
if not isinstance(_e, dict):
117+
_e = {str(ext): {"force_install": False} for ext in _e}
118+
119+
if len(set(_e.keys()).intersection({ext.name for ext in self.extensions})) > 0:
120+
msg = "Configuring the same extension in both 'extensions' and as a key in 'config['extensions']' is not allowed"
121+
raise ImproperConfigurationError(msg)
122+
123+
self.extensions.extend([ExtensionConfig.from_dict(name, ext_config) for name, ext_config in _e.items()])
124+
125+
def _configure_extensions(self, connection: DuckDBPyConnection) -> None:
126+
"""Configure extensions for the connection.
127+
128+
Args:
129+
connection: The DuckDB connection to configure extensions for.
130+
131+
Raises:
132+
ImproperConfigurationError: If extension installation or configuration fails.
133+
"""
134+
if self.extensions is Empty:
135+
return
136+
137+
for extension in cast("list[ExtensionConfig]", self.extensions):
138+
try:
139+
if extension.force_install:
140+
connection.install_extension(
141+
extension=extension.name,
142+
force_install=extension.force_install,
143+
repository=extension.repository,
144+
repository_url=extension.repository_url,
145+
version=extension.version,
146+
)
147+
connection.load_extension(extension.name)
148+
149+
if extension.config:
150+
for key, value in extension.config.items():
151+
connection.execute(f"SET {key}={value}")
152+
except Exception as e:
153+
msg = f"Failed to configure extension {extension.name}. Error: {e!s}"
154+
raise ImproperConfigurationError(msg) from e
155+
42156
@property
43157
def connection_config_dict(self) -> dict[str, Any]:
44158
"""Return the connection configuration as a dict.
45159
46160
Returns:
47161
A string keyed dict of config kwargs for the duckdb.connect() function.
48162
"""
49-
config = simple_asdict(self, exclude_empty=True, convert_nested=False)
163+
config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
50164
if not config.get("database"):
51165
config["database"] = ":memory:"
52166
return config
53167

54168
def create_connection(self) -> DuckDBPyConnection:
55-
"""Create and return a new database connection.
169+
"""Create and return a new database connection with configured extensions.
56170
57171
Returns:
58-
A new DuckDB connection instance.
172+
A new DuckDB connection instance with extensions installed and configured.
59173
60174
Raises:
61-
ImproperConfigurationError: If the connection could not be established.
175+
ImproperConfigurationError: If the connection could not be established or extensions could not be configured.
62176
"""
63177
import duckdb
64178

65179
try:
66-
return duckdb.connect(**self.connection_config_dict)
180+
connection = duckdb.connect(**self.connection_config_dict)
181+
self._configure_extensions(connection)
182+
return connection
67183
except Exception as e:
68184
msg = f"Could not configure the DuckDB connection. Error: {e!s}"
69185
raise ImproperConfigurationError(msg) from e

sqlspec/adapters/oracledb/config/_asyncio.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
OracleGenericPoolConfig,
1414
)
1515
from sqlspec.exceptions import ImproperConfigurationError
16-
from sqlspec.utils.dataclass import simple_asdict
16+
from sqlspec.typing import dataclass_to_dict
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import AsyncGenerator, Awaitable
@@ -36,6 +36,11 @@ class OracleAsyncDatabaseConfig(OracleGenericDatabaseConfig[AsyncConnectionPool,
3636

3737
pool_config: OracleAsyncPoolConfig | None = None
3838
"""Oracle Pool configuration"""
39+
pool_instance: AsyncConnectionPool | None = None
40+
"""Optional pool to use.
41+
42+
If set, the plugin will use the provided pool rather than instantiate one.
43+
"""
3944

4045
@property
4146
def pool_config_dict(self) -> dict[str, Any]:
@@ -46,7 +51,7 @@ def pool_config_dict(self) -> dict[str, Any]:
4651
function.
4752
"""
4853
if self.pool_config is not None:
49-
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
54+
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
5055
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
5156
raise ImproperConfigurationError(msg)
5257

@@ -71,14 +76,14 @@ async def create_pool(self) -> AsyncConnectionPool:
7176
return self.pool_instance
7277

7378
@asynccontextmanager
74-
async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]:
79+
async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]:
7580
db_pool = await self.create_pool()
7681
try:
7782
yield
7883
finally:
7984
await db_pool.close(force=True)
8085

81-
def provide_pool(self, *args: Any, **kwargs) -> Awaitable[AsyncConnectionPool]:
86+
def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[AsyncConnectionPool]:
8287
"""Create a pool instance.
8388
8489
Returns:

0 commit comments

Comments
 (0)