From 0777004c074ef837e445431b68d8716f4f3ccb62 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 18 Aug 2025 11:59:59 -0700 Subject: [PATCH] fix: always refresh TTL when refresh_on_read is enabled Remove 60% threshold optimization to prevent TTL drift between checkpoint keys and user-managed external keys. When refresh_on_read=True, the TTL of any checkpoint that has an expiry is now refreshed on every checkpoint read, ensuring predictable behavior for TTL synchronization use cases; keys without an expiry (persistent, TTL = -1) are left untouched. - Remove unnecessary threshold logic in both sync and async implementations - Simplify TTL refresh behavior to be more intuitive - Add comprehensive tests for new TTL behavior - Update documentation to be clearer about TTL functionality --- README.md | 17 +- langgraph/checkpoint/redis/__init__.py | 9 +- langgraph/checkpoint/redis/aio.py | 9 +- tests/test_async_ttl_synchronization.py | 305 ++++++++++++++++++++++++ tests/test_ttl_synchronization.py | 266 +++++++++++++++++++++ 5 files changed, 589 insertions(+), 17 deletions(-) create mode 100644 tests/test_async_ttl_synchronization.py create mode 100644 tests/test_ttl_synchronization.py diff --git a/README.md b/README.md index 5903013..61b0651 100644 --- a/README.md +++ b/README.md @@ -248,21 +248,22 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe ## Redis Checkpoint TTL Support -Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration: +Both Redis checkpoint savers and stores support automatic expiration using Redis TTL: ```python -# Configure TTL for checkpoint savers +# Configure automatic expiration ttl_config = { - "default_ttl": 60, # Default TTL in minutes - "refresh_on_read": True, # Refresh TTL when checkpoint is read + "default_ttl": 60, # Expire checkpoints after 60 minutes + "refresh_on_read": True, # Reset expiration time when reading checkpoints } -# Use with any checkpoint saver implementation -with RedisSaver.from_conn_string("redis://localhost:6379", 
ttl=ttl_config) as checkpointer: - checkpointer.setup() - # Use the checkpointer... +with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as saver: + saver.setup() + # Checkpoints will expire after 60 minutes of inactivity ``` +When no TTL is configured, checkpoints are persistent (never expire automatically). + ### Removing TTL (Pinning Threads) You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire: diff --git a/langgraph/checkpoint/redis/__init__.py b/langgraph/checkpoint/redis/__init__.py index 67405f2..a1b032c 100644 --- a/langgraph/checkpoint/redis/__init__.py +++ b/langgraph/checkpoint/redis/__init__.py @@ -936,14 +936,13 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: doc_checkpoint_id, ) - # Check current TTL before doing expensive refresh operations + # Always refresh TTL when refresh_on_read is enabled + # This ensures all related keys maintain synchronized TTLs current_ttl = self._redis.ttl(checkpoint_key) - default_ttl_minutes = self.ttl_config.get("default_ttl", 60) - ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL - # Only refresh if TTL is below threshold (or key doesn't exist) + # Only refresh if key exists and has TTL (skip keys with no expiry) # TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining - if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold): + if current_ttl > 0: # Note: We don't refresh TTL for keys with no expiry (TTL = -1) # Get all blob keys related to this checkpoint from langgraph.checkpoint.redis.base import ( diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 1ba54c0..8533dc4 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -483,11 +483,12 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: ) 
current_ttl = await self._redis.ttl(checkpoint_key) - default_ttl_minutes = self.ttl_config.get("default_ttl", 60) - ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL + # Always refresh TTL when refresh_on_read is enabled + # This ensures all related keys maintain synchronized TTLs - # Only refresh if TTL is below threshold (or key doesn't exist) - if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold): + # Only refresh if key exists and has TTL (skip keys with no expiry) + # TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining + if current_ttl > 0: # Get all blob keys related to this checkpoint from langgraph.checkpoint.redis.base import ( CHECKPOINT_BLOB_PREFIX, diff --git a/tests/test_async_ttl_synchronization.py b/tests/test_async_ttl_synchronization.py new file mode 100644 index 0000000..ae5c1a0 --- /dev/null +++ b/tests/test_async_ttl_synchronization.py @@ -0,0 +1,305 @@ +"""Test async TTL synchronization behavior for AsyncRedisSaver.""" + +import asyncio +import time +from contextlib import asynccontextmanager +from typing import AsyncGenerator +from uuid import uuid4 + +import pytest +from langgraph.checkpoint.base import Checkpoint + +from langgraph.checkpoint.redis.aio import AsyncRedisSaver + + +@asynccontextmanager +async def _saver( + redis_url: str, ttl_config: dict +) -> AsyncGenerator[AsyncRedisSaver, None]: + """Create an AsyncRedisSaver with the given TTL configuration.""" + async with AsyncRedisSaver.from_conn_string(redis_url, ttl=ttl_config) as saver: + await saver.setup() + yield saver + + +@pytest.mark.asyncio +async def test_async_ttl_refresh_on_read(redis_url: str) -> None: + """Test that TTL is always refreshed when refresh_on_read is enabled (async).""" + + # Configure with TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": True, + } + + async with _saver(redis_url, ttl_config) as saver: + 
thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {}) + + # Get the checkpoint key + from langgraph.checkpoint.redis.base import BaseRedisSaver + from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, + ) + + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(thread_id), + to_storage_safe_str(checkpoint_ns), + to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]), + ) + + # Check initial TTL (should be around 120 seconds) + initial_ttl = await saver._redis.ttl(checkpoint_key) + assert ( + 115 <= initial_ttl <= 120 + ), f"Initial TTL should be ~120s, got {initial_ttl}" + + # Wait a bit (simulate time passing) + await asyncio.sleep(2) + + # Read the checkpoint - this should refresh TTL to full value + result = await saver.aget_tuple(saved_config) + assert result is not None + + # Check TTL after read - should be refreshed to full value + refreshed_ttl = await saver._redis.ttl(checkpoint_key) + assert ( + 115 <= refreshed_ttl <= 120 + ), f"TTL should be refreshed to ~120s, got {refreshed_ttl}" + + +@pytest.mark.asyncio +async def test_async_ttl_no_refresh_when_disabled(redis_url: str) -> None: + """Test that TTL is not refreshed when refresh_on_read is disabled (async).""" + + # Configure without TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": False, # Don't refresh TTL on read + } + + async with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + 
+ # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {}) + + # Get the checkpoint key + from langgraph.checkpoint.redis.base import BaseRedisSaver + from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, + ) + + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(thread_id), + to_storage_safe_str(checkpoint_ns), + to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]), + ) + + # Check initial TTL + initial_ttl = await saver._redis.ttl(checkpoint_key) + assert ( + 115 <= initial_ttl <= 120 + ), f"Initial TTL should be ~120s, got {initial_ttl}" + + # Wait a bit + await asyncio.sleep(2) + + # Read the checkpoint - should NOT refresh TTL when refresh_on_read=False + result = await saver.aget_tuple(saved_config) + assert result is not None + + # Check TTL after read - should NOT be refreshed + current_ttl = await saver._redis.ttl(checkpoint_key) + assert ( + current_ttl < initial_ttl - 1 + ), f"TTL should have decreased, got {current_ttl}" + + +@pytest.mark.asyncio +async def test_async_ttl_synchronization_with_external_keys(redis_url: str) -> None: + """Test TTL synchronization between checkpoint keys and external user keys (async).""" + + # Configure with TTL refresh on read for synchronization + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": True, + } + + async with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + 
channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {}) + + # Create external keys that should expire together + external_key1 = f"user:metadata:{thread_id}" + external_key2 = f"user:context:{thread_id}" + + # Set external keys with same TTL + await saver._redis.set(external_key1, "metadata_value", ex=120) + await saver._redis.set(external_key2, "context_value", ex=120) + + # Wait a bit + await asyncio.sleep(2) + + # Read checkpoint - should refresh its TTL + result = await saver.aget_tuple(saved_config) + assert result is not None + + # Manually refresh external keys' TTL (simulating user's synchronization logic) + await saver._redis.expire(external_key1, 120) + await saver._redis.expire(external_key2, 120) + + # Check that all TTLs are synchronized + from langgraph.checkpoint.redis.base import BaseRedisSaver + from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, + ) + + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(thread_id), + to_storage_safe_str(checkpoint_ns), + to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]), + ) + + checkpoint_ttl = await saver._redis.ttl(checkpoint_key) + external_ttl1 = await saver._redis.ttl(external_key1) + external_ttl2 = await saver._redis.ttl(external_key2) + + # All TTLs should be close to each other (within 2 seconds) + assert ( + abs(checkpoint_ttl - external_ttl1) <= 2 + ), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl1}" + assert ( + abs(checkpoint_ttl - external_ttl2) <= 2 + ), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl2}" + assert ( + 115 <= checkpoint_ttl <= 120 + ), f"Checkpoint TTL should be ~120s, got {checkpoint_ttl}" + + +@pytest.mark.asyncio +async def 
test_async_ttl_no_refresh_for_persistent_keys(redis_url: str) -> None: + """Test that keys without TTL (persistent) are not affected by refresh logic (async).""" + + # Configure with TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes + "refresh_on_read": True, + } + + async with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {}) + + # Remove TTL to make it persistent + from langgraph.checkpoint.redis.base import BaseRedisSaver + from langgraph.checkpoint.redis.util import ( + to_storage_safe_id, + to_storage_safe_str, + ) + + checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( + to_storage_safe_id(thread_id), + to_storage_safe_str(checkpoint_ns), + to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]), + ) + await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + + # Verify key is persistent (TTL = -1) + ttl_before = await saver._redis.ttl(checkpoint_key) + assert ttl_before == -1, f"Key should be persistent (TTL=-1), got {ttl_before}" + + # Read the checkpoint + result = await saver.aget_tuple(saved_config) + assert result is not None + + # Verify key is still persistent (not affected by refresh) + ttl_after = await saver._redis.ttl(checkpoint_key) + assert ( + ttl_after == -1 + ), f"Key should remain persistent (TTL=-1), got {ttl_after}" diff --git a/tests/test_ttl_synchronization.py b/tests/test_ttl_synchronization.py new file mode 100644 index 0000000..bef5ec2 --- /dev/null +++ b/tests/test_ttl_synchronization.py @@ -0,0 +1,266 @@ +"""Test TTL 
synchronization behavior for RedisSaver.""" + +import time +from contextlib import contextmanager +from typing import Generator +from uuid import uuid4 + +import pytest +from langgraph.checkpoint.base import Checkpoint + +from langgraph.checkpoint.redis import RedisSaver + + +@contextmanager +def _saver(redis_url: str, ttl_config: dict) -> Generator[RedisSaver, None, None]: + """Create a RedisSaver with the given TTL configuration.""" + with RedisSaver.from_conn_string(redis_url, ttl=ttl_config) as saver: + saver.setup() + yield saver + + +def test_ttl_refresh_on_read(redis_url: str) -> None: + """Test that TTL is always refreshed when refresh_on_read is enabled.""" + + # Configure with TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": True, + } + + with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = saver.put(config, checkpoint, {"test": "metadata"}, {}) + + # Get the checkpoint key + checkpoint_key = saver._make_redis_checkpoint_key_cached( + thread_id, checkpoint_ns, saved_config["configurable"]["checkpoint_id"] + ) + + # Check initial TTL (should be around 120 seconds) + initial_ttl = saver._redis.ttl(checkpoint_key) + assert ( + 115 <= initial_ttl <= 120 + ), f"Initial TTL should be ~120s, got {initial_ttl}" + + # Wait a bit (simulate time passing) + time.sleep(2) + + # Read the checkpoint - this should refresh TTL to full value + result = saver.get_tuple(saved_config) + assert result is not None + + # Check TTL after read - should be refreshed to full value + refreshed_ttl = 
saver._redis.ttl(checkpoint_key) + assert ( + 115 <= refreshed_ttl <= 120 + ), f"TTL should be refreshed to ~120s, got {refreshed_ttl}" + + +def test_ttl_no_refresh_when_disabled(redis_url: str) -> None: + """Test that TTL is not refreshed when refresh_on_read is disabled.""" + + # Configure without TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": False, # Don't refresh TTL on read + } + + with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = saver.put(config, checkpoint, {"test": "metadata"}, {}) + + # Get the checkpoint key + checkpoint_key = saver._make_redis_checkpoint_key_cached( + thread_id, checkpoint_ns, saved_config["configurable"]["checkpoint_id"] + ) + + # Check initial TTL + initial_ttl = saver._redis.ttl(checkpoint_key) + assert ( + 115 <= initial_ttl <= 120 + ), f"Initial TTL should be ~120s, got {initial_ttl}" + + # Wait a bit + time.sleep(2) + + # Read the checkpoint - should NOT refresh TTL when refresh_on_read=False + result = saver.get_tuple(saved_config) + assert result is not None + + # Check TTL after read - should NOT be refreshed + current_ttl = saver._redis.ttl(checkpoint_key) + assert ( + current_ttl < initial_ttl - 1 + ), f"TTL should have decreased, got {current_ttl}" + + +def test_ttl_synchronization_with_external_keys(redis_url: str) -> None: + """Test TTL synchronization between checkpoint keys and external user keys.""" + + # Configure with TTL refresh on read for synchronization + ttl_config = { + "default_ttl": 2, # 2 minutes = 120 seconds + "refresh_on_read": True, + } + 
+ with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = saver.put(config, checkpoint, {"test": "metadata"}, {}) + + # Create external keys that should expire together + external_key1 = f"user:metadata:{thread_id}" + external_key2 = f"user:context:{thread_id}" + + # Set external keys with same TTL + saver._redis.set(external_key1, "metadata_value", ex=120) + saver._redis.set(external_key2, "context_value", ex=120) + + # Wait a bit + time.sleep(2) + + # Read checkpoint - should refresh its TTL + result = saver.get_tuple(saved_config) + assert result is not None + + # Manually refresh external keys' TTL (simulating user's synchronization logic) + saver._redis.expire(external_key1, 120) + saver._redis.expire(external_key2, 120) + + # Check that all TTLs are synchronized + checkpoint_key = saver._make_redis_checkpoint_key_cached( + thread_id, checkpoint_ns, saved_config["configurable"]["checkpoint_id"] + ) + + checkpoint_ttl = saver._redis.ttl(checkpoint_key) + external_ttl1 = saver._redis.ttl(external_key1) + external_ttl2 = saver._redis.ttl(external_key2) + + # All TTLs should be close to each other (within 2 seconds) + assert ( + abs(checkpoint_ttl - external_ttl1) <= 2 + ), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl1}" + assert ( + abs(checkpoint_ttl - external_ttl2) <= 2 + ), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl2}" + assert ( + 115 <= checkpoint_ttl <= 120 + ), f"Checkpoint TTL should be ~120s, got {checkpoint_ttl}" + + +def test_ttl_no_refresh_for_persistent_keys(redis_url: str) -> None: + """Test that keys without TTL 
(persistent) are not affected by refresh logic.""" + + # Configure with TTL refresh on read + ttl_config = { + "default_ttl": 2, # 2 minutes + "refresh_on_read": True, + } + + with _saver(redis_url, ttl_config) as saver: + thread_id = str(uuid4()) + checkpoint_ns = "" + checkpoint_id = str(uuid4()) + + # Create a checkpoint + checkpoint = Checkpoint( + v=1, + id=checkpoint_id, + ts="2024-01-01T00:00:00+00:00", + channel_values={"test": "value"}, + channel_versions={}, + versions_seen={}, + ) + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + } + + # Save the checkpoint + saved_config = saver.put(config, checkpoint, {"test": "metadata"}, {}) + + # Remove TTL to make it persistent + checkpoint_key = saver._make_redis_checkpoint_key_cached( + thread_id, checkpoint_ns, saved_config["configurable"]["checkpoint_id"] + ) + saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + + # Verify key is persistent (TTL = -1) + ttl_before = saver._redis.ttl(checkpoint_key) + assert ttl_before == -1, f"Key should be persistent (TTL=-1), got {ttl_before}" + + # Read the checkpoint + result = saver.get_tuple(saved_config) + assert result is not None + + # Verify key is still persistent (not affected by refresh) + ttl_after = saver._redis.ttl(checkpoint_key) + assert ( + ttl_after == -1 + ), f"Key should remain persistent (TTL=-1), got {ttl_after}"