fix: always refresh TTL when refresh_on_read is enabled #95

Merged
merged 1 commit on Aug 18, 2025

17 changes: 9 additions & 8 deletions README.md
@@ -248,21 +248,22 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe

## Redis Checkpoint TTL Support

Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration:
Both Redis checkpoint savers and stores support automatic expiration using Redis TTL:

```python
# Configure TTL for checkpoint savers
# Configure automatic expiration
ttl_config = {
"default_ttl": 60, # Default TTL in minutes
"refresh_on_read": True, # Refresh TTL when checkpoint is read
"default_ttl": 60, # Expire checkpoints after 60 minutes
"refresh_on_read": True, # Reset expiration time when reading checkpoints
}

# Use with any checkpoint saver implementation
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as checkpointer:
checkpointer.setup()
# Use the checkpointer...
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as saver:
saver.setup()
# Checkpoints will expire after 60 minutes of inactivity
```

When no TTL is configured, checkpoints are persistent (never expire automatically).

### Removing TTL (Pinning Threads)

You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire:
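For reference, a minimal async counterpart to the README example above; a sketch assuming the same `ttl_config` shape, using `AsyncRedisSaver.from_conn_string` and `setup()` the same way the tests added in this PR do.

```python
# Sketch only: async usage mirroring the README's synchronous example.
from langgraph.checkpoint.redis.aio import AsyncRedisSaver

ttl_config = {
    "default_ttl": 60,        # expire checkpoints after 60 minutes
    "refresh_on_read": True,  # reset expiration each time a checkpoint is read
}

async def main() -> None:
    async with AsyncRedisSaver.from_conn_string(
        "redis://localhost:6379", ttl=ttl_config
    ) as saver:
        await saver.setup()
        # use the saver with a compiled LangGraph graph...
```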
9 changes: 4 additions & 5 deletions langgraph/checkpoint/redis/__init__.py
@@ -936,14 +936,13 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
doc_checkpoint_id,
)

# Check current TTL before doing expensive refresh operations
# Always refresh TTL when refresh_on_read is enabled
# This ensures all related keys maintain synchronized TTLs
current_ttl = self._redis.ttl(checkpoint_key)
default_ttl_minutes = self.ttl_config.get("default_ttl", 60)
ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL

# Only refresh if TTL is below threshold (or key doesn't exist)
# Only refresh if key exists and has TTL (skip keys with no expiry)
# TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold):
if current_ttl > 0:
# Note: We don't refresh TTL for keys with no expiry (TTL = -1)
# Get all blob keys related to this checkpoint
from langgraph.checkpoint.redis.base import (
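In isolation, the new guard boils down to a check on the value Redis returns for TTL. A minimal sketch of that decision against plain redis-py, not the saver's internals (the function name is illustrative):

```python
import redis

def should_refresh_ttl(client: redis.Redis, key: str) -> bool:
    """Refresh only keys that exist and already carry an expiry.

    Redis TTL return values: -2 = key does not exist, -1 = key exists
    but has no expiry (persistent / "pinned"), > 0 = seconds remaining.
    """
    return client.ttl(key) > 0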
9 changes: 5 additions & 4 deletions langgraph/checkpoint/redis/aio.py
@@ -483,11 +483,12 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
)
current_ttl = await self._redis.ttl(checkpoint_key)

default_ttl_minutes = self.ttl_config.get("default_ttl", 60)
ttl_threshold = int(default_ttl_minutes * 60 * 0.6) # 60% of original TTL
# Always refresh TTL when refresh_on_read is enabled
# This ensures all related keys maintain synchronized TTLs

# Only refresh if TTL is below threshold (or key doesn't exist)
if current_ttl == -2 or (current_ttl > 0 and current_ttl <= ttl_threshold):
# Only refresh if key exists and has TTL (skip keys with no expiry)
# TTL states: -2 = key doesn't exist, -1 = key exists but no TTL, 0 = expired, >0 = seconds remaining
if current_ttl > 0:
# Get all blob keys related to this checkpoint
from langgraph.checkpoint.redis.base import (
CHECKPOINT_BLOB_PREFIX,
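The "synchronized TTLs" comment can be pictured with a small pipeline sketch. This is illustrative only, not the saver's internal code, and the parameter names are assumptions:

```python
import redis.asyncio as redis

async def refresh_checkpoint_ttls(
    client: redis.Redis,
    checkpoint_key: str,
    blob_keys: list[str],
    ttl_seconds: int,
) -> None:
    """Apply the same expiry to a checkpoint key and its related blob keys."""
    pipe = client.pipeline()
    pipe.expire(checkpoint_key, ttl_seconds)
    for key in blob_keys:
        pipe.expire(key, ttl_seconds)
    await pipe.execute()
```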
305 changes: 305 additions & 0 deletions tests/test_async_ttl_synchronization.py
@@ -0,0 +1,305 @@
"""Test async TTL synchronization behavior for AsyncRedisSaver."""

import asyncio
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from uuid import uuid4

import pytest
from langgraph.checkpoint.base import Checkpoint

from langgraph.checkpoint.redis.aio import AsyncRedisSaver


@asynccontextmanager
async def _saver(
redis_url: str, ttl_config: dict
) -> AsyncGenerator[AsyncRedisSaver, None]:
"""Create an AsyncRedisSaver with the given TTL configuration."""
async with AsyncRedisSaver.from_conn_string(redis_url, ttl=ttl_config) as saver:
await saver.setup()
yield saver


@pytest.mark.asyncio
async def test_async_ttl_refresh_on_read(redis_url: str) -> None:
"""Test that TTL is always refreshed when refresh_on_read is enabled (async)."""

# Configure with TTL refresh on read
ttl_config = {
"default_ttl": 2, # 2 minutes = 120 seconds
"refresh_on_read": True,
}

async with _saver(redis_url, ttl_config) as saver:
thread_id = str(uuid4())
checkpoint_ns = ""
checkpoint_id = str(uuid4())

# Create a checkpoint
checkpoint = Checkpoint(
v=1,
id=checkpoint_id,
ts="2024-01-01T00:00:00+00:00",
channel_values={"test": "value"},
channel_versions={},
versions_seen={},
)

config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
}
}

# Save the checkpoint
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})

# Get the checkpoint key
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)
Copilot AI commented on Aug 18, 2025:
[nitpick] The import statements for checkpoint key generation are repeated in multiple test functions. Consider moving these imports to the top of the file or creating a helper function to reduce code duplication.


checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
)

# Check initial TTL (should be around 120 seconds)
initial_ttl = await saver._redis.ttl(checkpoint_key)
assert (
115 <= initial_ttl <= 120
), f"Initial TTL should be ~120s, got {initial_ttl}"

# Wait a bit (simulate time passing)
await asyncio.sleep(2)

# Read the checkpoint - this should refresh TTL to full value
result = await saver.aget_tuple(saved_config)
assert result is not None

# Check TTL after read - should be refreshed to full value
refreshed_ttl = await saver._redis.ttl(checkpoint_key)
assert (
115 <= refreshed_ttl <= 120
), f"TTL should be refreshed to ~120s, got {refreshed_ttl}"


@pytest.mark.asyncio
async def test_async_ttl_no_refresh_when_disabled(redis_url: str) -> None:
"""Test that TTL is not refreshed when refresh_on_read is disabled (async)."""

# Configure without TTL refresh on read
ttl_config = {
"default_ttl": 2, # 2 minutes = 120 seconds
"refresh_on_read": False, # Don't refresh TTL on read
}

async with _saver(redis_url, ttl_config) as saver:
thread_id = str(uuid4())
checkpoint_ns = ""
checkpoint_id = str(uuid4())

# Create a checkpoint
checkpoint = Checkpoint(
v=1,
id=checkpoint_id,
ts="2024-01-01T00:00:00+00:00",
channel_values={"test": "value"},
channel_versions={},
versions_seen={},
)

config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
}
}

# Save the checkpoint
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})

# Get the checkpoint key
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)
Copilot AI commented on Aug 18, 2025:
[nitpick] The import statements for checkpoint key generation are repeated in multiple test functions. Consider moving these imports to the top of the file or creating a helper function to reduce code duplication.


checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
)

# Check initial TTL
initial_ttl = await saver._redis.ttl(checkpoint_key)
assert (
115 <= initial_ttl <= 120
), f"Initial TTL should be ~120s, got {initial_ttl}"

# Wait a bit
await asyncio.sleep(2)

# Read the checkpoint - should NOT refresh TTL when refresh_on_read=False
result = await saver.aget_tuple(saved_config)
assert result is not None

# Check TTL after read - should NOT be refreshed
current_ttl = await saver._redis.ttl(checkpoint_key)
assert (
current_ttl < initial_ttl - 1
), f"TTL should have decreased, got {current_ttl}"


@pytest.mark.asyncio
async def test_async_ttl_synchronization_with_external_keys(redis_url: str) -> None:
"""Test TTL synchronization between checkpoint keys and external user keys (async)."""

# Configure with TTL refresh on read for synchronization
ttl_config = {
"default_ttl": 2, # 2 minutes = 120 seconds
"refresh_on_read": True,
}

async with _saver(redis_url, ttl_config) as saver:
thread_id = str(uuid4())
checkpoint_ns = ""
checkpoint_id = str(uuid4())

# Create a checkpoint
checkpoint = Checkpoint(
v=1,
id=checkpoint_id,
ts="2024-01-01T00:00:00+00:00",
channel_values={"test": "value"},
channel_versions={},
versions_seen={},
)

config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
}
}

# Save the checkpoint
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})

# Create external keys that should expire together
external_key1 = f"user:metadata:{thread_id}"
external_key2 = f"user:context:{thread_id}"

# Set external keys with same TTL
await saver._redis.set(external_key1, "metadata_value", ex=120)
await saver._redis.set(external_key2, "context_value", ex=120)

# Wait a bit
await asyncio.sleep(2)

# Read checkpoint - should refresh its TTL
result = await saver.aget_tuple(saved_config)
assert result is not None

# Manually refresh external keys' TTL (simulating user's synchronization logic)
await saver._redis.expire(external_key1, 120)
await saver._redis.expire(external_key2, 120)

# Check that all TTLs are synchronized
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)

checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
)

checkpoint_ttl = await saver._redis.ttl(checkpoint_key)
external_ttl1 = await saver._redis.ttl(external_key1)
external_ttl2 = await saver._redis.ttl(external_key2)

# All TTLs should be close to each other (within 2 seconds)
assert (
abs(checkpoint_ttl - external_ttl1) <= 2
), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl1}"
assert (
abs(checkpoint_ttl - external_ttl2) <= 2
), f"TTLs not synchronized: {checkpoint_ttl} vs {external_ttl2}"
assert (
115 <= checkpoint_ttl <= 120
), f"Checkpoint TTL should be ~120s, got {checkpoint_ttl}"


@pytest.mark.asyncio
async def test_async_ttl_no_refresh_for_persistent_keys(redis_url: str) -> None:
"""Test that keys without TTL (persistent) are not affected by refresh logic (async)."""

# Configure with TTL refresh on read
ttl_config = {
"default_ttl": 2, # 2 minutes
"refresh_on_read": True,
}

async with _saver(redis_url, ttl_config) as saver:
thread_id = str(uuid4())
checkpoint_ns = ""
checkpoint_id = str(uuid4())

# Create a checkpoint
checkpoint = Checkpoint(
v=1,
id=checkpoint_id,
ts="2024-01-01T00:00:00+00:00",
channel_values={"test": "value"},
channel_versions={},
versions_seen={},
)

config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
}
}

# Save the checkpoint
saved_config = await saver.aput(config, checkpoint, {"test": "metadata"}, {})

# Remove TTL to make it persistent
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import (
to_storage_safe_id,
to_storage_safe_str,
)

checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(thread_id),
to_storage_safe_str(checkpoint_ns),
to_storage_safe_id(saved_config["configurable"]["checkpoint_id"]),
)
await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1)

# Verify key is persistent (TTL = -1)
ttl_before = await saver._redis.ttl(checkpoint_key)
assert ttl_before == -1, f"Key should be persistent (TTL=-1), got {ttl_before}"

# Read the checkpoint
result = await saver.aget_tuple(saved_config)
assert result is not None

# Verify key is still persistent (not affected by refresh)
ttl_after = await saver._redis.ttl(checkpoint_key)
assert (
ttl_after == -1
), f"Key should remain persistent (TTL=-1), got {ttl_after}"
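Following the review nitpick above, one possible helper (an assumption, not part of this PR) that would let the tests build checkpoint keys without repeating the imports in every function:

```python
from langgraph.checkpoint.redis.base import BaseRedisSaver
from langgraph.checkpoint.redis.util import to_storage_safe_id, to_storage_safe_str


def make_checkpoint_key(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> str:
    """Build a checkpoint key the same way the tests above do."""
    return BaseRedisSaver._make_redis_checkpoint_key(
        to_storage_safe_id(thread_id),
        to_storage_safe_str(checkpoint_ns),
        to_storage_safe_id(checkpoint_id),
    )
```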