diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md new file mode 100644 index 000000000000..04edc4680ea0 --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -0,0 +1,186 @@ +# Long Text Embedding with Chunked Processing + +This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length. + +## πŸš€ Quick Start + +### Start the Server + +Use the provided script to start a vLLM server with chunked processing enabled: + +```bash +# Basic usage (supports very long texts up to ~3M tokens) +./service.sh + +# Custom configuration with different models +MODEL_NAME="jinaai/jina-embeddings-v3" \ +MAX_EMBED_LEN=1048576 \ +./service.sh + +# For extremely long documents +MODEL_NAME="intfloat/multilingual-e5-large" \ +MAX_EMBED_LEN=3072000 \ +./service.sh +``` + +### Test Long Text Embedding + +Run the comprehensive test client: + +```bash +python client.py +``` + +## πŸ“ Files + +| File | Description | +|------|-------------| +| `service.sh` | Server startup script with chunked processing enabled | +| `client.py` | Comprehensive test client for long text embedding | + +## βš™οΈ Configuration + +### Server Configuration + +The key parameters for chunked processing are in the `--override-pooler-config`: + +```json +{ + "pooling_type": "auto", + "normalize": true, + "enable_chunked_processing": true, + "max_embed_len": 3072000 +} +``` + +!!! note + `pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length. + +#### Chunked Processing Behavior + +Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length: + +| Component | Behavior | Description | +|-----------|----------|-------------| +| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy | +| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts | +| **Performance** | Optimal | All chunks processed for complete semantic coverage | + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) | +| `PORT` | `31090` | Server port | +| `GPU_COUNT` | `1` | Number of GPUs to use | +| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) | +| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) | +| `API_KEY` | `EMPTY` | API key for authentication | + +## πŸ”§ How It Works + +1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables +2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity +3. **Unified Processing**: All chunks processed separately through the model using its configured pooling strategy +4. **MEAN Aggregation**: When input exceeds model's native length, results combined using token count-based weighted averaging across all chunks +5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing + +### Input Length Handling + +- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens) +- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered +- **Exceeds max_embed_len**: Input is rejected with clear error message +- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN` + +### Extreme Long Text Support + +With `MAX_EMBED_LEN=3072000`, you can process: + +- **Academic papers**: Full research papers with references +- **Legal documents**: Complete contracts and legal texts +- **Books**: Entire chapters or small books +- **Code repositories**: Large codebases and documentation + +## πŸ“Š Performance Characteristics + +### Chunked Processing Performance + +| Aspect | Behavior | Performance | +|--------|----------|-------------| +| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length | +| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead | +| **Memory Usage** | Proportional to number of chunks | Moderate, scalable | +| **Semantic Quality** | Complete text coverage | Optimal for long documents | + +## πŸ§ͺ Test Cases + +The test client demonstrates: + +- βœ… **Short text**: Normal processing (baseline) +- βœ… **Medium text**: Single chunk processing +- βœ… **Long text**: Multi-chunk processing with aggregation +- βœ… **Very long text**: Many chunks processing +- βœ… **Extreme long text**: Document-level processing (100K+ tokens) +- βœ… **Batch processing**: Mixed-length inputs in one request +- βœ… **Consistency**: Reproducible results across runs + +## πŸ› Troubleshooting + +### Common Issues + +1. **Chunked processing not enabled**: + + ```log + ValueError: This model's maximum position embeddings length is 4096 tokens... + ``` + + **Solution**: Ensure `enable_chunked_processing: true` in pooler config + +2. **Input exceeds max_embed_len**: + + ```log + ValueError: This model's maximum embedding input length is 3072000 tokens... + ``` + + **Solution**: Increase `max_embed_len` in pooler config or reduce input length + +3. **Memory errors**: + + ```log + RuntimeError: CUDA out of memory + ``` + + **Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs + +4. **Slow processing**: + **Expected**: Long text takes more time due to multiple inference calls + +### Debug Information + +Server logs show chunked processing activity: + +```log +INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing +INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) +``` + +## 🀝 Contributing + +To extend chunked processing support to other embedding models: + +1. Check model compatibility with the pooling architecture +2. Test with various text lengths +3. Validate embedding quality compared to single-chunk processing +4. Submit PR with test cases and documentation updates + +## πŸ†• Enhanced Features + +### max_embed_len Parameter + +The new `max_embed_len` parameter provides: + +- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable +- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len` +- **Extreme Length Support**: Process documents with millions of tokens +- **Clear Error Messages**: Better feedback when inputs exceed limits +- **Backward Compatibility**: Existing configurations continue to work diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py new file mode 100644 index 000000000000..6e9838ac6d8d --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script demonstrating long text embedding with chunked processing in vLLM. + +This example shows how to use vLLM's chunked processing feature to handle text +inputs that exceed the model's maximum token length. The feature automatically +splits long text into chunks and handles different pooling types optimally. + +Prerequisites: +1. Start vLLM server with chunked processing enabled: + + # MEAN pooling (processes all chunks, recommended for complete coverage) + vllm serve intfloat/multilingual-e5-large \ + --override-pooler-config \ + '{"pooling_type": "MEAN", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ + --served-model-name multilingual-e5-large \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + + # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) + vllm serve BAAI/bge-large-en-v1.5 \ + --override-pooler-config \ + '{"pooling_type": "CLS", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ + --served-model-name bge-large-en-v1.5 \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + +2. Install required dependencies: + pip install openai requests +""" + +import time + +import numpy as np +from openai import OpenAI + +# Configuration +API_KEY = "your-api-key" # Replace with your actual API key +BASE_URL = "http://localhost:31090/v1" +MODEL_NAME = "multilingual-e5-large" + + +def generate_long_text(base_text: str, repeat_count: int) -> str: + """Generate long text by repeating base text.""" + return base_text * repeat_count + + +def test_embedding_with_different_lengths(): + """Test embedding generation with different text lengths.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + # Test cases with different text lengths + test_cases = [ + { + "name": "Short Text", + "text": "Hello, this is a short text for embedding.", + "expected_chunks": 1, + }, + { + "name": "Medium Text", + "text": generate_long_text( + "This is a medium-length text that should fit within the " + "model's context window. " * 20, + 2, + ), + "expected_chunks": 1, + }, + { + "name": "Long Text (2 chunks)", + "text": generate_long_text( + "This is a very long text that will exceed the model's " + "maximum context length and trigger chunked processing. " * 50, + 5, + ), + "expected_chunks": 2, + }, + { + "name": "Very Long Text (3+ chunks)", + "text": generate_long_text( + "This text is extremely long and will definitely " + "require multiple chunks for processing. " * 100, + 10, + ), + "expected_chunks": 3, + }, + ] + + print("πŸ§ͺ Testing vLLM Long Text Embedding with Chunked Processing") + print("=" * 70) + + for i, test_case in enumerate(test_cases, 1): + print(f"\nπŸ“ Test {i}: {test_case['name']}") + print(f"Text length: {len(test_case['text'])} characters") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=test_case["text"], model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + # Extract embedding data + embedding = response.data[0].embedding + embedding_dim = len(embedding) + + print("βœ… Success!") + print(f" - Embedding dimension: {embedding_dim}") + print(f" - Processing time: {processing_time:.2f}s") + print(f" - Expected chunks: ~{test_case['expected_chunks']}") + print(f" - First 5 values: {embedding[:5]}") + + except Exception as e: + print(f"❌ Failed: {str(e)}") + + +def test_batch_embedding(): + """Test batch embedding with mixed-length inputs.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\nπŸ”„ Testing Batch Embedding with Mixed Lengths") + print("=" * 50) + + # Mix of short and long texts + batch_inputs = [ + "Short text 1", + generate_long_text("Medium length text that fits in one chunk. " * 20, 1), + "Another short text", + generate_long_text("Long text requiring chunked processing. " * 100, 5), + ] + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("βœ… Batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + print( + f" - Average time per input: {processing_time / len(batch_inputs):.2f}s" + ) + + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars β†’ {embedding_dim}D embedding" + ) + + except Exception as e: + print(f"❌ Batch processing failed: {str(e)}") + + +def test_multiple_long_texts_batch(): + """Test batch processing with multiple long texts to verify chunk ID uniqueness.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\nπŸ”§ Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)") + print("=" * 70) + + # Create multiple distinct long texts that will all require chunking + # Note: All pooling types now use MEAN aggregation across chunks: + # - Native pooling (MEAN/CLS/LAST) is used within each chunk + # - MEAN aggregation combines results across all chunks + # - Full semantic coverage for all pooling types + long_texts = [ + generate_long_text( + "First long document about artificial intelligence and machine learning. " + * 80, + 6, + ), + generate_long_text( + "Second long document about natural language processing and transformers. " + * 80, + 6, + ), + generate_long_text( + "Third long document about computer vision and neural networks. " * 80, 6 + ), + ] + + # Add some short texts to mix things up + batch_inputs = [ + "Short text before long texts", + long_texts[0], + "Short text between long texts", + long_texts[1], + long_texts[2], + "Short text after long texts", + ] + + print("πŸ“Š Batch composition:") + for i, text in enumerate(batch_inputs): + length = len(text) + text_type = "Long (will be chunked)" if length > 5000 else "Short" + print(f" - Input {i + 1}: {length} chars ({text_type})") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("\nβœ… Multiple long texts batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings returned: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + + # Verify each embedding is different (no incorrect aggregation) + embeddings = [data.embedding for data in response.data] + + if len(embeddings) >= 3: + import numpy as np + + # Compare embeddings of the long texts (indices 1, 3, 4) + long_embeddings = [ + np.array(embeddings[1]), # First long text + np.array(embeddings[3]), # Second long text + np.array(embeddings[4]), # Third long text + ] + + print("\nπŸ” Verifying embedding uniqueness:") + for i in range(len(long_embeddings)): + for j in range(i + 1, len(long_embeddings)): + cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / ( + np.linalg.norm(long_embeddings[i]) + * np.linalg.norm(long_embeddings[j]) + ) + print( + f" - Similarity between long text {i + 1} and {j + 1}: " + f"{cosine_sim:.4f}" + ) + + if ( + cosine_sim < 0.9 + ): # Different content should have lower similarity + print(" βœ… Good: Embeddings are appropriately different") + else: + print( + " ⚠️ High similarity - may indicate chunk " + "aggregation issue" + ) + + print("\nπŸ“‹ Per-input results:") + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + embedding_norm = np.linalg.norm(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars β†’ {embedding_dim}D " + f"embedding (norm: {embedding_norm:.4f})" + ) + + print( + "\nβœ… This test verifies the fix for chunk ID collisions in " + "batch processing" + ) + print(" - Before fix: Multiple long texts would have conflicting chunk IDs") + print(" - After fix: Each prompt's chunks have unique IDs with prompt index") + + except Exception as e: + print(f"❌ Multiple long texts batch test failed: {str(e)}") + print(" This might indicate the chunk ID collision bug is present!") + + +def test_embedding_consistency(): + """Test that chunked processing produces consistent results.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\nπŸ” Testing Embedding Consistency") + print("=" * 40) + + # Use the same long text multiple times + long_text = generate_long_text( + "Consistency test text for chunked processing validation. " * 50, 3 + ) + + embeddings = [] + + try: + for i in range(3): + response = client.embeddings.create( + input=long_text, model=MODEL_NAME, encoding_format="float" + ) + embeddings.append(response.data[0].embedding) + print(f" - Generated embedding {i + 1}") + + # Check consistency (embeddings should be identical) + if len(embeddings) >= 2: + # Calculate similarity between first two embeddings + + emb1 = np.array(embeddings[0]) + emb2 = np.array(embeddings[1]) + + # Cosine similarity + cosine_sim = np.dot(emb1, emb2) / ( + np.linalg.norm(emb1) * np.linalg.norm(emb2) + ) + + print("βœ… Consistency test completed!") + print(f" - Cosine similarity between runs: {cosine_sim:.6f}") + print(" - Expected: ~1.0 (identical embeddings)") + + if cosine_sim > 0.999: + print(" - βœ… High consistency achieved!") + else: + print(" - ⚠️ Consistency may vary due to numerical precision") + + except Exception as e: + print(f"❌ Consistency test failed: {str(e)}") + + +def main(): + """Main function to run all tests.""" + print("πŸš€ vLLM Long Text Embedding Client") + print(f"πŸ“‘ Connecting to: {BASE_URL}") + print(f"πŸ€– Model: {MODEL_NAME}") + masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****" + print(f"πŸ”‘ API Key: {masked_key}") + + # Run all test cases + test_embedding_with_different_lengths() + test_batch_embedding() + test_multiple_long_texts_batch() + test_embedding_consistency() + + print("\n" + "=" * 70) + print("πŸŽ‰ All tests completed!") + print("\nπŸ’‘ Key Features Demonstrated:") + print(" - βœ… Automatic chunked processing for long text") + print(" - βœ… Seamless handling of mixed-length batches") + print(" - βœ… Multiple long texts in single batch (chunk ID fix)") + print(" - βœ… Unified chunked processing:") + print(" β€’ Native pooling used within each chunk") + print(" β€’ MEAN aggregation across all chunks") + print(" β€’ Complete semantic coverage for all pooling types") + print(" - βœ… Consistent embedding generation") + print(" - βœ… Backward compatibility with short text") + print("\nπŸ“š For more information, see:") + print( + " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html" + ) + print(" - Chunked Processing Guide: openai_embedding_long_text.md") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh new file mode 100644 index 000000000000..f356d7d4529e --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# vLLM Embedding Server with Enhanced Chunked Processing +# This script starts a vLLM server with chunked processing enabled for long text embedding. +# Now supports proper pooling type validation and model-specific configurations. + +set -euo pipefail + +# Configuration +MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"} +MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"} + +PORT=${PORT:-31090} +GPU_COUNT=${GPU_COUNT:-1} +MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000} +API_KEY=${API_KEY:-"your-api-key"} + +# Enhanced pooling configuration with model-specific defaults +POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST +export VLLM_ENABLE_CHUNKED_PROCESSING=true +export CUDA_VISIBLE_DEVICES=2,3,4,5 +# export VLLM_ATTENTION_BACKEND=XFORMERS + +echo "πŸš€ Starting vLLM Embedding Server with Enhanced Chunked Processing" +echo "==================================================================" + +# Environment variables for optimization +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +# Function to determine optimal pooling type for known models +get_optimal_pooling_type() { + local model="$1" + case "$model" in + *"e5-"* | *"multilingual-e5"*) + echo "MEAN" # E5 series native pooling + ;; + *"bge-"*) + echo "CLS" # BGE series native pooling + ;; + *"gte-"*) + echo "LAST" # GTE series native pooling + ;; + *"sentence-t5"* | *"st5"*) + echo "MEAN" # Sentence-T5 native pooling + ;; + *"jina-embeddings"*) + echo "MEAN" # Jina embeddings native pooling + ;; + *"Qwen"*"Embedding"*) + echo "LAST" # Qwen embeddings native pooling + ;; + *) + echo "MEAN" # Default native pooling for unknown models + ;; + esac +} + +# Auto-detect pooling type if not explicitly set +if [ "$POOLING_TYPE" = "auto" ]; then + POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME") + echo "πŸ” Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME" +fi + +# Display configuration +echo "πŸ“‹ Configuration:" +echo " - Model: $MODEL_NAME" +echo " - Port: $PORT" +echo " - GPU Count: $GPU_COUNT" +echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}" +echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens" +echo " - Native Pooling Type: $POOLING_TYPE + Normalization" +echo " - Cross-chunk Aggregation: MEAN (automatic)" +echo "" + +# Validate GPU availability +if command -v nvidia-smi &> /dev/null; then + gpu_count=$(nvidia-smi --list-gpus | wc -l) + echo "πŸ–₯️ Available GPUs: $gpu_count" + if [ "$GPU_COUNT" -gt "$gpu_count" ]; then + echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available" + echo " Adjusting to use $gpu_count GPUs" + GPU_COUNT=$gpu_count + fi +else + echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped." +fi + +# Chunked processing uses unified MEAN aggregation +echo "ℹ️ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks" +echo " - All chunks processed for complete semantic coverage" +echo " - Weighted averaging based on chunk token counts" + +echo "" +echo "πŸ”§ Starting server with enhanced chunked processing configuration..." + +# Build pooler config JSON +POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}" + +# Start vLLM server with enhanced chunked processing +vllm serve "$MODEL_NAME" \ + --tensor-parallel-size "$GPU_COUNT" \ + --enforce-eager \ + --override-pooler-config "$POOLER_CONFIG" \ + --served-model-name ${MODEL_CODE} \ + --api-key "$API_KEY" \ + --trust-remote-code \ + --port "$PORT" \ + --host 0.0.0.0 + +echo "" +echo "βœ… vLLM Embedding Server started successfully!" +echo "" +echo "πŸ“‘ Server Information:" +echo " - Base URL: http://localhost:$PORT" +echo " - Model Code: ${MODEL_CODE}" +echo " - API Key: $API_KEY" +echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN" +echo "" +echo "πŸ§ͺ Test the server with:" +echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo "" +echo "πŸ“š Enhanced features enabled:" +echo " βœ… Intelligent native pooling type detection" +echo " βœ… Unified MEAN aggregation for chunked processing" +echo " βœ… Model-specific native pooling optimization" +echo " βœ… Enhanced max embedding length (${MAX_EMBED_LEN} tokens)" +echo " βœ… Complete semantic coverage for all pooling types" +echo " βœ… OpenAI-compatible API" +echo " βœ… GPU acceleration" +echo "" +echo "πŸ”§ Advanced usage:" +echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection" +echo " - Set MAX_EMBED_LEN to adjust maximum input length" +echo " - All pooling types use MEAN aggregation across chunks" diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/openai/test_embedding_long_text.py new file mode 100644 index 000000000000..86bd34abb97e --- /dev/null +++ b/tests/entrypoints/openai/test_embedding_long_text.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test cases for long text embedding with automatic chunking mechanism. + +This test suite validates vLLM's automatic chunking functionality for handling +text inputs that exceed the model's maximum token length, specifically targeting +the intfloat/multilingual-e5-small model (max token length: 512). +""" + +import random + +import openai +import pytest +import pytest_asyncio + +from vllm.entrypoints.openai.protocol import EmbeddingResponse + +from ...utils import RemoteOpenAIServer + + +def _generate_random_text(word_count: int) -> str: + """Generate random text with approximately the specified word count.""" + # Common English words with focus on verbs and nouns for realistic text + common_words = [ + # Essential articles and pronouns (minimal) + "the", + "and", + "you", + "they", + "this", + "that", + "these", + "those", + + # Action verbs + "create", + "build", + "develop", + "design", + "implement", + "execute", + "analyze", + "process", + "generate", + "calculate", + "evaluate", + "optimize", + "transform", + "integrate", + "configure", + "deploy", + "monitor", + "manage", + "discover", + "explore", + "investigate", + "research", + "study", + "examine", + "improve", + "enhance", + "upgrade", + "modify", + "update", + "maintain", + "solve", + "resolve", + "handle", + "address", + "tackle", + "overcome", + "communicate", + "collaborate", + "coordinate", + "organize", + "plan", + "achieve", + "accomplish", + "complete", + "finish", + "deliver", + "provide", + + # Technology and science nouns + "system", + "application", + "software", + "hardware", + "network", + "database", + "algorithm", + "model", + "framework", + "platform", + "interface", + "protocol", + "architecture", + "infrastructure", + "component", + "module", + "service", + "technology", + "innovation", + "solution", + "methodology", + "approach", + "artificial", + "intelligence", + "machine", + "learning", + "neural", + "network", + "computer", + "processor", + "memory", + "storage", + "computation", + "data", + "information", + "knowledge", + "insight", + "pattern", + "trend", + "analysis", + "research", + "development", + "engineering", + "science", + "mathematics", + "statistics", + "probability", + "optimization", + "performance", + "efficiency", + + # General nouns + "project", + "team", + "organization", + "company", + "business", + "industry", + "market", + "customer", + "user", + "client", + "product", + "feature", + "function", + "requirement", + "specification", + "documentation", + "report", + "result", + "outcome", + "impact", + "benefit", + "advantage", + "challenge", + "problem", + "opportunity", + "strategy", + "goal", + "objective", + "target", + "milestone", + "process", + "procedure", + "workflow", + "pipeline", + "operation", + "task", + "activity", + "event", + "session", + "meeting", + "discussion", + "decision" + ] + + words = [] + for _ in range(word_count): + words.append(random.choice(common_words)) + + # Add some punctuation for more realistic text + text = " ".join(words) + # Add periods every 10-20 words + words_list = text.split() + result = [] + for i, word in enumerate(words_list): + result.append(word) + if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1): + result[-1] += "." + + return " ".join(result) + + +MODEL_NAME = "intfloat/multilingual-e5-small" +DTYPE = "bfloat16" + +# Test text: Generate text with approximately 1500 words to exceed 1024 tokens +LONG_TEXT_1500_WORDS = _generate_random_text(1500) + +# Test text: Generate text with approximately 2500 words to exceed 2048 tokens +LONG_TEXT_2500_WORDS = _generate_random_text(2500) + + +@pytest.fixture(scope="module") +def server_with_chunked_processing(): + """Start server with automatic chunking processing enabled.""" + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", # Set smaller max_model_len to trigger chunking mechanism + '--override-pooler-config', + ('{"pooling_type": "MEAN", "normalize": true, ' + '"enable_chunked_processing": true, "max_embed_len": 10000}'), + "--gpu-memory-utilization", + "0.8", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_chunked_processing(server_with_chunked_processing): + """Create async client with chunking processing support.""" + async with server_with_chunked_processing.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_1500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~1500 character long text + (~1028 tokens, exceeding 512 token limit).""" + + # Verify text length + # Verify text has sufficient word count (approximately 1500 words) + word_count = len(LONG_TEXT_1500_WORDS.split()) + assert word_count >= 1400, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~1500 words, we expect roughly + # 1024+ tokens (exceeding 512 token limit) + # Should exceed single chunk limit of 512 + assert embeddings.usage.prompt_tokens > 800 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_2500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~2500 character long text + (~2048 tokens, requiring multiple chunks).""" + + # Verify text length + # Verify text has sufficient word count (approximately 2500 words) + word_count = len(LONG_TEXT_2500_WORDS.split()) + assert word_count >= 2300, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_2500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~2500 words, we expect + # roughly 2048+ tokens (requiring multiple chunks) + # Should require multiple chunks for processing + assert embeddings.usage.prompt_tokens > 1500 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_batch_long_text_embedding( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test batch long text embedding processing.""" + + input_texts = [ + LONG_TEXT_1500_WORDS, + LONG_TEXT_2500_WORDS, + "This is a short text test.", # Short text for comparison + ] + + # Send batch embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=input_texts, + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 3 # Three input texts + + # Verify each embedding dimension + for i, embedding_data in enumerate(embeddings.data): + assert len(embedding_data.embedding) == 384 + assert embedding_data.index == i + + # Verify embedding vector validity + embedding_vector = embedding_data.embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + # Verify token usage + assert embeddings.usage.completion_tokens == 0 + # Total token count should be very substantial + assert embeddings.usage.prompt_tokens > 1000 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_vs_normal_consistency( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test consistency between chunked and + normal processing (using short text).""" + + # Use a short text within the 512 token limit + short_text = ("Artificial intelligence technology is changing our world, " + "bringing unprecedented opportunities and challenges.") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[short_text], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 384 + assert embeddings.usage.completion_tokens == 0 + # Short text should not require chunked processing + assert embeddings.usage.prompt_tokens < 512 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # ιͺŒθ―embeddingε‘ι‡ηš„ζœ‰ζ•ˆζ€§ + embedding_vector = embeddings.data[0].embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_processing_response_format( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test response format and structure during chunked processing.""" + + # Test with long text to trigger chunking + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert embeddings.data[0].object == "embedding" + assert embeddings.data[0].index == 0 + + # Verify embedding vector properties + embedding_vector = embeddings.data[0].embedding + import math + vector_norm = math.sqrt(sum(x * x for x in embedding_vector)) + # Check that the vector is normalized + # (default behavior for most embedding models) + assert 0.8 < vector_norm < 1.2, ( + f"Vector norm should be reasonable, actual: {vector_norm}") diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 6649cd89ee34..b4ea15ef5a0f 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2598,6 +2598,25 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + This parameter enables accepting long inputs without requiring + VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds + max_embed_len, it will be handled according to the original max_model_len + validation logic. Defaults to None (i.e. set to max_model_len). + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 84ba00873103..9dcad8e391c6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from typing import Final, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator, Mapping +from typing import Any, Final, Literal, Optional, Union, cast import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never, override @@ -12,19 +14,28 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this docstring +# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext) + RequestPrompt, + ServeContext, + TextTokensPrompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput) + PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams +from vllm.utils import chunk_list logger = init_logger(__name__) @@ -46,6 +57,17 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + pooler_config = self.model_config.pooler_config + + # Avoid repeated attribute lookups + self.supports_chunked_processing = bool( + pooler_config and pooler_config.enable_chunked_processing) + self.max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len else None) + @override async def _preprocess( self, @@ -129,6 +151,435 @@ def _build_response( usage=usage, ) + def _get_max_position_embeddings(self) -> int: + """Get the model's effective maximum sequence length for chunking.""" + return self.model_config.max_model_len + + def _should_use_chunked_processing(self, request) -> bool: + """Check if chunked processing should be used for this request.""" + return isinstance( + request, + (EmbeddingCompletionRequest, + EmbeddingChatRequest)) and self.supports_chunked_processing + + async def _process_chunked_request( + self, + ctx: EmbeddingServeContext, + original_prompt: TextTokensPrompt, + pooling_params, + trace_headers, + prompt_idx: int, + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + """Process a single prompt using chunked processing.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + token_ids = original_prompt["prompt_token_ids"] + + # Split into chunks using max_position_embeddings + max_pos_embeddings = self._get_max_position_embeddings() + # Process all chunks for MEAN aggregation + for chunk_idx, chunk_tokens in enumerate( + chunk_list(token_ids, max_pos_embeddings)): + # Create a request ID for this chunk + chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" + f"chunk-{chunk_idx}") + + # Create engine prompt for this chunk + chunk_engine_prompt = EngineTokensPrompt( + prompt_token_ids=chunk_tokens) + + # Create chunk request prompt for logging + chunk_text = "" + chunk_request_prompt = TextTokensPrompt( + prompt=chunk_text, prompt_token_ids=chunk_tokens) + + # Log the chunk + self._log_inputs(chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Create generator for this chunk and wrap it to return indices + original_generator = self.engine_client.encode( + chunk_engine_prompt, + pooling_params, + chunk_request_id, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(original_generator) + + return generators + + def _validate_input( + self, + request, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + """Override to support chunked processing for embedding requests.""" + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingCompletionRequest, EmbeddingChatRequest)): + # Check if chunked processing is enabled for pooling models + enable_chunked = self._should_use_chunked_processing(request) + + # Use max_position_embeddings for chunked processing decisions + max_pos_embeddings = self._get_max_position_embeddings() + + # Determine the effective max length for validation + if self.max_embed_len is not None: + # Use max_embed_len for validation instead of max_model_len + length_type = "maximum embedding input length" + max_length_value = self.max_embed_len + else: + # Fall back to max_model_len validation (original behavior) + length_type = "maximum context length" + max_length_value = self.max_model_len + + validation_error_msg = ( + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input.") + + chunked_processing_error_msg = ( + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input " + "or enable chunked processing.") + + # Check if input exceeds max length + if token_num > max_length_value: + raise ValueError( + validation_error_msg.format( + length_type=length_type, + max_length_value=max_length_value, + token_num=token_num)) + + # Check for chunked processing + # when exceeding max_position_embeddings + if token_num > max_pos_embeddings: + if enable_chunked: + # Allow long inputs when chunked processing is enabled + logger.info( + "Input length %s exceeds max_position_embeddings " + "%s, will use chunked processing", token_num, + max_pos_embeddings) + else: + raise ValueError( + chunked_processing_error_msg.format( + length_type="maximum position embeddings length", + max_length_value=max_pos_embeddings, + token_num=token_num)) + + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # For other request types, use the parent's implementation + return super()._validate_input(request, input_ids, input_text) + + def _is_text_tokens_prompt(self, prompt) -> bool: + """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + async def _create_single_prompt_generator( + self, + ctx: EmbeddingServeContext, + engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], + request_prompt: RequestPrompt, + pooling_params: PoolingParams, + trace_headers: Optional[Mapping[str, str]], + prompt_index: int, + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + """Create a generator for a single prompt using standard processing.""" + request_id_item = f"{ctx.request_id}-{prompt_index}" + + self._log_inputs(request_id_item, + request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + + # Return the original generator without wrapping + return self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + @override + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Override to support chunked processing.""" + ctx = cast(EmbeddingServeContext, ctx) + + # Check if we should use chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + # If no chunked processing needed, delegate to parent class + if not use_chunked: + return await super()._prepare_generators(ctx) + + # Custom logic for chunked processing + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params + + # Verify and set the task for pooling params + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + max_pos_embeddings = self._get_max_position_embeddings() + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_prompt = ctx.request_prompts[i] + + # Check if this specific prompt needs chunked processing + if self._is_text_tokens_prompt(request_prompt): + # Cast to TextTokensPrompt since we've verified + # prompt_token_ids + text_tokens_prompt = cast(TextTokensPrompt, request_prompt) + if (len(text_tokens_prompt["prompt_token_ids"]) + > max_pos_embeddings): + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, text_tokens_prompt, pooling_params, + trace_headers, i) + generators.extend(chunk_generators) + continue + + # Normal processing for short prompts or non-token prompts + # Cast engine_prompt to the expected type for mypy + engine_prompt_typed = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = await self._create_single_prompt_generator( + ctx, engine_prompt_typed, request_prompt, pooling_params, + trace_headers, i) + generators.append(generator) + + from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + @override + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect and aggregate batch results + with support for chunked processing. + + For chunked requests, performs online aggregation to + minimize memory usage. + For regular requests, collects results normally. + """ + ctx = cast(EmbeddingServeContext, ctx) + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + if not use_chunked: + return await super()._collect_batch(ctx=ctx) + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + except (ValueError, IndexError): + # Fallback: extract from result_idx if parsing fails + prompt_idx = result_idx + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': None, + 'total_weight': 0, + 'chunk_count': 0, + 'request_id': result.request_id.split("-chunk-")[0] + } + + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, 'data'): + # PoolingOutput case + embedding_data = result.outputs.data + elif hasattr(result.outputs, 'embedding'): + # EmbeddingOutput case - + # convert embedding list to tensor + embedding_data = result.outputs.embedding + else: + return self.create_error_response( + f"Unsupported output type: " + f"{type(result.outputs).__name__}") + + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor(embedding_data, + dtype=torch.float32) + + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight + + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding + else: + # Accumulate + aggregator['weighted_sum'] += weighted_embedding + + aggregator['total_weight'] += weight + aggregator['chunk_count'] += 1 + else: + # Non-chunked result - extract prompt_idx from request_id + parts = result.request_id.split("-") + try: + # Last part should be prompt index + prompt_idx = int(parts[-1]) + except (ValueError, IndexError): + prompt_idx = result_idx # Fallback to result_idx + + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) + + # Finalize aggregated results + final_res_batch: list[Union[PoolingRequestOutput, + EmbeddingRequestOutput]] = [] + num_prompts = len(ctx.engine_prompts) + + for prompt_idx in range(num_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + + if (weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, + (int, float)) and total_weight > 0): + + # Compute final mean embedding + final_embedding = weighted_sum / total_weight + + # Create a PoolingRequestOutput + # for the aggregated result + pooling_output_data = PoolingOutput( + data=final_embedding) + + # Get original prompt token IDs for this prompt + original_prompt = ctx.request_prompts[prompt_idx] + if not self._is_text_tokens_prompt(original_prompt): + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"TextTokensPrompt") + + original_token_ids = cast( + TextTokensPrompt, + original_prompt)["prompt_token_ids"] + + pooling_request_output = PoolingRequestOutput( + request_id=aggregator['request_id'], + prompt_token_ids=original_token_ids, + outputs=pooling_output_data, + finished=True) + + final_res_batch.append(pooling_request_output) + else: + return self.create_error_response( + f"Failed to aggregate chunks " + f"for prompt {prompt_idx}") + elif prompt_idx in short_prompts_results: + final_res_batch.append( + cast(PoolingRequestOutput, + short_prompts_results[prompt_idx])) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) + + return None + + except Exception as e: + return self.create_error_response(str(e)) + class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd"