Skip to content

Commit b51c6b6

Browse files
authored
feat!: allow override of registry for declarative bases (#307)
* feat: allow easy override of `registry` for declarative bases * feat: implement a metadata registry * feat: remove the need for `AlembicCommand` subclass in Litestar
1 parent 2644e1e commit b51c6b6

File tree

23 files changed

+216
-184
lines changed

23 files changed

+216
-184
lines changed

advanced_alchemy/alembic/commands.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
self,
3939
engine: Engine | AsyncEngine,
4040
version_table_name: str,
41+
bind_key: str | None = None,
4142
file_: str | os.PathLike[str] | None = None,
4243
ini_section: str = "alembic",
4344
output_buffer: TextIO | None = None,
@@ -56,6 +57,7 @@ def __init__(
5657
Args:
5758
engine (sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine): The SQLAlchemy engine instance.
5859
version_table_name (str): The name of the version table.
60+
bind_key (str | None): The bind key for the metadata.
5961
file_ (str | os.PathLike[str] | None): The file path for the alembic configuration.
6062
ini_section (str): The ini section name.
6163
output_buffer (typing.TextIO | None): The output buffer for alembic commands.
@@ -70,6 +72,7 @@ def __init__(
7072
user_module_prefix (str | None): The prefix for user modules.
7173
"""
7274
self.template_directory = template_directory
75+
self.bind_key = bind_key
7376
self.version_table_name = version_table_name
7477
self.version_table_pk = engine.dialect.name != "spanner+spanner"
7578
self.version_table_schema = version_table_schema

advanced_alchemy/alembic/templates/asyncio/env.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy import Column, pool
77
from sqlalchemy.ext.asyncio import AsyncEngine, async_engine_from_config
88

9-
from advanced_alchemy.base import orm_registry
9+
from advanced_alchemy.base import metadata_registry
1010
from alembic import context
1111
from alembic.autogenerate import rewriter
1212
from alembic.operations import ops
@@ -23,16 +23,6 @@
2323
# this is the Alembic Config object, which provides
2424
# access to the values within the .ini file in use.
2525
config: AlembicCommandConfig = context.config # type: ignore # noqa: PGH003
26-
27-
28-
# add your model's MetaData object here
29-
# for 'autogenerate' support
30-
target_metadata = orm_registry.metadata
31-
32-
# other values from the config, defined by the needs of env.py,
33-
# can be acquired:
34-
# ... etc.
35-
3626
writer = rewriter.Rewriter()
3727

3828

@@ -75,7 +65,7 @@ def run_migrations_offline() -> None:
7565
"""
7666
context.configure(
7767
url=config.db_url,
78-
target_metadata=target_metadata,
68+
target_metadata=metadata_registry.get(config.bind_key),
7969
literal_binds=True,
8070
dialect_opts={"paramstyle": "named"},
8171
compare_type=config.compare_type,
@@ -94,7 +84,7 @@ def do_run_migrations(connection: Connection) -> None:
9484
"""Run migrations."""
9585
context.configure(
9686
connection=connection,
97-
target_metadata=target_metadata,
87+
target_metadata=metadata_registry.get(config.bind_key),
9888
compare_type=config.compare_type,
9989
version_table=config.version_table_name,
10090
version_table_pk=config.version_table_pk,

advanced_alchemy/alembic/templates/sync/env.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from sqlalchemy import Column, Engine, engine_from_config, pool
66

7-
from advanced_alchemy.base import orm_registry
7+
from advanced_alchemy.base import metadata_registry
88
from alembic import context
99
from alembic.autogenerate import rewriter
1010
from alembic.operations import ops
@@ -21,16 +21,6 @@
2121
# this is the Alembic Config object, which provides
2222
# access to the values within the .ini file in use.
2323
config: AlembicCommandConfig = context.config # type: ignore # noqa: PGH003
24-
25-
26-
# add your model's MetaData object here
27-
# for 'autogenerate' support
28-
target_metadata = orm_registry.metadata
29-
30-
# other values from the config, defined by the needs of env.py,
31-
# can be acquired:
32-
# ... etc.
33-
3424
writer = rewriter.Rewriter()
3525

3626

@@ -73,7 +63,7 @@ def run_migrations_offline() -> None:
7363
"""
7464
context.configure(
7565
url=config.db_url,
76-
target_metadata=target_metadata,
66+
target_metadata=metadata_registry.get(config.bind_key),
7767
literal_binds=True,
7868
dialect_opts={"paramstyle": "named"},
7969
compare_type=config.compare_type,
@@ -92,7 +82,7 @@ def do_run_migrations(connection: Connection) -> None:
9282
"""Run migrations."""
9383
context.configure(
9484
connection=connection,
95-
target_metadata=target_metadata,
85+
target_metadata=metadata_registry.get(config.bind_key),
9686
compare_type=config.compare_type,
9787
version_table=config.version_table_name,
9888
version_table_pk=config.version_table_pk,

0 commit comments

Comments
 (0)