feat: Add optional idempotency support to batches API #3171

Open · wants to merge 2 commits into main
9 changes: 6 additions & 3 deletions docs/source/providers/batches/index.md
@@ -2,12 +2,15 @@

## Overview

Protocol for batch processing API operations.

The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

The API is designed to allow use of openai client libraries for seamless integration.

This API provides the following extensions:
- idempotent batch creation

Note: This API is currently under active development and may undergo changes.

This section contains documentation for all available providers for the **batches** API.
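As a usage illustration (not part of the diff): a minimal sketch of creating an idempotent batch through the OpenAI Python client, mirroring the integration tests added later in this PR. The base URL and file ID are placeholders, and the idempotency_key travels via extra_body because it is a Llama Stack extension rather than part of the upstream client signature.

from openai import OpenAI

# Placeholder base URL for a running Llama Stack server's OpenAI-compatible API.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Two identical requests with the same idempotency key yield the same batch.
first = client.batches.create(
    input_file_id="file_abc123",
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"purpose": "demo"},
    extra_body={"idempotency_key": "demo-key-1"},
)
second = client.batches.create(
    input_file_id="file_abc123",
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"purpose": "demo"},
    extra_body={"idempotency_key": "demo-key-1"},
)
assert first.id == second.id  # no duplicate batch was created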
10 changes: 8 additions & 2 deletions llama_stack/apis/batches/batches.py
@@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):

@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.

"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

The API is designed to allow use of openai client libraries for seamless integration.

This API provides the following extensions:
- idempotent batch creation

Note: This API is currently under active development and may undergo changes.
"""

@@ -45,13 +49,15 @@ async def create_batch(
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.

:param input_file_id: The ID of an uploaded file containing requests for the batch.
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...
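For provider authors, a sketch of calling the extended protocol method directly; it assumes `batches` is an initialized implementation of the Batches protocol (for example, the reference provider changed below), and the file ID and key are placeholders.

import asyncio

async def create_idempotent_batch(batches) -> None:
    # `batches` is assumed to be an initialized Batches implementation.
    batch = await batches.create_batch(
        input_file_id="file_abc123",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"purpose": "demo"},
        idempotency_key="demo-key-1",  # optional; omit for a fresh batch per call
    )
    print(batch.id)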
80 changes: 64 additions & 16 deletions llama_stack/providers/inline/batches/reference/batches.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import asyncio
import hashlib
import itertools
import json
import time
@@ -136,28 +137,45 @@ async def create_batch(
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.

Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
This implementation provides optional idempotency: when an idempotency key
(idempotency_key) is provided, a deterministic batch ID is derived from that
key. If a batch with that ID already exists and its parameters match, it is
returned instead of creating a duplicate; reusing the key with different
parameters raises a conflict error. Without an idempotency key, each request
creates a new batch with a unique ID.

Args:
input_file_id: The ID of an uploaded file containing requests for the batch.
endpoint: The endpoint to be used for all requests in the batch.
completion_window: The time window within which the batch should be processed.
metadata: Optional metadata for the batch.
idempotency_key: Optional idempotency key for enabling idempotent behavior.

Returns:
The created or existing batch object.
"""

# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status

# TODO: set expiration time for garbage collection

if endpoint not in ["/v1/chat/completions"]:
@@ -171,6 +189,35 @@
)

batch_id = f"batch_{uuid.uuid4().hex[:16]}"

# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
Contributor: The default batch id uses a 16 char hex section; is there a reason to use a different length here?

Collaborator: +1

Collaborator (Author): secret way to tell the difference. happy to align them.

Collaborator: oh that's fine then. i personally like to add a prefix for those reasons (OpenAI follows this but they don't expose an idempotency key) but different size is okay.
batch_id = f"batch_{hash_digest}"

try:
existing_batch = await self.retrieve_batch(batch_id)

if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)

logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass

current_time = int(time.time())

batch = BatchObject(
@@ -185,6 +232,7 @@
)

await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")

if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))
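To summarize the ID scheme above (and the length discussion in the review thread), a standalone sketch of both derivations, taken directly from the diff: the idempotent path hashes the key, the default path draws random hex, and the two forms differ in length.

import hashlib
import uuid

def idempotent_batch_id(idempotency_key: str) -> str:
    # Deterministic path: SHA-256 of the key, truncated to 24 hex chars.
    digest = hashlib.sha256(idempotency_key.encode("utf-8")).hexdigest()[:24]
    return f"batch_{digest}"

def random_batch_id() -> str:
    # Default path: 16 random hex chars per batch.
    return f"batch_{uuid.uuid4().hex[:16]}"

assert idempotent_batch_id("token-1") == idempotent_batch_id("token-1")
assert idempotent_batch_id("token-1") != idempotent_batch_id("token-2")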
91 changes: 91 additions & 0 deletions tests/integration/batches/test_batches_idempotency.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Integration tests for batch idempotency functionality using the OpenAI client library.

This module tests the idempotency feature in the batches API using the OpenAI-compatible
client interface. These tests verify that the idempotency key (idempotency_key) works correctly
in a real client-server environment.

Test Categories:
1. Successful Idempotency: Same key returns same batch with identical parameters
- test_idempotent_batch_creation_successful: Verifies that requests with the same
idempotency key return identical batches, even with different metadata order

2. Conflict Detection: Same key with conflicting parameters raises HTTP 409 errors
- test_idempotency_conflict_with_different_params: Verifies that reusing an idempotency key
with truly conflicting parameters (both file ID and metadata values) raises ConflictError
"""

import time

import pytest
from openai import ConflictError


class TestBatchesIdempotencyIntegration:
"""Integration tests for batch idempotency using OpenAI client."""

def test_idempotent_batch_creation_successful(self, openai_client):
"""Test that identical requests with same idempotency key return the same batch."""
batch1 = openai_client.batches.create(
input_file_id="bogus-id",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"test_type": "idempotency_success",
"purpose": "integration_test",
},
extra_body={"idempotency_key": "test-idempotency-token-1"},
)

# sleep to ensure different timestamps
time.sleep(1)

batch2 = openai_client.batches.create(
input_file_id="bogus-id",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"purpose": "integration_test",
"test_type": "idempotency_success",
}, # Different order
extra_body={"idempotency_key": "test-idempotency-token-1"},
)

assert batch1.id == batch2.id
assert batch1.input_file_id == batch2.input_file_id
assert batch1.endpoint == batch2.endpoint
assert batch1.completion_window == batch2.completion_window
assert batch1.metadata == batch2.metadata
assert batch1.created_at == batch2.created_at

def test_idempotency_conflict_with_different_params(self, openai_client):
"""Test that using same idempotency key with different params raises conflict error."""
batch1 = openai_client.batches.create(
input_file_id="bogus-id-1",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"test_type": "conflict_test_1"},
extra_body={"idempotency_key": "conflict-token"},
)

with pytest.raises(ConflictError) as exc_info:
openai_client.batches.create(
input_file_id="bogus-id-2", # Different file ID
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"test_type": "conflict_test_2"}, # Different metadata
extra_body={"idempotency_key": "conflict-token"}, # Same token
)

assert exc_info.value.status_code == 409
assert "conflict" in str(exc_info.value).lower()

retrieved_batch = openai_client.batches.retrieve(batch1.id)
assert retrieved_batch.id == batch1.id
assert retrieved_batch.input_file_id == "bogus-id-1"
54 changes: 54 additions & 0 deletions tests/unit/providers/batches/conftest.py
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Shared fixtures for batches provider unit tests."""

import tempfile
from pathlib import Path
from unittest.mock import AsyncMock

import pytest

from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.fixture
async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)

# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()

provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()

# unit tests should not require background processing
provider.process_batches = False

yield provider

await provider.shutdown()


@pytest.fixture
def sample_batch_data():
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}
43 changes: 0 additions & 43 deletions tests/unit/providers/batches/test_reference.py
@@ -54,60 +54,17 @@
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""

@pytest.fixture
async def provider(self):
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)

# Create kvstore and mock APIs
from unittest.mock import AsyncMock

from llama_stack.providers.utils.kvstore import kvstore_impl

kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()

provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()

# unit tests should not require background processing
provider.process_batches = False

yield provider

await provider.shutdown()

@pytest.fixture
def sample_batch_data(self):
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}

def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.