From 9ebc61b07e69bd8fb5427c7abd0b9e2b3e7d58cc Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 05:57:18 +0800 Subject: [PATCH 01/39] The latest update introduces new text embedding examples and service scripts, incorporating chunk processing capabilities to handle exceptionally long inputs. The README documentation has been revised to provide comprehensive instructions on usage methods and configuration options. Signed-off-by: x22x22 --- .../openai_embedding_long_text/README.md | 196 +++++++ .../openai_embedding_long_text/client.py | 368 +++++++++++++ .../openai_embedding_long_text/service.sh | 138 +++++ vllm/config.py | 207 ++----- vllm/entrypoints/openai/serving_embedding.py | 504 +++++++++++++++++- vllm/entrypoints/openai/serving_engine.py | 58 +- 6 files changed, 1304 insertions(+), 167 deletions(-) create mode 100644 examples/online_serving/openai_embedding_long_text/README.md create mode 100644 examples/online_serving/openai_embedding_long_text/client.py create mode 100644 examples/online_serving/openai_embedding_long_text/service.sh diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md new file mode 100644 index 000000000000..dcd66a9fee9d --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -0,0 +1,196 @@ +# Long Text Embedding with Chunked Processing + +This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length. + +## ๐Ÿš€ Quick Start + +### 1. Start the Server + +Use the provided script to start a vLLM server with chunked processing enabled: + +```bash +# Basic usage (supports very long texts up to ~3M tokens) +./service.sh + +# Custom configuration with different models +MODEL_NAME="jinaai/jina-embeddings-v3" \ +MAX_EMBED_LEN=1048576 \ +./service.sh + +# For extremely long documents +MODEL_NAME="intfloat/multilingual-e5-large" \ +MAX_EMBED_LEN=3072000 \ +./service.sh +``` + +### 2. Test Long Text Embedding + +Run the comprehensive test client: + +```bash +python client.py +``` + +## ๐Ÿ“ Files + +| File | Description | +|------|-------------| +| `service.sh` | Server startup script with chunked processing enabled | +| `client.py` | Comprehensive test client for long text embedding | +| `../openai_embedding_client.py` | Basic embedding client (updated with chunked processing info) | + +## โš™๏ธ Configuration + +### Server Configuration + +The key parameters for chunked processing are in the `--override-pooler-config`: + +```json +{ + "pooling_type": "auto", + "normalize": true, + "enable_chunked_processing": true, + "max_embed_len": 3072000 +} +``` + +**Note**: `pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length. 
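The cross-chunk combination is a token-count-weighted average of the per-chunk embeddings. The following is a minimal NumPy sketch of that weighting, for illustration only; it is not the server's internal code path:

```python
import numpy as np


def aggregate_chunk_embeddings(
    chunk_embeddings: list[np.ndarray],  # one embedding per chunk (model's native pooling)
    chunk_token_counts: list[int],       # number of tokens in each chunk
) -> np.ndarray:
    """Combine per-chunk embeddings with a token-count-weighted MEAN."""
    weights = np.array(chunk_token_counts, dtype=np.float32)
    stacked = np.stack(chunk_embeddings).astype(np.float32)
    return (stacked * weights[:, None]).sum(axis=0) / weights.sum()


# Example: a long input split into chunks of 4096, 4096 and 1200 tokens
# doc_embedding = aggregate_chunk_embeddings([e1, e2, e3], [4096, 4096, 1200])
```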
+ +#### Chunked Processing Behavior + +Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length: + +| Component | Behavior | Description | +|-----------|----------|-------------| +| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy | +| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts | +| **Performance** | Optimal | All chunks processed for complete semantic coverage | + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) | +| `PORT` | `31090` | Server port | +| `GPU_COUNT` | `1` | Number of GPUs to use | +| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) | +| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) | +| `API_KEY` | `EMPTY` | API key for authentication | + +## ๐Ÿ”ง How It Works + +1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables +2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity +3. **Unified Processing**: All chunks processed separately through the model using its configured pooling strategy +4. **MEAN Aggregation**: When input exceeds model's native length, results combined using token count-based weighted averaging across all chunks +5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing + +### Input Length Handling + +- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens) +- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered +- **Exceeds max_embed_len**: Input is rejected with clear error message +- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN` + +### Extreme Long Text Support + +With `MAX_EMBED_LEN=3072000`, you can process: + +- **Academic papers**: Full research papers with references +- **Legal documents**: Complete contracts and legal texts +- **Books**: Entire chapters or small books +- **Code repositories**: Large codebases and documentation + +## ๐Ÿ“Š Performance Characteristics + +### Chunked Processing Performance + +| Aspect | Behavior | Performance | +|--------|----------|-------------| +| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length | +| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead | +| **Memory Usage** | Proportional to number of chunks | Moderate, scalable | +| **Semantic Quality** | Complete text coverage | Optimal for long documents | + +## ๐Ÿงช Test Cases + +The test client demonstrates: + +- โœ… **Short text**: Normal processing (baseline) +- โœ… **Medium text**: Single chunk processing +- โœ… **Long text**: Multi-chunk processing with aggregation +- โœ… **Very long text**: Many chunks processing +- โœ… **Extreme long text**: Document-level processing (100K+ tokens) +- โœ… **Batch processing**: Mixed-length inputs in one request +- โœ… **Consistency**: Reproducible results across runs + +## ๐Ÿ› Troubleshooting + +### Common Issues + +1. **Chunked processing not enabled**: + + ```log + ValueError: This model's maximum position embeddings length is 4096 tokens... 
+ ``` + + **Solution**: Ensure `enable_chunked_processing: true` in pooler config + +2. **Input exceeds max_embed_len**: + + ```log + ValueError: This model's maximum embedding input length is 3072000 tokens... + ``` + + **Solution**: Increase `max_embed_len` in pooler config or reduce input length + +3. **Memory errors**: + + ```log + RuntimeError: CUDA out of memory + ``` + + **Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs + +4. **Slow processing**: + **Expected**: Long text takes more time due to multiple inference calls + +### Debug Information + +Server logs show chunked processing activity: + +```log +INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing +INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) +``` + +## ๐Ÿ“š Additional Resources + +- [Pooling Models Documentation](../../docs/models/pooling_models.md#chunked-processing-for-long-text) +- [Supported Models List](../../docs/models/supported_models.md#text-embedding) +- [Original Feature Documentation](../../README_CHUNKED_PROCESSING.md) + +## ๐Ÿค Contributing + +To extend chunked processing support to other embedding models: + +1. Check model compatibility with the pooling architecture +2. Test with various text lengths +3. Validate embedding quality compared to single-chunk processing +4. Submit PR with test cases and documentation updates + +## ๐Ÿ†• Enhanced Features + +### max_embed_len Parameter + +The new `max_embed_len` parameter provides: + +- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable +- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len` +- **Extreme Length Support**: Process documents with millions of tokens +- **Clear Error Messages**: Better feedback when inputs exceed limits +- **Backward Compatibility**: Existing configurations continue to work + +--- + +**Note**: Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../../docs/models/supported_models.md#chunked-processing-for-long-text) for the complete list. diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py new file mode 100644 index 000000000000..7e3663f2854a --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script demonstrating long text embedding with chunked processing in vLLM. + +This example shows how to use vLLM's chunked processing feature to handle text +inputs that exceed the model's maximum token length. The feature automatically +splits long text into chunks and handles different pooling types optimally. + +Prerequisites: +1. 
Start vLLM server with chunked processing enabled: + + # MEAN pooling (processes all chunks, recommended for complete coverage) + vllm serve intfloat/multilingual-e5-large \ + --task embed \ + --override-pooler-config \ + '{"pooling_type": "MEAN", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ + --served-model-name multilingual-e5-large \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + + # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) + vllm serve BAAI/bge-large-en-v1.5 \ + --task embed \ + --override-pooler-config \ + '{"pooling_type": "CLS", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ + --served-model-name bge-large-en-v1.5 \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + +2. Install required dependencies: + pip install openai requests +""" + +import time + +import numpy as np +from openai import OpenAI + +# Configuration +API_KEY = "your-api-key" # Replace with your actual API key +BASE_URL = "http://localhost:31090/v1" +MODEL_NAME = "multilingual-e5-large" + + +def generate_long_text(base_text: str, repeat_count: int) -> str: + """Generate long text by repeating base text.""" + return base_text * repeat_count + + +def test_embedding_with_different_lengths(): + """Test embedding generation with different text lengths.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + # Test cases with different text lengths + test_cases = [ + { + "name": "Short Text", + "text": "Hello, this is a short text for embedding.", + "expected_chunks": 1, + }, + { + "name": "Medium Text", + "text": generate_long_text( + "This is a medium-length text that should fit within the " + "model's context window. " * 20, + 2, + ), + "expected_chunks": 1, + }, + { + "name": "Long Text (2 chunks)", + "text": generate_long_text( + "This is a very long text that will exceed the model's " + "maximum context length and trigger chunked processing. " * 50, + 5, + ), + "expected_chunks": 2, + }, + { + "name": "Very Long Text (3+ chunks)", + "text": generate_long_text( + "This text is extremely long and will definitely " + "require multiple chunks for processing. " * 100, + 10, + ), + "expected_chunks": 3, + }, + ] + + print("๐Ÿงช Testing vLLM Long Text Embedding with Chunked Processing") + print("=" * 70) + + for i, test_case in enumerate(test_cases, 1): + print(f"\n๐Ÿ“ Test {i}: {test_case['name']}") + print(f"Text length: {len(test_case['text'])} characters") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=test_case["text"], model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + # Extract embedding data + embedding = response.data[0].embedding + embedding_dim = len(embedding) + + print("โœ… Success!") + print(f" - Embedding dimension: {embedding_dim}") + print(f" - Processing time: {processing_time:.2f}s") + print(f" - Expected chunks: ~{test_case['expected_chunks']}") + print(f" - First 5 values: {embedding[:5]}") + + except Exception as e: + print(f"โŒ Failed: {str(e)}") + + +def test_batch_embedding(): + """Test batch embedding with mixed-length inputs.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n๐Ÿ”„ Testing Batch Embedding with Mixed Lengths") + print("=" * 50) + + # Mix of short and long texts + batch_inputs = [ + "Short text 1", + generate_long_text("Medium length text that fits in one chunk. 
" * 20, 1), + "Another short text", + generate_long_text("Long text requiring chunked processing. " * 100, 5), + ] + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("โœ… Batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + print( + f" - Average time per input: {processing_time / len(batch_inputs):.2f}s" + ) + + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars โ†’ {embedding_dim}D embedding" + ) + + except Exception as e: + print(f"โŒ Batch processing failed: {str(e)}") + + +def test_multiple_long_texts_batch(): + """Test batch processing with multiple long texts to verify chunk ID uniqueness.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n๐Ÿ”ง Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)") + print("=" * 70) + + # Create multiple distinct long texts that will all require chunking + # Note: All pooling types now use MEAN aggregation across chunks: + # - Native pooling (MEAN/CLS/LAST) is used within each chunk + # - MEAN aggregation combines results across all chunks + # - Full semantic coverage for all pooling types + long_texts = [ + generate_long_text( + "First long document about artificial intelligence and machine learning. " + * 80, + 6, + ), + generate_long_text( + "Second long document about natural language processing and transformers. " + * 80, + 6, + ), + generate_long_text( + "Third long document about computer vision and neural networks. 
" * 80, 6 + ), + ] + + # Add some short texts to mix things up + batch_inputs = [ + "Short text before long texts", + long_texts[0], + "Short text between long texts", + long_texts[1], + long_texts[2], + "Short text after long texts", + ] + + print("๐Ÿ“Š Batch composition:") + for i, text in enumerate(batch_inputs): + length = len(text) + text_type = "Long (will be chunked)" if length > 5000 else "Short" + print(f" - Input {i + 1}: {length} chars ({text_type})") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("\nโœ… Multiple long texts batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings returned: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + + # Verify each embedding is different (no incorrect aggregation) + embeddings = [data.embedding for data in response.data] + + if len(embeddings) >= 3: + import numpy as np + + # Compare embeddings of the long texts (indices 1, 3, 4) + long_embeddings = [ + np.array(embeddings[1]), # First long text + np.array(embeddings[3]), # Second long text + np.array(embeddings[4]), # Third long text + ] + + print("\n๐Ÿ” Verifying embedding uniqueness:") + for i in range(len(long_embeddings)): + for j in range(i + 1, len(long_embeddings)): + cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / ( + np.linalg.norm(long_embeddings[i]) + * np.linalg.norm(long_embeddings[j]) + ) + print( + f" - Similarity between long text {i + 1} and {j + 1}: " + f"{cosine_sim:.4f}" + ) + + if ( + cosine_sim < 0.9 + ): # Different content should have lower similarity + print(" โœ… Good: Embeddings are appropriately different") + else: + print( + " โš ๏ธ High similarity - may indicate chunk " + "aggregation issue" + ) + + print("\n๐Ÿ“‹ Per-input results:") + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + embedding_norm = np.linalg.norm(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars โ†’ {embedding_dim}D " + f"embedding (norm: {embedding_norm:.4f})" + ) + + print( + "\nโœ… This test verifies the fix for chunk ID collisions in " + "batch processing" + ) + print(" - Before fix: Multiple long texts would have conflicting chunk IDs") + print(" - After fix: Each prompt's chunks have unique IDs with prompt index") + + except Exception as e: + print(f"โŒ Multiple long texts batch test failed: {str(e)}") + print(" This might indicate the chunk ID collision bug is present!") + + +def test_embedding_consistency(): + """Test that chunked processing produces consistent results.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n๐Ÿ” Testing Embedding Consistency") + print("=" * 40) + + # Use the same long text multiple times + long_text = generate_long_text( + "Consistency test text for chunked processing validation. 
" * 50, 3 + ) + + embeddings = [] + + try: + for i in range(3): + response = client.embeddings.create( + input=long_text, model=MODEL_NAME, encoding_format="float" + ) + embeddings.append(response.data[0].embedding) + print(f" - Generated embedding {i + 1}") + + # Check consistency (embeddings should be identical) + if len(embeddings) >= 2: + # Calculate similarity between first two embeddings + + emb1 = np.array(embeddings[0]) + emb2 = np.array(embeddings[1]) + + # Cosine similarity + cosine_sim = np.dot(emb1, emb2) / ( + np.linalg.norm(emb1) * np.linalg.norm(emb2) + ) + + print("โœ… Consistency test completed!") + print(f" - Cosine similarity between runs: {cosine_sim:.6f}") + print(" - Expected: ~1.0 (identical embeddings)") + + if cosine_sim > 0.999: + print(" - โœ… High consistency achieved!") + else: + print(" - โš ๏ธ Consistency may vary due to numerical precision") + + except Exception as e: + print(f"โŒ Consistency test failed: {str(e)}") + + +def main(): + """Main function to run all tests.""" + print("๐Ÿš€ vLLM Long Text Embedding Client") + print(f"๐Ÿ“ก Connecting to: {BASE_URL}") + print(f"๐Ÿค– Model: {MODEL_NAME}") + masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****" + print(f"๐Ÿ”‘ API Key: {masked_key}") + + # Run all test cases + test_embedding_with_different_lengths() + test_batch_embedding() + test_multiple_long_texts_batch() + test_embedding_consistency() + + print("\n" + "=" * 70) + print("๐ŸŽ‰ All tests completed!") + print("\n๐Ÿ’ก Key Features Demonstrated:") + print(" - โœ… Automatic chunked processing for long text") + print(" - โœ… Seamless handling of mixed-length batches") + print(" - โœ… Multiple long texts in single batch (chunk ID fix)") + print(" - โœ… Unified chunked processing:") + print(" โ€ข Native pooling used within each chunk") + print(" โ€ข MEAN aggregation across all chunks") + print(" โ€ข Complete semantic coverage for all pooling types") + print(" - โœ… Consistent embedding generation") + print(" - โœ… Backward compatibility with short text") + print("\n๐Ÿ“š For more information, see:") + print( + " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html" + ) + print(" - Chunked Processing Guide: openai_embedding_long_text.md") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh new file mode 100644 index 000000000000..03feb485d6d4 --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# vLLM Embedding Server with Enhanced Chunked Processing +# This script starts a vLLM server with chunked processing enabled for long text embedding. +# Now supports proper pooling type validation and model-specific configurations. 
+ +set -euo pipefail + +# Configuration +MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"} +MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"} + +PORT=${PORT:-31090} +GPU_COUNT=${GPU_COUNT:-1} +MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000} +API_KEY=${API_KEY:-"your-api-key"} + +# Enhanced pooling configuration with model-specific defaults +POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST +export VLLM_ENABLE_CHUNKED_PROCESSING=true +export CUDA_VISIBLE_DEVICES=2,3,4,5 +# export VLLM_ATTENTION_BACKEND=XFORMERS + +echo "๐Ÿš€ Starting vLLM Embedding Server with Enhanced Chunked Processing" +echo "==================================================================" + +# Environment variables for optimization +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +# Function to determine optimal pooling type for known models +get_optimal_pooling_type() { + local model="$1" + case "$model" in + *"e5-"* | *"multilingual-e5"*) + echo "MEAN" # E5 series native pooling + ;; + *"bge-"*) + echo "CLS" # BGE series native pooling + ;; + *"gte-"*) + echo "LAST" # GTE series native pooling + ;; + *"sentence-t5"* | *"st5"*) + echo "MEAN" # Sentence-T5 native pooling + ;; + *"jina-embeddings"*) + echo "MEAN" # Jina embeddings native pooling + ;; + *"Qwen"*"Embedding"*) + echo "LAST" # Qwen embeddings native pooling + ;; + *) + echo "MEAN" # Default native pooling for unknown models + ;; + esac +} + +# Auto-detect pooling type if not explicitly set +if [ "$POOLING_TYPE" = "auto" ]; then + POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME") + echo "๐Ÿ” Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME" +fi + +# Display configuration +echo "๐Ÿ“‹ Configuration:" +echo " - Model: $MODEL_NAME" +echo " - Port: $PORT" +echo " - GPU Count: $GPU_COUNT" +echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}" +echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens" +echo " - Native Pooling Type: $POOLING_TYPE + Normalization" +echo " - Cross-chunk Aggregation: MEAN (automatic)" +echo "" + +# Validate GPU availability +if command -v nvidia-smi &> /dev/null; then + gpu_count=$(nvidia-smi --list-gpus | wc -l) + echo "๐Ÿ–ฅ๏ธ Available GPUs: $gpu_count" + if [ "$GPU_COUNT" -gt "$gpu_count" ]; then + echo "โš ๏ธ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available" + echo " Adjusting to use $gpu_count GPUs" + GPU_COUNT=$gpu_count + fi +else + echo "โš ๏ธ Warning: nvidia-smi not found. GPU detection skipped." +fi + +# Chunked processing uses unified MEAN aggregation +echo "โ„น๏ธ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks" +echo " - All chunks processed for complete semantic coverage" +echo " - Weighted averaging based on chunk token counts" + +echo "" +echo "๐Ÿ”ง Starting server with enhanced chunked processing configuration..." + +# Build pooler config JSON +POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}" + +# Start vLLM server with enhanced chunked processing +vllm serve "$MODEL_NAME" \ + --tensor-parallel-size "$GPU_COUNT" \ + --enforce-eager \ + --override-pooler-config "$POOLER_CONFIG" \ + --served-model-name ${MODEL_CODE} \ + --task embed \ + --api-key "$API_KEY" \ + --trust-remote-code \ + --port "$PORT" \ + --host 0.0.0.0 + +echo "" +echo "โœ… vLLM Embedding Server started successfully!" 
+echo "" +echo "๐Ÿ“ก Server Information:" +echo " - Base URL: http://localhost:$PORT" +echo " - Model Code: ${MODEL_CODE}" +echo " - API Key: $API_KEY" +echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN" +echo "" +echo "๐Ÿงช Test the server with:" +echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo "" +echo "๐Ÿ“š Enhanced features enabled:" +echo " โœ… Intelligent native pooling type detection" +echo " โœ… Unified MEAN aggregation for chunked processing" +echo " โœ… Model-specific native pooling optimization" +echo " โœ… Enhanced max embedding length (${MAX_EMBED_LEN} tokens)" +echo " โœ… Complete semantic coverage for all pooling types" +echo " โœ… OpenAI-compatible API" +echo " โœ… GPU acceleration" +echo "" +echo "๐Ÿ”ง Advanced usage:" +echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection" +echo " - Set MAX_EMBED_LEN to adjust maximum input length" +echo " - All pooling types use MEAN aggregation across chunks" diff --git a/vllm/config.py b/vllm/config.py index 899862bf541e..6564121d401b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,11 +11,10 @@ import uuid import warnings from collections import Counter -from collections.abc import Mapping from contextlib import contextmanager from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) -from functools import cached_property, lru_cache +from functools import cached_property from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args) @@ -39,8 +38,8 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - maybe_override_with_speculators_target_model, try_get_generation_config, - try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) + try_get_generation_config, try_get_safetensors_metadata, + try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect # yapf conflicts with isort for this block @@ -57,7 +56,6 @@ if TYPE_CHECKING: from _typeshed import DataclassInstance - from ray.runtime_env import RuntimeEnv from ray.util.placement_group import PlacementGroup from transformers.configuration_utils import PretrainedConfig @@ -75,7 +73,6 @@ else: DataclassInstance = Any PlacementGroup = Any - RuntimeEnv = Any PretrainedConfig = Any ExecutorBase = Any QuantizationConfig = Any @@ -377,8 +374,7 @@ class ModelConfig: max_logprobs: int = 20 """Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. The default value comes the default for the - OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * - vocab_size) logprobs are allowed to be returned and it may cause OOM.""" + OpenAI Chat Completions API.""" logprobs_mode: LogprobsMode = "raw_logprobs" """Indicates the content returned in the logprobs and prompt_logprobs. 
Supported mode: @@ -538,15 +534,6 @@ def __post_init__(self) -> None: "affect the random state of the Python process that " "launched vLLM.", self.seed) - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name(self.model, self.served_model_name) @@ -618,8 +605,8 @@ def __post_init__(self) -> None: self.config_format, hf_overrides_kw=hf_overrides_kw, hf_overrides_fn=hf_overrides_fn) - self.hf_config = hf_config + self.hf_text_config = get_hf_text_config(self.hf_config) self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) @@ -789,9 +776,6 @@ def _task_to_convert(task: TaskOption) -> ConvertType: raise ValueError( "`override_neuron_config` is only supported on Neuron.") - # Avoid running try_verify_and_update_config multiple times - self.config_updated = False - self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -815,17 +799,12 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" - if getattr(self, "runner_type", self.runner) == "pooling": - return "TransformersModel" if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. multimodal return "TransformersForMultimodalLM" - return "TransformersForCausalLM" - - def using_transformers_backend(self) -> bool: - """Check if the model is using the Transformers backend class.""" - return self.architecture == self._get_transformers_backend_cls() + else: + return "TransformersForCausalLM" @property def registry(self): @@ -888,12 +867,6 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None - def set_disable_mm_preprocessor_cache(self, value: bool) -> None: - mm_config = self.get_multimodal_config() - - self.disable_mm_preprocessor_cache = value - mm_config.disable_mm_preprocessor_cache = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -913,6 +886,15 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) + if self.is_matryoshka: + if pooler_config.normalize is None: + pooler_config.normalize = True + elif not pooler_config.normalize: + raise ValueError( + "`normalize` must be enabled (set to True) " + "for models that are compatible with " + "Matryoshka Representation.") + return pooler_config return None @@ -1099,21 +1081,6 @@ def _parse_quant_hf_config(self): if quant_cfg is None: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) - - else: - # Set quant_method for ModelOpt models. 
- producer_name = quant_cfg.get("producer", {}).get("name") - if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", - {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError( - f"Unknown ModelOpt quant algo: {quant_algo}") - return quant_cfg def _verify_quantization(self) -> None: @@ -1589,18 +1556,7 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config def try_get_generation_config(self) -> dict[str, Any]: - """ - This method attempts to retrieve the non-default values of the - generation config for this model. - - The generation config can contain information about special tokens, as - well as sampling parameters. Which is why this method exists separately - to `get_diff_sampling_param`. - - Returns: - A dictionary containing the non-default generation config. - """ - if self.generation_config in {"auto", "vllm"}: + if self.generation_config in ("auto", "vllm"): config = try_get_generation_config( self.hf_config_path or self.model, trust_remote_code=self.trust_remote_code, @@ -1619,18 +1575,13 @@ def try_get_generation_config(self) -> dict[str, Any]: def get_diff_sampling_param(self) -> dict[str, Any]: """ - This method returns a dictionary containing the non-default sampling - parameters with `override_generation_config` applied. - - The default sampling parameters are: - - - vLLM's neutral defaults if `self.generation_config="vllm"` - - the model's defaults if `self.generation_config="auto"` - - as defined in `generation_config.json` if - `self.generation_config="path/to/generation_config/dir"` + This method returns a dictionary containing the parameters + that differ from the default sampling parameters. If + `generation_config` is `"vllm"`, an empty dictionary is returned. Returns: - A dictionary containing the non-default sampling parameters. + dict[str, Any]: A dictionary with the differing sampling + parameters, if `generation_config` is `"vllm"` an empty dictionary. """ if self.generation_config == "vllm": config = {} @@ -2073,7 +2024,7 @@ class ParallelConfig: and when data_parallel_size > 0. Enables running an AsyncLLM and API server on a "per-node" basis where vLLM load balances between local data parallel ranks, but an external LB balances - between vLLM nodes/replicas. Set explicitly in conjunction with + between vLLM nodes/replicas. 
Set explicitly in conjunction with --data-parallel-start-rank.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" @@ -2107,9 +2058,6 @@ class ParallelConfig: ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" - ray_runtime_env: Optional["RuntimeEnv"] = None - """Ray runtime environment to pass to distributed workers.""" - placement_group: Optional["PlacementGroup"] = None """ray distributed model workers placement group.""" @@ -3022,13 +2970,10 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") - from vllm.transformers_utils.configs import ( - SpeculatorsConfig) from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) - if isinstance(self.draft_model_config.hf_config, - (EAGLEConfig, SpeculatorsConfig)): + EAGLEConfig): pass else: eagle_config = EAGLEConfig( @@ -3056,19 +3001,6 @@ def __post_init__(self): f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") - if self.speculative_token_tree is None: - # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) - else: - # Sort the token tree breadth-first. - tree_choices = ast.literal_eval( - self.speculative_token_tree) - self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) - self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, @@ -3200,19 +3132,10 @@ def _verify_args(self) -> Self: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - from vllm.transformers_utils.configs import SpeculatorsConfig - - eagle3_target_supported = ["llama"] - if self.draft_model_config and isinstance( - self.draft_model_config.hf_config, SpeculatorsConfig): - eagle3_target_supported.append("qwen") - - if self.method == "eagle3" and self.target_model_config and not any( - supported_model in - self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported): + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: raise ValueError( - f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 + "Eagle3 is only supported for Llama models. " f"Got {self.target_model_config.hf_text_config.model_type=}") return self @@ -3406,16 +3329,7 @@ def get_limit_per_prompt(self, modality: str) -> int: 999 if envs.VLLM_USE_V1 else 1, ) - def merge_mm_processor_kwargs( - self, - inference_kwargs: Mapping[str, object], - ) -> dict[str, object]: - """ - Get the keyword arguments to pass to the multi-modal processor - according to the extra arguments passed during inference. - """ - kwargs = self.mm_processor_kwargs or {} - return kwargs | dict(inference_kwargs) + # TODO: Add configs to init vision tower or not. @config @@ -3429,34 +3343,25 @@ class PoolerConfig: [`vllm.model_executor.layers.pooler.PoolingType`][]. """ - ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. - """ - dimensions: Optional[int] = None - """ - Reduce the dimensions of embeddings if model - support matryoshka representation. 
- """ - - ## for classification models - activation: Optional[bool] = None - """ - Whether to apply activation function to the classification outputs. + Whether to normalize the pooled outputs. Usually, this should be set to + ``True`` for embedding outputs. """ - ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the pooled outputs. Usually, this should be set + to ``True`` for classification outputs. """ + step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ + returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, @@ -3464,6 +3369,25 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_model_len to be accepted for embedding models. + This parameter enables accepting long inputs without requiring + VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds + max_embed_len, it will be handled according to the original max_model_len + validation logic. Defaults to None (use max_model_len validation). + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -4143,7 +4067,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 + fi_allreduce_fusion_max_token_num: int = 1024 """Max number of tokens to used in flashinfer allreduce fusion.""" # TODO(luka) better pass enabling system. 
@@ -4374,20 +4298,12 @@ def __repr__(self) -> str: "disabled_custom_ops": True, "compilation_time": True, "bs_to_padded_graph_size": True, + "pass_config": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, }, } - - # exclude default attr in pass_config - pass_config_exclude = {} - for attr, default_val in vars(PassConfig()).items(): - if getattr(self.pass_config, attr) == default_val: - pass_config_exclude[attr] = True - if pass_config_exclude: - exclude["pass_config"] = pass_config_exclude - # The cast to string is necessary because Pydantic is mocked in docs # builds and sphinx-argparse doesn't know the return type of decode() return str( @@ -5017,11 +4933,6 @@ def try_verify_and_update_config(self): if self.model_config is None: return - # Avoid running try_verify_and_update_config multiple times - if getattr(self.model_config, "config_updated", False): - return - self.model_config.config_updated = True - architecture = self.model_config.architecture if architecture is None: return @@ -5123,14 +5034,6 @@ def set_current_vllm_config(vllm_config: VllmConfig, finally: _current_vllm_config = old_vllm_config _current_prefix = old_prefix - # Clear the compilation config cache when context changes - get_cached_compilation_config.cache_clear() - - -@lru_cache(maxsize=1) -def get_cached_compilation_config(): - """Cache config to avoid repeated calls to get_current_vllm_config()""" - return get_current_vllm_config().compilation_config def get_current_vllm_config() -> VllmConfig: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 84ba00873103..42551a1854f1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from typing import Final, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, Final, Literal, Optional, Union, cast import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never, override @@ -12,15 +14,22 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this docstring +# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext) + ServeContext, + TextTokensPrompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) @@ -129,6 +138,497 @@ def _build_response( usage=usage, ) + def _get_max_position_embeddings(self) -> int: + """Get the model's effective maximum sequence length for chunking. + + This uses the same logic as vLLM's _get_and_verify_max_len to determine + the actual sequence length limit, + considering both model config and tokenizer config. + When max_model_len is set and smaller than max_position_embeddings, + use max_model_len for chunking. 
+ """ + hf_config = self.model_config.hf_config + + # Start with max_position_embeddings from model config + derived_max_len = getattr(hf_config, 'max_position_embeddings', 512) + + # Get tokenizer config for pooling models (embedding models) + if self.model_config.runner_type == "pooling": + from vllm.transformers_utils.config import try_get_tokenizer_config + tokenizer_config = try_get_tokenizer_config( + self.model_config.tokenizer, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + + # Consider model_max_length in tokenizer_config + # (same logic as _get_and_verify_max_len) + if tokenizer_config: + tokenizer_model_max_length = tokenizer_config.get( + 'model_max_length', derived_max_len) + derived_max_len = min(derived_max_len, + tokenizer_model_max_length) + + # Consider max_model_len when it's set and smaller than other limits + # max_model_len is set in OpenAIServing.__init__ + # from model_config.max_model_len + if self.max_model_len is not None: + derived_max_len = min(derived_max_len, self.max_model_len) + + return int(derived_max_len) + + def _should_use_chunked_processing(self, request) -> bool: + """Check if chunked processing should be used for this request.""" + if not isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + return False + + pooler_config = getattr(self.model_config, 'pooler_config', None) + + # For chunked processing, we always use MEAN aggregation + # for cross-chunk aggregation (native pooling is used within each chunk) + return (pooler_config is not None + and getattr(pooler_config, 'enable_chunked_processing', False)) + + def _chunk_token_ids(self, token_ids: list[int], + chunk_size: int) -> list[list[int]]: + """Split token IDs into chunks of specified size.""" + if len(token_ids) <= chunk_size: + return [token_ids] + + chunks = [] + for i in range(0, len(token_ids), chunk_size): + chunk = token_ids[i:i + chunk_size] + chunks.append(chunk) + return chunks + + async def _process_chunked_request( + self, + ctx: EmbeddingServeContext, + original_prompt: TextTokensPrompt, + pooling_params, + trace_headers, + prompt_idx: int, + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + """Process a single prompt using chunked processing.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + token_ids = original_prompt["prompt_token_ids"] + + # Split into chunks using max_position_embeddings + max_pos_embeddings = self._get_max_position_embeddings() + chunks = self._chunk_token_ids(token_ids, max_pos_embeddings) + + # Process all chunks for MEAN aggregation + chunks_to_process = chunks + chunk_indices = list(range(len(chunks))) + logger.info("Using chunked processing with MEAN aggregation") + + for i, (chunk_idx, chunk_tokens) in enumerate( + zip(chunk_indices, chunks_to_process)): + # Create a request ID for this chunk + chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" + f"chunk-{chunk_idx}") + + # Create engine prompt for this chunk + chunk_engine_prompt = EngineTokensPrompt( + prompt_token_ids=chunk_tokens) + + # Create chunk request prompt for logging + chunk_text = "" + chunk_request_prompt = TextTokensPrompt( + prompt=chunk_text, prompt_token_ids=chunk_tokens) + + # Log the chunk + self._log_inputs(chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Create generator for this chunk + generator = self.engine_client.encode( + chunk_engine_prompt, + pooling_params, + chunk_request_id, + 
lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + return generators + + def _validate_input( + self, + request, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + """Override to support chunked processing for embedding requests.""" + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + # Check if chunked processing is enabled for pooling models + pooler_config = getattr(self.model_config, 'pooler_config', None) + enable_chunked = (pooler_config is not None and getattr( + pooler_config, 'enable_chunked_processing', False)) + + # Get max_embed_len from pooler config if set + max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len else None) + + # Use max_position_embeddings for chunked processing decisions + max_pos_embeddings = self._get_max_position_embeddings() + + # Determine the effective max length for validation + if max_embed_len is not None: + # Use max_embed_len for validation instead of max_model_len + effective_max_len = max_embed_len + length_type = "maximum embedding input length" + max_length_value = max_embed_len + else: + # Fall back to max_model_len validation (original behavior) + effective_max_len = self.max_model_len + length_type = "maximum context length" + max_length_value = self.max_model_len + + validation_error_msg = ( + "This model's {length_type} is {max_length} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input." + ).format(length_type=length_type, + max_length=max_length_value, + token_num=token_num) + + # Check if input exceeds effective max length + if token_num > effective_max_len: + raise ValueError(validation_error_msg) + + # Check for chunked processing + # when exceeding max_position_embeddings + if token_num > max_pos_embeddings: + if enable_chunked: + # Allow long inputs when chunked processing is enabled + logger.info( + "Input length %s exceeds max_position_embeddings " + "%s, will use chunked processing", token_num, + max_pos_embeddings) + else: + raise ValueError( + f"This model's maximum position embeddings length is " + f"{max_pos_embeddings} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. 
Please reduce the length of the input or " + f"enable chunked processing.") + + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # For other request types, use the parent's implementation + return super()._validate_input(request, input_ids, input_text) + + def _is_text_tokens_prompt(self, prompt) -> bool: + """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Override to support chunked processing.""" + ctx = cast(EmbeddingServeContext, ctx) + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + # Verify and set the task for pooling params + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + # Check if we should use chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_prompt = ctx.request_prompts[i] + + # Check if this specific prompt needs chunked processing + max_pos_embeddings = self._get_max_position_embeddings() + if (use_chunked + and self._is_text_tokens_prompt(request_prompt)): + # Cast to TextTokensPrompt since we've + # verified prompt_token_ids + text_tokens_prompt = cast(TextTokensPrompt, request_prompt) + if len(text_tokens_prompt["prompt_token_ids"] + ) > max_pos_embeddings: + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, text_tokens_prompt, pooling_params, + trace_headers, i) + generators.extend(chunk_generators) + continue + + # Normal processing for short prompts or non-token prompts + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect and aggregate batch results + with support for chunked processing. 
+ + For chunked requests, performs online aggregation to + minimize memory usage. + For regular requests, collects results normally. + """ + ctx = cast(EmbeddingServeContext, ctx) + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + if use_chunked: + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': + None, + 'total_weight': + 0, + 'chunk_count': + 0, + 'request_id': + result.request_id.split("-chunk-")[0] + } + + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + embedding_data = result.outputs.data + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor( + embedding_data, dtype=torch.float32) + + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight + + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding + else: + # Accumulate + current_sum = aggregator['weighted_sum'] + if isinstance(current_sum, torch.Tensor): + aggregator['weighted_sum'] = ( + current_sum + weighted_embedding) + + total_weight = aggregator['total_weight'] + if isinstance(total_weight, (int, float)): + aggregator['total_weight'] = (total_weight + + weight) + + chunk_count = aggregator['chunk_count'] + if isinstance(chunk_count, int): + aggregator['chunk_count'] = chunk_count + 1 + + except (ValueError, IndexError): + return self.create_error_response( + f"Invalid chunk request ID format: " + f"{result.request_id}") + else: + # Non-chunked result + try: + prompt_idx = int(result.request_id.split("-")[-1]) + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) + except ValueError: + return self.create_error_response( + f"Invalid request ID format: " + f"{result.request_id}") + + # Build final result batch + final_res_batch = [] + + for prompt_idx, request_prompt in enumerate( + ctx.request_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + # Finalize weighted average + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + if (weighted_sum is not None + 
and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0): + final_embedding = weighted_sum / total_weight + + # Create aggregated result + from vllm.outputs import PoolingOutput + aggregated_output = PoolingOutput( + data=final_embedding) + + # Get original prompt token ids + if self._is_text_tokens_prompt(request_prompt): + text_tokens_prompt = cast( + TextTokensPrompt, request_prompt) + original_token_ids = text_tokens_prompt[ + "prompt_token_ids"] + else: + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"text tokens prompt") + + # Ensure request_id is string + request_id = aggregator['request_id'] + if not isinstance(request_id, str): + return self.create_error_response( + f"Invalid request_id type: " + f"{type(request_id)}") + + aggregated_result = PoolingRequestOutput( + request_id=request_id, + outputs=aggregated_output, + prompt_token_ids=original_token_ids, + finished=True, + ) + final_res_batch.append(aggregated_result) + else: + return self.create_error_response( + f"No valid aggregation data for prompt " + f"{prompt_idx}") + + elif prompt_idx in short_prompts_results: + # This was a short prompt + final_res_batch.append( + short_prompts_results[prompt_idx]) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) + else: + # Normal processing for non-chunked requests + num_prompts = len(ctx.engine_prompts) + normal_final_res_batch: list[ + Optional[PoolingRequestOutput]] = [None] * num_prompts + + async for result_idx, result in ctx.result_generator: + if result_idx < num_prompts: + # Cast to PoolingRequestOutput for embedding results + normal_final_res_batch[result_idx] = cast( + PoolingRequestOutput, result) + + if None in normal_final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + final_results = [ + res for res in normal_final_res_batch if res is not None + ] + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_results) + + return None + + except Exception as e: + return self.create_error_response(str(e)) + class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 71976fea1ee7..d74231d7e9d9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,6 +16,7 @@ from fastapi import Request from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import TypeIs if sys.version_info >= (3, 12): @@ -512,8 +513,13 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: if (isinstance(message, dict) and "content" in message and isinstance(message["content"], list)): for content_dict in message["content"]: - if "type" in content_dict: - message_types.add(content_dict["type"].split("_")[0]) + # Check if content_dict has a "type" key and it's a string + if isinstance(content_dict, dict): + type_value = content_dict.get("type") + if isinstance(type_value, str): + # Split on "_" and take the first part + base_type = type_value.split("_")[0] + message_types.add(base_type) return message_types async def _normalize_prompt_text_to_input( @@ -890,12 +896,23 @@ async def 
_preprocess_chat( **_chat_template_kwargs, ) else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) + # Type check for apply_hf_chat_template which only accepts + # PreTrainedTokenizer or PreTrainedTokenizerFast + if isinstance(tokenizer, + (PreTrainedTokenizer, PreTrainedTokenizerFast)): + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + else: + # For other tokenizer types, we need to handle this differently + # This shouldn't happen in normal operation, but we handle it + # for type safety + raise ValueError( + f"Unsupported tokenizer type for HF chat template: " + f"{type(tokenizer)}") mm_data = await mm_data_future @@ -932,9 +949,16 @@ async def _preprocess_chat( # For MistralTokenizer assert is_list_of(request_prompt, int), ( "Prompt has to be either a string or a list of token ids") - prompt_inputs = TextTokensPrompt( - prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt) + # Type check for decode method + if hasattr(tokenizer, 'decode'): + decoded_prompt = tokenizer.decode(request_prompt) + else: + # Fallback for tokenizers without decode method + raise ValueError( + f"Tokenizer {type(tokenizer)} does not support " + f"decode method") + prompt_inputs = TextTokensPrompt(prompt=decoded_prompt, + prompt_token_ids=request_prompt) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) @@ -995,7 +1019,9 @@ def _log_inputs( elif isinstance(inputs, list): prompt_token_ids = inputs elif 'prompt_embeds' in inputs: - prompt_embeds = inputs.get("prompt_embeds") + # Cast to proper type for log_inputs + prompt_embeds = cast(Optional[torch.Tensor], + inputs.get("prompt_embeds")) else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] @@ -1043,7 +1069,13 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token - return tokenizer.decode(token_id) + + # Type check for decode method + if hasattr(tokenizer, 'decode'): + return tokenizer.decode(token_id) + else: + # Fallback for tokenizers without decode method + return f"token_id:{token_id}" def _is_model_supported(self, model_name: Optional[str]) -> bool: if not model_name: From cab8200b94818596fc6856e6557d4e0751871b2a Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:08:50 +0800 Subject: [PATCH 02/39] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E5=A4=9A=E6=A8=A1=E6=80=81=E5=A4=84=E7=90=86=E5=99=A8=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E7=9A=84=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E6=AD=A3=E7=A1=AE=E5=90=88=E5=B9=B6=E4=BC=A0=E5=85=A5=E7=9A=84?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E3=80=82=E6=9B=B4=E6=96=B0=E4=BA=86=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E6=96=87=E4=BB=B6=E4=BB=A5=E4=BD=BF=E7=94=A8=E6=96=B0?= =?UTF-8?q?=E7=9A=84=E5=90=88=E5=B9=B6=E6=96=B9=E5=BC=8F=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 2 +- vllm/inputs/registry.py | 2 +- vllm/transformers_utils/processor.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 42551a1854f1..f569c54f2b00 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ 
b/vllm/entrypoints/openai/serving_embedding.py @@ -32,7 +32,7 @@ from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput) + PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams logger = init_logger(__name__) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6331a70b469a..81c9f995b552 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -154,7 +154,7 @@ def call_hf_processor( assert callable(hf_processor) mm_config = self.model_config.get_multimodal_config() - merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) + merged_kwargs = {**(mm_config.mm_processor_kwargs or {}), **kwargs} allowed_kwargs = get_allowed_kwarg_only_overrides( hf_processor, diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index a630d940b257..aad386f78250 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -55,7 +55,7 @@ def _merge_mm_kwargs( **kwargs, ): mm_config = model_config.get_multimodal_config() - merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) + merged_kwargs = {**(mm_config.mm_processor_kwargs or {}), **kwargs} factory = _get_processor_factory_fn(processor_cls) allowed_kwargs = get_allowed_kwarg_only_overrides( From 57987aa83dfc0a12db72dd719d8e573a7ca46aaf Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:24:46 +0800 Subject: [PATCH 03/39] restore Signed-off-by: x22x22 --- vllm/config.py | 209 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 153 insertions(+), 56 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6564121d401b..b7f263c86b95 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,10 +11,11 @@ import uuid import warnings from collections import Counter +from collections.abc import Mapping from contextlib import contextmanager from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) -from functools import cached_property +from functools import cached_property, lru_cache from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args) @@ -38,8 +39,8 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - try_get_generation_config, try_get_safetensors_metadata, - try_get_tokenizer_config, uses_mrope) + maybe_override_with_speculators_target_model, try_get_generation_config, + try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect # yapf conflicts with isort for this block @@ -56,6 +57,7 @@ if TYPE_CHECKING: from _typeshed import DataclassInstance + from ray.runtime_env import RuntimeEnv from ray.util.placement_group import PlacementGroup from transformers.configuration_utils import PretrainedConfig @@ -73,6 +75,7 @@ else: DataclassInstance = Any PlacementGroup = Any + RuntimeEnv = Any PretrainedConfig = Any ExecutorBase = Any QuantizationConfig = Any @@ -374,7 +377,8 @@ class ModelConfig: max_logprobs: int = 20 """Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. 
The default value comes the default for the - OpenAI Chat Completions API.""" + OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * + vocab_size) logprobs are allowed to be returned and it may cause OOM.""" logprobs_mode: LogprobsMode = "raw_logprobs" """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: @@ -534,6 +538,15 @@ def __post_init__(self) -> None: "affect the random state of the Python process that " "launched vLLM.", self.seed) + if self.runner != "draft": + # If we're not running the draft model, check for speculators config + # If speculators config, set model / tokenizer to be target model + self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code) + # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name(self.model, self.served_model_name) @@ -605,8 +618,8 @@ def __post_init__(self) -> None: self.config_format, hf_overrides_kw=hf_overrides_kw, hf_overrides_fn=hf_overrides_fn) - self.hf_config = hf_config + self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) @@ -776,6 +789,9 @@ def _task_to_convert(task: TaskOption) -> ConvertType: raise ValueError( "`override_neuron_config` is only supported on Neuron.") + # Avoid running try_verify_and_update_config multiple times + self.config_updated = False + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -799,12 +815,17 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" + if getattr(self, "runner_type", self.runner) == "pooling": + return "TransformersModel" if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. 
multimodal return "TransformersForMultimodalLM" - else: - return "TransformersForCausalLM" + return "TransformersForCausalLM" + + def using_transformers_backend(self) -> bool: + """Check if the model is using the Transformers backend class.""" + return self.architecture == self._get_transformers_backend_cls() @property def registry(self): @@ -867,6 +888,12 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None + def set_disable_mm_preprocessor_cache(self, value: bool) -> None: + mm_config = self.get_multimodal_config() + + self.disable_mm_preprocessor_cache = value + mm_config.disable_mm_preprocessor_cache = value + def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -886,15 +913,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) - if self.is_matryoshka: - if pooler_config.normalize is None: - pooler_config.normalize = True - elif not pooler_config.normalize: - raise ValueError( - "`normalize` must be enabled (set to True) " - "for models that are compatible with " - "Matryoshka Representation.") - return pooler_config return None @@ -1081,6 +1099,21 @@ def _parse_quant_hf_config(self): if quant_cfg is None: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) + + else: + # Set quant_method for ModelOpt models. + producer_name = quant_cfg.get("producer", {}).get("name") + if producer_name == "modelopt": + quant_algo = quant_cfg.get("quantization", + {}).get("quant_algo") + if quant_algo == "FP8": + quant_cfg["quant_method"] = "modelopt" + elif quant_algo == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + elif quant_algo is not None: + raise ValueError( + f"Unknown ModelOpt quant algo: {quant_algo}") + return quant_cfg def _verify_quantization(self) -> None: @@ -1556,7 +1589,18 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config def try_get_generation_config(self) -> dict[str, Any]: - if self.generation_config in ("auto", "vllm"): + """ + This method attempts to retrieve the non-default values of the + generation config for this model. + + The generation config can contain information about special tokens, as + well as sampling parameters. Which is why this method exists separately + to `get_diff_sampling_param`. + + Returns: + A dictionary containing the non-default generation config. + """ + if self.generation_config in {"auto", "vllm"}: config = try_get_generation_config( self.hf_config_path or self.model, trust_remote_code=self.trust_remote_code, @@ -1575,13 +1619,18 @@ def try_get_generation_config(self) -> dict[str, Any]: def get_diff_sampling_param(self) -> dict[str, Any]: """ - This method returns a dictionary containing the parameters - that differ from the default sampling parameters. If - `generation_config` is `"vllm"`, an empty dictionary is returned. + This method returns a dictionary containing the non-default sampling + parameters with `override_generation_config` applied. + + The default sampling parameters are: + + - vLLM's neutral defaults if `self.generation_config="vllm"` + - the model's defaults if `self.generation_config="auto"` + - as defined in `generation_config.json` if + `self.generation_config="path/to/generation_config/dir"` Returns: - dict[str, Any]: A dictionary with the differing sampling - parameters, if `generation_config` is `"vllm"` an empty dictionary. 
+ A dictionary containing the non-default sampling parameters. """ if self.generation_config == "vllm": config = {} @@ -2024,7 +2073,7 @@ class ParallelConfig: and when data_parallel_size > 0. Enables running an AsyncLLM and API server on a "per-node" basis where vLLM load balances between local data parallel ranks, but an external LB balances - between vLLM nodes/replicas. Set explicitly in conjunction with + between vLLM nodes/replicas. Set explicitly in conjunction with --data-parallel-start-rank.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" @@ -2058,6 +2107,9 @@ class ParallelConfig: ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" + ray_runtime_env: Optional["RuntimeEnv"] = None + """Ray runtime environment to pass to distributed workers.""" + placement_group: Optional["PlacementGroup"] = None """ray distributed model workers placement group.""" @@ -2970,10 +3022,13 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") + from vllm.transformers_utils.configs import ( + SpeculatorsConfig) from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, - EAGLEConfig): + (EAGLEConfig, SpeculatorsConfig)): pass else: eagle_config = EAGLEConfig( @@ -3001,6 +3056,19 @@ def __post_init__(self): f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. + tree_choices = ast.literal_eval( + self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, @@ -3132,10 +3200,19 @@ def _verify_args(self) -> Self: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - if self.method == "eagle3" and self.target_model_config and \ - "llama" not in self.target_model_config.hf_text_config.model_type: + from vllm.transformers_utils.configs import SpeculatorsConfig + + eagle3_target_supported = ["llama"] + if self.draft_model_config and isinstance( + self.draft_model_config.hf_config, SpeculatorsConfig): + eagle3_target_supported.append("qwen") + + if self.method == "eagle3" and self.target_model_config and not any( + supported_model in + self.target_model_config.hf_text_config.model_type + for supported_model in eagle3_target_supported): raise ValueError( - "Eagle3 is only supported for Llama models. " + f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 f"Got {self.target_model_config.hf_text_config.model_type=}") return self @@ -3329,7 +3406,16 @@ def get_limit_per_prompt(self, modality: str) -> int: 999 if envs.VLLM_USE_V1 else 1, ) - # TODO: Add configs to init vision tower or not. + def merge_mm_processor_kwargs( + self, + inference_kwargs: Mapping[str, object], + ) -> dict[str, object]: + """ + Get the keyword arguments to pass to the multi-modal processor + according to the extra arguments passed during inference. 
+ """ + kwargs = self.mm_processor_kwargs or {} + return kwargs | dict(inference_kwargs) @config @@ -3343,25 +3429,34 @@ class PoolerConfig: [`vllm.model_executor.layers.pooler.PoolingType`][]. """ + ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the pooled outputs. Usually, this should be set to - ``True`` for embedding outputs. + Whether to normalize the embeddings outputs. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. """ - softmax: Optional[bool] = None + ## for classification models + activation: Optional[bool] = None """ - Whether to apply softmax to the pooled outputs. Usually, this should be set - to ``True`` for classification outputs. + Whether to apply activation function to the classification outputs. """ + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + """ step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ - returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, @@ -3369,25 +3464,6 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ - enable_chunked_processing: Optional[bool] = None - """ - Whether to enable chunked processing for long inputs that exceed the model's - maximum position embeddings. When enabled, long inputs will be split into - chunks, processed separately, and then aggregated using weighted averaging. - This allows embedding models to handle arbitrarily long text without CUDA - errors. Defaults to False. - """ - - max_embed_len: Optional[int] = None - """ - Maximum input length allowed for embedding generation. When set, allows - inputs longer than max_model_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring - VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds - max_embed_len, it will be handled according to the original max_model_len - validation logic. Defaults to None (use max_model_len validation). - """ - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -4067,7 +4143,7 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 1024 + fi_allreduce_fusion_max_token_num: int = 16384 """Max number of tokens to used in flashinfer allreduce fusion.""" # TODO(luka) better pass enabling system. 
@@ -4298,12 +4374,20 @@ def __repr__(self) -> str: "disabled_custom_ops": True, "compilation_time": True, "bs_to_padded_graph_size": True, - "pass_config": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, }, } + + # exclude default attr in pass_config + pass_config_exclude = {} + for attr, default_val in vars(PassConfig()).items(): + if getattr(self.pass_config, attr) == default_val: + pass_config_exclude[attr] = True + if pass_config_exclude: + exclude["pass_config"] = pass_config_exclude + # The cast to string is necessary because Pydantic is mocked in docs # builds and sphinx-argparse doesn't know the return type of decode() return str( @@ -4933,6 +5017,11 @@ def try_verify_and_update_config(self): if self.model_config is None: return + # Avoid running try_verify_and_update_config multiple times + if getattr(self.model_config, "config_updated", False): + return + self.model_config.config_updated = True + architecture = self.model_config.architecture if architecture is None: return @@ -5034,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, finally: _current_vllm_config = old_vllm_config _current_prefix = old_prefix + # Clear the compilation config cache when context changes + get_cached_compilation_config.cache_clear() + + +@lru_cache(maxsize=1) +def get_cached_compilation_config(): + """Cache config to avoid repeated calls to get_current_vllm_config()""" + return get_current_vllm_config().compilation_config def get_current_vllm_config() -> VllmConfig: @@ -5158,4 +5255,4 @@ def update_config(config: DataclassInstanceT, current_value, # type: ignore[type-var] value) processed_overrides[field_name] = value - return replace(config, **processed_overrides) + return replace(config, **processed_overrides) \ No newline at end of file From 8e3ba7268943d7a965c449dd83493b227bab1fe3 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:26:49 +0800 Subject: [PATCH 04/39] Feature: Implement chunk processing and maximum embedding length configuration options to facilitate long-text input support. Signed-off-by: x22x22 --- vllm/config.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index b7f263c86b95..1a90dfb85021 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3464,6 +3464,25 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_model_len to be accepted for embedding models. + This parameter enables accepting long inputs without requiring + VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds + max_embed_len, it will be handled according to the original max_model_len + validation logic. Defaults to None (use max_model_len validation). 
+ """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, From f24b546820c58d85b5792c2a29fc2044498ee8f4 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:29:14 +0800 Subject: [PATCH 05/39] restore Signed-off-by: x22x22 --- vllm/inputs/registry.py | 4 ++-- vllm/transformers_utils/processor.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 81c9f995b552..120786e53161 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -154,7 +154,7 @@ def call_hf_processor( assert callable(hf_processor) mm_config = self.model_config.get_multimodal_config() - merged_kwargs = {**(mm_config.mm_processor_kwargs or {}), **kwargs} + merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) allowed_kwargs = get_allowed_kwarg_only_overrides( hf_processor, @@ -242,4 +242,4 @@ def dummy_data_for_profiling( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), multi_modal_data=dec_data.multi_modal_data, multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) + ) \ No newline at end of file diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index aad386f78250..5253f29302cd 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -55,7 +55,7 @@ def _merge_mm_kwargs( **kwargs, ): mm_config = model_config.get_multimodal_config() - merged_kwargs = {**(mm_config.mm_processor_kwargs or {}), **kwargs} + merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) factory = _get_processor_factory_fn(processor_cls) allowed_kwargs = get_allowed_kwarg_only_overrides( @@ -242,4 +242,4 @@ def cached_image_processor_from_config( revision=model_config.revision, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs), - ) + ) \ No newline at end of file From b46791be4b6f9bb4ec8eb64a048f3d13eca63e81 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:32:29 +0800 Subject: [PATCH 06/39] restore Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 508 +------------------ vllm/entrypoints/openai/serving_engine.py | 60 +-- 2 files changed, 18 insertions(+), 550 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f569c54f2b00..13b30265499f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,11 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from collections.abc import AsyncGenerator -from typing import Any, Final, Literal, Optional, Union, cast +from typing import Final, Literal, Optional, Union, cast import numpy as np -import torch from fastapi import Request from typing_extensions import assert_never, override @@ -14,25 +12,18 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this docstring -# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, - EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext, - TextTokensPrompt) -# yapf: enable + ServeContext) from vllm.entrypoints.openai.serving_models 
import OpenAIServingModels -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput, RequestOutput) + PoolingRequestOutput) from vllm.pooling_params import PoolingParams logger = init_logger(__name__) @@ -138,497 +129,6 @@ def _build_response( usage=usage, ) - def _get_max_position_embeddings(self) -> int: - """Get the model's effective maximum sequence length for chunking. - - This uses the same logic as vLLM's _get_and_verify_max_len to determine - the actual sequence length limit, - considering both model config and tokenizer config. - When max_model_len is set and smaller than max_position_embeddings, - use max_model_len for chunking. - """ - hf_config = self.model_config.hf_config - - # Start with max_position_embeddings from model config - derived_max_len = getattr(hf_config, 'max_position_embeddings', 512) - - # Get tokenizer config for pooling models (embedding models) - if self.model_config.runner_type == "pooling": - from vllm.transformers_utils.config import try_get_tokenizer_config - tokenizer_config = try_get_tokenizer_config( - self.model_config.tokenizer, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - - # Consider model_max_length in tokenizer_config - # (same logic as _get_and_verify_max_len) - if tokenizer_config: - tokenizer_model_max_length = tokenizer_config.get( - 'model_max_length', derived_max_len) - derived_max_len = min(derived_max_len, - tokenizer_model_max_length) - - # Consider max_model_len when it's set and smaller than other limits - # max_model_len is set in OpenAIServing.__init__ - # from model_config.max_model_len - if self.max_model_len is not None: - derived_max_len = min(derived_max_len, self.max_model_len) - - return int(derived_max_len) - - def _should_use_chunked_processing(self, request) -> bool: - """Check if chunked processing should be used for this request.""" - if not isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest)): - return False - - pooler_config = getattr(self.model_config, 'pooler_config', None) - - # For chunked processing, we always use MEAN aggregation - # for cross-chunk aggregation (native pooling is used within each chunk) - return (pooler_config is not None - and getattr(pooler_config, 'enable_chunked_processing', False)) - - def _chunk_token_ids(self, token_ids: list[int], - chunk_size: int) -> list[list[int]]: - """Split token IDs into chunks of specified size.""" - if len(token_ids) <= chunk_size: - return [token_ids] - - chunks = [] - for i in range(0, len(token_ids), chunk_size): - chunk = token_ids[i:i + chunk_size] - chunks.append(chunk) - return chunks - - async def _process_chunked_request( - self, - ctx: EmbeddingServeContext, - original_prompt: TextTokensPrompt, - pooling_params, - trace_headers, - prompt_idx: int, - ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: - """Process a single prompt using chunked processing.""" - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - token_ids = original_prompt["prompt_token_ids"] - - # Split into chunks using max_position_embeddings - max_pos_embeddings = self._get_max_position_embeddings() - chunks = self._chunk_token_ids(token_ids, max_pos_embeddings) - - # Process all chunks for MEAN aggregation - chunks_to_process = chunks - chunk_indices = list(range(len(chunks))) - 
logger.info("Using chunked processing with MEAN aggregation") - - for i, (chunk_idx, chunk_tokens) in enumerate( - zip(chunk_indices, chunks_to_process)): - # Create a request ID for this chunk - chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" - f"chunk-{chunk_idx}") - - # Create engine prompt for this chunk - chunk_engine_prompt = EngineTokensPrompt( - prompt_token_ids=chunk_tokens) - - # Create chunk request prompt for logging - chunk_text = "" - chunk_request_prompt = TextTokensPrompt( - prompt=chunk_text, prompt_token_ids=chunk_tokens) - - # Log the chunk - self._log_inputs(chunk_request_id, - chunk_request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) - - # Create generator for this chunk - generator = self.engine_client.encode( - chunk_engine_prompt, - pooling_params, - chunk_request_id, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) - - generators.append(generator) - - return generators - - def _validate_input( - self, - request, - input_ids: list[int], - input_text: str, - ) -> TextTokensPrompt: - """Override to support chunked processing for embedding requests.""" - token_num = len(input_ids) - - # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest)): - # Check if chunked processing is enabled for pooling models - pooler_config = getattr(self.model_config, 'pooler_config', None) - enable_chunked = (pooler_config is not None and getattr( - pooler_config, 'enable_chunked_processing', False)) - - # Get max_embed_len from pooler config if set - max_embed_len = (pooler_config.max_embed_len if pooler_config - and pooler_config.max_embed_len else None) - - # Use max_position_embeddings for chunked processing decisions - max_pos_embeddings = self._get_max_position_embeddings() - - # Determine the effective max length for validation - if max_embed_len is not None: - # Use max_embed_len for validation instead of max_model_len - effective_max_len = max_embed_len - length_type = "maximum embedding input length" - max_length_value = max_embed_len - else: - # Fall back to max_model_len validation (original behavior) - effective_max_len = self.max_model_len - length_type = "maximum context length" - max_length_value = self.max_model_len - - validation_error_msg = ( - "This model's {length_type} is {max_length} tokens. " - "However, you requested {token_num} tokens in the input for " - "embedding generation. Please reduce the length of the input." - ).format(length_type=length_type, - max_length=max_length_value, - token_num=token_num) - - # Check if input exceeds effective max length - if token_num > effective_max_len: - raise ValueError(validation_error_msg) - - # Check for chunked processing - # when exceeding max_position_embeddings - if token_num > max_pos_embeddings: - if enable_chunked: - # Allow long inputs when chunked processing is enabled - logger.info( - "Input length %s exceeds max_position_embeddings " - "%s, will use chunked processing", token_num, - max_pos_embeddings) - else: - raise ValueError( - f"This model's maximum position embeddings length is " - f"{max_pos_embeddings} tokens. However, you requested " - f"{token_num} tokens in the input for embedding " - f"generation. 
Please reduce the length of the input or " - f"enable chunked processing.") - - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) - - # For other request types, use the parent's implementation - return super()._validate_input(request, input_ids, input_text) - - def _is_text_tokens_prompt(self, prompt) -> bool: - """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) - - async def _prepare_generators( - self, - ctx: ServeContext, - ) -> Optional[ErrorResponse]: - """Override to support chunked processing.""" - ctx = cast(EmbeddingServeContext, ctx) - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], - None]] = [] - - try: - trace_headers = (None if ctx.raw_request is None else await - self._get_trace_headers(ctx.raw_request.headers)) - - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters") - - pooling_params = ctx.request.to_pooling_params() - - # Verify and set the task for pooling params - try: - pooling_params.verify("embed", self.model_config) - except ValueError as e: - return self.create_error_response(str(e)) - - if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") - - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - - # Check if we should use chunked processing - use_chunked = self._should_use_chunked_processing(ctx.request) - - for i, engine_prompt in enumerate(ctx.engine_prompts): - request_prompt = ctx.request_prompts[i] - - # Check if this specific prompt needs chunked processing - max_pos_embeddings = self._get_max_position_embeddings() - if (use_chunked - and self._is_text_tokens_prompt(request_prompt)): - # Cast to TextTokensPrompt since we've - # verified prompt_token_ids - text_tokens_prompt = cast(TextTokensPrompt, request_prompt) - if len(text_tokens_prompt["prompt_token_ids"] - ) > max_pos_embeddings: - # Use chunked processing for this prompt - chunk_generators = await self._process_chunked_request( - ctx, text_tokens_prompt, pooling_params, - trace_headers, i) - generators.extend(chunk_generators) - continue - - # Normal processing for short prompts or non-token prompts - request_id_item = f"{ctx.request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) - - # Mypy has an existing bug related to inferring the variance - # of TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) - - generators.append(generator) - - from vllm.utils import merge_async_iterators - ctx.result_generator = merge_async_iterators(*generators) - - return None - - except Exception as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - async def _collect_batch( - self, - ctx: ServeContext, - ) -> Optional[ErrorResponse]: - """Collect and aggregate batch results - with support for chunked processing. 
- - For chunked requests, performs online aggregation to - minimize memory usage. - For regular requests, collects results normally. - """ - ctx = cast(EmbeddingServeContext, ctx) - try: - if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") - - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - - if ctx.result_generator is None: - return self.create_error_response( - "Result generator not available") - - # Check if we used chunked processing - use_chunked = self._should_use_chunked_processing(ctx.request) - - if use_chunked: - # Online aggregation for chunked requests to - # minimize memory usage - # Track aggregation state for each prompt - prompt_aggregators: dict[int, dict[str, Any]] = {} - short_prompts_results: dict[int, PoolingRequestOutput] = {} - - async for result_idx, result in ctx.result_generator: - if "-chunk-" in result.request_id: - # Extract prompt_idx from chunked request_id - parts = result.request_id.split("-") - try: - prompt_idx = int(parts[parts.index("prompt") + 1]) - - # Initialize aggregator for this prompt if needed - if prompt_idx not in prompt_aggregators: - prompt_aggregators[prompt_idx] = { - 'weighted_sum': - None, - 'total_weight': - 0, - 'chunk_count': - 0, - 'request_id': - result.request_id.split("-chunk-")[0] - } - - aggregator = prompt_aggregators[prompt_idx] - - # MEAN pooling with online weighted averaging - # Ensure result is PoolingRequestOutput - # for embedding processing - if not isinstance(result, PoolingRequestOutput): - return self.create_error_response( - f"Expected PoolingRequestOutput for " - f"chunked embedding, got " - f"{type(result).__name__}") - - embedding_data = result.outputs.data - if not isinstance(embedding_data, torch.Tensor): - embedding_data = torch.tensor( - embedding_data, dtype=torch.float32) - - if result.prompt_token_ids is None: - return self.create_error_response( - "prompt_token_ids cannot be None for " - "chunked processing") - weight = len(result.prompt_token_ids) - - weighted_embedding = embedding_data.to( - dtype=torch.float32) * weight - - if aggregator['weighted_sum'] is None: - # First chunk - aggregator['weighted_sum'] = weighted_embedding - else: - # Accumulate - current_sum = aggregator['weighted_sum'] - if isinstance(current_sum, torch.Tensor): - aggregator['weighted_sum'] = ( - current_sum + weighted_embedding) - - total_weight = aggregator['total_weight'] - if isinstance(total_weight, (int, float)): - aggregator['total_weight'] = (total_weight + - weight) - - chunk_count = aggregator['chunk_count'] - if isinstance(chunk_count, int): - aggregator['chunk_count'] = chunk_count + 1 - - except (ValueError, IndexError): - return self.create_error_response( - f"Invalid chunk request ID format: " - f"{result.request_id}") - else: - # Non-chunked result - try: - prompt_idx = int(result.request_id.split("-")[-1]) - short_prompts_results[prompt_idx] = cast( - PoolingRequestOutput, result) - except ValueError: - return self.create_error_response( - f"Invalid request ID format: " - f"{result.request_id}") - - # Build final result batch - final_res_batch = [] - - for prompt_idx, request_prompt in enumerate( - ctx.request_prompts): - if prompt_idx in prompt_aggregators: - # Finalize MEAN aggregation for this chunked prompt - aggregator = prompt_aggregators[prompt_idx] - - # Finalize weighted average - weighted_sum = aggregator['weighted_sum'] - total_weight = aggregator['total_weight'] - if (weighted_sum is not None - 
and isinstance(weighted_sum, torch.Tensor) - and isinstance(total_weight, (int, float)) - and total_weight > 0): - final_embedding = weighted_sum / total_weight - - # Create aggregated result - from vllm.outputs import PoolingOutput - aggregated_output = PoolingOutput( - data=final_embedding) - - # Get original prompt token ids - if self._is_text_tokens_prompt(request_prompt): - text_tokens_prompt = cast( - TextTokensPrompt, request_prompt) - original_token_ids = text_tokens_prompt[ - "prompt_token_ids"] - else: - return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a " - f"text tokens prompt") - - # Ensure request_id is string - request_id = aggregator['request_id'] - if not isinstance(request_id, str): - return self.create_error_response( - f"Invalid request_id type: " - f"{type(request_id)}") - - aggregated_result = PoolingRequestOutput( - request_id=request_id, - outputs=aggregated_output, - prompt_token_ids=original_token_ids, - finished=True, - ) - final_res_batch.append(aggregated_result) - else: - return self.create_error_response( - f"No valid aggregation data for prompt " - f"{prompt_idx}") - - elif prompt_idx in short_prompts_results: - # This was a short prompt - final_res_batch.append( - short_prompts_results[prompt_idx]) - else: - return self.create_error_response( - f"Result not found for prompt {prompt_idx}") - - ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_res_batch) - else: - # Normal processing for non-chunked requests - num_prompts = len(ctx.engine_prompts) - normal_final_res_batch: list[ - Optional[PoolingRequestOutput]] = [None] * num_prompts - - async for result_idx, result in ctx.result_generator: - if result_idx < num_prompts: - # Cast to PoolingRequestOutput for embedding results - normal_final_res_batch[result_idx] = cast( - PoolingRequestOutput, result) - - if None in normal_final_res_batch: - return self.create_error_response( - "Failed to generate results for all prompts") - - final_results = [ - res for res in normal_final_res_batch if res is not None - ] - ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_results) - - return None - - except Exception as e: - return self.create_error_response(str(e)) - class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd" @@ -704,4 +204,4 @@ def _create_pooling_params( except ValueError as e: return self.create_error_response(str(e)) - return pooling_params + return pooling_params \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d74231d7e9d9..71515f6954a7 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,7 +16,6 @@ from fastapi import Request from pydantic import BaseModel, ConfigDict, Field from starlette.datastructures import Headers -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import TypeIs if sys.version_info >= (3, 12): @@ -513,13 +512,8 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: if (isinstance(message, dict) and "content" in message and isinstance(message["content"], list)): for content_dict in message["content"]: - # Check if content_dict has a "type" key and it's a string - if isinstance(content_dict, dict): - type_value = content_dict.get("type") - if isinstance(type_value, str): - # Split on "_" and take the first part - base_type = type_value.split("_")[0] - 
message_types.add(base_type) + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) return message_types async def _normalize_prompt_text_to_input( @@ -896,23 +890,12 @@ async def _preprocess_chat( **_chat_template_kwargs, ) else: - # Type check for apply_hf_chat_template which only accepts - # PreTrainedTokenizer or PreTrainedTokenizerFast - if isinstance(tokenizer, - (PreTrainedTokenizer, PreTrainedTokenizerFast)): - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - else: - # For other tokenizer types, we need to handle this differently - # This shouldn't happen in normal operation, but we handle it - # for type safety - raise ValueError( - f"Unsupported tokenizer type for HF chat template: " - f"{type(tokenizer)}") + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) mm_data = await mm_data_future @@ -949,16 +932,9 @@ async def _preprocess_chat( # For MistralTokenizer assert is_list_of(request_prompt, int), ( "Prompt has to be either a string or a list of token ids") - # Type check for decode method - if hasattr(tokenizer, 'decode'): - decoded_prompt = tokenizer.decode(request_prompt) - else: - # Fallback for tokenizers without decode method - raise ValueError( - f"Tokenizer {type(tokenizer)} does not support " - f"decode method") - prompt_inputs = TextTokensPrompt(prompt=decoded_prompt, - prompt_token_ids=request_prompt) + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) @@ -1019,9 +995,7 @@ def _log_inputs( elif isinstance(inputs, list): prompt_token_ids = inputs elif 'prompt_embeds' in inputs: - # Cast to proper type for log_inputs - prompt_embeds = cast(Optional[torch.Tensor], - inputs.get("prompt_embeds")) + prompt_embeds = inputs.get("prompt_embeds") else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] @@ -1069,13 +1043,7 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token - - # Type check for decode method - if hasattr(tokenizer, 'decode'): - return tokenizer.decode(token_id) - else: - # Fallback for tokenizers without decode method - return f"token_id:{token_id}" + return tokenizer.decode(token_id) def _is_model_supported(self, model_name: Optional[str]) -> bool: if not model_name: @@ -1104,4 +1072,4 @@ def clamp_prompt_logprobs( for logprob_values in logprob_dict.values(): if logprob_values.logprob == float('-inf'): logprob_values.logprob = -9999.0 - return prompt_logprobs + return prompt_logprobs \ No newline at end of file From 54c79301622baade3334853dc2b538a0f43a93e4 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 06:44:10 +0800 Subject: [PATCH 07/39] Feature: Implementation of Chunk Processing for Embedding Requests of Extensive Texts Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 495 ++++++++++++++++++- vllm/entrypoints/openai/serving_engine.py | 4 +- 2 files changed, 495 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 13b30265499f..13cf4a049c65 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,9 +2,11 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from typing import Final, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, Final, Literal, Optional, Union, cast import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never, override @@ -13,17 +15,20 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext) + ServeContext, + TextTokensPrompt) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput) + PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams logger = init_logger(__name__) @@ -129,6 +134,490 @@ def _build_response( usage=usage, ) + def _get_max_position_embeddings(self) -> int: + """Get the model's effective maximum sequence length for chunking. + + This uses the same logic as vLLM's _get_and_verify_max_len to determine + the actual sequence length limit, + considering both model config and tokenizer config. + When max_model_len is set and smaller than max_position_embeddings, + use max_model_len for chunking. + """ + hf_config = self.model_config.hf_config + + # Start with max_position_embeddings from model config + derived_max_len = getattr(hf_config, 'max_position_embeddings', 512) + + # Get tokenizer config for pooling models (embedding models) + if self.model_config.runner_type == "pooling": + from vllm.transformers_utils.config import try_get_tokenizer_config + tokenizer_config = try_get_tokenizer_config( + self.model_config.tokenizer, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + + # Consider model_max_length in tokenizer_config + # (same logic as _get_and_verify_max_len) + if tokenizer_config: + tokenizer_model_max_length = tokenizer_config.get( + 'model_max_length', derived_max_len) + derived_max_len = min(derived_max_len, + tokenizer_model_max_length) + + # Consider max_model_len when it's set and smaller than other limits + # max_model_len is set in OpenAIServing.__init__ + # from model_config.max_model_len + if self.max_model_len is not None: + derived_max_len = min(derived_max_len, self.max_model_len) + + return int(derived_max_len) + + def _should_use_chunked_processing(self, request) -> bool: + """Check if chunked processing should be used for this request.""" + if not isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + return False + + pooler_config = getattr(self.model_config, 'pooler_config', None) + + # For chunked processing, we always use MEAN aggregation + # for cross-chunk aggregation (native pooling is used within each chunk) + return (pooler_config is not None + and getattr(pooler_config, 'enable_chunked_processing', False)) + + def _chunk_token_ids(self, token_ids: list[int], + chunk_size: int) -> list[list[int]]: + """Split token IDs into chunks of specified size.""" + if len(token_ids) <= chunk_size: + return [token_ids] + + chunks = [] + for i in range(0, 
len(token_ids), chunk_size): + chunk = token_ids[i:i + chunk_size] + chunks.append(chunk) + return chunks + + async def _process_chunked_request( + self, + ctx: EmbeddingServeContext, + original_prompt: TextTokensPrompt, + pooling_params, + trace_headers, + prompt_idx: int, + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + """Process a single prompt using chunked processing.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + token_ids = original_prompt["prompt_token_ids"] + + # Split into chunks using max_position_embeddings + max_pos_embeddings = self._get_max_position_embeddings() + chunks = self._chunk_token_ids(token_ids, max_pos_embeddings) + + # Process all chunks for MEAN aggregation + chunks_to_process = chunks + chunk_indices = list(range(len(chunks))) + logger.info("Using chunked processing with MEAN aggregation") + + for i, (chunk_idx, chunk_tokens) in enumerate( + zip(chunk_indices, chunks_to_process)): + # Create a request ID for this chunk + chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" + f"chunk-{chunk_idx}") + + # Create engine prompt for this chunk + chunk_engine_prompt = EngineTokensPrompt( + prompt_token_ids=chunk_tokens) + + # Create chunk request prompt for logging + chunk_text = "" + chunk_request_prompt = TextTokensPrompt( + prompt=chunk_text, prompt_token_ids=chunk_tokens) + + # Log the chunk + self._log_inputs(chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Create generator for this chunk + generator = self.engine_client.encode( + chunk_engine_prompt, + pooling_params, + chunk_request_id, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + return generators + + def _validate_input( + self, + request, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + """Override to support chunked processing for embedding requests.""" + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + # Check if chunked processing is enabled for pooling models + pooler_config = getattr(self.model_config, 'pooler_config', None) + enable_chunked = (pooler_config is not None and getattr( + pooler_config, 'enable_chunked_processing', False)) + + # Get max_embed_len from pooler config if set + max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len else None) + + # Use max_position_embeddings for chunked processing decisions + max_pos_embeddings = self._get_max_position_embeddings() + + # Determine the effective max length for validation + if max_embed_len is not None: + # Use max_embed_len for validation instead of max_model_len + effective_max_len = max_embed_len + length_type = "maximum embedding input length" + max_length_value = max_embed_len + else: + # Fall back to max_model_len validation (original behavior) + effective_max_len = self.max_model_len + length_type = "maximum context length" + max_length_value = self.max_model_len + + validation_error_msg = ( + "This model's {length_type} is {max_length} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input." 
+ ).format(length_type=length_type, + max_length=max_length_value, + token_num=token_num) + + # Check if input exceeds effective max length + if token_num > effective_max_len: + raise ValueError(validation_error_msg) + + # Check for chunked processing + # when exceeding max_position_embeddings + if token_num > max_pos_embeddings: + if enable_chunked: + # Allow long inputs when chunked processing is enabled + logger.info( + "Input length %s exceeds max_position_embeddings " + "%s, will use chunked processing", token_num, + max_pos_embeddings) + else: + raise ValueError( + f"This model's maximum position embeddings length is " + f"{max_pos_embeddings} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input or " + f"enable chunked processing.") + + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # For other request types, use the parent's implementation + return super()._validate_input(request, input_ids, input_text) + + def _is_text_tokens_prompt(self, prompt) -> bool: + """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + @override + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Override to support chunked processing.""" + ctx = cast(EmbeddingServeContext, ctx) + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + # Verify and set the task for pooling params + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + # Check if we should use chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_prompt = ctx.request_prompts[i] + + # Check if this specific prompt needs chunked processing + max_pos_embeddings = self._get_max_position_embeddings() + if (use_chunked + and self._is_text_tokens_prompt(request_prompt)): + # Cast to TextTokensPrompt since we've + # verified prompt_token_ids + text_tokens_prompt = cast(TextTokensPrompt, request_prompt) + if len(text_tokens_prompt["prompt_token_ids"] + ) > max_pos_embeddings: + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, text_tokens_prompt, pooling_params, + trace_headers, i) + generators.extend(chunk_generators) + continue + + # Normal processing for short prompts or non-token prompts + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, 
EngineEmbedsPrompt], + engine_prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + @override + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect and aggregate batch results + with support for chunked processing. + + For chunked requests, performs online aggregation to + minimize memory usage. + For regular requests, collects results normally. + """ + ctx = cast(EmbeddingServeContext, ctx) + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + if use_chunked: + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': + None, + 'total_weight': + 0, + 'chunk_count': + 0, + 'request_id': + result.request_id.split("-chunk-")[0] + } + + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + embedding_data = result.outputs.data + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor( + embedding_data, dtype=torch.float32) + + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight + + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding + else: + # Accumulate + current_sum = aggregator['weighted_sum'] + if isinstance(current_sum, torch.Tensor): + aggregator['weighted_sum'] = ( + current_sum + weighted_embedding) + + total_weight = aggregator['total_weight'] + if isinstance(total_weight, (int, float)): + aggregator['total_weight'] = (total_weight + + weight) + + chunk_count = aggregator['chunk_count'] + if isinstance(chunk_count, int): + aggregator['chunk_count'] = chunk_count + 1 + + except (ValueError, IndexError): + return self.create_error_response( + f"Invalid chunk request ID format: " + 
f"{result.request_id}") + else: + # Non-chunked result + try: + prompt_idx = int(result.request_id.split("-")[-1]) + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) + except ValueError: + return self.create_error_response( + f"Invalid request ID format: {result.request_id}") + + # Finalize aggregated results + final_res_batch: list[PoolingRequestOutput] = [] + num_prompts = len(ctx.engine_prompts) + + for prompt_idx in range(num_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + + if (weighted_sum is not None and + isinstance(weighted_sum, torch.Tensor) and + isinstance(total_weight, (int, float)) and + total_weight > 0): + + # Compute final mean embedding + final_embedding = weighted_sum / total_weight + + # Create a PoolingRequestOutput for the aggregated result + from vllm.outputs import EmbeddingOutput + embedding_output = EmbeddingOutput( + embedding=final_embedding.tolist()) + + # Get original prompt token IDs for this prompt + original_prompt = ctx.request_prompts[prompt_idx] + if not self._is_text_tokens_prompt(original_prompt): + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"TextTokensPrompt") + + original_token_ids = cast(TextTokensPrompt, + original_prompt)["prompt_token_ids"] + + pooling_output = PoolingRequestOutput( + request_id=aggregator['request_id'], + prompt_token_ids=original_token_ids, + outputs=embedding_output, + finished=True + ) + + final_res_batch.append(pooling_output) + else: + return self.create_error_response( + f"Failed to aggregate chunks for prompt {prompt_idx}") + elif prompt_idx in short_prompts_results: + final_res_batch.append( + short_prompts_results[prompt_idx]) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) + else: + # Normal processing for non-chunked requests + num_prompts = len(ctx.engine_prompts) + normal_final_res_batch: list[ + Optional[PoolingRequestOutput]] = [None] * num_prompts + + async for result_idx, result in ctx.result_generator: + if result_idx < num_prompts: + # Cast to PoolingRequestOutput for embedding results + normal_final_res_batch[result_idx] = cast( + PoolingRequestOutput, result) + + if None in normal_final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + final_results = [ + res for res in normal_final_res_batch if res is not None + ] + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_results) + + return None + + except Exception as e: + return self.create_error_response(str(e)) + class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 71515f6954a7..e181bd22b835 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -995,7 +995,9 @@ def _log_inputs( elif isinstance(inputs, list): prompt_token_ids = inputs elif 'prompt_embeds' in inputs: - prompt_embeds = inputs.get("prompt_embeds") + # Cast to proper type for log_inputs + prompt_embeds = cast(Optional[torch.Tensor], + inputs.get("prompt_embeds")) else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] 
From 1ad1ae3b62b8a832eef51aed53b2f49fd92a54f7 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 14:31:22 +0800 Subject: [PATCH 08/39] revert: restore processor.py and registry.py to main branch state - Restore vllm/transformers_utils/processor.py to main branch - Restore vllm/inputs/registry.py to main branch - Ensure all file metadata matches main branch exactly Signed-off-by: x22x22 --- vllm/inputs/registry.py | 2 +- vllm/transformers_utils/processor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 120786e53161..6331a70b469a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -242,4 +242,4 @@ def dummy_data_for_profiling( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), multi_modal_data=dec_data.multi_modal_data, multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) \ No newline at end of file + ) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 5253f29302cd..a630d940b257 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -242,4 +242,4 @@ def cached_image_processor_from_config( revision=model_config.revision, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs), - ) \ No newline at end of file + ) From 35e0aeedc70abe615ce8190846642dd285f0393b Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 14:45:12 +0800 Subject: [PATCH 09/39] Refactor: Enhance the code structure and error handling logic for embedding generation. Signed-off-by: x22x22 --- vllm/config.py | 4 +- vllm/entrypoints/openai/serving_embedding.py | 64 +++++++++++--------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 1a90dfb85021..bdbd5565cca7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3476,11 +3476,11 @@ class PoolerConfig: max_embed_len: Optional[int] = None """ Maximum input length allowed for embedding generation. When set, allows - inputs longer than max_model_len to be accepted for embedding models. + inputs longer than max_embed_len to be accepted for embedding models. This parameter enables accepting long inputs without requiring VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds max_embed_len, it will be handled according to the original max_model_len - validation logic. Defaults to None (use max_model_len validation). + validation logic. Defaults to None (i.e. set to max_model_len). 
""" def compute_hash(self) -> str: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 13cf4a049c65..6fc152d77167 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -14,6 +14,8 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this docstring +# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, EmbeddingCompletionRequest, EmbeddingRequest, @@ -24,7 +26,9 @@ OpenAIServing, ServeContext, TextTokensPrompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, @@ -216,7 +220,6 @@ async def _process_chunked_request( # Process all chunks for MEAN aggregation chunks_to_process = chunks chunk_indices = list(range(len(chunks))) - logger.info("Using chunked processing with MEAN aggregation") for i, (chunk_idx, chunk_tokens) in enumerate( zip(chunk_indices, chunks_to_process)): @@ -290,12 +293,10 @@ def _validate_input( max_length_value = self.max_model_len validation_error_msg = ( - "This model's {length_type} is {max_length} tokens. " - "However, you requested {token_num} tokens in the input for " - "embedding generation. Please reduce the length of the input." - ).format(length_type=length_type, - max_length=max_length_value, - token_num=token_num) + f"This model's {length_type} is {max_length_value} tokens. " + f"However, you requested {token_num} tokens in the input for " + f"embedding generation. Please reduce the length of the input." 
+ ) # Check if input exceeds effective max length if token_num > effective_max_len: @@ -532,7 +533,8 @@ async def _collect_batch( PoolingRequestOutput, result) except ValueError: return self.create_error_response( - f"Invalid request ID format: {result.request_id}") + f"Invalid request ID " + f"format: {result.request_id}") # Finalize aggregated results final_res_batch: list[PoolingRequestOutput] = [] @@ -542,47 +544,51 @@ async def _collect_batch( if prompt_idx in prompt_aggregators: # Finalize MEAN aggregation for this chunked prompt aggregator = prompt_aggregators[prompt_idx] - + weighted_sum = aggregator['weighted_sum'] total_weight = aggregator['total_weight'] - - if (weighted_sum is not None and - isinstance(weighted_sum, torch.Tensor) and - isinstance(total_weight, (int, float)) and - total_weight > 0): - + + if (weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0): + # Compute final mean embedding final_embedding = weighted_sum / total_weight - - # Create a PoolingRequestOutput for the aggregated result + + # Create a PoolingRequestOutput + # for the aggregated result from vllm.outputs import EmbeddingOutput embedding_output = EmbeddingOutput( embedding=final_embedding.tolist()) - + # Get original prompt token IDs for this prompt original_prompt = ctx.request_prompts[prompt_idx] - if not self._is_text_tokens_prompt(original_prompt): + if not self._is_text_tokens_prompt( + original_prompt): return self.create_error_response( f"Chunked prompt {prompt_idx} is not a " f"TextTokensPrompt") - - original_token_ids = cast(TextTokensPrompt, - original_prompt)["prompt_token_ids"] - + + original_token_ids = cast( + TextTokensPrompt, + original_prompt)["prompt_token_ids"] + pooling_output = PoolingRequestOutput( request_id=aggregator['request_id'], prompt_token_ids=original_token_ids, outputs=embedding_output, - finished=True - ) - + finished=True) + final_res_batch.append(pooling_output) else: return self.create_error_response( - f"Failed to aggregate chunks for prompt {prompt_idx}") + f"Failed to aggregate chunks " + f"for prompt {prompt_idx}") elif prompt_idx in short_prompts_results: final_res_batch.append( - short_prompts_results[prompt_idx]) + cast(PoolingRequestOutput, + short_prompts_results[prompt_idx])) else: return self.create_error_response( f"Result not found for prompt {prompt_idx}") @@ -693,4 +699,4 @@ def _create_pooling_params( except ValueError as e: return self.create_error_response(str(e)) - return pooling_params \ No newline at end of file + return pooling_params From 483be3eb4f21e64140e58c59f70cc83a3e081129 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 14:48:35 +0800 Subject: [PATCH 10/39] Refactor: Enhance the code structure and error handling logic for embedding generation. 
Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6fc152d77167..3d5569b3096e 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -34,6 +34,7 @@ from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams +from vllm.transformers_utils.config import try_get_tokenizer_config logger = init_logger(__name__) @@ -154,7 +155,6 @@ def _get_max_position_embeddings(self) -> int: # Get tokenizer config for pooling models (embedding models) if self.model_config.runner_type == "pooling": - from vllm.transformers_utils.config import try_get_tokenizer_config tokenizer_config = try_get_tokenizer_config( self.model_config.tokenizer, trust_remote_code=self.model_config.trust_remote_code, From d410c3424637a6873f8ea9b13b5dcdaccc26c38c Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 15:23:44 +0800 Subject: [PATCH 11/39] Refactor: Enhance the code structure and error handling logic for embedding generation. Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 3d5569b3096e..5f2062c79c07 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -537,7 +537,8 @@ async def _collect_batch( f"format: {result.request_id}") # Finalize aggregated results - final_res_batch: list[PoolingRequestOutput] = [] + final_res_batch: list[Union[PoolingRequestOutput, + EmbeddingRequestOutput]] = [] num_prompts = len(ctx.engine_prompts) for prompt_idx in range(num_prompts): @@ -574,7 +575,7 @@ async def _collect_batch( TextTokensPrompt, original_prompt)["prompt_token_ids"] - pooling_output = PoolingRequestOutput( + pooling_output = EmbeddingRequestOutput( request_id=aggregator['request_id'], prompt_token_ids=original_token_ids, outputs=embedding_output, From 8880316f3f1e72fe4ca6510895cb98d34f1609c4 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 17:26:42 +0800 Subject: [PATCH 12/39] Feature: Implementation of an automatic chunking mechanism for long text embedding, accompanied by corresponding unit tests. Signed-off-by: x22x22 --- .../openai/test_embedding_long_text.py | 274 ++++++++++++++++++ .../long_text_1500_words.txt | 45 +++ .../long_text_2500_words.txt | 73 +++++ vllm/entrypoints/openai/serving_embedding.py | 23 +- 4 files changed, 408 insertions(+), 7 deletions(-) create mode 100644 tests/entrypoints/openai/test_embedding_long_text.py create mode 100644 tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt create mode 100644 tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/openai/test_embedding_long_text.py new file mode 100644 index 000000000000..307871997860 --- /dev/null +++ b/tests/entrypoints/openai/test_embedding_long_text.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test cases for long text embedding with automatic chunking mechanism. 
+ +This test suite validates vLLM's automatic chunking functionality for handling +text inputs that exceed the model's maximum token length, specifically targeting +the intfloat/multilingual-e5-small model (max token length: 512). +""" + +import os + +import openai +import pytest +import pytest_asyncio + +from vllm.entrypoints.openai.protocol import EmbeddingResponse + +from ...utils import RemoteOpenAIServer + + +def _load_text_file(filename: str) -> str: + """Load text content from file in the same directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(current_dir, filename) + with open(file_path, encoding='utf-8') as f: + return f.read().strip() + + +MODEL_NAME = "intfloat/multilingual-e5-small" +DTYPE = "bfloat16" + +# Test text: Load text with approximately 1500 words to exceed 1024 tokens +LONG_TEXT_1500_WORDS = _load_text_file( + './test_embedding_long_text_datasets/long_text_1500_words.txt') + +# Test text: Construct text with approximately 2500 words to exceed 2048 tokens +LONG_TEXT_2500_WORDS = _load_text_file( + './test_embedding_long_text_datasets/long_text_2500_words.txt') + + +@pytest.fixture(scope="module") +def server_with_chunked_processing(): + """Start server with automatic chunking processing enabled.""" + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", # Set smaller max_model_len to trigger chunking mechanism + '--override-pooler-config', + ('{"pooling_type": "MEAN", "normalize": true, ' + '"enable_chunked_processing": true, "max_embed_len": 10000}'), + "--gpu-memory-utilization", + "0.8", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_chunked_processing(server_with_chunked_processing): + """Create async client with chunking processing support.""" + async with server_with_chunked_processing.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_1500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~1500 character long text + (~1028 tokens, exceeding 512 token limit).""" + + # Verify text length + # Verify text has sufficient word count (approximately 1500 words) + word_count = len(LONG_TEXT_1500_WORDS.split()) + assert word_count >= 1400, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~1500 words, we expect roughly + # 1024+ tokens (exceeding 512 token limit) + # Should exceed single chunk limit of 512 + assert embeddings.usage.prompt_tokens > 800 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in 
embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_long_text_embedding_2500_chars( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test embedding processing for ~2500 character long text + (~2048 tokens, requiring multiple chunks).""" + + # Verify text length + # Verify text has sufficient word count (approximately 2500 words) + word_count = len(LONG_TEXT_2500_WORDS.split()) + assert word_count >= 2300, ( + f"Test text word count insufficient: {word_count} words") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_2500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding + ) == 384 # multilingual-e5-small embedding dimension + assert embeddings.usage.completion_tokens == 0 + # Due to chunked processing, token count should + # reflect actual processed tokens + # With ~2500 words, we expect + # roughly 2048+ tokens (requiring multiple chunks) + # Should require multiple chunks for processing + assert embeddings.usage.prompt_tokens > 1500 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # Verify embedding vector validity + embedding_vector = embeddings.data[0].embedding + assert all( + isinstance(x, float) + for x in embedding_vector), "Embedding vector should contain floats" + assert not all( + x == 0 + for x in embedding_vector), "Embedding vector should not be all zeros" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_batch_long_text_embedding( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test batch long text embedding processing.""" + + input_texts = [ + LONG_TEXT_1500_WORDS, + LONG_TEXT_2500_WORDS, + "This is a short text test.", # Short text for comparison + ] + + # Send batch embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=input_texts, + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 3 # Three input texts + + # Verify each embedding dimension + for i, embedding_data in enumerate(embeddings.data): + assert len(embedding_data.embedding) == 384 + assert embedding_data.index == i + + # Verify embedding vector validity + embedding_vector = embedding_data.embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + # Verify token usage + assert embeddings.usage.completion_tokens == 0 + # Total token count should be very substantial + assert embeddings.usage.prompt_tokens > 1000 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_vs_normal_consistency( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test consistency between chunked and + normal processing (using short 
text).""" + + # Use a short text within the 512 token limit + short_text = ("Artificial intelligence technology is changing our world, " + "bringing unprecedented opportunities and challenges.") + + # Send embedding request + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[short_text], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 384 + assert embeddings.usage.completion_tokens == 0 + # Short text should not require chunked processing + assert embeddings.usage.prompt_tokens < 512 + assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens + + # ้ชŒ่ฏembeddingๅ‘้‡็š„ๆœ‰ๆ•ˆๆ€ง + embedding_vector = embeddings.data[0].embedding + assert all(isinstance(x, float) for x in embedding_vector) + assert not all(x == 0 for x in embedding_vector) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_chunked_processing_response_format( + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + """Test response format and structure during chunked processing.""" + + # Test with long text to trigger chunking + embedding_response = await client_with_chunked_processing.embeddings.create( + model=model_name, + input=[LONG_TEXT_1500_WORDS], + encoding_format="float", + ) + + # Verify response structure + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert embeddings.data[0].object == "embedding" + assert embeddings.data[0].index == 0 + + # Verify embedding vector properties + embedding_vector = embeddings.data[0].embedding + import math + vector_norm = math.sqrt(sum(x * x for x in embedding_vector)) + # Check that the vector is normalized + # (default behavior for most embedding models) + assert 0.8 < vector_norm < 1.2, ( + f"Vector norm should be reasonable, actual: {vector_norm}") diff --git a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt new file mode 100644 index 000000000000..62c407571e38 --- /dev/null +++ b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt @@ -0,0 +1,45 @@ +The development of artificial intelligence technology is profoundly transforming our world in unprecedented ways that continue to reshape every aspect of human civilization and society. From sophisticated machine learning algorithms to complex deep neural networks, from advanced natural language processing systems to cutting-edge computer vision models, artificial intelligence technology has demonstrated tremendous potential across virtually every field of human endeavor and scientific research. In the healthcare and medical sector, artificial intelligence systems can help doctors and medical professionals diagnose diseases more accurately than ever before, improve treatment outcomes significantly, reduce medical errors substantially, and accelerate drug discovery processes dramatically while enhancing patient care quality and reducing healthcare costs. 
+ +In transportation and logistics industries, autonomous driving technology is gradually maturing and becoming more sophisticated with each passing year, offering the potential to dramatically reduce traffic accidents, improve fuel efficiency substantially, optimize route planning intelligently, and revolutionize urban mobility systems completely. The integration of artificial intelligence in transportation extends beyond individual vehicles to encompass entire smart city infrastructures, traffic management systems, and public transportation networks that can adapt dynamically to changing conditions and user demands while minimizing environmental impact and maximizing operational efficiency. + +In education and learning environments, personalized artificial intelligence tutoring systems can create highly specialized and adaptive learning plans based on each individual student's unique characteristics, learning style preferences, cognitive pace, academic strengths, and specific areas requiring improvement. These systems can provide real-time feedback, adjust difficulty levels automatically, and offer customized educational content that maximizes learning effectiveness while maintaining student engagement and motivation throughout the educational process. + +However, the rapid development of artificial intelligence technology also brings numerous complex challenges that society must address thoughtfully and comprehensively, including significant changes in employment structures and job markets, critical privacy protection issues and data security concerns, the persistent risk of algorithmic bias and unfair discrimination, ethical considerations around autonomous decision-making processes, and fundamental questions about human agency and control in an increasingly automated world where machines make decisions that affect human lives. + +We need to pay careful attention to these multifaceted social impacts while simultaneously advancing technological development responsibly, ensuring that artificial intelligence technology can truly benefit all segments of human society rather than exacerbating existing inequalities or creating new forms of digital divide. Future artificial intelligence development requires extensive interdisciplinary cooperation and collaboration, involving not only technical experts and computer scientists, but also sociologists, ethicists, policymakers, economists, legal experts, and professionals from various other fields to build a more intelligent, equitable, fair, and sustainable future world for everyone. + +The economic implications of artificial intelligence adoption are far-reaching and complex, affecting labor markets, productivity levels, wealth distribution, and global competitiveness in ways that require careful analysis and strategic planning. While AI technologies can increase efficiency and create new opportunities for innovation and growth, they also pose significant challenges for workers whose jobs may become automated or fundamentally changed. This necessitates comprehensive retraining programs, educational reforms, and social safety nets to ensure a smooth transition for affected populations and communities. + +Environmental considerations also play a crucial role in AI development, as the computational requirements for training and running large AI models consume significant energy resources and contribute to carbon emissions. 
Sustainable AI practices, green computing initiatives, and energy-efficient algorithms are becoming increasingly important as the field continues to grow and expand globally, requiring researchers and developers to balance performance with environmental responsibility. + +International cooperation and governance frameworks are essential for managing the global impact of artificial intelligence technologies and ensuring that their development and deployment serve the common good. Cross-border collaboration on AI safety standards, ethical guidelines, and regulatory frameworks can help ensure that AI development proceeds in a manner that benefits humanity as a whole while minimizing potential risks and negative consequences for individuals and societies. + +The future of artificial intelligence holds immense promise for solving complex global challenges, from climate change and healthcare to education and scientific research, but realizing this potential requires careful planning, responsible development practices, and ongoing dialogue between technologists, policymakers, and society at large. We must ensure that AI serves the common good and contributes to human flourishing while addressing concerns about privacy, security, fairness, and human autonomy. + +In the realm of scientific research, artificial intelligence is accelerating discoveries in fields ranging from astronomy and physics to biology and chemistry, enabling researchers to analyze vast datasets, identify patterns, and generate hypotheses at unprecedented scales and speeds. Machine learning models are helping scientists understand complex phenomena, predict outcomes, and design experiments more efficiently than ever before. + +The entertainment industry has also been transformed by AI technologies, with applications in content creation, recommendation systems, and interactive experiences that personalize entertainment for individual users. From AI-generated music and art to sophisticated recommendation algorithms that help users discover new content, artificial intelligence is reshaping how we create, distribute, and consume entertainment media. + +In agriculture and food production, AI systems are optimizing crop yields, reducing waste, and improving sustainability through precision farming techniques, automated monitoring systems, and predictive analytics that help farmers make better decisions about planting, irrigation, and harvesting. These technologies are crucial for addressing global food security challenges and environmental sustainability concerns. + +Financial services have embraced AI for fraud detection, risk assessment, algorithmic trading, and customer service applications that improve efficiency and security while reducing costs. However, these applications also raise important questions about fairness, transparency, and accountability in automated decision-making processes that affect people's financial lives and opportunities. + +As we continue to develop and deploy artificial intelligence technologies, it is essential that we maintain a focus on human values, ethical principles, and social responsibility to ensure that these powerful tools serve humanity's best interests and contribute to a more prosperous, equitable, and sustainable future for all people around the world. 
+ +The manufacturing sector has witnessed remarkable transformations through the integration of artificial intelligence technologies, with smart factories utilizing predictive maintenance systems, quality control algorithms, and automated production processes that enhance efficiency while reducing waste and operational costs. These innovations enable manufacturers to respond more quickly to market demands, customize products for individual customers, and maintain higher standards of quality and safety. + +Retail and e-commerce industries have embraced artificial intelligence for inventory management, demand forecasting, personalized marketing campaigns, and customer service chatbots that provide round-the-clock support to consumers. These applications help businesses optimize their operations, improve customer satisfaction, and increase sales while reducing overhead costs and improving supply chain efficiency. + +In the field of cybersecurity, artificial intelligence systems play increasingly critical roles in detecting threats, preventing attacks, and responding to security incidents in real-time. Machine learning algorithms can identify patterns in network traffic, recognize malicious behavior, and automatically implement protective measures to safeguard sensitive data and critical infrastructure from cyber threats. + +The legal profession has begun incorporating AI tools for document review, legal research, contract analysis, and case prediction, enabling lawyers to work more efficiently and provide better services to their clients. However, these applications also raise important questions about professional responsibility, client confidentiality, and the role of human judgment in legal decision-making processes. + +Space exploration and astronomy have benefited tremendously from artificial intelligence applications that help analyze vast amounts of data from telescopes, satellites, and space missions. AI systems can identify celestial objects, predict astronomical events, and assist in planning complex space missions that would be impossible without advanced computational support. + +The energy sector is leveraging artificial intelligence for grid optimization, renewable energy forecasting, and smart building management systems that reduce energy consumption and improve sustainability. These technologies are essential for transitioning to cleaner energy sources and addressing climate change challenges while maintaining reliable power supplies for growing populations. + +Social media platforms and digital communication technologies rely heavily on artificial intelligence for content moderation, recommendation algorithms, and user engagement optimization. While these systems can enhance user experiences and facilitate global communication, they also raise concerns about privacy, misinformation, and the potential for manipulation of public opinion. + +As artificial intelligence continues to evolve and expand into new domains, it becomes increasingly important to establish robust governance frameworks, ethical guidelines, and regulatory mechanisms that ensure these technologies are developed and deployed responsibly. This requires ongoing collaboration between technologists, policymakers, ethicists, and civil society organizations to address the complex challenges and opportunities presented by AI advancement. 
+ +The future success of artificial intelligence development will depend on our ability to balance innovation with responsibility, ensuring that these powerful technologies serve the common good while respecting human dignity, privacy, and autonomy. By working together across disciplines, sectors, and borders, we can harness the transformative potential of artificial intelligence to create a better world for current and future generations. The choices we make today about how to develop, deploy, and govern AI technologies will shape the trajectory of human civilization for decades to come, making it essential that we proceed with wisdom, caution, and unwavering commitment to human welfare and flourishing in this rapidly evolving technological landscape that demands careful consideration and thoughtful implementation across all sectors and industries worldwide for sustainable progress and long-term societal benefit through responsible innovation and ethical development practices. diff --git a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt new file mode 100644 index 000000000000..8e63f4c4b45e --- /dev/null +++ b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt @@ -0,0 +1,73 @@ +With the continuous and accelerating deepening of globalization processes throughout the modern era, interconnections and interdependencies between countries and regions around the world have become increasingly complex, multifaceted, strategically important, and fundamentally transformative for human civilization. Economic globalization has fundamentally enabled and facilitated the unprecedented free flow of goods, services, financial capital, technological innovations, intellectual property, and human resources across international borders and continental boundaries, thereby promoting sustained prosperity, economic growth, technological advancement, and comprehensive development of the integrated world economy on a scale never before witnessed in human history. + +The dramatic rise and expansion of multinational corporations and transnational enterprises has completely transformed and revolutionized traditional business models, operational frameworks, competitive strategies, market dynamics, and corporate governance structures, making global supply chain management significantly more complex, sophisticated, refined, technologically advanced, and strategically crucial than ever before in the history of international commerce and trade. These organizations now operate across multiple continents, coordinate activities in dozens of countries, and manage supply chains that span thousands of miles and involve millions of workers worldwide. 
+ +Meanwhile, the rapid and revolutionary development of cutting-edge information technology, digital communications infrastructure, data processing capabilities, and computational systems has provided extraordinarily strong, reliable, comprehensive, and scalable technical infrastructure and support systems for accelerated globalization processes, with transformative technologies such as the Internet, mobile communications networks, cloud computing platforms, artificial intelligence systems, machine learning algorithms, blockchain technologies, and quantum computing making information dissemination, knowledge sharing, cross-border collaboration, and international communication dramatically more convenient, efficient, instantaneous, cost-effective, and accessible to people worldwide. + +In this dynamic and rapidly evolving global context, cultural exchanges, artistic collaborations, intellectual interactions, and creative partnerships have also become significantly more frequent, intensive, meaningful, and profoundly impactful on societies worldwide. Diverse cultures from different countries, regions, civilizations, and historical backgrounds continuously collide, interact, merge, synthesize, and influence each other in unprecedented ways, creating numerous innovative cultural phenomena, artistic expressions, creative art forms, literary works, musical compositions, and intellectual movements that transcend traditional geographical, political, and cultural boundaries while enriching human experience and understanding. + +Language learning and multilingual communication have become increasingly important and valuable skills in the global marketplace, and comprehensive multilingual abilities have become absolutely essential competencies for modern global citizens, international professionals, business leaders, diplomats, and academics who seek to participate effectively in the interconnected world economy and global society. The ability to communicate across linguistic and cultural barriers has become a critical factor in personal and professional success in virtually every field of human endeavor. + +Educational internationalization and academic mobility have also emerged as critically important trends and strategic priorities for universities, research institutions, and educational systems worldwide, with growing numbers of students, researchers, scholars, and academics choosing to pursue studies abroad, participate in international exchange programs, and receive education under diverse cultural backgrounds, educational systems, pedagogical approaches, and academic traditions that broaden their perspectives and enhance their global competencies while fostering cross-cultural understanding and cooperation. + +However, globalization has simultaneously brought various negative impacts, unintended consequences, and complex challenges that require careful consideration, thoughtful analysis, strategic planning, and effective management by governments, international organizations, civil society groups, and global leaders. Environmental pollution problems have become more serious, widespread, and threatening to planetary health, and climate change has emerged as a common existential challenge facing all of humanity that requires urgent, coordinated, and sustained global action involving unprecedented levels of international cooperation and commitment. 
+ +Economic inequality and wealth gaps have widened significantly in many regions and countries around the world, and social inequality issues have become increasingly prominent, politically significant, and socially destabilizing, creating tensions between different economic classes, social groups, and demographic segments within societies. The benefits of globalization have not been distributed equally, leading to growing concerns about social justice, economic fairness, and inclusive development that addresses the needs of all people regardless of their economic status, geographic location, or social background. + +Cultural homogenization and the loss of traditional cultural practices, languages, and local customs have also become serious concerns as global media, international brands, and standardized products spread across different societies, potentially eroding cultural diversity and unique local identities that have developed over centuries or millennia. This trend threatens the rich tapestry of human cultural heritage and raises important questions about how to preserve cultural authenticity while embracing beneficial aspects of global integration. + +Technological disruption and automation driven by artificial intelligence, robotics, and advanced manufacturing systems have created new challenges for employment, job security, and workforce development, as traditional jobs become obsolete while new types of work emerge that require different skills, educational backgrounds, and technological competencies. This transformation demands comprehensive retraining programs, educational reforms, and adaptive social policies to help workers navigate the changing landscape of employment opportunities. + +Geopolitical tensions and conflicts have also been influenced by globalization processes, as countries compete for economic advantages, technological leadership, and strategic resources while navigating complex interdependencies that can create both opportunities for cooperation and sources of friction and disagreement. The interconnected nature of the global economy means that conflicts in one region can have far-reaching consequences for countries and communities around the world. + +To address these multifaceted challenges effectively, the international community must work together to develop comprehensive policies, innovative solutions, and collaborative frameworks that harness the benefits of globalization while mitigating its negative impacts and ensuring that the process of global integration serves the interests of all people and contributes to sustainable development, social progress, and human flourishing on a planetary scale. + +The role of international organizations, multilateral institutions, and global governance mechanisms has become increasingly important in managing the complexities of globalization and addressing transnational challenges that no single country can solve alone. Organizations such as the United Nations, World Bank, International Monetary Fund, and World Trade Organization play crucial roles in facilitating cooperation, establishing standards, and providing frameworks for addressing global issues. + +Sustainable development has emerged as a central theme in discussions about globalization, with growing recognition that economic growth must be balanced with environmental protection and social equity to ensure long-term prosperity for all people. 
The United Nations Sustainable Development Goals provide a comprehensive framework for addressing these interconnected challenges and creating a more sustainable and equitable world. + +The digital revolution has fundamentally transformed how people communicate, work, learn, and interact across borders, creating new opportunities for collaboration and innovation while also raising concerns about privacy, security, and digital divides. The COVID-19 pandemic has accelerated many of these digital transformations, demonstrating both the potential and the limitations of technology-mediated global connections. + +Urbanization and migration patterns have been significantly influenced by globalization, with millions of people moving from rural to urban areas and across national borders in search of better opportunities. This movement of people has created both opportunities for cultural exchange and economic development, as well as challenges related to integration, housing, and social services in destination communities. + +The future of globalization will likely be shaped by emerging technologies, changing geopolitical dynamics, environmental constraints, and evolving social values. Success in navigating these changes will require adaptive governance systems, inclusive economic models, and a commitment to international cooperation that prioritizes the common good over narrow national interests. + +Education and capacity building will play crucial roles in preparing people for the challenges and opportunities of an increasingly interconnected world. This includes not only technical skills and knowledge but also cultural competency, critical thinking abilities, and ethical frameworks for navigating complex global issues. + +Ultimately, the goal of globalization should be to create a world where all people can benefit from increased connectivity, shared knowledge, and collective problem-solving capabilities while maintaining their cultural identities and local communities. Achieving this vision will require ongoing dialogue, cooperation, and commitment from individuals, communities, organizations, and governments around the world. + +The healthcare sector has been profoundly transformed by globalization, with medical knowledge, technologies, and treatments spreading rapidly across borders to benefit patients worldwide. International medical research collaborations have accelerated the development of new drugs, vaccines, and treatment protocols, while telemedicine technologies enable healthcare providers to consult with specialists and share expertise across vast distances. However, globalization has also highlighted significant disparities in healthcare access and quality between developed and developing nations, creating moral imperatives for more equitable distribution of medical resources and knowledge. + +Global supply chains have become increasingly sophisticated and interconnected, enabling companies to source materials, components, and services from multiple countries to optimize costs, quality, and efficiency. This interconnectedness has created unprecedented opportunities for economic development and specialization, allowing countries to focus on their comparative advantages while accessing goods and services from around the world. However, recent global crises have also revealed the vulnerabilities of these complex supply networks, highlighting the need for greater resilience, diversification, and strategic planning in global trade relationships. 
+ +The financial sector has undergone dramatic globalization, with capital markets, banking systems, and investment flows becoming increasingly integrated across national boundaries. This integration has facilitated economic growth, enabled more efficient allocation of capital, and provided opportunities for individuals and businesses to access financial services and investment opportunities worldwide. However, it has also created systemic risks, as financial crises can now spread rapidly across borders, affecting economies and communities far from their origins. + +Cultural globalization has facilitated unprecedented exchanges of ideas, art, literature, music, and entertainment across different societies and civilizations. This cultural cross-pollination has enriched human experience, fostered creativity and innovation, and promoted greater understanding between different peoples and cultures. However, it has also raised concerns about cultural imperialism, the dominance of certain languages and cultural forms, and the potential loss of local traditions and indigenous knowledge systems. + +Environmental challenges have become increasingly global in scope, requiring coordinated international responses to address issues such as climate change, biodiversity loss, ocean pollution, and deforestation. Globalization has both contributed to these environmental problems through increased industrial activity and consumption, while also providing the means for global cooperation and knowledge sharing necessary to address them effectively. + +The role of non-governmental organizations, civil society groups, and international advocacy networks has expanded significantly in the globalized world, enabling these organizations to mobilize resources, coordinate campaigns, and influence policy decisions across multiple countries and regions. These networks have been instrumental in advancing human rights, environmental protection, and social justice causes on a global scale. + +Technological innovation and knowledge transfer have accelerated through globalization, with research and development activities becoming increasingly international and collaborative. Universities, research institutions, and technology companies now routinely engage in cross-border partnerships that combine expertise, resources, and perspectives from different countries and cultures to tackle complex scientific and technological challenges. + +Labor markets have become more globally integrated, with workers increasingly mobile across borders and employment opportunities expanding beyond national boundaries. This has created opportunities for individuals to pursue careers and education in different countries, while also creating challenges related to brain drain, labor standards, and worker protections in an increasingly competitive global marketplace. + +The governance of globalization remains one of the most significant challenges facing the international community, as traditional national governments struggle to regulate and manage economic, social, and environmental processes that transcend their borders. This has led to the development of new forms of global governance, including international organizations, multilateral agreements, and transnational regulatory frameworks that attempt to coordinate policies and standards across different countries and regions. 
+ +Looking toward the future, the trajectory of globalization will likely be shaped by emerging technologies such as artificial intelligence, biotechnology, and renewable energy systems, as well as by evolving geopolitical relationships, demographic changes, and environmental constraints. Successfully managing these changes will require innovative approaches to international cooperation, adaptive governance systems, and a renewed commitment to ensuring that the benefits of global integration are shared more equitably among all people and communities worldwide. + +The COVID-19 pandemic has provided important lessons about the interconnected nature of global systems and the need for better preparedness, coordination, and resilience in the face of global challenges. The pandemic demonstrated both the vulnerabilities of globalized systems and their potential for rapid adaptation and innovation when faced with existential threats. These experiences have highlighted the importance of investing in global health infrastructure, strengthening international cooperation mechanisms, and building more resilient and sustainable economic systems. + +Digital transformation has accelerated dramatically in recent years, fundamentally changing how people work, learn, communicate, and conduct business across borders. The rise of remote work, online education, digital commerce, and virtual collaboration tools has created new possibilities for global integration while also raising questions about digital equity, cybersecurity, and the future of physical communities and workplaces. + +Climate change represents perhaps the greatest long-term challenge facing the globalized world, requiring unprecedented levels of international cooperation and coordination to address effectively. The transition to sustainable energy systems, the development of climate adaptation strategies, and the implementation of carbon reduction policies will require global collaboration on a scale never before attempted in human history. + +The emergence of new economic powers and the shifting balance of global influence will continue to reshape international relationships and governance structures in the coming decades. Managing these transitions peacefully and constructively will require diplomatic skill, mutual understanding, and a commitment to multilateral approaches to global problem-solving. + +Young people around the world are increasingly connected through digital technologies and shared global experiences, creating new opportunities for cross-cultural understanding and collaboration. This generation of global citizens will play a crucial role in shaping the future direction of globalization and addressing the complex challenges facing humanity. + +The development of space technologies and the potential for space exploration and colonization may open entirely new frontiers for human expansion and cooperation, requiring new forms of international governance and collaboration to ensure that space remains a domain for peaceful cooperation rather than conflict and competition. + +Ultimately, the future of globalization will be determined by the choices made by individuals, communities, organizations, and governments around the world. By embracing the positive aspects of global integration while addressing its negative consequences, humanity can work toward a more connected, prosperous, and sustainable future for all people on Earth. 
+ +The path forward requires careful balance between global integration and local autonomy, between economic efficiency and social equity, between technological progress and environmental sustainability. Success will depend on our collective ability to learn from past mistakes, adapt to changing circumstances, and maintain focus on shared human values and common goals. Through continued dialogue, cooperation, and commitment to inclusive development, the global community can navigate the challenges and opportunities of an interconnected world while ensuring that the benefits of globalization reach every corner of our planet and every member of the human family. This vision of inclusive globalization represents both our greatest challenge and our most important opportunity for creating a better world for future generations to inherit and cherish through sustained effort, mutual understanding, cooperation, and dedication to building bridges across cultural, economic, and political divides that have historically separated different peoples and nations around the world while fostering peace, prosperity, and sustainable development for all humanity through collaborative international efforts, shared responsibility, and commitment to creating lasting positive change that benefits everyone equally and fairly across all nations, cultures, and communities through comprehensive global cooperation, mutual respect, understanding, and dedication to building a more equitable and just world for present and future generations to enjoy and thrive in together. diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 5f2062c79c07..7e8be750004e 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -488,7 +488,16 @@ async def _collect_batch( f"chunked embedding, got " f"{type(result).__name__}") - embedding_data = result.outputs.data + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, 'data'): + # PoolingOutput case + embedding_data = result.outputs.data + else: + # EmbeddingOutput case - + # convert embedding list to tensor + embedding_data = result.outputs.embedding + if not isinstance(embedding_data, torch.Tensor): embedding_data = torch.tensor( embedding_data, dtype=torch.float32) @@ -559,9 +568,9 @@ async def _collect_batch( # Create a PoolingRequestOutput # for the aggregated result - from vllm.outputs import EmbeddingOutput - embedding_output = EmbeddingOutput( - embedding=final_embedding.tolist()) + from vllm.outputs import PoolingOutput + pooling_output_data = PoolingOutput( + data=final_embedding) # Get original prompt token IDs for this prompt original_prompt = ctx.request_prompts[prompt_idx] @@ -575,13 +584,13 @@ async def _collect_batch( TextTokensPrompt, original_prompt)["prompt_token_ids"] - pooling_output = EmbeddingRequestOutput( + pooling_request_output = PoolingRequestOutput( request_id=aggregator['request_id'], prompt_token_ids=original_token_ids, - outputs=embedding_output, + outputs=pooling_output_data, finished=True) - final_res_batch.append(pooling_output) + final_res_batch.append(pooling_request_output) else: return self.create_error_response( f"Failed to aggregate chunks " From 6e62421690adcf8aac6e2f9994b4bf668e22c0ab Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 6 Aug 2025 20:31:45 +0800 Subject: [PATCH 13/39] Feature: Implementation of an automatic chunking mechanism for long text embedding, accompanied by corresponding unit tests. 
Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7e8be750004e..d9fc63ed30e5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -493,10 +493,14 @@ async def _collect_batch( if hasattr(result.outputs, 'data'): # PoolingOutput case embedding_data = result.outputs.data - else: + elif hasattr(result.outputs, 'embedding'): # EmbeddingOutput case - # convert embedding list to tensor embedding_data = result.outputs.embedding + else: + return self.create_error_response( + f"Unsupported output type: " + f"{type(result.outputs).__name__}") if not isinstance(embedding_data, torch.Tensor): embedding_data = torch.tensor( From ae380ed6c3687ceb4d17ec03755ef3cc09baea4b Mon Sep 17 00:00:00 2001 From: x22x22 Date: Thu, 7 Aug 2025 00:39:19 +0800 Subject: [PATCH 14/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 51 ++++++++++++------- vllm/entrypoints/openai/serving_engine.py | 53 +++++++++++++++----- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index d9fc63ed30e5..742cfc9e3f9a 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -56,6 +56,15 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Cache chunked processing support to avoid repeated attribute lookups + pooler_config = getattr(self.model_config, 'pooler_config', None) + self.supports_chunked_processing = ( + pooler_config is not None + and getattr(pooler_config, 'enable_chunked_processing', False)) + @override async def _preprocess( self, @@ -182,12 +191,9 @@ def _should_use_chunked_processing(self, request) -> bool: (EmbeddingChatRequest, EmbeddingCompletionRequest)): return False - pooler_config = getattr(self.model_config, 'pooler_config', None) - # For chunked processing, we always use MEAN aggregation # for cross-chunk aggregation (native pooling is used within each chunk) - return (pooler_config is not None - and getattr(pooler_config, 'enable_chunked_processing', False)) + return self.supports_chunked_processing def _chunk_token_ids(self, token_ids: list[int], chunk_size: int) -> list[list[int]]: @@ -219,10 +225,8 @@ async def _process_chunked_request( # Process all chunks for MEAN aggregation chunks_to_process = chunks - chunk_indices = list(range(len(chunks))) - for i, (chunk_idx, chunk_tokens) in enumerate( - zip(chunk_indices, chunks_to_process)): + for chunk_idx, chunk_tokens in enumerate(chunks_to_process): # Create a request ID for this chunk chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" f"chunk-{chunk_idx}") @@ -269,9 +273,10 @@ def _validate_input( if isinstance(request, (EmbeddingChatRequest, EmbeddingCompletionRequest)): # Check if chunked processing is enabled for pooling models + enable_chunked = self._should_use_chunked_processing(request) + + # Get pooler config for max_embed_len pooler_config = getattr(self.model_config, 'pooler_config', None) - enable_chunked = (pooler_config is not None and getattr( - pooler_config, 'enable_chunked_processing', False)) # Get max_embed_len from pooler config if set max_embed_len = (pooler_config.max_embed_len if 
pooler_config @@ -293,14 +298,23 @@ def _validate_input( max_length_value = self.max_model_len validation_error_msg = ( - f"This model's {length_type} is {max_length_value} tokens. " - f"However, you requested {token_num} tokens in the input for " - f"embedding generation. Please reduce the length of the input." - ) + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input.") + + chunked_processing_error_msg = ( + "This model's {length_type} is {max_length_value} tokens. " + "However, you requested {token_num} tokens in the input for " + "embedding generation. Please reduce the length of the input " + "or enable chunked processing.") # Check if input exceeds effective max length if token_num > effective_max_len: - raise ValueError(validation_error_msg) + raise ValueError( + validation_error_msg.format( + length_type=length_type, + max_length_value=max_length_value, + token_num=token_num)) # Check for chunked processing # when exceeding max_position_embeddings @@ -313,11 +327,10 @@ def _validate_input( max_pos_embeddings) else: raise ValueError( - f"This model's maximum position embeddings length is " - f"{max_pos_embeddings} tokens. However, you requested " - f"{token_num} tokens in the input for embedding " - f"generation. Please reduce the length of the input or " - f"enable chunked processing.") + chunked_processing_error_msg.format( + length_type="maximum position embeddings length", + max_length_value=max_pos_embeddings, + token_num=token_num)) return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e181bd22b835..120bf0ec788a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -512,8 +512,12 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: if (isinstance(message, dict) and "content" in message and isinstance(message["content"], list)): for content_dict in message["content"]: - if "type" in content_dict: - message_types.add(content_dict["type"].split("_")[0]) + # Use Any type to handle dynamic dict access safely + if (isinstance(content_dict, dict) + and "type" in content_dict): + content_type = cast(Any, content_dict).get("type") + if isinstance(content_type, str): + message_types.add(content_type.split("_")[0]) return message_types async def _normalize_prompt_text_to_input( @@ -890,12 +894,25 @@ async def _preprocess_chat( **_chat_template_kwargs, ) else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) + # Cast tokenizer to compatible type for apply_hf_chat_template + from transformers import (PreTrainedTokenizer, + PreTrainedTokenizerFast) + if isinstance(tokenizer, + (PreTrainedTokenizer, PreTrainedTokenizerFast)): + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + else: + # Fallback for other tokenizer types + request_prompt = apply_hf_chat_template( + tokenizer=cast(PreTrainedTokenizer, tokenizer), + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) mm_data = await mm_data_future @@ -932,9 +949,14 @@ async def _preprocess_chat( # For MistralTokenizer assert is_list_of(request_prompt, int), ( "Prompt has to be either a 
string or a list of token ids") - prompt_inputs = TextTokensPrompt( - prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt) + # Ensure tokenizer has decode method + if hasattr(tokenizer, 'decode'): + decoded_prompt = tokenizer.decode(request_prompt) + else: + # Fallback for tokenizers without decode method + decoded_prompt = str(request_prompt) + prompt_inputs = TextTokensPrompt(prompt=decoded_prompt, + prompt_token_ids=request_prompt) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) @@ -1045,7 +1067,12 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token - return tokenizer.decode(token_id) + # Ensure tokenizer has decode method + if hasattr(tokenizer, 'decode'): + return tokenizer.decode(token_id) + else: + # Fallback for tokenizers without decode method + return f"token_id:{token_id}" def _is_model_supported(self, model_name: Optional[str]) -> bool: if not model_name: @@ -1074,4 +1101,4 @@ def clamp_prompt_logprobs( for logprob_values in logprob_dict.values(): if logprob_values.logprob == float('-inf'): logprob_values.logprob = -9999.0 - return prompt_logprobs \ No newline at end of file + return prompt_logprobs From 3ce8d47e76f3d80ffbb93888d585449ca6a01a95 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Thu, 7 Aug 2025 09:56:50 +0800 Subject: [PATCH 15/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/config.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 41 ++------------------ 2 files changed, 4 insertions(+), 39 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bdbd5565cca7..71e3b007b6c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -5274,4 +5274,4 @@ def update_config(config: DataclassInstanceT, current_value, # type: ignore[type-var] value) processed_overrides[field_name] = value - return replace(config, **processed_overrides) \ No newline at end of file + return replace(config, **processed_overrides) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 742cfc9e3f9a..f2f8c5947b34 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -32,9 +32,8 @@ from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput, RequestOutput) + PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams -from vllm.transformers_utils.config import try_get_tokenizer_config logger = init_logger(__name__) @@ -149,41 +148,8 @@ def _build_response( ) def _get_max_position_embeddings(self) -> int: - """Get the model's effective maximum sequence length for chunking. - - This uses the same logic as vLLM's _get_and_verify_max_len to determine - the actual sequence length limit, - considering both model config and tokenizer config. - When max_model_len is set and smaller than max_position_embeddings, - use max_model_len for chunking. 
- """ - hf_config = self.model_config.hf_config - - # Start with max_position_embeddings from model config - derived_max_len = getattr(hf_config, 'max_position_embeddings', 512) - - # Get tokenizer config for pooling models (embedding models) - if self.model_config.runner_type == "pooling": - tokenizer_config = try_get_tokenizer_config( - self.model_config.tokenizer, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - - # Consider model_max_length in tokenizer_config - # (same logic as _get_and_verify_max_len) - if tokenizer_config: - tokenizer_model_max_length = tokenizer_config.get( - 'model_max_length', derived_max_len) - derived_max_len = min(derived_max_len, - tokenizer_model_max_length) - - # Consider max_model_len when it's set and smaller than other limits - # max_model_len is set in OpenAIServing.__init__ - # from model_config.max_model_len - if self.max_model_len is not None: - derived_max_len = min(derived_max_len, self.max_model_len) - - return int(derived_max_len) + """Get the model's effective maximum sequence length for chunking.""" + return self.model_config.max_model_len def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" @@ -585,7 +551,6 @@ async def _collect_batch( # Create a PoolingRequestOutput # for the aggregated result - from vllm.outputs import PoolingOutput pooling_output_data = PoolingOutput( data=final_embedding) From 54ad46e3c64b44b2c97a4f6d235c80643df2bf4b Mon Sep 17 00:00:00 2001 From: x22x22 Date: Thu, 7 Aug 2025 09:58:37 +0800 Subject: [PATCH 16/39] Refactoring inelegant code Signed-off-by: x22x22 --- examples/online_serving/openai_embedding_long_text/client.py | 2 -- examples/online_serving/openai_embedding_long_text/service.sh | 1 - 2 files changed, 3 deletions(-) diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py index 7e3663f2854a..6e9838ac6d8d 100644 --- a/examples/online_serving/openai_embedding_long_text/client.py +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -13,7 +13,6 @@ # MEAN pooling (processes all chunks, recommended for complete coverage) vllm serve intfloat/multilingual-e5-large \ - --task embed \ --override-pooler-config \ '{"pooling_type": "MEAN", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ @@ -24,7 +23,6 @@ # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) vllm serve BAAI/bge-large-en-v1.5 \ - --task embed \ --override-pooler-config \ '{"pooling_type": "CLS", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh index 03feb485d6d4..f356d7d4529e 100644 --- a/examples/online_serving/openai_embedding_long_text/service.sh +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -105,7 +105,6 @@ vllm serve "$MODEL_NAME" \ --enforce-eager \ --override-pooler-config "$POOLER_CONFIG" \ --served-model-name ${MODEL_CODE} \ - --task embed \ --api-key "$API_KEY" \ --trust-remote-code \ --port "$PORT" \ From 503ab003a41151873819c190b38e2d74a17ae0b3 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Thu, 7 Aug 2025 10:09:54 +0800 Subject: [PATCH 17/39] Refactoring inelegant code Signed-off-by: x22x22 --- .../openai/test_embedding_long_text.py | 193 ++++++++++++++++-- 
.../long_text_1500_words.txt | 45 ---- .../long_text_2500_words.txt | 73 ------- 3 files changed, 180 insertions(+), 131 deletions(-) delete mode 100644 tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt delete mode 100644 tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/openai/test_embedding_long_text.py index 307871997860..86bd34abb97e 100644 --- a/tests/entrypoints/openai/test_embedding_long_text.py +++ b/tests/entrypoints/openai/test_embedding_long_text.py @@ -8,7 +8,7 @@ the intfloat/multilingual-e5-small model (max token length: 512). """ -import os +import random import openai import pytest @@ -19,24 +19,191 @@ from ...utils import RemoteOpenAIServer -def _load_text_file(filename: str) -> str: - """Load text content from file in the same directory.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - file_path = os.path.join(current_dir, filename) - with open(file_path, encoding='utf-8') as f: - return f.read().strip() +def _generate_random_text(word_count: int) -> str: + """Generate random text with approximately the specified word count.""" + # Common English words with focus on verbs and nouns for realistic text + common_words = [ + # Essential articles and pronouns (minimal) + "the", + "and", + "you", + "they", + "this", + "that", + "these", + "those", + + # Action verbs + "create", + "build", + "develop", + "design", + "implement", + "execute", + "analyze", + "process", + "generate", + "calculate", + "evaluate", + "optimize", + "transform", + "integrate", + "configure", + "deploy", + "monitor", + "manage", + "discover", + "explore", + "investigate", + "research", + "study", + "examine", + "improve", + "enhance", + "upgrade", + "modify", + "update", + "maintain", + "solve", + "resolve", + "handle", + "address", + "tackle", + "overcome", + "communicate", + "collaborate", + "coordinate", + "organize", + "plan", + "achieve", + "accomplish", + "complete", + "finish", + "deliver", + "provide", + + # Technology and science nouns + "system", + "application", + "software", + "hardware", + "network", + "database", + "algorithm", + "model", + "framework", + "platform", + "interface", + "protocol", + "architecture", + "infrastructure", + "component", + "module", + "service", + "technology", + "innovation", + "solution", + "methodology", + "approach", + "artificial", + "intelligence", + "machine", + "learning", + "neural", + "network", + "computer", + "processor", + "memory", + "storage", + "computation", + "data", + "information", + "knowledge", + "insight", + "pattern", + "trend", + "analysis", + "research", + "development", + "engineering", + "science", + "mathematics", + "statistics", + "probability", + "optimization", + "performance", + "efficiency", + + # General nouns + "project", + "team", + "organization", + "company", + "business", + "industry", + "market", + "customer", + "user", + "client", + "product", + "feature", + "function", + "requirement", + "specification", + "documentation", + "report", + "result", + "outcome", + "impact", + "benefit", + "advantage", + "challenge", + "problem", + "opportunity", + "strategy", + "goal", + "objective", + "target", + "milestone", + "process", + "procedure", + "workflow", + "pipeline", + "operation", + "task", + "activity", + "event", + "session", + "meeting", + "discussion", + "decision" + ] + + words = [] + for _ in range(word_count): + words.append(random.choice(common_words)) 
+ + # Add some punctuation for more realistic text + text = " ".join(words) + # Add periods every 10-20 words + words_list = text.split() + result = [] + for i, word in enumerate(words_list): + result.append(word) + if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1): + result[-1] += "." + + return " ".join(result) MODEL_NAME = "intfloat/multilingual-e5-small" DTYPE = "bfloat16" -# Test text: Load text with approximately 1500 words to exceed 1024 tokens -LONG_TEXT_1500_WORDS = _load_text_file( - './test_embedding_long_text_datasets/long_text_1500_words.txt') +# Test text: Generate text with approximately 1500 words to exceed 1024 tokens +LONG_TEXT_1500_WORDS = _generate_random_text(1500) -# Test text: Construct text with approximately 2500 words to exceed 2048 tokens -LONG_TEXT_2500_WORDS = _load_text_file( - './test_embedding_long_text_datasets/long_text_2500_words.txt') +# Test text: Generate text with approximately 2500 words to exceed 2048 tokens +LONG_TEXT_2500_WORDS = _generate_random_text(2500) @pytest.fixture(scope="module") diff --git a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt deleted file mode 100644 index 62c407571e38..000000000000 --- a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_1500_words.txt +++ /dev/null @@ -1,45 +0,0 @@ -The development of artificial intelligence technology is profoundly transforming our world in unprecedented ways that continue to reshape every aspect of human civilization and society. From sophisticated machine learning algorithms to complex deep neural networks, from advanced natural language processing systems to cutting-edge computer vision models, artificial intelligence technology has demonstrated tremendous potential across virtually every field of human endeavor and scientific research. In the healthcare and medical sector, artificial intelligence systems can help doctors and medical professionals diagnose diseases more accurately than ever before, improve treatment outcomes significantly, reduce medical errors substantially, and accelerate drug discovery processes dramatically while enhancing patient care quality and reducing healthcare costs. - -In transportation and logistics industries, autonomous driving technology is gradually maturing and becoming more sophisticated with each passing year, offering the potential to dramatically reduce traffic accidents, improve fuel efficiency substantially, optimize route planning intelligently, and revolutionize urban mobility systems completely. The integration of artificial intelligence in transportation extends beyond individual vehicles to encompass entire smart city infrastructures, traffic management systems, and public transportation networks that can adapt dynamically to changing conditions and user demands while minimizing environmental impact and maximizing operational efficiency. - -In education and learning environments, personalized artificial intelligence tutoring systems can create highly specialized and adaptive learning plans based on each individual student's unique characteristics, learning style preferences, cognitive pace, academic strengths, and specific areas requiring improvement. 
These systems can provide real-time feedback, adjust difficulty levels automatically, and offer customized educational content that maximizes learning effectiveness while maintaining student engagement and motivation throughout the educational process. - -However, the rapid development of artificial intelligence technology also brings numerous complex challenges that society must address thoughtfully and comprehensively, including significant changes in employment structures and job markets, critical privacy protection issues and data security concerns, the persistent risk of algorithmic bias and unfair discrimination, ethical considerations around autonomous decision-making processes, and fundamental questions about human agency and control in an increasingly automated world where machines make decisions that affect human lives. - -We need to pay careful attention to these multifaceted social impacts while simultaneously advancing technological development responsibly, ensuring that artificial intelligence technology can truly benefit all segments of human society rather than exacerbating existing inequalities or creating new forms of digital divide. Future artificial intelligence development requires extensive interdisciplinary cooperation and collaboration, involving not only technical experts and computer scientists, but also sociologists, ethicists, policymakers, economists, legal experts, and professionals from various other fields to build a more intelligent, equitable, fair, and sustainable future world for everyone. - -The economic implications of artificial intelligence adoption are far-reaching and complex, affecting labor markets, productivity levels, wealth distribution, and global competitiveness in ways that require careful analysis and strategic planning. While AI technologies can increase efficiency and create new opportunities for innovation and growth, they also pose significant challenges for workers whose jobs may become automated or fundamentally changed. This necessitates comprehensive retraining programs, educational reforms, and social safety nets to ensure a smooth transition for affected populations and communities. - -Environmental considerations also play a crucial role in AI development, as the computational requirements for training and running large AI models consume significant energy resources and contribute to carbon emissions. Sustainable AI practices, green computing initiatives, and energy-efficient algorithms are becoming increasingly important as the field continues to grow and expand globally, requiring researchers and developers to balance performance with environmental responsibility. - -International cooperation and governance frameworks are essential for managing the global impact of artificial intelligence technologies and ensuring that their development and deployment serve the common good. Cross-border collaboration on AI safety standards, ethical guidelines, and regulatory frameworks can help ensure that AI development proceeds in a manner that benefits humanity as a whole while minimizing potential risks and negative consequences for individuals and societies. - -The future of artificial intelligence holds immense promise for solving complex global challenges, from climate change and healthcare to education and scientific research, but realizing this potential requires careful planning, responsible development practices, and ongoing dialogue between technologists, policymakers, and society at large. 
We must ensure that AI serves the common good and contributes to human flourishing while addressing concerns about privacy, security, fairness, and human autonomy. - -In the realm of scientific research, artificial intelligence is accelerating discoveries in fields ranging from astronomy and physics to biology and chemistry, enabling researchers to analyze vast datasets, identify patterns, and generate hypotheses at unprecedented scales and speeds. Machine learning models are helping scientists understand complex phenomena, predict outcomes, and design experiments more efficiently than ever before. - -The entertainment industry has also been transformed by AI technologies, with applications in content creation, recommendation systems, and interactive experiences that personalize entertainment for individual users. From AI-generated music and art to sophisticated recommendation algorithms that help users discover new content, artificial intelligence is reshaping how we create, distribute, and consume entertainment media. - -In agriculture and food production, AI systems are optimizing crop yields, reducing waste, and improving sustainability through precision farming techniques, automated monitoring systems, and predictive analytics that help farmers make better decisions about planting, irrigation, and harvesting. These technologies are crucial for addressing global food security challenges and environmental sustainability concerns. - -Financial services have embraced AI for fraud detection, risk assessment, algorithmic trading, and customer service applications that improve efficiency and security while reducing costs. However, these applications also raise important questions about fairness, transparency, and accountability in automated decision-making processes that affect people's financial lives and opportunities. - -As we continue to develop and deploy artificial intelligence technologies, it is essential that we maintain a focus on human values, ethical principles, and social responsibility to ensure that these powerful tools serve humanity's best interests and contribute to a more prosperous, equitable, and sustainable future for all people around the world. - -The manufacturing sector has witnessed remarkable transformations through the integration of artificial intelligence technologies, with smart factories utilizing predictive maintenance systems, quality control algorithms, and automated production processes that enhance efficiency while reducing waste and operational costs. These innovations enable manufacturers to respond more quickly to market demands, customize products for individual customers, and maintain higher standards of quality and safety. - -Retail and e-commerce industries have embraced artificial intelligence for inventory management, demand forecasting, personalized marketing campaigns, and customer service chatbots that provide round-the-clock support to consumers. These applications help businesses optimize their operations, improve customer satisfaction, and increase sales while reducing overhead costs and improving supply chain efficiency. - -In the field of cybersecurity, artificial intelligence systems play increasingly critical roles in detecting threats, preventing attacks, and responding to security incidents in real-time. Machine learning algorithms can identify patterns in network traffic, recognize malicious behavior, and automatically implement protective measures to safeguard sensitive data and critical infrastructure from cyber threats. 
- -The legal profession has begun incorporating AI tools for document review, legal research, contract analysis, and case prediction, enabling lawyers to work more efficiently and provide better services to their clients. However, these applications also raise important questions about professional responsibility, client confidentiality, and the role of human judgment in legal decision-making processes. - -Space exploration and astronomy have benefited tremendously from artificial intelligence applications that help analyze vast amounts of data from telescopes, satellites, and space missions. AI systems can identify celestial objects, predict astronomical events, and assist in planning complex space missions that would be impossible without advanced computational support. - -The energy sector is leveraging artificial intelligence for grid optimization, renewable energy forecasting, and smart building management systems that reduce energy consumption and improve sustainability. These technologies are essential for transitioning to cleaner energy sources and addressing climate change challenges while maintaining reliable power supplies for growing populations. - -Social media platforms and digital communication technologies rely heavily on artificial intelligence for content moderation, recommendation algorithms, and user engagement optimization. While these systems can enhance user experiences and facilitate global communication, they also raise concerns about privacy, misinformation, and the potential for manipulation of public opinion. - -As artificial intelligence continues to evolve and expand into new domains, it becomes increasingly important to establish robust governance frameworks, ethical guidelines, and regulatory mechanisms that ensure these technologies are developed and deployed responsibly. This requires ongoing collaboration between technologists, policymakers, ethicists, and civil society organizations to address the complex challenges and opportunities presented by AI advancement. - -The future success of artificial intelligence development will depend on our ability to balance innovation with responsibility, ensuring that these powerful technologies serve the common good while respecting human dignity, privacy, and autonomy. By working together across disciplines, sectors, and borders, we can harness the transformative potential of artificial intelligence to create a better world for current and future generations. The choices we make today about how to develop, deploy, and govern AI technologies will shape the trajectory of human civilization for decades to come, making it essential that we proceed with wisdom, caution, and unwavering commitment to human welfare and flourishing in this rapidly evolving technological landscape that demands careful consideration and thoughtful implementation across all sectors and industries worldwide for sustainable progress and long-term societal benefit through responsible innovation and ethical development practices. 
diff --git a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt b/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt deleted file mode 100644 index 8e63f4c4b45e..000000000000 --- a/tests/entrypoints/openai/test_embedding_long_text_datasets/long_text_2500_words.txt +++ /dev/null @@ -1,73 +0,0 @@ -With the continuous and accelerating deepening of globalization processes throughout the modern era, interconnections and interdependencies between countries and regions around the world have become increasingly complex, multifaceted, strategically important, and fundamentally transformative for human civilization. Economic globalization has fundamentally enabled and facilitated the unprecedented free flow of goods, services, financial capital, technological innovations, intellectual property, and human resources across international borders and continental boundaries, thereby promoting sustained prosperity, economic growth, technological advancement, and comprehensive development of the integrated world economy on a scale never before witnessed in human history. - -The dramatic rise and expansion of multinational corporations and transnational enterprises has completely transformed and revolutionized traditional business models, operational frameworks, competitive strategies, market dynamics, and corporate governance structures, making global supply chain management significantly more complex, sophisticated, refined, technologically advanced, and strategically crucial than ever before in the history of international commerce and trade. These organizations now operate across multiple continents, coordinate activities in dozens of countries, and manage supply chains that span thousands of miles and involve millions of workers worldwide. - -Meanwhile, the rapid and revolutionary development of cutting-edge information technology, digital communications infrastructure, data processing capabilities, and computational systems has provided extraordinarily strong, reliable, comprehensive, and scalable technical infrastructure and support systems for accelerated globalization processes, with transformative technologies such as the Internet, mobile communications networks, cloud computing platforms, artificial intelligence systems, machine learning algorithms, blockchain technologies, and quantum computing making information dissemination, knowledge sharing, cross-border collaboration, and international communication dramatically more convenient, efficient, instantaneous, cost-effective, and accessible to people worldwide. - -In this dynamic and rapidly evolving global context, cultural exchanges, artistic collaborations, intellectual interactions, and creative partnerships have also become significantly more frequent, intensive, meaningful, and profoundly impactful on societies worldwide. Diverse cultures from different countries, regions, civilizations, and historical backgrounds continuously collide, interact, merge, synthesize, and influence each other in unprecedented ways, creating numerous innovative cultural phenomena, artistic expressions, creative art forms, literary works, musical compositions, and intellectual movements that transcend traditional geographical, political, and cultural boundaries while enriching human experience and understanding. 
- -Language learning and multilingual communication have become increasingly important and valuable skills in the global marketplace, and comprehensive multilingual abilities have become absolutely essential competencies for modern global citizens, international professionals, business leaders, diplomats, and academics who seek to participate effectively in the interconnected world economy and global society. The ability to communicate across linguistic and cultural barriers has become a critical factor in personal and professional success in virtually every field of human endeavor. - -Educational internationalization and academic mobility have also emerged as critically important trends and strategic priorities for universities, research institutions, and educational systems worldwide, with growing numbers of students, researchers, scholars, and academics choosing to pursue studies abroad, participate in international exchange programs, and receive education under diverse cultural backgrounds, educational systems, pedagogical approaches, and academic traditions that broaden their perspectives and enhance their global competencies while fostering cross-cultural understanding and cooperation. - -However, globalization has simultaneously brought various negative impacts, unintended consequences, and complex challenges that require careful consideration, thoughtful analysis, strategic planning, and effective management by governments, international organizations, civil society groups, and global leaders. Environmental pollution problems have become more serious, widespread, and threatening to planetary health, and climate change has emerged as a common existential challenge facing all of humanity that requires urgent, coordinated, and sustained global action involving unprecedented levels of international cooperation and commitment. - -Economic inequality and wealth gaps have widened significantly in many regions and countries around the world, and social inequality issues have become increasingly prominent, politically significant, and socially destabilizing, creating tensions between different economic classes, social groups, and demographic segments within societies. The benefits of globalization have not been distributed equally, leading to growing concerns about social justice, economic fairness, and inclusive development that addresses the needs of all people regardless of their economic status, geographic location, or social background. - -Cultural homogenization and the loss of traditional cultural practices, languages, and local customs have also become serious concerns as global media, international brands, and standardized products spread across different societies, potentially eroding cultural diversity and unique local identities that have developed over centuries or millennia. This trend threatens the rich tapestry of human cultural heritage and raises important questions about how to preserve cultural authenticity while embracing beneficial aspects of global integration. - -Technological disruption and automation driven by artificial intelligence, robotics, and advanced manufacturing systems have created new challenges for employment, job security, and workforce development, as traditional jobs become obsolete while new types of work emerge that require different skills, educational backgrounds, and technological competencies. 
This transformation demands comprehensive retraining programs, educational reforms, and adaptive social policies to help workers navigate the changing landscape of employment opportunities. - -Geopolitical tensions and conflicts have also been influenced by globalization processes, as countries compete for economic advantages, technological leadership, and strategic resources while navigating complex interdependencies that can create both opportunities for cooperation and sources of friction and disagreement. The interconnected nature of the global economy means that conflicts in one region can have far-reaching consequences for countries and communities around the world. - -To address these multifaceted challenges effectively, the international community must work together to develop comprehensive policies, innovative solutions, and collaborative frameworks that harness the benefits of globalization while mitigating its negative impacts and ensuring that the process of global integration serves the interests of all people and contributes to sustainable development, social progress, and human flourishing on a planetary scale. - -The role of international organizations, multilateral institutions, and global governance mechanisms has become increasingly important in managing the complexities of globalization and addressing transnational challenges that no single country can solve alone. Organizations such as the United Nations, World Bank, International Monetary Fund, and World Trade Organization play crucial roles in facilitating cooperation, establishing standards, and providing frameworks for addressing global issues. - -Sustainable development has emerged as a central theme in discussions about globalization, with growing recognition that economic growth must be balanced with environmental protection and social equity to ensure long-term prosperity for all people. The United Nations Sustainable Development Goals provide a comprehensive framework for addressing these interconnected challenges and creating a more sustainable and equitable world. - -The digital revolution has fundamentally transformed how people communicate, work, learn, and interact across borders, creating new opportunities for collaboration and innovation while also raising concerns about privacy, security, and digital divides. The COVID-19 pandemic has accelerated many of these digital transformations, demonstrating both the potential and the limitations of technology-mediated global connections. - -Urbanization and migration patterns have been significantly influenced by globalization, with millions of people moving from rural to urban areas and across national borders in search of better opportunities. This movement of people has created both opportunities for cultural exchange and economic development, as well as challenges related to integration, housing, and social services in destination communities. - -The future of globalization will likely be shaped by emerging technologies, changing geopolitical dynamics, environmental constraints, and evolving social values. Success in navigating these changes will require adaptive governance systems, inclusive economic models, and a commitment to international cooperation that prioritizes the common good over narrow national interests. - -Education and capacity building will play crucial roles in preparing people for the challenges and opportunities of an increasingly interconnected world. 
This includes not only technical skills and knowledge but also cultural competency, critical thinking abilities, and ethical frameworks for navigating complex global issues. - -Ultimately, the goal of globalization should be to create a world where all people can benefit from increased connectivity, shared knowledge, and collective problem-solving capabilities while maintaining their cultural identities and local communities. Achieving this vision will require ongoing dialogue, cooperation, and commitment from individuals, communities, organizations, and governments around the world. - -The healthcare sector has been profoundly transformed by globalization, with medical knowledge, technologies, and treatments spreading rapidly across borders to benefit patients worldwide. International medical research collaborations have accelerated the development of new drugs, vaccines, and treatment protocols, while telemedicine technologies enable healthcare providers to consult with specialists and share expertise across vast distances. However, globalization has also highlighted significant disparities in healthcare access and quality between developed and developing nations, creating moral imperatives for more equitable distribution of medical resources and knowledge. - -Global supply chains have become increasingly sophisticated and interconnected, enabling companies to source materials, components, and services from multiple countries to optimize costs, quality, and efficiency. This interconnectedness has created unprecedented opportunities for economic development and specialization, allowing countries to focus on their comparative advantages while accessing goods and services from around the world. However, recent global crises have also revealed the vulnerabilities of these complex supply networks, highlighting the need for greater resilience, diversification, and strategic planning in global trade relationships. - -The financial sector has undergone dramatic globalization, with capital markets, banking systems, and investment flows becoming increasingly integrated across national boundaries. This integration has facilitated economic growth, enabled more efficient allocation of capital, and provided opportunities for individuals and businesses to access financial services and investment opportunities worldwide. However, it has also created systemic risks, as financial crises can now spread rapidly across borders, affecting economies and communities far from their origins. - -Cultural globalization has facilitated unprecedented exchanges of ideas, art, literature, music, and entertainment across different societies and civilizations. This cultural cross-pollination has enriched human experience, fostered creativity and innovation, and promoted greater understanding between different peoples and cultures. However, it has also raised concerns about cultural imperialism, the dominance of certain languages and cultural forms, and the potential loss of local traditions and indigenous knowledge systems. - -Environmental challenges have become increasingly global in scope, requiring coordinated international responses to address issues such as climate change, biodiversity loss, ocean pollution, and deforestation. Globalization has both contributed to these environmental problems through increased industrial activity and consumption, while also providing the means for global cooperation and knowledge sharing necessary to address them effectively. 
- -The role of non-governmental organizations, civil society groups, and international advocacy networks has expanded significantly in the globalized world, enabling these organizations to mobilize resources, coordinate campaigns, and influence policy decisions across multiple countries and regions. These networks have been instrumental in advancing human rights, environmental protection, and social justice causes on a global scale. - -Technological innovation and knowledge transfer have accelerated through globalization, with research and development activities becoming increasingly international and collaborative. Universities, research institutions, and technology companies now routinely engage in cross-border partnerships that combine expertise, resources, and perspectives from different countries and cultures to tackle complex scientific and technological challenges. - -Labor markets have become more globally integrated, with workers increasingly mobile across borders and employment opportunities expanding beyond national boundaries. This has created opportunities for individuals to pursue careers and education in different countries, while also creating challenges related to brain drain, labor standards, and worker protections in an increasingly competitive global marketplace. - -The governance of globalization remains one of the most significant challenges facing the international community, as traditional national governments struggle to regulate and manage economic, social, and environmental processes that transcend their borders. This has led to the development of new forms of global governance, including international organizations, multilateral agreements, and transnational regulatory frameworks that attempt to coordinate policies and standards across different countries and regions. - -Looking toward the future, the trajectory of globalization will likely be shaped by emerging technologies such as artificial intelligence, biotechnology, and renewable energy systems, as well as by evolving geopolitical relationships, demographic changes, and environmental constraints. Successfully managing these changes will require innovative approaches to international cooperation, adaptive governance systems, and a renewed commitment to ensuring that the benefits of global integration are shared more equitably among all people and communities worldwide. - -The COVID-19 pandemic has provided important lessons about the interconnected nature of global systems and the need for better preparedness, coordination, and resilience in the face of global challenges. The pandemic demonstrated both the vulnerabilities of globalized systems and their potential for rapid adaptation and innovation when faced with existential threats. These experiences have highlighted the importance of investing in global health infrastructure, strengthening international cooperation mechanisms, and building more resilient and sustainable economic systems. - -Digital transformation has accelerated dramatically in recent years, fundamentally changing how people work, learn, communicate, and conduct business across borders. The rise of remote work, online education, digital commerce, and virtual collaboration tools has created new possibilities for global integration while also raising questions about digital equity, cybersecurity, and the future of physical communities and workplaces. 
- -Climate change represents perhaps the greatest long-term challenge facing the globalized world, requiring unprecedented levels of international cooperation and coordination to address effectively. The transition to sustainable energy systems, the development of climate adaptation strategies, and the implementation of carbon reduction policies will require global collaboration on a scale never before attempted in human history. - -The emergence of new economic powers and the shifting balance of global influence will continue to reshape international relationships and governance structures in the coming decades. Managing these transitions peacefully and constructively will require diplomatic skill, mutual understanding, and a commitment to multilateral approaches to global problem-solving. - -Young people around the world are increasingly connected through digital technologies and shared global experiences, creating new opportunities for cross-cultural understanding and collaboration. This generation of global citizens will play a crucial role in shaping the future direction of globalization and addressing the complex challenges facing humanity. - -The development of space technologies and the potential for space exploration and colonization may open entirely new frontiers for human expansion and cooperation, requiring new forms of international governance and collaboration to ensure that space remains a domain for peaceful cooperation rather than conflict and competition. - -Ultimately, the future of globalization will be determined by the choices made by individuals, communities, organizations, and governments around the world. By embracing the positive aspects of global integration while addressing its negative consequences, humanity can work toward a more connected, prosperous, and sustainable future for all people on Earth. - -The path forward requires careful balance between global integration and local autonomy, between economic efficiency and social equity, between technological progress and environmental sustainability. Success will depend on our collective ability to learn from past mistakes, adapt to changing circumstances, and maintain focus on shared human values and common goals. Through continued dialogue, cooperation, and commitment to inclusive development, the global community can navigate the challenges and opportunities of an interconnected world while ensuring that the benefits of globalization reach every corner of our planet and every member of the human family. This vision of inclusive globalization represents both our greatest challenge and our most important opportunity for creating a better world for future generations to inherit and cherish through sustained effort, mutual understanding, cooperation, and dedication to building bridges across cultural, economic, and political divides that have historically separated different peoples and nations around the world while fostering peace, prosperity, and sustainable development for all humanity through collaborative international efforts, shared responsibility, and commitment to creating lasting positive change that benefits everyone equally and fairly across all nations, cultures, and communities through comprehensive global cooperation, mutual respect, understanding, and dedication to building a more equitable and just world for present and future generations to enjoy and thrive in together. 
From 8949c8f3819cd05303a82e49b839e19b916ad34a Mon Sep 17 00:00:00 2001 From: x22x22 Date: Mon, 11 Aug 2025 10:56:25 +0800 Subject: [PATCH 18/39] Refactoring inelegant code Signed-off-by: x22x22 --- .../openai_embedding_long_text/README.md | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md index dcd66a9fee9d..37bc023596e4 100644 --- a/examples/online_serving/openai_embedding_long_text/README.md +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -4,7 +4,7 @@ This directory contains examples for using vLLM's **chunked processing** feature ## ๐Ÿš€ Quick Start -### 1. Start the Server +### Start the Server Use the provided script to start a vLLM server with chunked processing enabled: @@ -23,7 +23,7 @@ MAX_EMBED_LEN=3072000 \ ./service.sh ``` -### 2. Test Long Text Embedding +### Test Long Text Embedding Run the comprehensive test client: @@ -37,7 +37,7 @@ python client.py |------|-------------| | `service.sh` | Server startup script with chunked processing enabled | | `client.py` | Comprehensive test client for long text embedding | -| `../openai_embedding_client.py` | Basic embedding client (updated with chunked processing info) | +| | Basic embedding client (updated with chunked processing info) | ## โš™๏ธ Configuration @@ -54,7 +54,8 @@ The key parameters for chunked processing are in the `--override-pooler-config`: } ``` -**Note**: `pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length. +!!! note + `pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length. #### Chunked Processing Behavior @@ -166,8 +167,8 @@ INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) ## ๐Ÿ“š Additional Resources -- [Pooling Models Documentation](../../docs/models/pooling_models.md#chunked-processing-for-long-text) -- [Supported Models List](../../docs/models/supported_models.md#text-embedding) +- [Pooling Models Documentation](../models/pooling_models.md#chunked-processing-for-long-text) +- [Supported Models List](../models/supported_models.md#text-embedding) - [Original Feature Documentation](../../README_CHUNKED_PROCESSING.md) ## ๐Ÿค Contributing @@ -193,4 +194,5 @@ The new `max_embed_len` parameter provides: --- -**Note**: Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../../docs/models/supported_models.md#chunked-processing-for-long-text) for the complete list. +!!! note + Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../models/supported_models.md#chunked-processing-for-long-text) for the complete list. 
\ No newline at end of file From d42419e49b1210d03586deb655862cae797a53cb Mon Sep 17 00:00:00 2001 From: x22x22 Date: Mon, 11 Aug 2025 11:21:09 +0800 Subject: [PATCH 19/39] Refactoring inelegant code Signed-off-by: x22x22 --- .../openai_embedding_long_text/README.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md index 37bc023596e4..60441d2683a2 100644 --- a/examples/online_serving/openai_embedding_long_text/README.md +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -37,7 +37,6 @@ python client.py |------|-------------| | `service.sh` | Server startup script with chunked processing enabled | | `client.py` | Comprehensive test client for long text embedding | -| | Basic embedding client (updated with chunked processing info) | ## โš™๏ธ Configuration @@ -165,12 +164,6 @@ INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) ``` -## ๐Ÿ“š Additional Resources - -- [Pooling Models Documentation](../models/pooling_models.md#chunked-processing-for-long-text) -- [Supported Models List](../models/supported_models.md#text-embedding) -- [Original Feature Documentation](../../README_CHUNKED_PROCESSING.md) - ## ๐Ÿค Contributing To extend chunked processing support to other embedding models: @@ -192,7 +185,4 @@ The new `max_embed_len` parameter provides: - **Clear Error Messages**: Better feedback when inputs exceed limits - **Backward Compatibility**: Existing configurations continue to work ---- -!!! note - Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../models/supported_models.md#chunked-processing-for-long-text) for the complete list. 
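The troubleshooting log lines retained above summarize the core behavior: an input longer than `max_position_embeddings` is split into chunks, each chunk is embedded with the model's native pooling, and the per-chunk embeddings are combined with a token-count-weighted mean. The standalone sketch below illustrates that logic in isolation; `embed_chunk` is a hypothetical stand-in for a single model forward pass, and the splitting helper mirrors (but does not call) the `chunk_list` utility the later commits switch to.

```python
# Standalone sketch of chunked embedding with MEAN aggregation (illustrative
# only; `embed_chunk` is a hypothetical stand-in for one model forward pass).
import torch


def split_into_chunks(token_ids: list[int], chunk_size: int) -> list[list[int]]:
    """Split token ids into consecutive chunks of at most `chunk_size`."""
    return [token_ids[i:i + chunk_size]
            for i in range(0, len(token_ids), chunk_size)]


def embed_long_input(token_ids: list[int], chunk_size: int,
                     embed_chunk) -> torch.Tensor:
    """Embed each chunk, then combine with a token-count-weighted mean."""
    weighted_sum: torch.Tensor | None = None
    total_weight = 0
    for chunk in split_into_chunks(token_ids, chunk_size):
        emb = embed_chunk(chunk).to(dtype=torch.float32)
        weight = len(chunk)  # weight each chunk by its token count
        contribution = emb * weight
        weighted_sum = (contribution if weighted_sum is None
                        else weighted_sum + contribution)
        total_weight += weight
    assert weighted_sum is not None and total_weight > 0
    return weighted_sum / total_weight
```

With `chunk_size` set to the model's `max_position_embeddings`, a 150000-token input at a 4096-token limit yields ceil(150000 / 4096) = 37 chunks, consistent with the example log above.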
\ No newline at end of file From ac5b69a056bea0864967241a45e43299dc3b9415 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Mon, 11 Aug 2025 11:51:46 +0800 Subject: [PATCH 20/39] Refactoring inelegant code Signed-off-by: x22x22 --- examples/online_serving/openai_embedding_long_text/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md index 60441d2683a2..04edc4680ea0 100644 --- a/examples/online_serving/openai_embedding_long_text/README.md +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -184,5 +184,3 @@ The new `max_embed_len` parameter provides: - **Extreme Length Support**: Process documents with millions of tokens - **Clear Error Messages**: Better feedback when inputs exceed limits - **Backward Compatibility**: Existing configurations continue to work - - From b8fe266091d4db5894b5f6ae02b9588b562ab08f Mon Sep 17 00:00:00 2001 From: x22x22 Date: Tue, 12 Aug 2025 01:19:42 +0800 Subject: [PATCH 21/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f2f8c5947b34..cbcccf8fb4b1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -34,6 +34,7 @@ from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams +from vllm.utils import chunk_list logger = init_logger(__name__) @@ -161,18 +162,6 @@ def _should_use_chunked_processing(self, request) -> bool: # for cross-chunk aggregation (native pooling is used within each chunk) return self.supports_chunked_processing - def _chunk_token_ids(self, token_ids: list[int], - chunk_size: int) -> list[list[int]]: - """Split token IDs into chunks of specified size.""" - if len(token_ids) <= chunk_size: - return [token_ids] - - chunks = [] - for i in range(0, len(token_ids), chunk_size): - chunk = token_ids[i:i + chunk_size] - chunks.append(chunk) - return chunks - async def _process_chunked_request( self, ctx: EmbeddingServeContext, @@ -187,7 +176,7 @@ async def _process_chunked_request( # Split into chunks using max_position_embeddings max_pos_embeddings = self._get_max_position_embeddings() - chunks = self._chunk_token_ids(token_ids, max_pos_embeddings) + chunks = list(chunk_list(token_ids, max_pos_embeddings)) # Process all chunks for MEAN aggregation chunks_to_process = chunks From e9a5d70f9b5b158e9c491f7337398e95af202ef0 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Tue, 12 Aug 2025 01:23:26 +0800 Subject: [PATCH 22/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index cbcccf8fb4b1..f608b158fbcc 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -65,6 +65,10 @@ def __init__(self, *args, **kwargs): pooler_config is not None and getattr(pooler_config, 'enable_chunked_processing', False)) + # Cache max_embed_len to avoid repeated attribute lookups + self.max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len 
else None) + @override async def _preprocess( self, @@ -230,12 +234,8 @@ def _validate_input( # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) - # Get pooler config for max_embed_len - pooler_config = getattr(self.model_config, 'pooler_config', None) - - # Get max_embed_len from pooler config if set - max_embed_len = (pooler_config.max_embed_len if pooler_config - and pooler_config.max_embed_len else None) + # Use cached max_embed_len value + max_embed_len = self.max_embed_len # Use max_position_embeddings for chunked processing decisions max_pos_embeddings = self._get_max_position_embeddings() From d0c1c9eeacfb4dd39af7815a894201e720a8dd87 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Tue, 12 Aug 2025 01:36:01 +0800 Subject: [PATCH 23/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_engine.py | 51 ++++++----------------- 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 64e512c92377..79d8848f4084 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -510,12 +510,8 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: if (isinstance(message, dict) and "content" in message and isinstance(message["content"], list)): for content_dict in message["content"]: - # Use Any type to handle dynamic dict access safely - if (isinstance(content_dict, dict) - and "type" in content_dict): - content_type = cast(Any, content_dict).get("type") - if isinstance(content_type, str): - message_types.add(content_type.split("_")[0]) + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) return message_types async def _normalize_prompt_text_to_input( @@ -892,25 +888,12 @@ async def _preprocess_chat( **_chat_template_kwargs, ) else: - # Cast tokenizer to compatible type for apply_hf_chat_template - from transformers import (PreTrainedTokenizer, - PreTrainedTokenizerFast) - if isinstance(tokenizer, - (PreTrainedTokenizer, PreTrainedTokenizerFast)): - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - else: - # Fallback for other tokenizer types - request_prompt = apply_hf_chat_template( - tokenizer=cast(PreTrainedTokenizer, tokenizer), - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) mm_data = await mm_data_future @@ -947,14 +930,9 @@ async def _preprocess_chat( # For MistralTokenizer assert is_list_of(request_prompt, int), ( "Prompt has to be either a string or a list of token ids") - # Ensure tokenizer has decode method - if hasattr(tokenizer, 'decode'): - decoded_prompt = tokenizer.decode(request_prompt) - else: - # Fallback for tokenizers without decode method - decoded_prompt = str(request_prompt) - prompt_inputs = TextTokensPrompt(prompt=decoded_prompt, - prompt_token_ids=request_prompt) + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) @@ -1121,12 +1099,7 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token - 
# Ensure tokenizer has decode method - if hasattr(tokenizer, 'decode'): - return tokenizer.decode(token_id) - else: - # Fallback for tokenizers without decode method - return f"token_id:{token_id}" + return tokenizer.decode(token_id) def _is_model_supported(self, model_name: Optional[str]) -> bool: if not model_name: From 4de2c2b346a5761b3fede1e3252fb1544d07b609 Mon Sep 17 00:00:00 2001 From: Kdump Date: Tue, 12 Aug 2025 01:48:47 +0800 Subject: [PATCH 24/39] Update vllm/entrypoints/openai/serving_embedding.py Co-authored-by: Cyrus Leung Signed-off-by: Kdump --- vllm/entrypoints/openai/serving_embedding.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f608b158fbcc..ebbc34c14dd1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -180,12 +180,9 @@ async def _process_chunked_request( # Split into chunks using max_position_embeddings max_pos_embeddings = self._get_max_position_embeddings() - chunks = list(chunk_list(token_ids, max_pos_embeddings)) - # Process all chunks for MEAN aggregation - chunks_to_process = chunks - - for chunk_idx, chunk_tokens in enumerate(chunks_to_process): + for chunk_idx, chunk_tokens in enumerate( + chunk_list(token_ids, max_pos_embeddings)): # Create a request ID for this chunk chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" f"chunk-{chunk_idx}") From dc067f371d290f4eceaee0d272171ae3530100a1 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Tue, 12 Aug 2025 02:08:28 +0800 Subject: [PATCH 25/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index ebbc34c14dd1..ea42502f2696 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -182,7 +182,7 @@ async def _process_chunked_request( max_pos_embeddings = self._get_max_position_embeddings() # Process all chunks for MEAN aggregation for chunk_idx, chunk_tokens in enumerate( - chunk_list(token_ids, max_pos_embeddings)): + chunk_list(token_ids, max_pos_embeddings)): # Create a request ID for this chunk chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" f"chunk-{chunk_idx}") From cf19859a17ca1240bdf68d92d7eee80180ca6d39 Mon Sep 17 00:00:00 2001 From: Kdump Date: Wed, 13 Aug 2025 09:12:38 +0800 Subject: [PATCH 26/39] Update vllm/entrypoints/openai/serving_embedding.py Co-authored-by: Maximilien de Bayser Signed-off-by: Kdump --- vllm/entrypoints/openai/serving_embedding.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index ea42502f2696..2fdce8058cc5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -158,13 +158,7 @@ def _get_max_position_embeddings(self) -> int: def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" - if not isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest)): - return False - - # For chunked processing, we always use MEAN aggregation - # for cross-chunk aggregation (native pooling is used within each chunk) - return self.supports_chunked_processing + return isinstance(request, 
EmbeddingRequest) and self.supports_chunked_processing async def _process_chunked_request( self, From 8fab60390ac7aefb9af2a5fc885950334456bd15 Mon Sep 17 00:00:00 2001 From: Kdump Date: Wed, 13 Aug 2025 09:12:50 +0800 Subject: [PATCH 27/39] Update vllm/entrypoints/openai/serving_embedding.py Co-authored-by: Maximilien de Bayser Signed-off-by: Kdump --- vllm/entrypoints/openai/serving_embedding.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 2fdce8058cc5..4035a1029d2a 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -397,7 +397,11 @@ async def _collect_batch( if ctx.engine_prompts is None: return self.create_error_response( "Engine prompts not available") + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + if not use_chunked: + return await super()._collect_batch(ctx=ctx) if ctx.request_prompts is None: return self.create_error_response( "Request prompts not available") From 8c7d56b21caf500659f8e343508a01930139f965 Mon Sep 17 00:00:00 2001 From: Kdump Date: Wed, 13 Aug 2025 09:13:03 +0800 Subject: [PATCH 28/39] Update vllm/entrypoints/openai/serving_embedding.py Co-authored-by: Maximilien de Bayser Signed-off-by: Kdump --- vllm/entrypoints/openai/serving_embedding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4035a1029d2a..573f530ddd97 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -220,8 +220,7 @@ def _validate_input( token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest)): + if isinstance(request, EmbeddingRequest): # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) From fa3b69f7f530ca6050aebc43d46f9c07b184f307 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 09:50:41 +0800 Subject: [PATCH 29/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 573f530ddd97..7785b1979ea7 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -17,7 +17,6 @@ # yapf conflicts with isort for this docstring # yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, - EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, @@ -158,7 +157,8 @@ def _get_max_position_embeddings(self) -> int: def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" - return isinstance(request, EmbeddingRequest) and self.supports_chunked_processing + return isinstance( + request, EmbeddingRequest) and self.supports_chunked_processing async def _process_chunked_request( self, @@ -224,21 +224,16 @@ def _validate_input( # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) - # Use cached max_embed_len value - max_embed_len = self.max_embed_len - # Use max_position_embeddings for 
chunked processing decisions max_pos_embeddings = self._get_max_position_embeddings() # Determine the effective max length for validation - if max_embed_len is not None: + if self.max_embed_len is not None: # Use max_embed_len for validation instead of max_model_len - effective_max_len = max_embed_len length_type = "maximum embedding input length" - max_length_value = max_embed_len + max_length_value = self.max_embed_len else: # Fall back to max_model_len validation (original behavior) - effective_max_len = self.max_model_len length_type = "maximum context length" max_length_value = self.max_model_len @@ -253,8 +248,8 @@ def _validate_input( "embedding generation. Please reduce the length of the input " "or enable chunked processing.") - # Check if input exceeds effective max length - if token_num > effective_max_len: + # Check if input exceeds max length + if token_num > max_length_value: raise ValueError( validation_error_msg.format( length_type=length_type, From 6584107ae16f6912acdef18f19d4aea91f8cb803 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 09:53:56 +0800 Subject: [PATCH 30/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 91 ++++++++++++-------- 1 file changed, 54 insertions(+), 37 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7785b1979ea7..7a0b2349427b 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from typing import Any, Final, Literal, Optional, Union, cast import numpy as np @@ -23,6 +23,7 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, + RequestPrompt, ServeContext, TextTokensPrompt) # yapf: enable @@ -283,6 +284,37 @@ def _is_text_tokens_prompt(self, prompt) -> bool: return (isinstance(prompt, dict) and "prompt_token_ids" in prompt and "prompt_embeds" not in prompt) + async def _create_single_prompt_generator( + self, + ctx: EmbeddingServeContext, + engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], + request_prompt: RequestPrompt, + pooling_params: PoolingParams, + trace_headers: Optional[Mapping[str, str]], + prompt_index: int, + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + """Create a generator for a single prompt using standard processing.""" + request_id_item = f"{ctx.request_id}-{prompt_index}" + + self._log_inputs(request_id_item, + request_prompt, + params=pooling_params, + lora_request=ctx.lora_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + return self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + @override async def _prepare_generators( self, @@ -290,6 +322,15 @@ async def _prepare_generators( ) -> Optional[ErrorResponse]: """Override to support chunked processing.""" ctx = cast(EmbeddingServeContext, ctx) + + # Check if we should use chunked processing + use_chunked = 
self._should_use_chunked_processing(ctx.request) + + # If no chunked processing needed, delegate to parent class + if not use_chunked: + return await super()._prepare_generators(ctx) + + # Custom logic for chunked processing generators: list[AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]] = [] @@ -298,11 +339,9 @@ async def _prepare_generators( trace_headers = (None if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)) - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters") - - pooling_params = ctx.request.to_pooling_params() + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params # Verify and set the task for pooling params try: @@ -318,21 +357,18 @@ async def _prepare_generators( return self.create_error_response( "Request prompts not available") - # Check if we should use chunked processing - use_chunked = self._should_use_chunked_processing(ctx.request) + max_pos_embeddings = self._get_max_position_embeddings() for i, engine_prompt in enumerate(ctx.engine_prompts): request_prompt = ctx.request_prompts[i] # Check if this specific prompt needs chunked processing - max_pos_embeddings = self._get_max_position_embeddings() - if (use_chunked - and self._is_text_tokens_prompt(request_prompt)): - # Cast to TextTokensPrompt since we've - # verified prompt_token_ids + if self._is_text_tokens_prompt(request_prompt): + # Cast to TextTokensPrompt since we've verified + # prompt_token_ids text_tokens_prompt = cast(TextTokensPrompt, request_prompt) - if len(text_tokens_prompt["prompt_token_ids"] - ) > max_pos_embeddings: + if (len(text_tokens_prompt["prompt_token_ids"]) + > max_pos_embeddings): # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( ctx, text_tokens_prompt, pooling_params, @@ -341,28 +377,9 @@ async def _prepare_generators( continue # Normal processing for short prompts or non-token prompts - request_id_item = f"{ctx.request_id}-{i}" - - self._log_inputs(request_id_item, - request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) - - # Mypy has an existing bug related to inferring the variance - # of TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) - + generator = await self._create_single_prompt_generator( + ctx, engine_prompt, request_prompt, pooling_params, + trace_headers, i) generators.append(generator) from vllm.utils import merge_async_iterators From f4d48ce74b9560d1d9d9d58dbe6aafd3dfdf51e7 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 10:14:25 +0800 Subject: [PATCH 31/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a0b2349427b..2f811ff91511 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -17,6 +17,7 @@ # yapf conflicts with isort for this docstring # 
yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, @@ -159,7 +160,9 @@ def _get_max_position_embeddings(self) -> int: def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" return isinstance( - request, EmbeddingRequest) and self.supports_chunked_processing + request, + (EmbeddingCompletionRequest, + EmbeddingChatRequest)) and self.supports_chunked_processing async def _process_chunked_request( self, @@ -221,7 +224,8 @@ def _validate_input( token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, EmbeddingRequest): + if isinstance(request, + (EmbeddingCompletionRequest, EmbeddingChatRequest)): # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) @@ -377,8 +381,12 @@ async def _prepare_generators( continue # Normal processing for short prompts or non-token prompts + # Cast engine_prompt to the expected type for mypy + engine_prompt_typed = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) generator = await self._create_single_prompt_generator( - ctx, engine_prompt, request_prompt, pooling_params, + ctx, engine_prompt_typed, request_prompt, pooling_params, trace_headers, i) generators.append(generator) From 94a7576732ff40e2d2f10bff960a096bff400321 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 10:27:26 +0800 Subject: [PATCH 32/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 316 +++++++++---------- 1 file changed, 144 insertions(+), 172 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 2f811ff91511..4be2557282a2 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -416,11 +416,13 @@ async def _collect_batch( if ctx.engine_prompts is None: return self.create_error_response( "Engine prompts not available") + # Check if we used chunked processing use_chunked = self._should_use_chunked_processing(ctx.request) if not use_chunked: return await super()._collect_batch(ctx=ctx) + if ctx.request_prompts is None: return self.create_error_response( "Request prompts not available") @@ -429,190 +431,160 @@ async def _collect_batch( return self.create_error_response( "Result generator not available") - # Check if we used chunked processing - use_chunked = self._should_use_chunked_processing(ctx.request) + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': None, + 'total_weight': 0, + 'chunk_count': 0, + 'request_id': + result.request_id.split("-chunk-")[0] + } - if use_chunked: - # Online aggregation for chunked requests to - # minimize memory usage - # Track aggregation state for each prompt - prompt_aggregators: dict[int, 
dict[str, Any]] = {} - short_prompts_results: dict[int, PoolingRequestOutput] = {} - - async for result_idx, result in ctx.result_generator: - if "-chunk-" in result.request_id: - # Extract prompt_idx from chunked request_id - parts = result.request_id.split("-") - try: - prompt_idx = int(parts[parts.index("prompt") + 1]) - - # Initialize aggregator for this prompt if needed - if prompt_idx not in prompt_aggregators: - prompt_aggregators[prompt_idx] = { - 'weighted_sum': - None, - 'total_weight': - 0, - 'chunk_count': - 0, - 'request_id': - result.request_id.split("-chunk-")[0] - } - - aggregator = prompt_aggregators[prompt_idx] - - # MEAN pooling with online weighted averaging - # Ensure result is PoolingRequestOutput - # for embedding processing - if not isinstance(result, PoolingRequestOutput): - return self.create_error_response( - f"Expected PoolingRequestOutput for " - f"chunked embedding, got " - f"{type(result).__name__}") - - # Handle both PoolingOutput and - # EmbeddingOutput types - if hasattr(result.outputs, 'data'): - # PoolingOutput case - embedding_data = result.outputs.data - elif hasattr(result.outputs, 'embedding'): - # EmbeddingOutput case - - # convert embedding list to tensor - embedding_data = result.outputs.embedding - else: - return self.create_error_response( - f"Unsupported output type: " - f"{type(result.outputs).__name__}") - - if not isinstance(embedding_data, torch.Tensor): - embedding_data = torch.tensor( - embedding_data, dtype=torch.float32) - - if result.prompt_token_ids is None: - return self.create_error_response( - "prompt_token_ids cannot be None for " - "chunked processing") - weight = len(result.prompt_token_ids) - - weighted_embedding = embedding_data.to( - dtype=torch.float32) * weight - - if aggregator['weighted_sum'] is None: - # First chunk - aggregator['weighted_sum'] = weighted_embedding - else: - # Accumulate - current_sum = aggregator['weighted_sum'] - if isinstance(current_sum, torch.Tensor): - aggregator['weighted_sum'] = ( - current_sum + weighted_embedding) - - total_weight = aggregator['total_weight'] - if isinstance(total_weight, (int, float)): - aggregator['total_weight'] = (total_weight + - weight) - - chunk_count = aggregator['chunk_count'] - if isinstance(chunk_count, int): - aggregator['chunk_count'] = chunk_count + 1 - - except (ValueError, IndexError): + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): return self.create_error_response( - f"Invalid chunk request ID format: " - f"{result.request_id}") - else: - # Non-chunked result - try: - prompt_idx = int(result.request_id.split("-")[-1]) - short_prompts_results[prompt_idx] = cast( - PoolingRequestOutput, result) - except ValueError: + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, 'data'): + # PoolingOutput case + embedding_data = result.outputs.data + elif hasattr(result.outputs, 'embedding'): + # EmbeddingOutput case - + # convert embedding list to tensor + embedding_data = result.outputs.embedding + else: return self.create_error_response( - f"Invalid request ID " - f"format: {result.request_id}") + f"Unsupported output type: " + f"{type(result.outputs).__name__}") - # Finalize aggregated results - final_res_batch: list[Union[PoolingRequestOutput, - 
EmbeddingRequestOutput]] = [] - num_prompts = len(ctx.engine_prompts) + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor(embedding_data, + dtype=torch.float32) - for prompt_idx in range(num_prompts): - if prompt_idx in prompt_aggregators: - # Finalize MEAN aggregation for this chunked prompt - aggregator = prompt_aggregators[prompt_idx] + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) - weighted_sum = aggregator['weighted_sum'] - total_weight = aggregator['total_weight'] + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight - if (weighted_sum is not None - and isinstance(weighted_sum, torch.Tensor) - and isinstance(total_weight, (int, float)) - and total_weight > 0): - - # Compute final mean embedding - final_embedding = weighted_sum / total_weight - - # Create a PoolingRequestOutput - # for the aggregated result - pooling_output_data = PoolingOutput( - data=final_embedding) - - # Get original prompt token IDs for this prompt - original_prompt = ctx.request_prompts[prompt_idx] - if not self._is_text_tokens_prompt( - original_prompt): - return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a " - f"TextTokensPrompt") - - original_token_ids = cast( - TextTokensPrompt, - original_prompt)["prompt_token_ids"] - - pooling_request_output = PoolingRequestOutput( - request_id=aggregator['request_id'], - prompt_token_ids=original_token_ids, - outputs=pooling_output_data, - finished=True) - - final_res_batch.append(pooling_request_output) + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding else: - return self.create_error_response( - f"Failed to aggregate chunks " - f"for prompt {prompt_idx}") - elif prompt_idx in short_prompts_results: - final_res_batch.append( - cast(PoolingRequestOutput, - short_prompts_results[prompt_idx])) - else: - return self.create_error_response( - f"Result not found for prompt {prompt_idx}") + # Accumulate + current_sum = aggregator['weighted_sum'] + if isinstance(current_sum, torch.Tensor): + aggregator['weighted_sum'] = ( + current_sum + weighted_embedding) - ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_res_batch) - else: - # Normal processing for non-chunked requests - num_prompts = len(ctx.engine_prompts) - normal_final_res_batch: list[ - Optional[PoolingRequestOutput]] = [None] * num_prompts - - async for result_idx, result in ctx.result_generator: - if result_idx < num_prompts: - # Cast to PoolingRequestOutput for embedding results - normal_final_res_batch[result_idx] = cast( + total_weight = aggregator['total_weight'] + if isinstance(total_weight, (int, float)): + aggregator['total_weight'] = (total_weight + + weight) + + chunk_count = aggregator['chunk_count'] + if isinstance(chunk_count, int): + aggregator['chunk_count'] = chunk_count + 1 + + except (ValueError, IndexError): + return self.create_error_response( + f"Invalid chunk request ID format: " + f"{result.request_id}") + else: + # Non-chunked result + try: + prompt_idx = int(result.request_id.split("-")[-1]) + short_prompts_results[prompt_idx] = cast( PoolingRequestOutput, result) + except ValueError: + return self.create_error_response( + f"Invalid request ID " + f"format: {result.request_id}") + + # Finalize aggregated results + final_res_batch: list[Union[PoolingRequestOutput, + EmbeddingRequestOutput]] = [] + 
num_prompts = len(ctx.engine_prompts) + + for prompt_idx in range(num_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + + if (weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, + (int, float)) and total_weight > 0): + + # Compute final mean embedding + final_embedding = weighted_sum / total_weight + + # Create a PoolingRequestOutput + # for the aggregated result + pooling_output_data = PoolingOutput( + data=final_embedding) + + # Get original prompt token IDs for this prompt + original_prompt = ctx.request_prompts[prompt_idx] + if not self._is_text_tokens_prompt(original_prompt): + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"TextTokensPrompt") + + original_token_ids = cast( + TextTokensPrompt, + original_prompt)["prompt_token_ids"] - if None in normal_final_res_batch: + pooling_request_output = PoolingRequestOutput( + request_id=aggregator['request_id'], + prompt_token_ids=original_token_ids, + outputs=pooling_output_data, + finished=True) + + final_res_batch.append(pooling_request_output) + else: + return self.create_error_response( + f"Failed to aggregate chunks " + f"for prompt {prompt_idx}") + elif prompt_idx in short_prompts_results: + final_res_batch.append( + cast(PoolingRequestOutput, + short_prompts_results[prompt_idx])) + else: return self.create_error_response( - "Failed to generate results for all prompts") - - final_results = [ - res for res in normal_final_res_batch if res is not None - ] - ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_results) + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) return None From 3444141369f836b8981aebfc89abacd0131c85d8 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 10:30:43 +0800 Subject: [PATCH 33/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4be2557282a2..cebf140227c0 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -502,14 +502,8 @@ async def _collect_batch( aggregator['weighted_sum'] = ( current_sum + weighted_embedding) - total_weight = aggregator['total_weight'] - if isinstance(total_weight, (int, float)): - aggregator['total_weight'] = (total_weight + - weight) - - chunk_count = aggregator['chunk_count'] - if isinstance(chunk_count, int): - aggregator['chunk_count'] = chunk_count + 1 + aggregator['total_weight'] += weight + aggregator['chunk_count'] += 1 except (ValueError, IndexError): return self.create_error_response( From ac02136d7940d75e61a751cd309de9df569c8704 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 10:33:03 +0800 Subject: [PATCH 34/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index cebf140227c0..e1ff034ff12f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ 
b/vllm/entrypoints/openai/serving_embedding.py @@ -497,10 +497,7 @@ async def _collect_batch( aggregator['weighted_sum'] = weighted_embedding else: # Accumulate - current_sum = aggregator['weighted_sum'] - if isinstance(current_sum, torch.Tensor): - aggregator['weighted_sum'] = ( - current_sum + weighted_embedding) + aggregator['weighted_sum'] += weighted_embedding aggregator['total_weight'] += weight aggregator['chunk_count'] += 1 From 17c4317025c9e61c5b131ca6c057a2d78f555e22 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 10:58:35 +0800 Subject: [PATCH 35/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 167 ++++++++++--------- 1 file changed, 86 insertions(+), 81 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e1ff034ff12f..6991438ccd35 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -171,9 +171,10 @@ async def _process_chunked_request( pooling_params, trace_headers, prompt_idx: int, - ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + ) -> list[AsyncGenerator[tuple[int, int, PoolingRequestOutput], None]]: """Process a single prompt using chunked processing.""" - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + generators: list[AsyncGenerator[tuple[int, int, PoolingRequestOutput], + None]] = [] token_ids = original_prompt["prompt_token_ids"] # Split into chunks using max_position_embeddings @@ -200,8 +201,8 @@ async def _process_chunked_request( params=pooling_params, lora_request=ctx.lora_request) - # Create generator for this chunk - generator = self.engine_client.encode( + # Create generator for this chunk and wrap it to return indices + original_generator = self.engine_client.encode( chunk_engine_prompt, pooling_params, chunk_request_id, @@ -210,7 +211,15 @@ async def _process_chunked_request( priority=getattr(ctx.request, "priority", 0), ) - generators.append(generator) + # Wrap the generator to return (prompt_idx, chunk_idx, result) + # Use default parameters to capture loop variables + async def wrapped_generator(gen=original_generator, + p_idx=prompt_idx, + c_idx=chunk_idx): + async for result in gen: + yield (p_idx, c_idx, result) + + generators.append(wrapped_generator()) return generators @@ -296,7 +305,8 @@ async def _create_single_prompt_generator( pooling_params: PoolingParams, trace_headers: Optional[Mapping[str, str]], prompt_index: int, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + ) -> AsyncGenerator[tuple[int, int, Union[RequestOutput, + PoolingRequestOutput]], None]: """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" @@ -310,7 +320,9 @@ async def _create_single_prompt_generator( # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], engine_prompt) - return self.engine_client.encode( + + # Wrap the original generator to return (prompt_idx, chunk_idx, result) + original_generator = self.engine_client.encode( engine_prompt, pooling_params, request_id_item, @@ -319,6 +331,10 @@ async def _create_single_prompt_generator( priority=getattr(ctx.request, "priority", 0), ) + async for result in original_generator: + # chunk_idx is always 0 for non-chunked + yield (prompt_index, 0, result) + @override async def _prepare_generators( self, @@ -335,10 +351,14 @@ async def 
_prepare_generators( return await super()._prepare_generators(ctx) # Custom logic for chunked processing - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], + generators: list[AsyncGenerator[tuple[int, int, + Union[RequestOutput, + PoolingRequestOutput]], None]] = [] + # Track which prompts use chunked processing + ctx.chunked_prompts = set() + try: trace_headers = (None if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)) @@ -374,6 +394,7 @@ async def _prepare_generators( if (len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings): # Use chunked processing for this prompt + ctx.chunked_prompts.add(i) # Track chunked prompt chunk_generators = await self._process_chunked_request( ctx, text_tokens_prompt, pooling_params, trace_headers, i) @@ -437,85 +458,69 @@ async def _collect_batch( prompt_aggregators: dict[int, dict[str, Any]] = {} short_prompts_results: dict[int, PoolingRequestOutput] = {} - async for result_idx, result in ctx.result_generator: - if "-chunk-" in result.request_id: - # Extract prompt_idx from chunked request_id - parts = result.request_id.split("-") - try: - prompt_idx = int(parts[parts.index("prompt") + 1]) - - # Initialize aggregator for this prompt if needed - if prompt_idx not in prompt_aggregators: - prompt_aggregators[prompt_idx] = { - 'weighted_sum': None, - 'total_weight': 0, - 'chunk_count': 0, - 'request_id': - result.request_id.split("-chunk-")[0] - } - - aggregator = prompt_aggregators[prompt_idx] - - # MEAN pooling with online weighted averaging - # Ensure result is PoolingRequestOutput - # for embedding processing - if not isinstance(result, PoolingRequestOutput): - return self.create_error_response( - f"Expected PoolingRequestOutput for " - f"chunked embedding, got " - f"{type(result).__name__}") - - # Handle both PoolingOutput and - # EmbeddingOutput types - if hasattr(result.outputs, 'data'): - # PoolingOutput case - embedding_data = result.outputs.data - elif hasattr(result.outputs, 'embedding'): - # EmbeddingOutput case - - # convert embedding list to tensor - embedding_data = result.outputs.embedding - else: - return self.create_error_response( - f"Unsupported output type: " - f"{type(result.outputs).__name__}") + async for prompt_idx, chunk_idx, result in ctx.result_generator: + # This is a chunked result + if prompt_idx in ctx.chunked_prompts: + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + 'weighted_sum': None, + 'total_weight': 0, + 'chunk_count': 0, + 'request_id': result.request_id.split("-chunk-")[0] + } - if not isinstance(embedding_data, torch.Tensor): - embedding_data = torch.tensor(embedding_data, - dtype=torch.float32) + aggregator = prompt_aggregators[prompt_idx] - if result.prompt_token_ids is None: - return self.create_error_response( - "prompt_token_ids cannot be None for " - "chunked processing") - weight = len(result.prompt_token_ids) + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, 'data'): + # PoolingOutput case + embedding_data = result.outputs.data + elif hasattr(result.outputs, 'embedding'): + # EmbeddingOutput case - + # convert 
embedding list to tensor + embedding_data = result.outputs.embedding + else: + return self.create_error_response( + f"Unsupported output type: " + f"{type(result.outputs).__name__}") + + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor(embedding_data, + dtype=torch.float32) - weighted_embedding = embedding_data.to( - dtype=torch.float32) * weight + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) - if aggregator['weighted_sum'] is None: - # First chunk - aggregator['weighted_sum'] = weighted_embedding - else: - # Accumulate - aggregator['weighted_sum'] += weighted_embedding + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight - aggregator['total_weight'] += weight - aggregator['chunk_count'] += 1 + if aggregator['weighted_sum'] is None: + # First chunk + aggregator['weighted_sum'] = weighted_embedding + else: + # Accumulate + aggregator['weighted_sum'] += weighted_embedding - except (ValueError, IndexError): - return self.create_error_response( - f"Invalid chunk request ID format: " - f"{result.request_id}") + aggregator['total_weight'] += weight + aggregator['chunk_count'] += 1 else: - # Non-chunked result - try: - prompt_idx = int(result.request_id.split("-")[-1]) - short_prompts_results[prompt_idx] = cast( - PoolingRequestOutput, result) - except ValueError: - return self.create_error_response( - f"Invalid request ID " - f"format: {result.request_id}") + # Non-chunked result (chunk_idx == 0) + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) # Finalize aggregated results final_res_batch: list[Union[PoolingRequestOutput, From 8866b5d4360a57e5f60b13d6ddf7202f663d55a8 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 12:18:11 +0800 Subject: [PATCH 36/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 57 +++++++++----------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6991438ccd35..e2603f495422 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -171,10 +171,9 @@ async def _process_chunked_request( pooling_params, trace_headers, prompt_idx: int, - ) -> list[AsyncGenerator[tuple[int, int, PoolingRequestOutput], None]]: + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: """Process a single prompt using chunked processing.""" - generators: list[AsyncGenerator[tuple[int, int, PoolingRequestOutput], - None]] = [] + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] token_ids = original_prompt["prompt_token_ids"] # Split into chunks using max_position_embeddings @@ -211,15 +210,7 @@ async def _process_chunked_request( priority=getattr(ctx.request, "priority", 0), ) - # Wrap the generator to return (prompt_idx, chunk_idx, result) - # Use default parameters to capture loop variables - async def wrapped_generator(gen=original_generator, - p_idx=prompt_idx, - c_idx=chunk_idx): - async for result in gen: - yield (p_idx, c_idx, result) - - generators.append(wrapped_generator()) + generators.append(original_generator) return generators @@ -305,8 +296,7 @@ async def _create_single_prompt_generator( pooling_params: PoolingParams, trace_headers: Optional[Mapping[str, str]], prompt_index: int, - ) -> AsyncGenerator[tuple[int, int, 
Union[RequestOutput, - PoolingRequestOutput]], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" @@ -321,8 +311,8 @@ async def _create_single_prompt_generator( engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], engine_prompt) - # Wrap the original generator to return (prompt_idx, chunk_idx, result) - original_generator = self.engine_client.encode( + # Return the original generator without wrapping + return self.engine_client.encode( engine_prompt, pooling_params, request_id_item, @@ -331,10 +321,6 @@ async def _create_single_prompt_generator( priority=getattr(ctx.request, "priority", 0), ) - async for result in original_generator: - # chunk_idx is always 0 for non-chunked - yield (prompt_index, 0, result) - @override async def _prepare_generators( self, @@ -351,14 +337,10 @@ async def _prepare_generators( return await super()._prepare_generators(ctx) # Custom logic for chunked processing - generators: list[AsyncGenerator[tuple[int, int, - Union[RequestOutput, - PoolingRequestOutput]], + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], None]] = [] - # Track which prompts use chunked processing - ctx.chunked_prompts = set() - try: trace_headers = (None if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)) @@ -394,7 +376,6 @@ async def _prepare_generators( if (len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings): # Use chunked processing for this prompt - ctx.chunked_prompts.add(i) # Track chunked prompt chunk_generators = await self._process_chunked_request( ctx, text_tokens_prompt, pooling_params, trace_headers, i) @@ -458,9 +439,16 @@ async def _collect_batch( prompt_aggregators: dict[int, dict[str, Any]] = {} short_prompts_results: dict[int, PoolingRequestOutput] = {} - async for prompt_idx, chunk_idx, result in ctx.result_generator: - # This is a chunked result - if prompt_idx in ctx.chunked_prompts: + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + except (ValueError, IndexError): + # Fallback: extract from result_idx if parsing fails + prompt_idx = result_idx + # Initialize aggregator for this prompt if needed if prompt_idx not in prompt_aggregators: prompt_aggregators[prompt_idx] = { @@ -518,7 +506,14 @@ async def _collect_batch( aggregator['total_weight'] += weight aggregator['chunk_count'] += 1 else: - # Non-chunked result (chunk_idx == 0) + # Non-chunked result - extract prompt_idx from request_id + parts = result.request_id.split("-") + try: + # Last part should be prompt index + prompt_idx = int(parts[-1]) + except (ValueError, IndexError): + prompt_idx = result_idx # Fallback to result_idx + short_prompts_results[prompt_idx] = cast( PoolingRequestOutput, result) From b5230ed82cf701fddf1b63ed8cf8c03a09c6ce29 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Aug 2025 08:05:31 +0000 Subject: [PATCH 37/39] Reduce diff Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/serving_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 79d8848f4084..fb9d456df78e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ 
b/vllm/entrypoints/openai/serving_engine.py @@ -1049,9 +1049,7 @@ def _log_inputs( elif isinstance(inputs, list): prompt_token_ids = inputs elif 'prompt_embeds' in inputs: - # Cast to proper type for log_inputs - prompt_embeds = cast(Optional[torch.Tensor], - inputs.get("prompt_embeds")) + prompt_embeds = inputs.get("prompt_embeds") else: prompt = inputs["prompt"] prompt_token_ids = inputs["prompt_token_ids"] From b362cbd0e91e6732ebcdf08d8773476daa841052 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Aug 2025 08:14:36 +0000 Subject: [PATCH 38/39] Simplify Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/serving_embedding.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e2603f495422..ed69aec20dfe 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -60,13 +60,11 @@ class EmbeddingMixin(OpenAIServing): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Cache chunked processing support to avoid repeated attribute lookups - pooler_config = getattr(self.model_config, 'pooler_config', None) - self.supports_chunked_processing = ( - pooler_config is not None - and getattr(pooler_config, 'enable_chunked_processing', False)) + pooler_config = self.model_config.pooler_config - # Cache max_embed_len to avoid repeated attribute lookups + # Avoid repeated attribute lookups + self.supports_chunked_processing = ( + pooler_config and pooler_config.enable_chunked_processing) self.max_embed_len = (pooler_config.max_embed_len if pooler_config and pooler_config.max_embed_len else None) From d515efdb5dbd88ef3ec1d3463dc4c3326245b08d Mon Sep 17 00:00:00 2001 From: x22x22 Date: Wed, 13 Aug 2025 16:57:38 +0800 Subject: [PATCH 39/39] Refactoring inelegant code Signed-off-by: x22x22 --- vllm/entrypoints/openai/serving_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index ed69aec20dfe..9dcad8e391c6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs): pooler_config = self.model_config.pooler_config # Avoid repeated attribute lookups - self.supports_chunked_processing = ( + self.supports_chunked_processing = bool( pooler_config and pooler_config.enable_chunked_processing) self.max_embed_len = (pooler_config.max_embed_len if pooler_config and pooler_config.max_embed_len else None)
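Taken together, the final revisions leave `_collect_batch` doing online aggregation: chunk results may arrive in any order, and each prompt keeps only a running weighted sum, total weight, and chunk count rather than a list of per-chunk embeddings. The sketch below shows that accumulator pattern in a framework-free form; the class and the `(prompt_idx, weight, embedding)` stream are hypothetical names, standing in for the per-prompt aggregator dictionaries the real code maintains.

```python
# Illustrative accumulator for online, token-count-weighted MEAN aggregation.
# Names and the (prompt_idx, weight, embedding) stream are hypothetical; the
# patched _collect_batch keeps equivalent per-prompt state in plain dicts.
from dataclasses import dataclass

import torch


@dataclass
class ChunkAggregator:
    weighted_sum: torch.Tensor | None = None
    total_weight: int = 0
    chunk_count: int = 0

    def add(self, embedding: torch.Tensor, weight: int) -> None:
        """Fold one chunk's embedding into the running weighted sum."""
        contribution = embedding.to(dtype=torch.float32) * weight
        self.weighted_sum = (contribution if self.weighted_sum is None
                             else self.weighted_sum + contribution)
        self.total_weight += weight
        self.chunk_count += 1

    def finalize(self) -> torch.Tensor:
        """Return the token-count-weighted mean over all observed chunks."""
        assert self.weighted_sum is not None and self.total_weight > 0
        return self.weighted_sum / self.total_weight


def aggregate(results) -> dict[int, torch.Tensor]:
    """Consume an iterable of (prompt_idx, weight, embedding) chunk results."""
    aggregators: dict[int, ChunkAggregator] = {}
    for prompt_idx, weight, embedding in results:
        aggregators.setdefault(prompt_idx, ChunkAggregator()).add(embedding, weight)
    return {idx: agg.finalize() for idx, agg in aggregators.items()}
```

Because only the running sums are retained, peak memory stays proportional to the number of prompts rather than the number of chunks, which is the motivation stated in the commit series for the online-aggregation rewrite.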