Skip to content

Commit cbc6ef6

Browse files
committed
fix: always refresh TTL when refresh_on_read is enabled
Remove 60% threshold optimization to prevent TTL drift between checkpoint keys and user-managed external keys. When refresh_on_read=True, TTL is now always refreshed on checkpoint reads, ensuring predictable behavior for TTL synchronization use cases. - Remove unnecessary threshold logic in both sync and async implementations - Simplify TTL refresh behavior to be more intuitive - Add comprehensive tests for new TTL behavior - Update documentation to be clearer about TTL functionality
1 parent 7a0e417 commit cbc6ef6

File tree

5 files changed

+589
-17
lines changed

5 files changed

+589
-17
lines changed

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,21 +248,22 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe
248248

249249
## Redis Checkpoint TTL Support
250250

251-
Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration:
251+
Both Redis checkpoint savers and stores support automatic expiration using Redis TTL:
252252

253253
```python
254-
# Configure TTL for checkpoint savers
254+
# Configure automatic expiration
255255
ttl_config = {
256-
"default_ttl": 60, # Default TTL in minutes
257-
"refresh_on_read": True, # Refresh TTL when checkpoint is read
256+
"default_ttl": 60, # Expire checkpoints after 60 minutes
257+
"refresh_on_read": True, # Reset expiration time when reading checkpoints
258258
}
259259

260-
# Use with any checkpoint saver implementation
261-
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as checkpointer:
262-
checkpointer.setup()
263-
# Use the checkpointer...
260+
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as saver:
261+
saver.setup()
262+
# Checkpoints will expire after 60 minutes of inactivity
264263
```
265264

265+
When no TTL is configured, checkpoints are persistent (never expire automatically).
266+
266267
### Removing TTL (Pinning Threads)
267268

268269
You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire:

langgraph/checkpoint/redis/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -936,14 +936,13 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
936936
doc_checkpoint_id,
937937
)
938938

939-
# Check current TTL before doing expensive refresh operations
939+
# Always refresh TTL when refresh_on_read is enabled
940+
# This ensures all related keys maintain synchronized TTLs
940941
current_ttl = self._redis.ttl(checkpoint_key)
941-
default_ttl_minutes = self.ttl_config.get("default_ttl", 60)
942-
ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL
943942

944-
# Only refresh if TTL is below threshold (or key doesn't exist)
943+
# Only refresh if key exists and has TTL (skip keys with no expiry)
945944
# TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
946-
if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold):
945+
if current_ttl > 0:
947946
# Note: We don't refresh TTL for keys with no expiry (TTL = -1)
948947
# Get all blob keys related to this checkpoint
949948
from langgraph.checkpoint.redis.base import (

langgraph/checkpoint/redis/aio.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,12 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
483483
)
484484
current_ttl = await self._redis.ttl(checkpoint_key)
485485

486-
default_ttl_minutes = self.ttl_config.get("default_ttl", 60)
487-
ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL
486+
# Always refresh TTL when refresh_on_read is enabled
487+
# This ensures all related keys maintain synchronized TTLs
488488

489-
# Only refresh if TTL is below threshold (or key doesn't exist)
490-
if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold):
489+
# Only refresh if key exists and has TTL (skip keys with no expiry)
490+
# TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
491+
if current_ttl > 0:
491492
# Get all blob keys related to this checkpoint
492493
from langgraph.checkpoint.redis.base import (
493494
CHECKPOINT_BLOB_PREFIX,
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
"""Test async TTL synchronization behavior for AsyncRedisSaver."""
2+
3+
import asyncio
4+
import time
5+
from contextlib import asynccontextmanager
6+
from typing import AsyncGenerator
7+
from uuid import uuid4
8+
9+
import pytest
10+
from langgraph.checkpoint.base import Checkpoint
11+
12+
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
13+
14+
15+
@asynccontextmanager
16+
async def _saver(
17+
redis_url: str, ttl_config: dict
18+
) -> AsyncGenerator[AsyncRedisSaver, None]:
19+
"""Create an AsyncRedisSaver with the given TTL configuration."""
20+
async with AsyncRedisSaver.from_conn_string(redis_url, ttl=ttl_config) as saver:
21+
await saver.setup()
22+
yield saver
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_async_ttl_refresh_on_read(redis_url: str) -> None:
27+
"""Test that TTL is always refreshed when refresh_on_read is enabled (async)."""
28+
29+
# Configure with TTL refresh on read
30+
ttl_config = {
31+
"default_ttl": 2, # 2 minutes = 120 seconds
32+
"refresh_on_read": True,
33+
}
34+
35+
async with _saver(redis_url, ttl_config) as saver:
36+
thread_id = str(uuid4())
37+
checkpoint_ns = ""
38+
checkpoint_id = str(uuid4())
39+
40+
# Create a checkpoint
41+
checkpoint = Checkpoint(
42+
v=1,
43+
id=checkpoint_id,
44+
ts="2024-01-01T00:00:00+00:00",
45+
channel_values={"test": "value"},
46+
channel_versions={},
47+
versions_seen={},
48+
)
49+
50+
config = {
51+
"configurable": {
52+
"thread_id": thread_id,
53+
"checkpoint_ns": checkpoint_ns,
54+
}
55+
}
56+
57+
# Save the checkpoint
58+
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})
59+
60+
# Get the checkpoint key
61+
from langgraph.checkpoint.redis.base import BaseRedisSaver
62+
from langgraph.checkpoint.redis.util import (
63+
to_storage_safe_id,
64+
to_storage_safe_str,
65+
)
66+
67+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
68+
to_storage_safe_id(thread_id),
69+
to_storage_safe_str(checkpoint_ns),
70+
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
71+
)
72+
73+
# Check initial TTL (should be around 120 seconds)
74+
initial_ttl = await saver._redis.ttl(checkpoint_key)
75+
assert (
76+
115 <= initial_ttl <= 120
77+
), f"Initial TTL should be ~120s, got {initial_ttl}"
78+
79+
# Wait a bit (simulate time passing)
80+
await asyncio.sleep(2)
81+
82+
# Read the checkpoint - this should refresh TTL to full value
83+
result = await saver.aget_tuple(saved_config)
84+
assert result is not None
85+
86+
# Check TTL after read - should be refreshed to full value
87+
refreshed_ttl = await saver._redis.ttl(checkpoint_key)
88+
assert (
89+
115 <= refreshed_ttl <= 120
90+
), f"TTL should be refreshed to ~120s, got {refreshed_ttl}"
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_async_ttl_no_refresh_when_disabled(redis_url: str) -> None:
95+
"""Test that TTL is not refreshed when refresh_on_read is disabled (async)."""
96+
97+
# Configure without TTL refresh on read
98+
ttl_config = {
99+
"default_ttl": 2, # 2 minutes = 120 seconds
100+
"refresh_on_read": False, # Don't refresh TTL on read
101+
}
102+
103+
async with _saver(redis_url, ttl_config) as saver:
104+
thread_id = str(uuid4())
105+
checkpoint_ns = ""
106+
checkpoint_id = str(uuid4())
107+
108+
# Create a checkpoint
109+
checkpoint = Checkpoint(
110+
v=1,
111+
id=checkpoint_id,
112+
ts="2024-01-01T00:00:00+00:00",
113+
channel_values={"test": "value"},
114+
channel_versions={},
115+
versions_seen={},
116+
)
117+
118+
config = {
119+
"configurable": {
120+
"thread_id": thread_id,
121+
"checkpoint_ns": checkpoint_ns,
122+
}
123+
}
124+
125+
# Save the checkpoint
126+
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})
127+
128+
# Get the checkpoint key
129+
from langgraph.checkpoint.redis.base import BaseRedisSaver
130+
from langgraph.checkpoint.redis.util import (
131+
to_storage_safe_id,
132+
to_storage_safe_str,
133+
)
134+
135+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
136+
to_storage_safe_id(thread_id),
137+
to_storage_safe_str(checkpoint_ns),
138+
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
139+
)
140+
141+
# Check initial TTL
142+
initial_ttl = await saver._redis.ttl(checkpoint_key)
143+
assert (
144+
115 <= initial_ttl <= 120
145+
), f"Initial TTL should be ~120s, got {initial_ttl}"
146+
147+
# Wait a bit
148+
await asyncio.sleep(2)
149+
150+
# Read the checkpoint - should NOT refresh TTL when refresh_on_read=False
151+
result = await saver.aget_tuple(saved_config)
152+
assert result is not None
153+
154+
# Check TTL after read - should NOT be refreshed
155+
current_ttl = await saver._redis.ttl(checkpoint_key)
156+
assert (
157+
current_ttl < initial_ttl - 1
158+
), f"TTL should have decreased, got {current_ttl}"
159+
160+
161+
@pytest.mark.asyncio
162+
async def test_async_ttl_synchronization_with_external_keys(redis_url: str) -> None:
163+
"""Test TTL synchronization between checkpoint keys and external user keys (async)."""
164+
165+
# Configure with TTL refresh on read for synchronization
166+
ttl_config = {
167+
"default_ttl": 2, # 2 minutes = 120 seconds
168+
"refresh_on_read": True,
169+
}
170+
171+
async with _saver(redis_url, ttl_config) as saver:
172+
thread_id = str(uuid4())
173+
checkpoint_ns = ""
174+
checkpoint_id = str(uuid4())
175+
176+
# Create a checkpoint
177+
checkpoint = Checkpoint(
178+
v=1,
179+
id=checkpoint_id,
180+
ts="2024-01-01T00:00:00+00:00",
181+
channel_values={"test": "value"},
182+
channel_versions={},
183+
versions_seen={},
184+
)
185+
186+
config = {
187+
"configurable": {
188+
"thread_id": thread_id,
189+
"checkpoint_ns": checkpoint_ns,
190+
}
191+
}
192+
193+
# Save the checkpoint
194+
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})
195+
196+
# Create external keys that should expire together
197+
external_key1 = f"user:metadata:{thread_id}"
198+
external_key2 = f"user:context:{thread_id}"
199+
200+
# Set external keys with same TTL
201+
await saver._redis.set(external_key1, "metadata_value", ex=120)
202+
await saver._redis.set(external_key2, "context_value", ex=120)
203+
204+
# Wait a bit
205+
await asyncio.sleep(2)
206+
207+
# Read checkpoint - should refresh its TTL
208+
result = await saver.aget_tuple(saved_config)
209+
assert result is not None
210+
211+
# Manually refresh external keys' TTL (simulating user's synchronization logic)
212+
await saver._redis.expire(external_key1, 120)
213+
await saver._redis.expire(external_key2, 120)
214+
215+
# Check that all TTLs are synchronized
216+
from langgraph.checkpoint.redis.base import BaseRedisSaver
217+
from langgraph.checkpoint.redis.util import (
218+
to_storage_safe_id,
219+
to_storage_safe_str,
220+
)
221+
222+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
223+
to_storage_safe_id(thread_id),
224+
to_storage_safe_str(checkpoint_ns),
225+
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
226+
)
227+
228+
checkpoint_ttl = await saver._redis.ttl(checkpoint_key)
229+
external_ttl1 = await saver._redis.ttl(external_key1)
230+
external_ttl2 = await saver._redis.ttl(external_key2)
231+
232+
# All TTLs should be close to each other (within 2 seconds)
233+
assert (
234+
abs(checkpoint_ttl - external_ttl1) <= 2
235+
), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl1}"
236+
assert (
237+
abs(checkpoint_ttl - external_ttl2) <= 2
238+
), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl2}"
239+
assert (
240+
115 <= checkpoint_ttl <= 120
241+
), f"Checkpoint TTL should be ~120s, got {checkpoint_ttl}"
242+
243+
244+
@pytest.mark.asyncio
245+
async def test_async_ttl_no_refresh_for_persistent_keys(redis_url: str) -> None:
246+
"""Test that keys without TTL (persistent) are not affected by refresh logic (async)."""
247+
248+
# Configure with TTL refresh on read
249+
ttl_config = {
250+
"default_ttl": 2, # 2 minutes
251+
"refresh_on_read": True,
252+
}
253+
254+
async with _saver(redis_url, ttl_config) as saver:
255+
thread_id = str(uuid4())
256+
checkpoint_ns = ""
257+
checkpoint_id = str(uuid4())
258+
259+
# Create a checkpoint
260+
checkpoint = Checkpoint(
261+
v=1,
262+
id=checkpoint_id,
263+
ts="2024-01-01T00:00:00+00:00",
264+
channel_values={"test": "value"},
265+
channel_versions={},
266+
versions_seen={},
267+
)
268+
269+
config = {
270+
"configurable": {
271+
"thread_id": thread_id,
272+
"checkpoint_ns": checkpoint_ns,
273+
}
274+
}
275+
276+
# Save the checkpoint
277+
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})
278+
279+
# Remove TTL to make it persistent
280+
from langgraph.checkpoint.redis.base import BaseRedisSaver
281+
from langgraph.checkpoint.redis.util import (
282+
to_storage_safe_id,
283+
to_storage_safe_str,
284+
)
285+
286+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
287+
to_storage_safe_id(thread_id),
288+
to_storage_safe_str(checkpoint_ns),
289+
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
290+
)
291+
await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)
292+
293+
# Verify key is persistent (TTL = -1)
294+
ttl_before = await saver._redis.ttl(checkpoint_key)
295+
assert ttl_before == -1, f"Key should be persistent (TTL=-1), got {ttl_before}"
296+
297+
# Read the checkpoint
298+
result = await saver.aget_tuple(saved_config)
299+
assert result is not None
300+
301+
# Verify key is still persistent (not affected by refresh)
302+
ttl_after = await saver._redis.ttl(checkpoint_key)
303+
assert (
304+
ttl_after == -1
305+
), f"Key should remain persistent (TTL=-1), got {ttl_after}"

0 commit comments

Comments
 (0)