From dcc2a23bee54a0a84bd585427b314f5f645e862e Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 27 Mar 2024 16:42:27 -0500 Subject: [PATCH 1/9] Solve the issue #1312 1. Provide a descriptor `ModelGetter` for solving the `db.Model` type from the `db` type dynamically. 2. Add `t.Type[sa_orm.MappedAsDataclass]` to `_FSA_MCT`. 3. Let `SQLAlchemy(...)` annotated by the provided `model_class` type. --- src/flask_sqlalchemy/extension.py | 208 ++++++++++++++++++++++++------ 1 file changed, 167 insertions(+), 41 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 3429e059..de434622 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -10,10 +10,12 @@ import sqlalchemy.event as sa_event import sqlalchemy.exc as sa_exc import sqlalchemy.orm as sa_orm +import typing_extensions as te from flask import abort from flask import current_app from flask import Flask from flask import has_app_context +from sqlalchemy.util import typing as compat_typing from .model import _QueryProperty from .model import BindMixin @@ -32,15 +34,14 @@ # Type accepted for model_class argument -_FSA_MCT = t.TypeVar( - "_FSA_MCT", - bound=t.Union[ - t.Type[Model], - sa_orm.DeclarativeMeta, - t.Type[sa_orm.DeclarativeBase], - t.Type[sa_orm.DeclarativeBaseNoMeta], - ], -) +_FSA_MCT = t.Union[ + t.Type[Model], + sa_orm.DeclarativeMeta, + t.Type[sa_orm.DeclarativeBase], + t.Type[sa_orm.DeclarativeBaseNoMeta], + t.Type[sa_orm.MappedAsDataclass], +] +_FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True) # Type returned by make_declarative_base @@ -48,6 +49,117 @@ class _FSAModel(Model): metadata: sa.MetaData +if t.TYPE_CHECKING: + + class _FSAModel_KW(_FSAModel): + def __init__(self, **kw: t.Any) -> None: + ... + +else: + # To minimize side effects, the type hint only works for static type checker. + # At run time, `_FSAModel_KW` falls back to `_FSAModel` + _FSAModel_KW = _FSAModel + + +if t.TYPE_CHECKING: + + @compat_typing.dataclass_transform( + field_specifiers=( + sa_orm.MappedColumn, + sa_orm.RelationshipProperty, + sa_orm.Composite, + sa_orm.Synonym, + sa_orm.mapped_column, + sa_orm.relationship, + sa_orm.composite, + sa_orm.synonym, + sa_orm.deferred, + ), + ) + class _FSAModel_DataClass(_FSAModel): + ... + +else: + # To minimize side effects, the type hint only works for static type checker. + # At run time, `_FSAModel_DataClass` falls back to `_FSAModel` + _FSAModel_DataClass = _FSAModel + + +class ModelGetter: + """Model getter for the ``SQLAlchemy().Model`` property. + + This getter is used for determining the correct type of ``SQLAlchemy().Model``. + + When ``SQLAlchemy`` is initialized by + + .. code-block:: python + + db = SQLAlchemy(model_class=MappedAsDataclass) + + the ``db.Model`` property needs to be a class decorated by ``dataclass_transform``. + + Otherwise, the ``db.Model`` property needs to provide a synthesized initialization + method accepting unknown keyword arguments. These keyword arguments are not + annotated but limited in the range of data items. This rule is guaranteed by the + featuers of all other candidates of ``model_class``. + + Calling the class property ``SQLAlchemy.Model`` will return this descriptor + directly. + """ + + # This variant is at first. Its priority is highest for making SQLAlchemy[Any] + # exports a Model with type[_FSAModel_KW]. + # Note that in actual using cases, users do not need to inherit Model classes. + @te.overload + def __get__( + self, obj: SQLAlchemy[t.Type[Model]], obj_cls: t.Any = None + ) -> t.Type[_FSAModel_KW]: + ... + + # This variant needs to be prior than DeclarativeBase, because a class may inherit + # multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO + # list, this configuration make type[_FSAModel_DataClass] preferred. + @te.overload + def __get__( + self, obj: SQLAlchemy[t.Type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None + ) -> t.Type[_FSAModel_DataClass]: + ... + + @te.overload + def __get__( + self, obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None + ) -> t.Type[_FSAModel_KW]: + ... + + @te.overload + def __get__( + self, + obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBaseNoMeta]], + obj_cls: t.Any = None, + ) -> t.Type[_FSAModel_KW]: + ... + + @te.overload + def __get__( + self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None + ) -> t.Type[_FSAModel_KW]: + ... + + @te.overload + def __get__( + self: te.Self, obj: None, obj_cls: t.Optional[t.Type[SQLAlchemy[t.Any]]] = None + ) -> t.Type[_FSAModel]: + ... + + def __get__( + self: te.Self, obj: t.Optional[SQLAlchemy[t.Any]], obj_cls: t.Any = None + ) -> t.Union[te.Self, t.Type[Model], t.Type[t.Any]]: + if isinstance(obj, SQLAlchemy): + return obj._Model + else: + return self + + def _get_2x_declarative_bases( model_class: _FSA_MCT, ) -> list[t.Type[t.Union[sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta]]]: @@ -58,7 +170,7 @@ def _get_2x_declarative_bases( ] -class SQLAlchemy: +class SQLAlchemy(t.Generic[_FSA_MCT_T]): """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, associating tables and models with specific engines, and cleaning up connections and sessions after each request. @@ -168,7 +280,7 @@ def __init__( metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, query_class: type[Query] = Query, - model_class: _FSA_MCT = Model, # type: ignore[assignment] + model_class: _FSA_MCT_T = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, disable_autonaming: bool = False, @@ -241,29 +353,17 @@ def __init__( This is a subclass of SQLAlchemy's ``Table`` rather than a function. """ - self.Model = self._make_declarative_base( + self._Model = self._make_declarative_base( model_class, disable_autonaming=disable_autonaming ) - """A SQLAlchemy declarative model class. Subclass this to define database - models. - - If a model does not set ``__tablename__``, it will be generated by converting - the class name from ``CamelCase`` to ``snake_case``. It will not be generated - if the model looks like it uses single-table inheritance. - - If a model or parent class sets ``__bind_key__``, it will use that metadata and - database engine. Otherwise, it will use the default :attr:`metadata` and - :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. - - For code using the SQLAlchemy 1.x API, customize this model by subclassing - :class:`.Model` and passing the ``model_class`` parameter to the extension. - A fully created declarative model class can be - passed as well, to use a custom metaclass. - - For code using the SQLAlchemy 2.x API, customize this model by subclassing - :class:`sqlalchemy.orm.DeclarativeBase` or - :class:`sqlalchemy.orm.DeclarativeBaseNoMeta` - and passing the ``model_class`` parameter to the extension. + """A SQLAlchemy declarative model class. This private model class is returned + by ``_make_declarative_base``. + + At run time, this class is the same as ``SQLAlchemy.Model``. Accessing + ``SQLAlchemy.Model`` rather than this class is more recommended because + ``SQLAlchemy.Model`` can provide better type hints. + + :meta private: """ if engine_options is None: @@ -277,6 +377,31 @@ def __init__( if app is not None: self.init_app(app) + # Need to be placed after __init__ because __init__ takes a default value + # named `Model`. + Model = ModelGetter() + """A SQLAlchemy declarative model class. Subclass this to define database + models. + + If a model does not set ``__tablename__``, it will be generated by converting + the class name from ``CamelCase`` to ``snake_case``. It will not be generated + if the model looks like it uses single-table inheritance. + + If a model or parent class sets ``__bind_key__``, it will use that metadata and + database engine. Otherwise, it will use the default :attr:`metadata` and + :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. + + For code using the SQLAlchemy 1.x API, customize this model by subclassing + :class:`.Model` and passing the ``model_class`` parameter to the extension. + A fully created declarative model class can be + passed as well, to use a custom metaclass. + + For code using the SQLAlchemy 2.x API, customize this model by subclassing + :class:`sqlalchemy.orm.DeclarativeBase` or + :class:`sqlalchemy.orm.DeclarativeBaseNoMeta` + and passing the ``model_class`` parameter to the extension. + """ + def __repr__(self) -> str: if not has_app_context(): return f"<{type(self).__name__}>" @@ -503,9 +628,7 @@ def __new__( return Table def _make_declarative_base( - self, - model_class: _FSA_MCT, - disable_autonaming: bool = False, + self, model_class: _FSA_MCT, disable_autonaming: bool = False ) -> t.Type[_FSAModel]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -534,7 +657,7 @@ def _make_declarative_base( ``model`` can be an already created declarative model class. """ model: t.Type[_FSAModel] - declarative_bases = _get_2x_declarative_bases(model_class) + declarative_bases = _get_2x_declarative_bases(t.cast(t.Any, model_class)) if len(declarative_bases) > 1: # raise error if more than one declarative base is found raise ValueError( @@ -547,11 +670,14 @@ def _make_declarative_base( mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: mixin_classes.remove(NameMixin) - model = types.new_class( - "FlaskSQLAlchemyBase", - (*mixin_classes, *model_class.__bases__), - {"metaclass": type(declarative_bases[0])}, - lambda ns: ns.update(body), + model = t.cast( + t.Type[_FSAModel], + types.new_class( + "FlaskSQLAlchemyBase", + (*mixin_classes, *model_class.__bases__), + {"metaclass": type(declarative_bases[0])}, + lambda ns: ns.update(body), + ), ) elif not isinstance(model_class, sa_orm.DeclarativeMeta): metadata = self._make_metadata(None) From 8f2bfd1b7dbe0c5829ef57b99f3da09b95666ffe Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 27 Mar 2024 16:43:34 -0500 Subject: [PATCH 2/9] Ensure the full type hints for parameterized `SQLAlchemy`. --- src/flask_sqlalchemy/model.py | 6 +++--- src/flask_sqlalchemy/session.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index c6f9e5a9..e79bf4a4 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -33,7 +33,7 @@ class Model: already created declarative model class as ``model_class``. """ - __fsa__: t.ClassVar[SQLAlchemy] + __fsa__: t.ClassVar[SQLAlchemy[t.Any]] """Internal reference to the extension object. :meta private: @@ -75,7 +75,7 @@ class BindMetaMixin(type): directly on the child model. """ - __fsa__: SQLAlchemy + __fsa__: SQLAlchemy[t.Any] metadata: sa.MetaData def __init__( @@ -106,7 +106,7 @@ class BindMixin: .. versionchanged:: 3.1.0 """ - __fsa__: SQLAlchemy + __fsa__: SQLAlchemy[t.Any] metadata: sa.MetaData @classmethod diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index 631fffa8..b1c15710 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -23,7 +23,7 @@ class Session(sa_orm.Session): Renamed from ``SignallingSession``. """ - def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: + def __init__(self, db: SQLAlchemy[t.Any], **kwargs: t.Any) -> None: super().__init__(**kwargs) self._db = db self._model_changes: dict[object, tuple[t.Any, str]] = {} From 1638622055755213f6c68af429c07f9998881648 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 27 Mar 2024 16:44:40 -0500 Subject: [PATCH 3/9] Ensure the full type hints for parameterized `SQLAlchemy` in all test codes. --- tests/conftest.py | 4 ++-- tests/test_cli.py | 2 +- tests/test_engine.py | 2 +- tests/test_extension_object.py | 2 +- tests/test_extension_repr.py | 13 ++++++++----- tests/test_legacy_query.py | 13 +++++++------ tests/test_metadata.py | 16 +++++++++------- tests/test_model.py | 28 +++++++++++++++------------- tests/test_model_bind.py | 16 +++++++++------- tests/test_model_name.py | 26 +++++++++++++------------- tests/test_pagination.py | 6 +++--- tests/test_record_queries.py | 4 +++- tests/test_session.py | 2 +- tests/test_table_bind.py | 10 ++++++---- tests/test_view_query.py | 10 +++++----- 15 files changed, 84 insertions(+), 70 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d4ab92f4..44270a5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,7 +63,7 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: @pytest.fixture(params=test_classes) -def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: +def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy[t.Any]: if request.param is not Model: return SQLAlchemy(app, model_class=types.new_class(*request.param)) else: @@ -79,7 +79,7 @@ def model_class(request: pytest.FixtureRequest) -> t.Any: @pytest.fixture -def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]: +def Todo(app: Flask, db: SQLAlchemy[t.Any]) -> t.Generator[t.Any, None, None]: if issubclass(db.Model, (sa_orm.MappedAsDataclass)): class Todo(db.Model): diff --git a/tests/test_cli.py b/tests/test_cli.py index 91672733..42131592 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_shell_context(db: SQLAlchemy, Todo: t.Any) -> None: +def test_shell_context(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: context = add_models_to_shell() assert context["db"] is db assert context["Todo"] is Todo diff --git a/tests/test_engine.py b/tests/test_engine.py index 0e88d5e3..dcdb0ad8 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,7 +12,7 @@ from flask_sqlalchemy import SQLAlchemy -def test_default_engine(app: Flask, db: SQLAlchemy) -> None: +def test_default_engine(app: Flask, db: SQLAlchemy[t.Any]) -> None: with app.app_context(): assert db.engine is db.engines[None] diff --git a/tests/test_extension_object.py b/tests/test_extension_object.py index 0cb5a608..e0ae8699 100644 --- a/tests/test_extension_object.py +++ b/tests/test_extension_object.py @@ -13,7 +13,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() diff --git a/tests/test_extension_repr.py b/tests/test_extension_repr.py index cfa94c75..c2312e5a 100644 --- a/tests/test_extension_repr.py +++ b/tests/test_extension_repr.py @@ -1,12 +1,15 @@ from __future__ import annotations +import typing as t + from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model def test_repr_no_context() -> None: - db = SQLAlchemy() + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -15,7 +18,7 @@ def test_repr_no_context() -> None: def test_repr_default() -> None: - db = SQLAlchemy() + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -25,7 +28,7 @@ def test_repr_default() -> None: def test_repr_default_plustwo() -> None: - db = SQLAlchemy() + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" app.config["SQLALCHEMY_BINDS"] = { @@ -39,7 +42,7 @@ def test_repr_default_plustwo() -> None: def test_repr_nodefault() -> None: - db = SQLAlchemy() + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = {"x": "sqlite:///:memory:"} @@ -49,7 +52,7 @@ def test_repr_nodefault() -> None: def test_repr_nodefault_plustwo() -> None: - db = SQLAlchemy() + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = { "a": "sqlite:///:memory:", diff --git a/tests/test_legacy_query.py b/tests/test_legacy_query.py index 170e5bb7..40aa08df 100644 --- a/tests/test_legacy_query.py +++ b/tests/test_legacy_query.py @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model from flask_sqlalchemy.query import Query @@ -25,7 +26,7 @@ def ignore_query_warning() -> t.Generator[None, None, None]: @pytest.mark.usefixtures("app_ctx") -def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() @@ -36,7 +37,7 @@ def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_first_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.commit() assert Todo.query.filter_by(title="a").first_or_404().title == "a" @@ -46,7 +47,7 @@ def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_one_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.add(Todo(title="b")) db.session.add(Todo(title="b")) @@ -63,7 +64,7 @@ def test_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: +def test_paginate(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add_all(Todo() for _ in range(150)) db.session.commit() p = Todo.query.paginate() @@ -75,7 +76,7 @@ def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_default_query_class(db: SQLAlchemy) -> None: +def test_default_query_class(db: SQLAlchemy[t.Any]) -> None: class Parent(db.Model): id = sa.Column(sa.Integer, primary_key=True) children1 = db.relationship("Child", backref="parent1", lazy="dynamic") @@ -101,7 +102,7 @@ def test_custom_query_class(app: Flask) -> None: class CustomQuery(Query): pass - db = SQLAlchemy(app, query_class=CustomQuery) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app, query_class=CustomQuery) class Parent(db.Model): id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 8b54e5bc..833a08db 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -13,7 +13,7 @@ from flask_sqlalchemy.model import Model -def test_default_metadata(db: SQLAlchemy) -> None: +def test_default_metadata(db: SQLAlchemy[t.Any]) -> None: assert db.metadata is db.metadatas[None] assert db.metadata.info["bind_key"] is None assert db.Model.metadata is db.metadata @@ -21,7 +21,7 @@ def test_default_metadata(db: SQLAlchemy) -> None: def test_custom_metadata_1x() -> None: metadata = sa.MetaData() - db = SQLAlchemy(metadata=metadata) + db: SQLAlchemy[t.Any] = SQLAlchemy(metadata=metadata) assert db.metadata is metadata assert db.metadata.info["bind_key"] is None assert db.Model.metadata is db.metadata @@ -34,7 +34,9 @@ class Base(sa_orm.DeclarativeBase): pass with pytest.deprecated_call(): - db = SQLAlchemy(model_class=Base, metadata=custom_metadata) + db: SQLAlchemy[t.Type[Base]] = SQLAlchemy( + model_class=Base, metadata=custom_metadata + ) assert db.metadata is Base.metadata assert db.metadata.info["bind_key"] is None @@ -88,7 +90,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: model_class.metadata = sa.MetaData( naming_convention={"pk": "spk_%(table_name)s"} ) - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[t.Type[t.Any]] = SQLAlchemy(app, model_class=model_class) else: db = SQLAlchemy( app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) @@ -100,7 +102,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") def test_create_drop_all(app: Flask) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -131,7 +133,7 @@ class Post(db.Model): @pytest.mark.parametrize("bind_key", ["a", ["a"]]) def test_create_key_spec(app: Flask, bind_key: str | list[str | None]) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -151,7 +153,7 @@ class Post(db.Model): def test_reflect(app: Flask) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///user.db" app.config["SQLALCHEMY_BINDS"] = {"post": "sqlite:///post.db"} - db = SQLAlchemy(app) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) db.Table("post", sa.Column("id", sa.Integer, primary_key=True), bind_key="post") db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 0968a1e2..837b5186 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -15,7 +15,7 @@ def test_default_model_class_1x(app: Flask) -> None: - db = SQLAlchemy(app) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) assert db.Model.query_class is db.Query assert db.Model.metadata is db.metadata @@ -27,7 +27,7 @@ def test_custom_model_class_1x(app: Flask) -> None: class CustomModel(Model): pass - db = SQLAlchemy(app, model_class=CustomModel) + db: SQLAlchemy[t.Type[CustomModel]] = SQLAlchemy(app, model_class=CustomModel) assert issubclass(db.Model, CustomModel) assert isinstance(db.Model, DefaultMeta) @@ -39,7 +39,7 @@ class CustomMeta(DefaultMeta): pass CustomModel = sa_orm.declarative_base(cls=base, name="Model", metaclass=CustomMeta) - db = SQLAlchemy(app, model_class=CustomModel) + db: SQLAlchemy[CustomMeta] = SQLAlchemy(app, model_class=CustomModel) assert db.Model is CustomModel assert db.Model.query_class is db.Query assert "query" in db.Model.__dict__ @@ -82,11 +82,11 @@ class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass): @pytest.mark.usefixtures("app_ctx") -def test_declaredattr(app: Flask, model_class: t.Any) -> None: +def test_declaredattr(app: Flask, model_class: t.Type[Model]) -> None: if model_class is Model: class IdModel(Model): - @sa.orm.declared_attr + @sa_orm.declared_attr @classmethod def id(cls: type[Model]): # type: ignore[no-untyped-def] for base in cls.__mro__[1:-1]: @@ -96,7 +96,9 @@ def id(cls: type[Model]): # type: ignore[no-untyped-def] return sa.Column(sa.ForeignKey(base.id), primary_key=True) return sa.Column(sa.Integer, primary_key=True) - db = SQLAlchemy(app, model_class=IdModel) + db: t.Union[SQLAlchemy[t.Type[IdModel]], SQLAlchemy[t.Type[Base]]] = SQLAlchemy( + app, model_class=IdModel + ) class User(db.Model): name = db.Column(db.String) @@ -140,7 +142,7 @@ class Employee(User): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") def test_abstractmodel(app: Flask, model_class: t.Any) -> None: - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[t.Any] = SQLAlchemy(app, model_class=model_class) if issubclass(db.Model, (sa_orm.MappedAsDataclass)): @@ -201,7 +203,7 @@ class Post(TimestampModel): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") def test_mixinmodel(app: Flask, model_class: t.Any) -> None: - db = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[t.Type[t.Any]] = SQLAlchemy(app, model_class=model_class) if issubclass(db.Model, (sa_orm.MappedAsDataclass)): @@ -216,7 +218,7 @@ class TimestampMixin(sa_orm.MappedAsDataclass): init=False, ) - class Post(TimestampMixin, db.Model): + class Post(db.Model, TimestampMixin): id: sa_orm.Mapped[int] = sa_orm.mapped_column( db.Integer, primary_key=True, init=False ) @@ -232,7 +234,7 @@ class TimestampMixin: # type: ignore[no-redef] db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow ) - class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + class Post(db.Model, TimestampMixin): # type: ignore[no-redef] id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True) title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) @@ -244,7 +246,7 @@ class TimestampMixin: # type: ignore[no-redef] db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow ) - class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + class Post(db.Model, TimestampMixin): # type: ignore[no-redef] id = db.Column(db.Integer, primary_key=True) title = db.Column(db.String, nullable=False) @@ -258,7 +260,7 @@ class Post(TimestampMixin, db.Model): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") -def test_model_repr(db: SQLAlchemy) -> None: +def test_model_repr(db: SQLAlchemy[t.Type[Model]]) -> None: class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -286,7 +288,7 @@ class Base(sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta): # type: ignor @pytest.mark.usefixtures("app_ctx") def test_disable_autonaming_true_sql1(app: Flask) -> None: - db = SQLAlchemy(app, disable_autonaming=True) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app, disable_autonaming=True) with pytest.raises(sa_exc.InvalidRequestError): diff --git a/tests/test_model_bind.py b/tests/test_model_bind.py index 7c633c83..c058b3bc 100644 --- a/tests/test_model_bind.py +++ b/tests/test_model_bind.py @@ -1,18 +1,20 @@ from __future__ import annotations +import typing as t + import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -def test_bind_key_default(db: SQLAlchemy) -> None: +def test_bind_key_default(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) assert User.metadata is db.metadata -def test_metadata_per_bind(db: SQLAlchemy) -> None: +def test_metadata_per_bind(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): __bind_key__ = "other" id = sa.Column(sa.Integer, primary_key=True) @@ -20,7 +22,7 @@ class User(db.Model): assert User.metadata is db.metadatas["other"] -def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: +def test_multiple_binds_same_table_name(db: SQLAlchemy[t.Any]) -> None: class UserA(db.Model): __tablename__ = "user" id = sa.Column(sa.Integer, primary_key=True) @@ -35,7 +37,7 @@ class UserB(db.Model): assert UserA.__table__.metadata is not UserB.__table__.metadata -def test_inherit_parent(db: SQLAlchemy) -> None: +def test_inherit_parent(db: SQLAlchemy[t.Any]) -> None: class User(db.Model): __bind_key__ = "auth" id = sa.Column(sa.Integer, primary_key=True) @@ -51,7 +53,7 @@ class Admin(User): assert "metadata" not in Admin.__dict__ -def test_inherit_abstract_parent(db: SQLAlchemy) -> None: +def test_inherit_abstract_parent(db: SQLAlchemy[t.Any]) -> None: class AbstractUser(db.Model): __abstract__ = True __bind_key__ = "auth" @@ -63,7 +65,7 @@ class User(AbstractUser): assert "metadata" not in User.__dict__ -def test_explicit_metadata(db: SQLAlchemy) -> None: +def test_explicit_metadata(db: SQLAlchemy[t.Any]) -> None: other_metadata = sa.MetaData() class User(db.Model): @@ -75,7 +77,7 @@ class User(db.Model): assert "other" not in db.metadatas -def test_explicit_table(db: SQLAlchemy) -> None: +def test_explicit_table(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), diff --git a/tests/test_model_name.py b/tests/test_model_name.py index 1b8cf87c..fef987f9 100644 --- a/tests/test_model_name.py +++ b/tests/test_model_name.py @@ -48,7 +48,7 @@ def test_camel_to_snake_case(name: str, expect: str) -> None: assert camel_to_snake_case(name) == expect -def test_name(db: SQLAlchemy) -> None: +def test_name(db: SQLAlchemy[t.Any]) -> None: class FOOBar(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -64,7 +64,7 @@ class Ham(db.Model): assert Ham.__tablename__ == "spam" -def test_single_name(db: SQLAlchemy) -> None: +def test_single_name(db: SQLAlchemy[t.Any]) -> None: """Single table inheritance should not set a new name.""" class Duck(db.Model): @@ -77,7 +77,7 @@ class Mallard(Duck): assert Mallard.__tablename__ == "duck" -def test_joined_name(db: SQLAlchemy) -> None: +def test_joined_name(db: SQLAlchemy[t.Any]) -> None: """Model has a separate primary key; it should set a new name.""" class Duck(db.Model): @@ -89,7 +89,7 @@ class Donald(Duck): assert Donald.__tablename__ == "donald" -def test_mixin_id(db: SQLAlchemy) -> None: +def test_mixin_id(db: SQLAlchemy[t.Any]) -> None: """Primary key provided by mixin should still allow model to set tablename. """ @@ -104,7 +104,7 @@ class Duck(Base, db.Model): assert Duck.__tablename__ == "duck" -def test_mixin_attr(db: SQLAlchemy) -> None: +def test_mixin_attr(db: SQLAlchemy[t.Any]) -> None: """A declared attr tablename will be used down multiple levels of inheritance. """ @@ -130,7 +130,7 @@ class Mallard(Duck): assert Mallard.__tablename__ == "MALLARD" -def test_abstract_name(db: SQLAlchemy) -> None: +def test_abstract_name(db: SQLAlchemy[t.Any]) -> None: """Abstract model should not set a name. Subclass should set a name.""" class Base(db.Model): @@ -144,7 +144,7 @@ class Duck(Base): assert Duck.__tablename__ == "duck" -def test_complex_inheritance(db: SQLAlchemy) -> None: +def test_complex_inheritance(db: SQLAlchemy[t.Any]) -> None: """Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class. """ @@ -163,7 +163,7 @@ class RubberDuck(IdMixin, Duck): # type: ignore[misc] assert RubberDuck.__tablename__ == "rubber_duck" -def test_manual_name(db: SQLAlchemy) -> None: +def test_manual_name(db: SQLAlchemy[t.Any]) -> None: """Setting a manual name prevents generation for the immediate model. A name is generated for joined but not single-table inheritance. """ @@ -189,7 +189,7 @@ class Donald(Duck): assert Donald.__tablename__ == "DUCK" -def test_primary_constraint(db: SQLAlchemy) -> None: +def test_primary_constraint(db: SQLAlchemy[t.Any]) -> None: """Primary key will be picked up from table args.""" class Duck(db.Model): @@ -201,7 +201,7 @@ class Duck(db.Model): assert Duck.__tablename__ == "duck" -def test_no_access_to_class_property(db: SQLAlchemy) -> None: +def test_no_access_to_class_property(db: SQLAlchemy[t.Any]) -> None: """Ensure the implementation doesn't access class properties or declared attrs while inspecting the unmapped model. """ @@ -237,7 +237,7 @@ def floats(self) -> None: assert not ns.floats -def test_metadata_has_table(db: SQLAlchemy) -> None: +def test_metadata_has_table(db: SQLAlchemy[t.Any]) -> None: user = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) class User(db.Model): @@ -246,7 +246,7 @@ class User(db.Model): assert User.__table__ is user -def test_correct_error_for_no_primary_key(db: SQLAlchemy) -> None: +def test_correct_error_for_no_primary_key(db: SQLAlchemy[t.Any]) -> None: with pytest.raises(sa_exc.ArgumentError) as info: class User(db.Model): @@ -255,7 +255,7 @@ class User(db.Model): assert "could not assemble any primary key" in str(info.value) -def test_single_has_parent_table(db: SQLAlchemy) -> None: +def test_single_has_parent_table(db: SQLAlchemy[t.Any]) -> None: class Duck(db.Model): id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 14e24a9e..5149b5ca 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -119,7 +119,7 @@ def test_iter_pages_short(page: int) -> None: class _PaginateCallable: - def __init__(self, app: Flask, db: SQLAlchemy, Todo: t.Any) -> None: + def __init__(self, app: Flask, db: SQLAlchemy[t.Any], Todo: t.Any) -> None: self.app = app self.db = db self.Todo = Todo @@ -143,7 +143,7 @@ def __call__( @pytest.fixture -def paginate(app: Flask, db: SQLAlchemy, Todo: t.Any) -> _PaginateCallable: +def paginate(app: Flask, db: SQLAlchemy[t.Any], Todo: t.Any) -> _PaginateCallable: with app.app_context(): for i in range(1, 251): db.session.add(Todo(title=f"task {i}")) @@ -197,7 +197,7 @@ def test_error_out(paginate: _PaginateCallable, page: t.Any, per_page: t.Any) -> @pytest.mark.usefixtures("app_ctx") -def test_no_items_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_no_items_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: p = db.paginate(db.select(Todo)) assert len(p.items) == 0 diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index c5cc73a2..805e45fc 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import typing as t import pytest import sqlalchemy as sa @@ -8,13 +9,14 @@ from flask import Flask from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model from flask_sqlalchemy.record_queries import get_recorded_queries @pytest.mark.usefixtures("app_ctx") def test_query_info(app: Flask) -> None: app.config["SQLALCHEMY_RECORD_QUERIES"] = True - db = SQLAlchemy(app) + db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) # Copied and pasted from conftest.py if issubclass(db.Model, (sa_orm.MappedAsDataclass)): diff --git a/tests/test_session.py b/tests/test_session.py index cf75626a..5ff170e1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -11,7 +11,7 @@ from flask_sqlalchemy.session import Session -def test_scope(app: Flask, db: SQLAlchemy) -> None: +def test_scope(app: Flask, db: SQLAlchemy[t.Any]) -> None: with pytest.raises(RuntimeError): db.session() diff --git a/tests/test_table_bind.py b/tests/test_table_bind.py index fd83d1a9..78a85fc6 100644 --- a/tests/test_table_bind.py +++ b/tests/test_table_bind.py @@ -1,23 +1,25 @@ from __future__ import annotations +import typing as t + import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy -def test_bind_key_default(db: SQLAlchemy) -> None: +def test_bind_key_default(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) assert user_table.metadata is db.metadata -def test_metadata_per_bind(db: SQLAlchemy) -> None: +def test_metadata_per_bind(db: SQLAlchemy[t.Any]) -> None: user_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" ) assert user_table.metadata is db.metadatas["other"] -def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: +def test_multiple_binds_same_table_name(db: SQLAlchemy[t.Any]) -> None: user1_table = db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) user2_table = db.Table( "user", sa.Column("id", sa.Integer, primary_key=True), bind_key="other" @@ -27,7 +29,7 @@ def test_multiple_binds_same_table_name(db: SQLAlchemy) -> None: assert user2_table.metadata is db.metadatas["other"] -def test_explicit_metadata(db: SQLAlchemy) -> None: +def test_explicit_metadata(db: SQLAlchemy[t.Any]) -> None: other_metadata = sa.MetaData() user_table = db.Table( "user", diff --git a/tests/test_view_query.py b/tests/test_view_query.py index c1d056c1..0557b40b 100644 --- a/tests/test_view_query.py +++ b/tests/test_view_query.py @@ -12,7 +12,7 @@ @pytest.mark.usefixtures("app_ctx") -def test_view_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_view_get_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: item = Todo() db.session.add(item) db.session.commit() @@ -22,7 +22,7 @@ def test_view_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_first_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.commit() result = db.first_or_404(db.select(Todo).filter_by(title="a")) @@ -33,7 +33,7 @@ def test_first_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_view_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: +def test_view_one_or_404(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add(Todo(title="a")) db.session.add(Todo(title="b")) db.session.add(Todo(title="b")) @@ -51,7 +51,7 @@ def test_view_one_or_404(db: SQLAlchemy, Todo: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") -def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: +def test_paginate(db: SQLAlchemy[t.Any], Todo: t.Any) -> None: db.session.add_all(Todo() for _ in range(150)) db.session.commit() p = db.paginate(db.select(Todo)) @@ -64,7 +64,7 @@ def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: # This test creates its own inline model so that it can use that as the type @pytest.mark.usefixtures("app_ctx") -def test_view_get_or_404_typed(db: SQLAlchemy, app: Flask) -> None: +def test_view_get_or_404_typed(db: SQLAlchemy[t.Any], app: Flask) -> None: # Copied and pasted from conftest.py if issubclass(db.Model, (sa_orm.MappedAsDataclass)): From 3065b7c2ae063caf200b1e8bca444576b7002cd6 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 27 Mar 2024 16:45:27 -0500 Subject: [PATCH 4/9] Fix an issue when `tox p` fails because `mypy` is forbidden. See details here: https://stackoverflow.com/a/47716994/8266012 --- tox.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tox.ini b/tox.ini index 7085991f..1915efe3 100644 --- a/tox.ini +++ b/tox.ini @@ -23,12 +23,14 @@ commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} deps = pre-commit skip_install = true commands = pre-commit run --all-files +allowlist_externals = mypy [testenv:typing] deps = -r requirements/mypy.txt commands = mypy --python-version 3.8 mypy --python-version 3.11 +allowlist_externals = mypy [testenv:docs] deps = -r requirements/docs.txt From 2dddeb7568c8cb5fbd4bc6b62616897fffa315a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:05:35 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/flask_sqlalchemy/extension.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 80b50c64..3d54ac00 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -52,8 +52,7 @@ class _FSAModel(Model): if t.TYPE_CHECKING: class _FSAModel_KW(_FSAModel): - def __init__(self, **kw: t.Any) -> None: - ... + def __init__(self, **kw: t.Any) -> None: ... else: # To minimize side effects, the type hint only works for static type checker. @@ -76,8 +75,7 @@ def __init__(self, **kw: t.Any) -> None: sa_orm.deferred, ), ) - class _FSAModel_DataClass(_FSAModel): - ... + class _FSAModel_DataClass(_FSAModel): ... else: # To minimize side effects, the type hint only works for static type checker. @@ -113,8 +111,7 @@ class ModelGetter: @te.overload def __get__( self, obj: SQLAlchemy[t.Type[Model]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: - ... + ) -> t.Type[_FSAModel_KW]: ... # This variant needs to be prior than DeclarativeBase, because a class may inherit # multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO @@ -122,34 +119,29 @@ def __get__( @te.overload def __get__( self, obj: SQLAlchemy[t.Type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_DataClass]: - ... + ) -> t.Type[_FSAModel_DataClass]: ... @te.overload def __get__( self, obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: - ... + ) -> t.Type[_FSAModel_KW]: ... @te.overload def __get__( self, obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBaseNoMeta]], obj_cls: t.Any = None, - ) -> t.Type[_FSAModel_KW]: - ... + ) -> t.Type[_FSAModel_KW]: ... @te.overload def __get__( self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: - ... + ) -> t.Type[_FSAModel_KW]: ... @te.overload def __get__( self: te.Self, obj: None, obj_cls: t.Optional[t.Type[SQLAlchemy[t.Any]]] = None - ) -> t.Type[_FSAModel]: - ... + ) -> t.Type[_FSAModel]: ... def __get__( self: te.Self, obj: t.Optional[SQLAlchemy[t.Any]], obj_cls: t.Any = None From 3c46668e8ee515a0eed01345dacc2527b36fcf08 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Mon, 29 Apr 2024 08:12:02 -0500 Subject: [PATCH 6/9] `t.Type` -> `type` to match the pre-commit check. --- src/flask_sqlalchemy/extension.py | 34 +++++++++++++++---------------- tests/test_extension_repr.py | 10 ++++----- tests/test_legacy_query.py | 2 +- tests/test_metadata.py | 10 ++++----- tests/test_model.py | 14 ++++++------- tests/test_record_queries.py | 2 +- 6 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 3d54ac00..066418a1 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -35,11 +35,11 @@ # Type accepted for model_class argument _FSA_MCT = t.Union[ - t.Type[Model], + type[Model], sa_orm.DeclarativeMeta, - t.Type[sa_orm.DeclarativeBase], - t.Type[sa_orm.DeclarativeBaseNoMeta], - t.Type[sa_orm.MappedAsDataclass], + type[sa_orm.DeclarativeBase], + type[sa_orm.DeclarativeBaseNoMeta], + type[sa_orm.MappedAsDataclass], ] _FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True) @@ -110,42 +110,42 @@ class ModelGetter: # Note that in actual using cases, users do not need to inherit Model classes. @te.overload def __get__( - self, obj: SQLAlchemy[t.Type[Model]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: ... + self, obj: SQLAlchemy[type[Model]], obj_cls: t.Any = None + ) -> type[_FSAModel_KW]: ... # This variant needs to be prior than DeclarativeBase, because a class may inherit # multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO # list, this configuration make type[_FSAModel_DataClass] preferred. @te.overload def __get__( - self, obj: SQLAlchemy[t.Type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_DataClass]: ... + self, obj: SQLAlchemy[type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None + ) -> type[_FSAModel_DataClass]: ... @te.overload def __get__( - self, obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: ... + self, obj: SQLAlchemy[type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None + ) -> type[_FSAModel_KW]: ... @te.overload def __get__( self, - obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBaseNoMeta]], + obj: SQLAlchemy[type[sa_orm.DeclarativeBaseNoMeta]], obj_cls: t.Any = None, - ) -> t.Type[_FSAModel_KW]: ... + ) -> type[_FSAModel_KW]: ... @te.overload def __get__( self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None - ) -> t.Type[_FSAModel_KW]: ... + ) -> type[_FSAModel_KW]: ... @te.overload def __get__( - self: te.Self, obj: None, obj_cls: t.Optional[t.Type[SQLAlchemy[t.Any]]] = None - ) -> t.Type[_FSAModel]: ... + self: te.Self, obj: None, obj_cls: t.Optional[type[SQLAlchemy[t.Any]]] = None + ) -> type[_FSAModel]: ... def __get__( self: te.Self, obj: t.Optional[SQLAlchemy[t.Any]], obj_cls: t.Any = None - ) -> t.Union[te.Self, t.Type[Model], t.Type[t.Any]]: + ) -> t.Union[te.Self, type[Model], type[t.Any]]: if isinstance(obj, SQLAlchemy): return obj._Model else: @@ -665,7 +665,7 @@ def _make_declarative_base( if disable_autonaming: mixin_classes.remove(NameMixin) model = t.cast( - t.Type[_FSAModel], + type[_FSAModel], types.new_class( "FlaskSQLAlchemyBase", (*mixin_classes, *model_class.__bases__), diff --git a/tests/test_extension_repr.py b/tests/test_extension_repr.py index c2312e5a..de226f69 100644 --- a/tests/test_extension_repr.py +++ b/tests/test_extension_repr.py @@ -9,7 +9,7 @@ def test_repr_no_context() -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -18,7 +18,7 @@ def test_repr_no_context() -> None: def test_repr_default() -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" @@ -28,7 +28,7 @@ def test_repr_default() -> None: def test_repr_default_plustwo() -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" app.config["SQLALCHEMY_BINDS"] = { @@ -42,7 +42,7 @@ def test_repr_default_plustwo() -> None: def test_repr_nodefault() -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = {"x": "sqlite:///:memory:"} @@ -52,7 +52,7 @@ def test_repr_nodefault() -> None: def test_repr_nodefault_plustwo() -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy() + db: SQLAlchemy[type[Model]] = SQLAlchemy() app = Flask(__name__) app.config["SQLALCHEMY_BINDS"] = { "a": "sqlite:///:memory:", diff --git a/tests/test_legacy_query.py b/tests/test_legacy_query.py index 40aa08df..7d073d84 100644 --- a/tests/test_legacy_query.py +++ b/tests/test_legacy_query.py @@ -102,7 +102,7 @@ def test_custom_query_class(app: Flask) -> None: class CustomQuery(Query): pass - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app, query_class=CustomQuery) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app, query_class=CustomQuery) class Parent(db.Model): id = sa.Column(sa.Integer, primary_key=True) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 833a08db..e1fc4c8f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -34,7 +34,7 @@ class Base(sa_orm.DeclarativeBase): pass with pytest.deprecated_call(): - db: SQLAlchemy[t.Type[Base]] = SQLAlchemy( + db: SQLAlchemy[type[Base]] = SQLAlchemy( model_class=Base, metadata=custom_metadata ) @@ -90,7 +90,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: model_class.metadata = sa.MetaData( naming_convention={"pk": "spk_%(table_name)s"} ) - db: SQLAlchemy[t.Type[t.Any]] = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[type[t.Any]] = SQLAlchemy(app, model_class=model_class) else: db = SQLAlchemy( app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) @@ -102,7 +102,7 @@ def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: @pytest.mark.usefixtures("app_ctx") def test_create_drop_all(app: Flask) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -133,7 +133,7 @@ class Post(db.Model): @pytest.mark.parametrize("bind_key", ["a", ["a"]]) def test_create_key_spec(app: Flask, bind_key: str | list[str | None]) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -153,7 +153,7 @@ class Post(db.Model): def test_reflect(app: Flask) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///user.db" app.config["SQLALCHEMY_BINDS"] = {"post": "sqlite:///post.db"} - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) db.Table("user", sa.Column("id", sa.Integer, primary_key=True)) db.Table("post", sa.Column("id", sa.Integer, primary_key=True), bind_key="post") db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 4fe60151..11098591 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -20,7 +20,7 @@ def now() -> datetime: def test_default_model_class_1x(app: Flask) -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) assert db.Model.query_class is db.Query assert db.Model.metadata is db.metadata @@ -32,7 +32,7 @@ def test_custom_model_class_1x(app: Flask) -> None: class CustomModel(Model): pass - db: SQLAlchemy[t.Type[CustomModel]] = SQLAlchemy(app, model_class=CustomModel) + db: SQLAlchemy[type[CustomModel]] = SQLAlchemy(app, model_class=CustomModel) assert issubclass(db.Model, CustomModel) assert isinstance(db.Model, DefaultMeta) @@ -87,7 +87,7 @@ class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass): @pytest.mark.usefixtures("app_ctx") -def test_declaredattr(app: Flask, model_class: t.Type[Model]) -> None: +def test_declaredattr(app: Flask, model_class: type[Model]) -> None: if model_class is Model: class IdModel(Model): @@ -101,7 +101,7 @@ def id(cls: type[Model]): # type: ignore[no-untyped-def] return sa.Column(sa.ForeignKey(base.id), primary_key=True) return sa.Column(sa.Integer, primary_key=True) - db: t.Union[SQLAlchemy[t.Type[IdModel]], SQLAlchemy[t.Type[Base]]] = SQLAlchemy( + db: t.Union[SQLAlchemy[type[IdModel]], SQLAlchemy[type[Base]]] = SQLAlchemy( app, model_class=IdModel ) @@ -206,7 +206,7 @@ class Post(TimestampModel): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") def test_mixinmodel(app: Flask, model_class: t.Any) -> None: - db: SQLAlchemy[t.Type[t.Any]] = SQLAlchemy(app, model_class=model_class) + db: SQLAlchemy[type[t.Any]] = SQLAlchemy(app, model_class=model_class) if issubclass(db.Model, (sa_orm.MappedAsDataclass)): @@ -261,7 +261,7 @@ class Post(db.Model, TimestampMixin): # type: ignore[no-redef] @pytest.mark.usefixtures("app_ctx") -def test_model_repr(db: SQLAlchemy[t.Type[Model]]) -> None: +def test_model_repr(db: SQLAlchemy[type[Model]]) -> None: class User(db.Model): id = sa.Column(sa.Integer, primary_key=True) @@ -289,7 +289,7 @@ class Base(sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta): # type: ignor @pytest.mark.usefixtures("app_ctx") def test_disable_autonaming_true_sql1(app: Flask) -> None: - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app, disable_autonaming=True) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app, disable_autonaming=True) with pytest.raises(sa_exc.InvalidRequestError): diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index 805e45fc..720a1f54 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -16,7 +16,7 @@ @pytest.mark.usefixtures("app_ctx") def test_query_info(app: Flask) -> None: app.config["SQLALCHEMY_RECORD_QUERIES"] = True - db: SQLAlchemy[t.Type[Model]] = SQLAlchemy(app) + db: SQLAlchemy[type[Model]] = SQLAlchemy(app) # Copied and pasted from conftest.py if issubclass(db.Model, (sa_orm.MappedAsDataclass)): From d40889488945a520e741cd688ad714002c26648a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:12:39 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_extension_repr.py | 2 -- tests/test_record_queries.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/test_extension_repr.py b/tests/test_extension_repr.py index de226f69..d9ff5b1c 100644 --- a/tests/test_extension_repr.py +++ b/tests/test_extension_repr.py @@ -1,7 +1,5 @@ from __future__ import annotations -import typing as t - from flask import Flask from flask_sqlalchemy import SQLAlchemy diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index 720a1f54..6d5d932f 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import typing as t import pytest import sqlalchemy as sa From 9b97634c53508f5625ca04c5a8a35ade6f5018b4 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Mon, 29 Apr 2024 08:52:31 -0500 Subject: [PATCH 8/9] Change t.Union/t.Optional to "|" operator to prevent `tox -p` to fail. --- src/flask_sqlalchemy/extension.py | 14 +++++++------- tests/test_model.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 066418a1..d77e79c0 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -35,11 +35,11 @@ # Type accepted for model_class argument _FSA_MCT = t.Union[ - type[Model], + t.Type[Model], sa_orm.DeclarativeMeta, - type[sa_orm.DeclarativeBase], - type[sa_orm.DeclarativeBaseNoMeta], - type[sa_orm.MappedAsDataclass], + t.Type[sa_orm.DeclarativeBase], + t.Type[sa_orm.DeclarativeBaseNoMeta], + t.Type[sa_orm.MappedAsDataclass], ] _FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True) @@ -140,12 +140,12 @@ def __get__( @te.overload def __get__( - self: te.Self, obj: None, obj_cls: t.Optional[type[SQLAlchemy[t.Any]]] = None + self: te.Self, obj: None, obj_cls: type[SQLAlchemy[t.Any]] | None = None ) -> type[_FSAModel]: ... def __get__( - self: te.Self, obj: t.Optional[SQLAlchemy[t.Any]], obj_cls: t.Any = None - ) -> t.Union[te.Self, type[Model], type[t.Any]]: + self: te.Self, obj: SQLAlchemy[t.Any] | None, obj_cls: t.Any = None + ) -> te.Self | type[Model] | type[t.Any]: if isinstance(obj, SQLAlchemy): return obj._Model else: diff --git a/tests/test_model.py b/tests/test_model.py index 11098591..ed24bb93 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -101,7 +101,7 @@ def id(cls: type[Model]): # type: ignore[no-untyped-def] return sa.Column(sa.ForeignKey(base.id), primary_key=True) return sa.Column(sa.Integer, primary_key=True) - db: t.Union[SQLAlchemy[type[IdModel]], SQLAlchemy[type[Base]]] = SQLAlchemy( + db: SQLAlchemy[type[IdModel]] | SQLAlchemy[type[Base]] = SQLAlchemy( app, model_class=IdModel ) From f383186b66385e0994d4fcf1c92a6f45460f2113 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Mon, 29 Apr 2024 09:03:25 -0500 Subject: [PATCH 9/9] `t.cast` should accept `t.Type[...]` instead of `type[...]` to prevent runtime failure in python3.8. --- src/flask_sqlalchemy/extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index d77e79c0..9b479e74 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -665,7 +665,7 @@ def _make_declarative_base( if disable_autonaming: mixin_classes.remove(NameMixin) model = t.cast( - type[_FSAModel], + t.Type[_FSAModel], types.new_class( "FlaskSQLAlchemyBase", (*mixin_classes, *model_class.__bases__),