2626from dataclasses import dataclass
2727from operator import attrgetter
2828from typing import (
29+ TYPE_CHECKING ,
2930 Any ,
3031 Callable ,
3132 ClassVar ,
5354 text ,
5455 true ,
5556)
56- from sqlalchemy .orm import InstrumentedAttribute
5757from sqlalchemy .sql import operators as op
5858from sqlalchemy .sql .dml import ReturningDelete , ReturningUpdate
5959from typing_extensions import TypeAlias , TypedDict , TypeVar
6060
6161from advanced_alchemy .base import ModelProtocol
6262
63+ if TYPE_CHECKING :
64+ from sqlalchemy .orm import InstrumentedAttribute
65+
6366__all__ = (
6467 "BeforeAfter" ,
6568 "CollectionFilter" ,
@@ -153,7 +156,9 @@ def append_to_statement(
153156 return statement
154157
155158 @staticmethod
156- def _get_instrumented_attr (model : Any , key : Union [str , InstrumentedAttribute [Any ]]) -> InstrumentedAttribute [Any ]:
159+ def _get_instrumented_attr (
160+ model : Any , key : "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]"
161+ ) -> "Union[ColumnElement[Any], InstrumentedAttribute[Any]]" :
157162 """Get SQLAlchemy instrumented attribute from model.
158163
159164 Args:
@@ -166,9 +171,7 @@ def _get_instrumented_attr(model: Any, key: Union[str, InstrumentedAttribute[Any
166171 See Also:
167172 :class:`sqlalchemy.orm.attributes.InstrumentedAttribute`: SQLAlchemy attribute
168173 """
169- if isinstance (key , str ):
170- return cast ("InstrumentedAttribute[Any]" , getattr (model , key ))
171- return key
174+ return cast ("InstrumentedAttribute[Any]" , getattr (model , key )) if isinstance (key , str ) else key
172175
173176
174177@dataclass
@@ -186,8 +189,8 @@ class BeforeAfter(StatementFilter):
186189
187190 """
188191
189- field_name : str
190- """Name of the model attribute to filter on ."""
192+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
193+ """Field name, model attribute, or func expression ."""
191194 before : Optional [datetime .datetime ]
192195 """Filter results where field is earlier than this value."""
193196 after : Optional [datetime .datetime ]
@@ -232,8 +235,8 @@ class OnBeforeAfter(StatementFilter):
232235
233236 """
234237
235- field_name : str
236- """Name of the model attribute to filter on ."""
238+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
239+ """Field name, model attribute, or func expression ."""
237240 on_or_before : Optional [datetime .datetime ]
238241 """Filter results where field is on or earlier than this value."""
239242 on_or_after : Optional [datetime .datetime ]
@@ -280,8 +283,8 @@ class CollectionFilter(InAnyFilter, Generic[T]):
280283 Use ``prefer_any=True`` in ``append_to_statement`` to use the ``ANY`` operator.
281284 """
282285
283- field_name : str
284- """Name of the model attribute to filter on ."""
286+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
287+ """Field name, model attribute, or func expression ."""
285288 values : Union [Collection [T ], None ]
286289 """Values for the ``IN`` clause. If this is None, no filter is applied.
287290 An empty list will force an empty result set (WHERE 1=-1)"""
@@ -339,8 +342,8 @@ class NotInCollectionFilter(InAnyFilter, Generic[T]):
339342
340343 """
341344
342- field_name : str
343- """Name of the model attribute to filter on ."""
345+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
346+ """Field name, model attribute, or func expression ."""
344347 values : Union [Collection [T ], None ]
345348 """Values for the ``NOT IN`` clause. If None or empty, no filter is applied."""
346349
@@ -443,8 +446,8 @@ class OrderBy(StatementFilter):
443446 - :meth:`sqlalchemy.sql.expression.ColumnElement.desc`: Descending order
444447 """
445448
446- field_name : str
447- """Name of the model attribute to sort on ."""
449+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
450+ """Field name, model attribute, or func expression (e.g., ``func.random()``) ."""
448451 sort_order : Literal ["asc" , "desc" ] = "asc"
449452 """Sort direction ("asc" or "desc")."""
450453
@@ -614,8 +617,8 @@ class ComparisonFilter(StatementFilter):
614617 ValueError: If an invalid operator is provided
615618 """
616619
617- field_name : str
618- """Name of the model attribute to filter on ."""
620+ field_name : "Union[ str, ColumnElement[Any], InstrumentedAttribute[Any]]"
621+ """Field name, model attribute, or func expression ."""
619622 operator : str
620623 """Comparison operator to use (one of 'eq', 'ne', 'gt', 'ge', 'lt', 'le')."""
621624 value : Any
0 commit comments