Skip to content

Commit 05a413b

Browse files
authored
feat: add support for SQLAlchemy func expressions in filter classes (#585)
Adds support for SQLAlchemy func() expressions in filter classes to eliminate type checker errors when using database functions like `func.random()` or `func.lower()`.
1 parent 441ff76 commit 05a413b

File tree

6 files changed

+1567
-1093
lines changed

6 files changed

+1567
-1093
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.14.0"
25+
rev: "v0.14.2"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/filters.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from dataclasses import dataclass
2727
from operator import attrgetter
2828
from typing import (
29+
TYPE_CHECKING,
2930
Any,
3031
Callable,
3132
ClassVar,
@@ -53,13 +54,15 @@
5354
text,
5455
true,
5556
)
56-
from sqlalchemy.orm import InstrumentedAttribute
5757
from sqlalchemy.sql import operators as op
5858
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
5959
from typing_extensions import TypeAlias, TypedDict, TypeVar
6060

6161
from advanced_alchemy.base import ModelProtocol
6262

63+
if TYPE_CHECKING:
64+
from sqlalchemy.orm import InstrumentedAttribute
65+
6366
__all__ = (
6467
"BeforeAfter",
6568
"CollectionFilter",
@@ -153,7 +156,9 @@ def append_to_statement(
153156
return statement
154157

155158
@staticmethod
156-
def _get_instrumented_attr(model: Any, key: Union[str, InstrumentedAttribute[Any]]) -> InstrumentedAttribute[Any]:
159+
def _get_instrumented_attr(
160+
model: Any, key: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
161+
) -> "Union[ColumnElement[Any], InstrumentedAttribute[Any]]":
157162
"""Get SQLAlchemy instrumented attribute from model.
158163
159164
Args:
@@ -166,9 +171,7 @@ def _get_instrumented_attr(model: Any, key: Union[str, InstrumentedAttribute[Any
166171
See Also:
167172
:class:`sqlalchemy.orm.attributes.InstrumentedAttribute`: SQLAlchemy attribute
168173
"""
169-
if isinstance(key, str):
170-
return cast("InstrumentedAttribute[Any]", getattr(model, key))
171-
return key
174+
return cast("InstrumentedAttribute[Any]", getattr(model, key)) if isinstance(key, str) else key
172175

173176

174177
@dataclass
@@ -186,8 +189,8 @@ class BeforeAfter(StatementFilter):
186189
187190
"""
188191

189-
field_name: str
190-
"""Name of the model attribute to filter on."""
192+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
193+
"""Field name, model attribute, or func expression."""
191194
before: Optional[datetime.datetime]
192195
"""Filter results where field is earlier than this value."""
193196
after: Optional[datetime.datetime]
@@ -232,8 +235,8 @@ class OnBeforeAfter(StatementFilter):
232235
233236
"""
234237

235-
field_name: str
236-
"""Name of the model attribute to filter on."""
238+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
239+
"""Field name, model attribute, or func expression."""
237240
on_or_before: Optional[datetime.datetime]
238241
"""Filter results where field is on or earlier than this value."""
239242
on_or_after: Optional[datetime.datetime]
@@ -280,8 +283,8 @@ class CollectionFilter(InAnyFilter, Generic[T]):
280283
Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
281284
"""
282285

283-
field_name: str
284-
"""Name of the model attribute to filter on."""
286+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
287+
"""Field name, model attribute, or func expression."""
285288
values: Union[Collection[T], None]
286289
"""Values for the ``IN`` clause. If this is None, no filter is applied.
287290
An empty list will force an empty result set (WHERE 1=-1)"""
@@ -339,8 +342,8 @@ class NotInCollectionFilter(InAnyFilter, Generic[T]):
339342
340343
"""
341344

342-
field_name: str
343-
"""Name of the model attribute to filter on."""
345+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
346+
"""Field name, model attribute, or func expression."""
344347
values: Union[Collection[T], None]
345348
"""Values for the ``NOT IN`` clause. If None or empty, no filter is applied."""
346349

@@ -443,8 +446,8 @@ class OrderBy(StatementFilter):
443446
- :meth:`sqlalchemy.sql.expression.ColumnElement.desc`: Descending order
444447
"""
445448

446-
field_name: str
447-
"""Name of the model attribute to sort on."""
449+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
450+
"""Field name, model attribute, or func expression (e.g., ``func.random()``)."""
448451
sort_order: Literal["asc", "desc"] = "asc"
449452
"""Sort direction ("asc" or "desc")."""
450453

@@ -614,8 +617,8 @@ class ComparisonFilter(StatementFilter):
614617
ValueError: If an invalid operator is provided
615618
"""
616619

617-
field_name: str
618-
"""Name of the model attribute to filter on."""
620+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
621+
"""Field name, model attribute, or func expression."""
619622
operator: str
620623
"""Comparison operator to use (one of 'eq', 'ne', 'gt', 'ge', 'lt', 'le')."""
621624
value: Any

advanced_alchemy/repository/memory/_async.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,36 +209,58 @@ def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
209209
def _apply_limit_offset_pagination(result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
210210
return result[offset:limit]
211211

212-
@staticmethod
212+
def _extract_field_name(self, field: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]") -> str:
213+
"""Extract string field name from various input types.
214+
215+
Args:
216+
field: Field name, column element, or instrumented attribute
217+
218+
Returns:
219+
str: String field name for use with getattr()
220+
221+
Raises:
222+
RepositoryError: If a ColumnElement (func expression) is used with mock repository
223+
"""
224+
if isinstance(field, str):
225+
return field
226+
if isinstance(field, InstrumentedAttribute):
227+
return field.key
228+
msg = f"{type(field)} columns are not supported in mock repositories (in-memory filtering)"
229+
raise RepositoryError(msg)
230+
213231
def _filter_in_collection(
232+
self,
214233
result: list[ModelT],
215-
field_name: str,
234+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
216235
values: abc.Collection[Any],
217236
) -> list[ModelT]:
218-
return [item for item in result if getattr(item, field_name) in values]
237+
field_str = self._extract_field_name(field_name)
238+
return [item for item in result if getattr(item, field_str) in values]
219239

220-
@staticmethod
221240
def _filter_not_in_collection(
241+
self,
222242
result: list[ModelT],
223-
field_name: str,
243+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
224244
values: abc.Collection[Any],
225245
) -> list[ModelT]:
226246
if not values:
227247
return result
228-
return [item for item in result if getattr(item, field_name) not in values]
248+
field_str = self._extract_field_name(field_name)
249+
return [item for item in result if getattr(item, field_str) not in values]
229250

230-
@staticmethod
231251
def _filter_on_datetime_field(
252+
self,
232253
result: list[ModelT],
233-
field_name: str,
254+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
234255
before: Optional[datetime.datetime] = None,
235256
after: Optional[datetime.datetime] = None,
236257
on_or_before: Optional[datetime.datetime] = None,
237258
on_or_after: Optional[datetime.datetime] = None,
238259
) -> list[ModelT]:
260+
field_str = self._extract_field_name(field_name)
239261
result_: list[ModelT] = []
240262
for item in result:
241-
attr: datetime.datetime = getattr(item, field_name)
263+
attr: datetime.datetime = getattr(item, field_str)
242264
if before is not None and attr < before:
243265
result_.append(item)
244266
if after is not None and attr > after:
@@ -302,9 +324,13 @@ def _filter_result_by_kwargs(
302324
except AttributeError as error:
303325
raise RepositoryError from error
304326

305-
@staticmethod
306-
def _order_by(result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
307-
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
327+
def _order_by(
328+
self,
329+
result: list[ModelT],
330+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
331+
sort_desc: bool = False,
332+
) -> list[ModelT]:
333+
return sorted(result, key=lambda item: getattr(item, self._extract_field_name(field_name)), reverse=sort_desc)
308334

309335
def _apply_filters(
310336
self,

advanced_alchemy/repository/memory/_sync.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -210,36 +210,58 @@ def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
210210
def _apply_limit_offset_pagination(result: list[ModelT], limit: int, offset: int) -> list[ModelT]:
211211
return result[offset:limit]
212212

213-
@staticmethod
213+
def _extract_field_name(self, field: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]") -> str:
214+
"""Extract string field name from various input types.
215+
216+
Args:
217+
field: Field name, column element, or instrumented attribute
218+
219+
Returns:
220+
str: String field name for use with getattr()
221+
222+
Raises:
223+
RepositoryError: If a ColumnElement (func expression) is used with mock repository
224+
"""
225+
if isinstance(field, str):
226+
return field
227+
if isinstance(field, InstrumentedAttribute):
228+
return field.key
229+
msg = f"{type(field)} columns are not supported in mock repositories (in-memory filtering)"
230+
raise RepositoryError(msg)
231+
214232
def _filter_in_collection(
233+
self,
215234
result: list[ModelT],
216-
field_name: str,
235+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
217236
values: abc.Collection[Any],
218237
) -> list[ModelT]:
219-
return [item for item in result if getattr(item, field_name) in values]
238+
field_str = self._extract_field_name(field_name)
239+
return [item for item in result if getattr(item, field_str) in values]
220240

221-
@staticmethod
222241
def _filter_not_in_collection(
242+
self,
223243
result: list[ModelT],
224-
field_name: str,
244+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
225245
values: abc.Collection[Any],
226246
) -> list[ModelT]:
227247
if not values:
228248
return result
229-
return [item for item in result if getattr(item, field_name) not in values]
249+
field_str = self._extract_field_name(field_name)
250+
return [item for item in result if getattr(item, field_str) not in values]
230251

231-
@staticmethod
232252
def _filter_on_datetime_field(
253+
self,
233254
result: list[ModelT],
234-
field_name: str,
255+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
235256
before: Optional[datetime.datetime] = None,
236257
after: Optional[datetime.datetime] = None,
237258
on_or_before: Optional[datetime.datetime] = None,
238259
on_or_after: Optional[datetime.datetime] = None,
239260
) -> list[ModelT]:
261+
field_str = self._extract_field_name(field_name)
240262
result_: list[ModelT] = []
241263
for item in result:
242-
attr: datetime.datetime = getattr(item, field_name)
264+
attr: datetime.datetime = getattr(item, field_str)
243265
if before is not None and attr < before:
244266
result_.append(item)
245267
if after is not None and attr > after:
@@ -303,9 +325,13 @@ def _filter_result_by_kwargs(
303325
except AttributeError as error:
304326
raise RepositoryError from error
305327

306-
@staticmethod
307-
def _order_by(result: list[ModelT], field_name: str, sort_desc: bool = False) -> list[ModelT]:
308-
return sorted(result, key=lambda item: getattr(item, field_name), reverse=sort_desc)
328+
def _order_by(
329+
self,
330+
result: list[ModelT],
331+
field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
332+
sort_desc: bool = False,
333+
) -> list[ModelT]:
334+
return sorted(result, key=lambda item: getattr(item, self._extract_field_name(field_name)), reverse=sort_desc)
309335

310336
def _apply_filters(
311337
self,

tests/integration/test_filters.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55
from pytest import FixtureRequest
6-
from sqlalchemy import Engine, String, select
6+
from sqlalchemy import Engine, String, func, select
77
from sqlalchemy.ext.asyncio import AsyncEngine
88
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
99

@@ -456,6 +456,78 @@ def test_order_by_filter(session: Session, movie_model_sync: type[DeclarativeBas
456456
assert results[0].title == "The Hangover"
457457

458458

459+
def test_order_by_with_func_random(session: Session, movie_model_sync: type[DeclarativeBase]) -> None:
460+
"""Test OrderBy filter with func.random() expression."""
461+
Movie = movie_model_sync
462+
463+
# Skip mock engines
464+
dialect_name = getattr(session.bind.dialect, "name", "")
465+
if dialect_name == "mock":
466+
pytest.skip("Mock engines not supported for filter tests")
467+
468+
# Skip Oracle - uses dbms_random.value() instead of random()
469+
if dialect_name.startswith("oracle"):
470+
pytest.skip("Oracle uses dbms_random.value() instead of random()")
471+
472+
# Clean any existing data first, then setup fresh data
473+
if dialect_name != "mock":
474+
session.execute(Movie.__table__.delete())
475+
session.commit()
476+
setup_movie_data(session, Movie)
477+
478+
# Test func.random() - should not raise type error
479+
order_by_filter = OrderBy(field_name=func.random())
480+
statement = order_by_filter.append_to_statement(select(Movie), Movie)
481+
results = session.execute(statement).scalars().all()
482+
# Should return all movies, order is random
483+
assert len(results) == 3
484+
485+
486+
def test_order_by_with_func_lower(session: Session, movie_model_sync: type[DeclarativeBase]) -> None:
487+
"""Test OrderBy filter with func.lower() for case-insensitive sorting."""
488+
Movie = movie_model_sync
489+
490+
# Skip mock engines
491+
if getattr(session.bind.dialect, "name", "") == "mock":
492+
pytest.skip("Mock engines not supported for filter tests")
493+
494+
# Clean any existing data first, then setup fresh data
495+
if getattr(session.bind.dialect, "name", "") != "mock":
496+
session.execute(Movie.__table__.delete())
497+
session.commit()
498+
setup_movie_data(session, Movie)
499+
500+
# Test func.lower() for case-insensitive alphabetical sorting
501+
order_by_filter = OrderBy(field_name=func.lower(Movie.title), sort_order="asc")
502+
statement = order_by_filter.append_to_statement(select(Movie), Movie)
503+
results = session.execute(statement).scalars().all()
504+
# Should be sorted alphabetically: Shawshank, The Hangover, The Matrix
505+
assert results[0].title == "Shawshank Redemption"
506+
assert results[1].title == "The Hangover"
507+
assert results[2].title == "The Matrix"
508+
509+
510+
def test_order_by_with_instrumented_attribute(session: Session, movie_model_sync: type[DeclarativeBase]) -> None:
511+
"""Test OrderBy filter with InstrumentedAttribute (Model.field)."""
512+
Movie = movie_model_sync
513+
514+
# Skip mock engines
515+
if getattr(session.bind.dialect, "name", "") == "mock":
516+
pytest.skip("Mock engines not supported for filter tests")
517+
518+
# Clean any existing data first, then setup fresh data
519+
if getattr(session.bind.dialect, "name", "") != "mock":
520+
session.execute(Movie.__table__.delete())
521+
session.commit()
522+
setup_movie_data(session, Movie)
523+
524+
# Test with InstrumentedAttribute (backward compatibility)
525+
order_by_filter = OrderBy(field_name=Movie.release_date, sort_order="asc")
526+
statement = order_by_filter.append_to_statement(select(Movie), Movie)
527+
results = session.execute(statement).scalars().all()
528+
assert results[0].title == "Shawshank Redemption"
529+
530+
459531
def test_search_filter(session: Session, movie_model_sync: type[DeclarativeBase]) -> None:
460532
Movie = movie_model_sync
461533

0 commit comments

Comments
 (0)