Skip to content

Commit 51a38ff

Browse files
authored
Add hash field expiration support for Redis 7.4+ (#752)
* Add hash field expiration support for Redis 7.4+ - Add Field(expire=N) parameter for declarative TTL on hash fields - Add expire_field(), field_ttl(), persist_field() instance methods - Add field_expirations parameter to save() for runtime TTL overrides - Add supports_hash_field_expiration() version check (requires redis-py 5.1+) - Add comprehensive tests for all expiration functionality Closes #750 * Fix: preserve field TTLs when save() is called HSET removes field-level TTLs, so save() now: 1. Gets current TTLs before HSET 2. Restores them after HSET This prevents manually-set TTLs from being overwritten. Fixes #753
1 parent 0c1abc8 commit 51a38ff

File tree

2 files changed

+568
-0
lines changed

2 files changed

+568
-0
lines changed

aredis_om/model/model.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,32 @@
7777
log = logging.getLogger(__name__)
7878
escaper = TokenEscaper()
7979

80+
# Minimum redis-py version for hash field expiration support
81+
_HASH_FIELD_EXPIRATION_MIN_VERSION = (5, 1, 0)
82+
83+
84+
def supports_hash_field_expiration() -> bool:
85+
"""
86+
Check if the installed redis-py version supports hash field expiration commands.
87+
88+
Hash field expiration (HEXPIRE, HTTL, HPERSIST, etc.) was added in redis-py 5.1.0
89+
and requires Redis server 7.4+.
90+
91+
Returns:
92+
True if redis-py >= 5.1.0 and has the hexpire method, False otherwise.
93+
"""
94+
try:
95+
import redis as redis_lib
96+
97+
version_str = getattr(redis_lib, "__version__", "0.0.0")
98+
version_parts = tuple(int(x) for x in version_str.split(".")[:3])
99+
if version_parts >= _HASH_FIELD_EXPIRATION_MIN_VERSION:
100+
# Also check that the method actually exists
101+
return hasattr(redis_lib.asyncio.Redis, "hexpire")
102+
return False
103+
except (ValueError, AttributeError):
104+
return False
105+
80106

81107
def convert_datetime_to_timestamp(obj):
82108
"""Convert datetime objects to Unix timestamps for storage."""
@@ -1879,13 +1905,15 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None:
18791905
index = kwargs.pop("index", None)
18801906
full_text_search = kwargs.pop("full_text_search", None)
18811907
vector_options = kwargs.pop("vector_options", None)
1908+
expire = kwargs.pop("expire", None)
18821909
super().__init__(default=default, **kwargs)
18831910
self.primary_key = primary_key
18841911
self.sortable = sortable
18851912
self.case_sensitive = case_sensitive
18861913
self.index = index
18871914
self.full_text_search = full_text_search
18881915
self.vector_options = vector_options
1916+
self.expire = expire
18891917

18901918

18911919
class RelationshipInfo(Representation):
@@ -1996,8 +2024,27 @@ def Field(
19962024
index: Union[bool, UndefinedType] = Undefined,
19972025
full_text_search: Union[bool, UndefinedType] = Undefined,
19982026
vector_options: Optional[VectorFieldOptions] = None,
2027+
expire: Optional[int] = None,
19992028
**kwargs: Unpack[_FromFieldInfoInputs],
20002029
) -> Any:
2030+
"""
2031+
Create a field with Redis OM specific options.
2032+
2033+
Args:
2034+
default: Default value for the field.
2035+
primary_key: Whether this field is the primary key.
2036+
sortable: Whether this field should be sortable in queries.
2037+
case_sensitive: Whether string matching should be case-sensitive.
2038+
index: Whether this field should be indexed for searching.
2039+
full_text_search: Whether to enable full-text search on this field.
2040+
vector_options: Vector field configuration for similarity search.
2041+
expire: TTL in seconds for this field (HashModel only, requires Redis 7.4+).
2042+
When set, the field will automatically expire after save().
2043+
**kwargs: Additional Pydantic field options.
2044+
2045+
Returns:
2046+
FieldInfo instance with the configured options.
2047+
"""
20012048
field_info = FieldInfo(
20022049
**kwargs,
20032050
default=default,
@@ -2007,6 +2054,7 @@ def Field(
20072054
index=index,
20082055
full_text_search=full_text_search,
20092056
vector_options=vector_options,
2057+
expire=expire,
20102058
)
20112059
return field_info
20122060

@@ -2631,12 +2679,62 @@ def __init_subclass__(cls, **kwargs):
26312679
f"HashModels cannot index dataclass fields. Field: {name}"
26322680
)
26332681

2682+
def _get_field_expirations(
2683+
self, field_expirations: Optional[Dict[str, int]] = None
2684+
) -> Dict[str, int]:
2685+
"""
2686+
Collect field expirations from Field(expire=N) defaults and overrides.
2687+
2688+
Args:
2689+
field_expirations: Optional dict of {field_name: ttl_seconds} to override defaults.
2690+
2691+
Returns:
2692+
Dict of field names to TTL in seconds.
2693+
"""
2694+
expirations: Dict[str, int] = {}
2695+
2696+
# Collect default expirations from Field(expire=N)
2697+
for name, field in self.model_fields.items():
2698+
field_info = field
2699+
# Handle metadata-wrapped FieldInfo
2700+
if (
2701+
not isinstance(field, FieldInfo)
2702+
and hasattr(field, "metadata")
2703+
and len(field.metadata) > 0
2704+
and isinstance(field.metadata[0], FieldInfo)
2705+
):
2706+
field_info = field.metadata[0]
2707+
2708+
expire = getattr(field_info, "expire", None)
2709+
if expire is not None:
2710+
expirations[name] = expire
2711+
2712+
# Override with explicit field_expirations
2713+
if field_expirations:
2714+
expirations.update(field_expirations)
2715+
2716+
return expirations
2717+
26342718
async def save(
26352719
self: "Model",
26362720
pipeline: Optional[redis.client.Pipeline] = None,
26372721
nx: bool = False,
26382722
xx: bool = False,
2723+
field_expirations: Optional[Dict[str, int]] = None,
26392724
) -> Optional["Model"]:
2725+
"""
2726+
Save the model to Redis.
2727+
2728+
Args:
2729+
pipeline: Optional Redis pipeline for batching commands.
2730+
nx: Only save if the key doesn't exist.
2731+
xx: Only save if the key already exists.
2732+
field_expirations: Dict of {field_name: ttl_seconds} to set field expirations.
2733+
Overrides any Field(expire=N) defaults. Requires Redis 7.4+.
2734+
2735+
Returns:
2736+
The saved model, or None if nx/xx conditions weren't met.
2737+
"""
26402738
if nx and xx:
26412739
raise ValueError("Cannot specify both nx and xx")
26422740
if pipeline and (nx or xx):
@@ -2666,6 +2764,12 @@ async def save(
26662764

26672765
key = self.key()
26682766

2767+
# Collect field expirations
2768+
expirations = self._get_field_expirations(field_expirations)
2769+
2770+
# Check if we're using a pipeline (pipelines don't support TTL preservation)
2771+
is_pipeline = pipeline is not None
2772+
26692773
async def _do_save(conn):
26702774
# Check nx/xx conditions (HSET doesn't support these natively)
26712775
if nx or xx:
@@ -2675,7 +2779,37 @@ async def _do_save(conn):
26752779
if xx and not exists:
26762780
return None # Key doesn't exist, xx means only update existing
26772781

2782+
# Preserve existing field TTLs before HSET (HSET removes field-level TTLs)
2783+
# See issue #753: .save() conflicts with TTL on unrelated field
2784+
# Note: TTL preservation is skipped when using pipelines because
2785+
# pipeline commands return futures, not actual values
2786+
preserved_ttls: Dict[str, int] = {}
2787+
if supports_hash_field_expiration() and not is_pipeline:
2788+
fields_to_check = [f for f in document.keys() if f != "pk"]
2789+
if fields_to_check:
2790+
current_ttls = await conn.httl(key, *fields_to_check)
2791+
if current_ttls:
2792+
for i, field_name in enumerate(fields_to_check):
2793+
if current_ttls[i] > 0: # Has a TTL
2794+
preserved_ttls[field_name] = current_ttls[i]
2795+
26782796
await conn.hset(key, mapping=document)
2797+
2798+
# Apply field expirations after HSET (requires Redis 7.4+)
2799+
# When using pipelines, we can still apply default expirations but
2800+
# can't preserve manually-set TTLs
2801+
if supports_hash_field_expiration():
2802+
for field_name in document.keys():
2803+
if field_name == "pk":
2804+
continue
2805+
# Priority: preserved TTL > explicit field_expirations > Field(expire=N) default
2806+
if field_name in preserved_ttls:
2807+
# Restore the TTL that was removed by HSET
2808+
await conn.hexpire(key, preserved_ttls[field_name], field_name)
2809+
elif field_name in expirations:
2810+
# Apply new expiration (from Field(expire=N) or field_expirations param)
2811+
await conn.hexpire(key, expirations[field_name], field_name)
2812+
26792813
return self
26802814

26812815
# TODO: Wrap any Redis response errors in a custom exception?
@@ -2861,6 +2995,101 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
28612995

28622996
return schema
28632997

2998+
# =========================================================================
2999+
# Hash Field Expiration Methods (Redis 7.4+)
3000+
# =========================================================================
3001+
3002+
async def expire_field(
3003+
self,
3004+
field_name: str,
3005+
seconds: int,
3006+
nx: bool = False,
3007+
xx: bool = False,
3008+
gt: bool = False,
3009+
lt: bool = False,
3010+
) -> int:
3011+
"""
3012+
Set a TTL on a specific hash field.
3013+
3014+
Requires Redis 7.4+ and redis-py >= 5.1.0.
3015+
3016+
Args:
3017+
field_name: The name of the field to expire.
3018+
seconds: TTL in seconds.
3019+
nx: Only set expiry if field has no expiry.
3020+
xx: Only set expiry if field already has an expiry.
3021+
gt: Only set expiry if new expiry is greater than current.
3022+
lt: Only set expiry if new expiry is less than current.
3023+
3024+
Returns:
3025+
1 if expiry was set, -2 if field doesn't exist, 0 if conditions not met.
3026+
3027+
Raises:
3028+
NotImplementedError: If redis-py version doesn't support HEXPIRE.
3029+
"""
3030+
if not supports_hash_field_expiration():
3031+
raise NotImplementedError(
3032+
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
3033+
)
3034+
3035+
db = self.db()
3036+
key = self.key()
3037+
result = await db.hexpire(key, seconds, field_name, nx=nx, xx=xx, gt=gt, lt=lt)
3038+
# hexpire returns a list of results, one per field
3039+
return result[0] if result else -2
3040+
3041+
async def field_ttl(self, field_name: str) -> int:
3042+
"""
3043+
Get the remaining TTL of a hash field in seconds.
3044+
3045+
Requires Redis 7.4+ and redis-py >= 5.1.0.
3046+
3047+
Args:
3048+
field_name: The name of the field.
3049+
3050+
Returns:
3051+
TTL in seconds, -1 if no expiry, -2 if field doesn't exist.
3052+
3053+
Raises:
3054+
NotImplementedError: If redis-py version doesn't support HTTL.
3055+
"""
3056+
if not supports_hash_field_expiration():
3057+
raise NotImplementedError(
3058+
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
3059+
)
3060+
3061+
db = self.db()
3062+
key = self.key()
3063+
result = await db.httl(key, field_name)
3064+
# httl returns a list of results, one per field
3065+
return result[0] if result else -2
3066+
3067+
async def persist_field(self, field_name: str) -> int:
3068+
"""
3069+
Remove the expiration from a hash field.
3070+
3071+
Requires Redis 7.4+ and redis-py >= 5.1.0.
3072+
3073+
Args:
3074+
field_name: The name of the field.
3075+
3076+
Returns:
3077+
1 if expiry was removed, -1 if no expiry, -2 if field doesn't exist.
3078+
3079+
Raises:
3080+
NotImplementedError: If redis-py version doesn't support HPERSIST.
3081+
"""
3082+
if not supports_hash_field_expiration():
3083+
raise NotImplementedError(
3084+
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
3085+
)
3086+
3087+
db = self.db()
3088+
key = self.key()
3089+
result = await db.hpersist(key, field_name)
3090+
# hpersist returns a list of results, one per field
3091+
return result[0] if result else -2
3092+
28643093

28653094
class JsonModel(RedisModel, abc.ABC):
28663095
def __init_subclass__(cls, **kwargs):

0 commit comments

Comments
 (0)