diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index d998c1f73b51..734a817fd1a0 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -4,8 +4,7 @@ set -xu remove_docker_container() { - docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; + docker rm -f tpu-test || true; } trap remove_docker_container EXIT diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index e565d4b24694..9e7b5a546243 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -5,7 +5,6 @@ set -xu remove_docker_container() { docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; } trap remove_docker_container EXIT diff --git a/.buildkite/scripts/tpu/config_v6e_1.env b/.buildkite/scripts/tpu/config_v6e_1.env index 03ec116f698d..c9e3c26571e7 100644 --- a/.buildkite/scripts/tpu/config_v6e_1.env +++ b/.buildkite/scripts/tpu/config_v6e_1.env @@ -1,6 +1,6 @@ # Environment config TEST_NAME=llama8b -CONTAINER_NAME=vllm-tpu +CONTAINER_NAME=tpu-test # vllm config MODEL=meta-llama/Llama-3.1-8B-Instruct diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh index 8959877a3c05..08e36611809d 100755 --- a/.buildkite/scripts/tpu/docker_run_bm.sh +++ b/.buildkite/scripts/tpu/docker_run_bm.sh @@ -12,8 +12,6 @@ source /etc/environment source $ENV_FILE remove_docker_container() { - docker rm -f tpu-test || true; - docker rm -f vllm-tpu || true; docker rm -f $CONTAINER_NAME || true; } diff --git a/.buildkite/scripts/tpu/quantized_v6e_1.env b/.buildkite/scripts/tpu/quantized_v6e_1.env index bab34b3be3b9..bd25c803081a 100644 --- a/.buildkite/scripts/tpu/quantized_v6e_1.env +++ b/.buildkite/scripts/tpu/quantized_v6e_1.env @@ -1,6 +1,6 @@ # Environment config TEST_NAME=llama8bw8a8 -CONTAINER_NAME=vllm-tpu +CONTAINER_NAME=tpu-test # vllm config MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 88e1197d703a..e139c6b30586 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -664,7 +664,7 @@ steps: # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py + - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/test_cutlass_mla_decode.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' @@ -749,7 +749,6 @@ steps: # this test fails consistently. 
# TODO: investigate and fix - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s models/multimodal/generation/test_maverick.py diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index df26376504b9..82c20ffa6554 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -49,6 +49,7 @@ best_throughput=0 best_max_num_seqs=0 best_num_batched_tokens=0 best_goodput=0 +best_request_rate=0 start_server() { local gpu_memory_utilization=$1 @@ -57,18 +58,35 @@ start_server() { local vllm_log=$4 local profile_dir=$5 - pkill -f vllm + pkill -if vllm - VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \ - --port 8004 \ - --gpu-memory-utilization $gpu_memory_utilization \ - --max-num-seqs $max_num_seqs \ - --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size $TP \ - --enable-prefix-caching \ - --load-format dummy \ - --download-dir "$DOWNLOAD_DIR" \ - --max-model-len $MAX_MODEL_LEN > "$vllm_log" 2>&1 & + # Define the common arguments as a bash array. + # Each argument and its value are separate elements. + local common_args_array=( + "$MODEL" + "--disable-log-requests" + "--port" "8004" + "--gpu-memory-utilization" "$gpu_memory_utilization" + "--max-num-seqs" "$max_num_seqs" + "--max-num-batched-tokens" "$max_num_batched_tokens" + "--tensor-parallel-size" "$TP" + "--enable-prefix-caching" + "--load-format" "dummy" + "--download-dir" "$DOWNLOAD_DIR" + "--max-model-len" "$MAX_MODEL_LEN" + ) + + # Use the array expansion "${common_args_array[@]}" + # This correctly passes each element as a separate argument. + if [[ -n "$profile_dir" ]]; then + # Start server with profiling enabled + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ + vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + else + # Start server without profiling + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \ + vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + fi # wait for 10 minutes... server_started=0 @@ -82,6 +100,7 @@ start_server() { sleep 10 fi done + if (( ! server_started )); then echo "server did not start within 10 minutes. Please check server log at $vllm_log". return 1 @@ -90,37 +109,20 @@ start_server() { fi } -update_best_profile() { - local profile_dir=$1 - local profile_index=$2 - sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort)) - selected_profile_file= - if [[ "$SYSTEM" == "TPU" ]]; then - selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb" - fi - if [[ "$SYSTEM" == "GPU" ]]; then - selected_profile_file="${sorted_paths[$profile_index]}" - fi - rm -f $PROFILE_PATH/* - cp $selected_profile_file $PROFILE_PATH -} - run_benchmark() { local max_num_seqs=$1 local max_num_batched_tokens=$2 local gpu_memory_utilization=$3 echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" - local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}" echo "vllm_log: $vllm_log" echo rm -f $vllm_log - mkdir -p $profile_dir - pkill -f vllm - local profile_index=0 + pkill -if vllm echo "starting server..." 
- start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir + # Call start_server without a profile_dir to avoid profiling overhead + start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log "" result=$? if [[ "$result" -eq 1 ]]; then echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" @@ -134,7 +136,8 @@ run_benchmark() { # get a basic qps by using request-rate inf bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt" prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) -adjusted_input_len=$(( INPUT_LEN - prefix_len )) + adjusted_input_len=$(( INPUT_LEN - prefix_len )) + # --profile flag is removed from this call vllm bench serve \ --backend vllm \ --model $MODEL \ @@ -148,8 +151,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len )) --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ - --port 8004 \ - --profile &> "$bm_log" + --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') @@ -163,7 +165,6 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len )) # start from request-rate as int(throughput) + 1 request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do - profile_index=$((profile_index+1)) # clear prefix cache curl -X POST http://0.0.0.0:8004/reset_prefix_cache sleep 5 @@ -201,12 +202,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len )) best_max_num_seqs=$max_num_seqs best_num_batched_tokens=$max_num_batched_tokens best_goodput=$goodput - if [[ "$SYSTEM" == "TPU" ]]; then - update_best_profile "$profile_dir/plugins/profile" $profile_index - fi - if [[ "$SYSTEM" == "GPU" ]]; then - update_best_profile "$profile_dir" $profile_index - fi + best_request_rate=$request_rate fi else echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" @@ -215,7 +211,7 @@ adjusted_input_len=$(( INPUT_LEN - prefix_len )) echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" - pkill vllm + pkill -if vllm sleep 10 printf '=%.0s' $(seq 1 20) return 0 @@ -228,7 +224,8 @@ read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST" gpu_memory_utilization=0.98 find_gpu_memory_utilization=0 while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do - start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" + # Pass empty string for profile_dir argument + start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" "" result=$? 
if [[ "$result" -eq 0 ]]; then find_gpu_memory_utilization=1 @@ -251,5 +248,45 @@ for num_seqs in "${num_seqs_list[@]}"; do done done echo "finish permutations" + +# ================================================================================= +# FINAL PROFILING RUN FOR THE BEST CONFIGURATION +# ================================================================================= +if (( $(echo "$best_throughput > 0" | bc -l) )); then + echo + echo "Benchmark tuning finished. Now running profiling on the best configuration found..." + echo "Best config: max_num_seqs: $best_max_num_seqs, max_num_batched_tokens: $best_num_batched_tokens, throughput: $best_throughput" + echo + + vllm_log="$LOG_FOLDER/vllm_log_BEST_PROFILE.txt" + bm_log="$LOG_FOLDER/bm_log_BEST_PROFILE.txt" + + # Start server with the best params and profiling ENABLED + echo "Starting server for profiling..." + start_server $gpu_memory_utilization $best_max_num_seqs $best_num_batched_tokens "$vllm_log" "$PROFILE_PATH" + + # Run benchmark with the best params and the --profile flag + echo "Running benchmark with profiling..." + prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) + adjusted_input_len=$(( INPUT_LEN - prefix_len )) + vllm bench serve \ + --backend vllm \ + --model $MODEL \ + --dataset-name random \ + --random-input-len $adjusted_input_len \ + --random-output-len $OUTPUT_LEN \ + --ignore-eos \ + --disable-tqdm \ + --request-rate $best_request_rate \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ + --num-prompts 100 \ + --random-prefix-len $prefix_len \ + --port 8004 \ + --profile &> "$bm_log" +else + echo "No configuration met the latency requirements. Skipping final profiling run." +fi +pkill -if vllm echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" diff --git a/benchmarks/kernels/benchmark_trtllm_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py similarity index 99% rename from benchmarks/kernels/benchmark_trtllm_attention.py rename to benchmarks/kernels/benchmark_trtllm_decode_attention.py index 68c48858e61c..77136edca45b 100644 --- a/benchmarks/kernels/benchmark_trtllm_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -41,7 +41,6 @@ def benchmark_decode( device = "cuda" torch.manual_seed(0) - # Currently only HEAD_GRP_SIZE == 8 is supported HEAD_GRP_SIZE = 8 MAX_SEQ_LEN = max_seq_len diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py new file mode 100644 index 000000000000..67bd9aebbcca --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +import random +from datetime import datetime + +import flashinfer +import torch + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + 
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_prefill( + num_seqs, + max_seq_len, + page_size=16, + dtype=torch.bfloat16, + kv_layout="HND", + num_kv_heads=8, + kv_cache_dtype="auto", + head_dim=128, + warmup=10, + trials=20, +): + torch.set_default_device("cuda") + torch.manual_seed(0) + + HEAD_GRP_SIZE = 8 + MAX_SEQ_LEN = max_seq_len + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / page_size) + + workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) + + num_qo_heads = num_kv_heads * HEAD_GRP_SIZE + sm_scale = float(1.0 / (head_dim**0.5)) + + q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + q_lens[-1] = MAX_SEQ_LEN + max_q_len = max(q_lens) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum( + torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ), + ] + ) + q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) + + kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] + kv_lens[-1] = MAX_SEQ_LEN + + seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) + kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) + k_scale = v_scale = 1.0 + + if kv_cache_dtype.startswith("fp8"): + kv_cache, _ = to_float8(kv_cache) + + output_trtllm = torch.empty(q.shape, dtype=dtype) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + page_size - 1) // page_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % page_size + if kv_last_page_len == 0: + kv_last_page_len = page_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + output_baseline = torch.empty(q.shape, dtype=dtype) + + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=True, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=kv_cache.dtype, + ) + + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + def baseline_prefill(): + return wrapper.run( + q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline + ) + + def trt_prefill(): + return flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=q, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_q_len=max_q_len, + 
max_kv_len=max_seq_len, + bmm1_scale=k_scale * sm_scale, + bmm2_scale=v_scale, + batch_size=num_seqs, + cum_seq_lens_q=q_indptr, + cum_seq_lens_kv=kv_indptr, + out=output_trtllm, + ) + + trt_mean, trt_std = time_fn(trt_prefill) + baseline_mean, baseline_std = time_fn(baseline_prefill) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trt_mean) / baseline_mean + + print( + f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" + f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" + ) + + # Return results for CSV writing + return { + "num_seqs": num_seqs, + "trt_mean": trt_mean, + "trt_std": trt_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(dtype), + "kv_cache_dtype": kv_cache_dtype, + "page_size": page_size, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "num_seqs", + "trt_mean", + "trt_std", + "baseline_mean", + "baseline_std", + "speedup_percent", + "q_dtype", + "kv_cache_dtype", + "page_size", + "num_kv_heads", + "head_dim", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + print( + "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " + "output_dtype: bfloat16" + ) + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_prefill( + bs, + max_seq_len, + dtype=torch.bfloat16, + kv_cache_dtype="auto", + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index ef45a5fbebf6..4eb4b464a216 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 + GIT_TAG b99f8c821771fd11feb66d5c89661e9858fde359 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 621179a70169..9c0ed1d09572 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -467,6 +467,12 @@ function (define_gpu_extension_target GPU_MOD_NAME) if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. 
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) + # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder + target_include_directories(${GPU_MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc + ${GPU_INCLUDE_DIRECTORIES}) + else() + target_include_directories(${GPU_MOD_NAME} PRIVATE csrc + ${GPU_INCLUDE_DIRECTORIES}) endif() if (GPU_ARCHITECTURES) @@ -482,8 +488,6 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_compile_definitions(${GPU_MOD_NAME} PRIVATE "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}") - target_include_directories(${GPU_MOD_NAME} PRIVATE csrc - ${GPU_INCLUDE_DIRECTORIES}) target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 0e1eab66f0b9..5fe5dd04bd89 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,7 +1,8 @@ #include "common.cuh" #include "dispatch_utils.h" - +#include "../vectorization_utils.cuh" #include +#include #ifndef USE_ROCM #include @@ -12,74 +13,127 @@ namespace vllm { template -__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - - // Invert the scale so that we can use multiplications to avoid expensive - // division. - const float inverted_scale = 1.0f / (*scale); - scaled_fp8_conversion_vec( - out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x); +__global__ void scaled_fp8_quant_kernel_strided( + fp8_type* __restrict__ out, const scalar_t* __restrict__ input, + const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, + int64_t out_row_stride) { + const int64_t token_idx = blockIdx.x; // one token per block + const int tid = threadIdx.x; + + const scalar_t* token_in = input + token_idx * in_row_stride; + fp8_type* token_out = out + token_idx * out_row_stride; + + const float inv_scale = 1.0f / (*scale); + + vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + inv_scale); + }); } template -__global__ void dynamic_per_token_scaled_fp8_quant_kernel( - fp8_type* __restrict__ out, float* __restrict__ scale, - scalar_t const* __restrict__ input, float const* __restrict__ scale_ub, - const int hidden_size) { - int const tid = threadIdx.x; - int const token_idx = blockIdx.x; +__global__ void segmented_max_reduction_strided( + float* __restrict__ scale, const scalar_t* __restrict__ input, + int hidden_size, int64_t in_row_stride, int64_t num_tokens) { + __shared__ float cache[256]; + const int tid = threadIdx.x; + int64_t token_idx = blockIdx.x; + + // one block per token. Guard in case gridDim.x > num_tokens. + if (token_idx >= num_tokens) { + return; + } - // Use int64 to avoid overflowing an int32 when calculating this offset - int64_t offset = static_cast(token_idx) * hidden_size; - scalar_t const* __restrict__ token_input = &input[offset]; - fp8_type* __restrict__ token_output = &out[offset]; - - // For vectorization, token_input and token_output pointers need to be - // aligned at 32-byte and 16-byte addresses respectively. 
- bool const can_vectorize = hidden_size % 16 == 0; - - float absmax_val = 0.0f; - if (can_vectorize) { - absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x); - } else { - for (int i = tid; i < hidden_size; i += blockDim.x) { - float const x = static_cast(token_input[i]); - absmax_val = fmaxf(absmax_val, fabsf(x)); + const scalar_t* row_ptr = input + token_idx * in_row_stride; + + // each thread scans elements of the row in a strided fashion. + float thread_max = 0.0f; + for (int e = tid; e < hidden_size; e += blockDim.x) { + float v = fabsf(static_cast(row_ptr[e])); + thread_max = fmaxf(thread_max, v); + } + + cache[tid] = thread_max; + __syncthreads(); + + // parallel reduction to find row max. + for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { + if (tid < offset) { + cache[tid] = fmaxf(cache[tid], cache[tid + offset]); } + __syncthreads(); } + // thread 0 updates global scale (per-tensor) atomically. + if (tid == 0) { + atomicMaxFloat(scale, cache[0] / quant_type_max_v); + } +} + +template +__global__ void scaled_fp8_quant_kernel_strided_dynamic( + fp8_type* __restrict__ out, const scalar_t* __restrict__ input, + const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, + int64_t out_row_stride) { + const int64_t token_idx = blockIdx.x; + const int tid = threadIdx.x; + + const scalar_t* token_in = input + token_idx * in_row_stride; + fp8_type* token_out = out + token_idx * out_row_stride; + + const float reciprocal_scale = 1.0f / (*scale); + vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + reciprocal_scale); + }); +} + +template +__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( + fp8_type* __restrict__ out, float* __restrict__ scale, + const scalar_t* __restrict__ input, const float* __restrict__ scale_ub, + int hidden_size, int64_t in_row_stride, int64_t out_row_stride) { + const int64_t token_idx = blockIdx.x; + const int tid = threadIdx.x; + + // Use int64 to avoid overflowing an int32 when calculating this offset + int64_t in_offset = static_cast(token_idx) * in_row_stride; + int64_t out_offset = static_cast(token_idx) * out_row_stride; + const scalar_t* token_in = input + in_offset; + fp8_type* token_out = out + out_offset; + + // 1) per-token absmax + float absmax_val = 0.f; + vectorize_read_with_alignment<16>( + token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) { + absmax_val = fmaxf(absmax_val, fabsf(static_cast(v))); + }); + using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - float const block_absmax_val_maybe = - BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); + __shared__ typename BlockReduce::TempStorage tmp; + const float block_max = + BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + __shared__ float token_scale; if (tid == 0) { - if (scale_ub) { - token_scale = fminf(block_absmax_val_maybe, *scale_ub); - } else { - token_scale = block_absmax_val_maybe; - } - // token scale computation + token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max; token_scale = fmaxf(token_scale / quant_type_max_v, min_scaling_factor::val()); scale[token_idx] = token_scale; } __syncthreads(); - // Note that we don't use inverted scales so we can match FBGemm impl. 
- if (can_vectorize) { - scaled_fp8_conversion_vec( - token_output, token_input, token_scale, hidden_size, tid, blockDim.x); - } else { - for (int i = tid; i < hidden_size; i += blockDim.x) { - token_output[i] = scaled_fp8_conversion( - static_cast(token_input[i]), token_scale); - } - } + // 2) quantize + vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + token_scale); + }); } } // namespace vllm @@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor const& scale) // [1] { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - int const block_size = 256; - int const num_tokens = input.numel() / input.size(-1); - int const num_elems = input.numel(); - dim3 const grid(num_tokens); - dim3 const block(block_size); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(block_size); + + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::scaled_fp8_quant_kernel + vllm::scaled_fp8_quant_kernel_strided <<>>( out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); }); }); } @@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scale) // [1] { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - int const block_size = 256; - int const num_tokens = input.numel() / input.size(-1); - int const num_elems = input.numel(); - dim3 const grid(num_tokens); - dim3 const block(block_size); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(block_size); + + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // scale tensor should be initialised to <=0 before reduction + AT_CUDA_CHECK( + cudaMemsetAsync(scale.data_ptr(), 0, sizeof(float), stream)); + VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::segmented_max_reduction - <<>>(scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel + vllm::segmented_max_reduction_strided + <<>>( + scale.data_ptr(), input.data_ptr(), + hidden_size, in_row_stride, + static_cast(num_tokens)); + + 
vllm::scaled_fp8_quant_kernel_strided_dynamic <<>>( out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); }); }); } @@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scales, std::optional const& scale_ub) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - int const block_size = 256; - dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, block_size)); + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, block_size)); + + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant( VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::dynamic_per_token_scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size); + vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< + scalar_t, fp8_t><<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + hidden_size, in_row_stride, out_row_stride); }); }); } diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d36f94a8f10d..1aad6330c44b 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -55,111 +55,4 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, #endif } -// Compute the absolute maximum m of the input tensor and store -// m / float8_e4m3::max() in *scale. Each thread block performs a -// reduction tree and the memory in scale is atomically updated. -// So to get the right answer, *scale needs to be initialized to -// a value <= 0.0 and we need to wait for all thread blocks to -// finish before consuming *scale. 
-template -__global__ void segmented_max_reduction(float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { - __shared__ float cache[256]; - int64_t i = blockDim.x * blockIdx.x + threadIdx.x; - - // First store maximum for all values processes by - // the current thread in cache[threadIdx.x] - scalar_t tmp = 0.0; - while (i < num_elems) { - float x = static_cast(input[i]); - tmp = fmaxf(tmp, fabsf(x)); - i += blockDim.x * gridDim.x; - } - cache[threadIdx.x] = tmp; - - __syncthreads(); - - // Now perform parallel reduction within the thread block - int ib = blockDim.x / 2; - while (ib != 0) { - if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; - } - __syncthreads(); - ib /= 2; - } - // Finally, since cache[0] contains the maximum for this thread block, - // atomically write the max to the target location - if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / quant_type_max_v); - } -} - -template -__device__ float thread_max_vec(scalar_t const* __restrict__ input, - int64_t const num_elems, int const tid, - int const step) { - constexpr size_t VEC_SIZE = 16; - using scalarxN_t = vec_n_t; - // Vectorized input/output to better utilize memory bandwidth. - auto const* vectorized_in = reinterpret_cast(input); - - // num_elems / VEC_SIZE (which is 16) - int64_t const num_vec_elems = num_elems >> 4; - float absmax_val = 0.0f; - -#pragma unroll - for (int64_t i = tid; i < num_vec_elems; i += step) { - scalarxN_t in_vec = vectorized_in[i]; -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j])); - } - } - - // Handle the remaining elements if num_elems is not divisible by VEC_SIZE - for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) { - absmax_val = fmaxf(absmax_val, fabsf(input[i])); - } - - return absmax_val; -} - -template -__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out, - scalar_t const* __restrict__ input, - float const scale, - int64_t const num_elems, - int const tid, int const step) { - constexpr size_t VEC_SIZE = 16; - using scalarxN_t = vec_n_t; - using float8xN_t = q8_n_t; - // Vectorized input/output to better utilize memory bandwidth. 
- auto const* vectorized_in = reinterpret_cast(input); - auto* vectorized_out = reinterpret_cast(out); - - // num_elems / VEC_SIZE (which is 16) - int64_t const num_vec_elems = num_elems >> 4; - -#pragma unroll - for (int64_t i = tid; i < num_vec_elems; i += step) { - scalarxN_t in_vec = vectorized_in[i]; - float8xN_t out_vec; - -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - out_vec.val[j] = scaled_fp8_conversion( - static_cast(in_vec.val[j]), scale); - } - vectorized_out[i] = out_vec; - } - - // Handle the remaining elements if num_elems is not divisible by VEC_SIZE - for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) { - out[i] = scaled_fp8_conversion( - static_cast(input[i]), scale); - } -} - } // namespace vllm diff --git a/docker/Dockerfile b/docker/Dockerfile index 0d6afca74e86..d444087a3eff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -119,6 +119,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 # as it was causing spam when compiling the CUTLASS kernels @@ -181,6 +183,8 @@ COPY requirements/build.txt requirements/build.txt # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/build.txt \ @@ -272,6 +276,8 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt @@ -341,6 +347,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully @@ -384,7 +392,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" # Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt # We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. -ARG FLASHINFER_GIT_REF="v0.2.9rc2" +ARG FLASHINFER_GIT_REF="v0.2.9" RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . 
/etc/environment git clone --depth 1 --recursive --shallow-submodules \ @@ -472,6 +480,8 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docs/getting_started/installation/google_tpu.md b/docs/getting_started/installation/google_tpu.md index 55d69d11fa40..6f09babb3aba 100644 --- a/docs/getting_started/installation/google_tpu.md +++ b/docs/getting_started/installation/google_tpu.md @@ -85,7 +85,7 @@ gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ | PROJECT_ID | Your Google Cloud project | | ZONE | The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see [TPU regions and zones] | | ACCELERATOR_TYPE | The TPU version you want to use. Specify the TPU version, for example `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information, see [TPU versions]. | -| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). For more information see [TPU VM images]. | +| RUNTIME_VERSION | The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s). | | SERVICE_ACCOUNT | The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: `tpu-service-account@.iam.gserviceaccount.com` | Connect to your TPU VM using SSH: @@ -94,6 +94,9 @@ Connect to your TPU VM using SSH: gcloud compute tpus tpu-vm ssh TPU_NAME --project PROJECT_ID --zone ZONE ``` +!!! note + When configuring `RUNTIME_VERSION` ("TPU software version") on GCP, ensure it matches the TPU generation you've selected by referencing the [TPU VM images] compatibility matrix. Using an incompatible version may prevent vLLM from running correctly. + [TPU versions]: https://cloud.google.com/tpu/docs/runtimes [TPU VM images]: https://cloud.google.com/tpu/docs/runtimes [TPU regions and zones]: https://cloud.google.com/tpu/docs/regions-zones diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 1fbbba7ace5e..c6588363b63f 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -120,7 +120,7 @@ A code example can be found here: th { white-space: nowrap; @@ -354,6 +356,7 @@ th { | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -419,7 +422,9 @@ See [this page](./pooling_models.md) for more information on how to use pooling Since some model architectures support both generative and pooling tasks, you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode. -#### Text Embedding +#### Embedding + +These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| @@ -457,28 +462,10 @@ If your model is not in the above list, we will try to automatically convert the [as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings of the whole prompt are extracted from the normalized hidden state corresponding to the last token. -#### Reward Modeling - -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `LlamaForCausalLM`C | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | - -C Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion)) -\* Feature support is the same as that of the original model. - -If your model is not in the above list, we will try to automatically convert the model using -[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. - -!!! important - For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, - e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. - #### Classification +These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API. + | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | @@ -491,7 +478,10 @@ If your model is not in the above list, we will try to automatically convert the If your model is not in the above list, we will try to automatically convert the model using [as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. 
By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. -#### Sentence Pair Scoring +#### Cross-encoder / Reranker + +Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. +These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| @@ -501,6 +491,7 @@ If your model is not in the above list, we will try to automatically convert the | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | | | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | | +| `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | C Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -526,6 +517,28 @@ If your model is not in the above list, we will try to automatically convert the vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' ``` +#### Reward Modeling + +These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API. + +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `LlamaForCausalLM`C | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* | \* | + +C Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion)) +\* Feature support is the same as that of the original model. + +If your model is not in the above list, we will try to automatically convert the model using +[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. + +!!! important + For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + [](){ #supported-mm-models } ## List of Multimodal Language Models @@ -579,6 +592,8 @@ See [this page](generative_models.md) for more information on how to use generat #### Text Generation +These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. 
Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API. + | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| | `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | | ✅︎ | @@ -592,7 +607,7 @@ See [this page](generative_models.md) for more information on how to use generat | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4MoeForCausalLM` | GLM-4.5 | T + IE+ + VE+ | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | @@ -614,6 +629,7 @@ See [this page](generative_models.md) for more information on how to use generat | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I+ | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -720,11 +736,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition. See [this page](./pooling_models.md) for more information on how to use pooling models. -!!! important - Since some model architectures support both generative and pooling tasks, - you should explicitly specify `--runner pooling` to ensure that the model is used in pooling mode instead of generative mode. +#### Embedding -#### Text Embedding +These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API. !!! note To get the best results, you should use pooling models that are specifically trained as such. @@ -742,7 +756,10 @@ The following table lists those that are tested in vLLM. --- -#### Scoring +#### Cross-encoder / Reranker + +Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. +These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. 
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | |-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py index 73da7af85f1d..0c7d32d7862e 100644 --- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py @@ -46,7 +46,7 @@ def _listen_for_register(poller, router_socket): global prefill_instances global prefill_cv with prefill_cv: - node = prefill_instances.pop(data["http_address"], None) + node = prefill_instances.get(data["http_address"], None) prefill_instances[data["http_address"]] = ( data["zmq_address"], time.time() + DEFAULT_PING_SECONDS, @@ -57,7 +57,7 @@ def _listen_for_register(poller, router_socket): global decode_instances global decode_cv with decode_cv: - node = decode_instances.pop(data["http_address"], None) + node = decode_instances.get(data["http_address"], None) decode_instances[data["http_address"]] = ( data["zmq_address"], time.time() + DEFAULT_PING_SECONDS, @@ -69,6 +69,7 @@ def _listen_for_register(poller, router_socket): remote_address, data, ) + return if node is None: print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]") diff --git a/requirements/common.txt b/requirements/common.txt index 6b57a3d2f1d0..5405df359a33 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,13 +7,13 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.53.2 +transformers >= 4.55.0 huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads. tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp -openai >= 1.87.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) +openai >= 1.98.0 # For Responses API with reasoning content pydantic >= 2.10 prometheus_client >= 0.18.0 pillow # Required for image processing @@ -49,3 +49,4 @@ ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring +openai-harmony >= 0.0.3 # Required for gpt-oss diff --git a/requirements/test.in b/requirements/test.in index 9ecaaae92727..9c8c75dd6f70 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -35,7 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.53.2 +transformers==4.55.0 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. 
diff --git a/requirements/test.txt b/requirements/test.txt index 691420df87c4..08ba964f22a4 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -214,7 +214,7 @@ fiona==1.10.1 # via torchgeo flask==3.1.1 # via mlflow -fonttools==4.54.1 +fonttools==4.55.0 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -286,7 +286,7 @@ httpx==0.27.2 # via # -r requirements/test.in # schemathesis -huggingface-hub==0.33.1 +huggingface-hub==0.34.3 # via # -r requirements/test.in # accelerate @@ -1148,7 +1148,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.53.2 +transformers==4.55.0 # via # -r requirements/test.in # genai-perf diff --git a/setup.py b/setup.py index 64cfbb8db962..c6f4985c5930 100644 --- a/setup.py +++ b/setup.py @@ -665,7 +665,7 @@ def _read_requirements(filename: str) -> list[str]: "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.9rc2"], + "flashinfer": ["flashinfer-python==0.2.9"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py new file mode 100644 index 000000000000..abdce8935ea5 --- /dev/null +++ b/tests/entrypoints/llm/test_classify.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(activation): + outputs = llm.classify( + prompts, + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) + return torch.tensor([x.outputs.probs for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + softmax(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py new file mode 100644 index 000000000000..ba20d7b9548e --- /dev/null +++ b/tests/entrypoints/llm/test_embedding.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "intfloat/multilingual-e5-small" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(normalize): + outputs = llm.embed(prompts, + pooling_params=PoolingParams(normalize=normalize), + use_tqdm=False) + return torch.tensor([x.outputs.embedding for x in outputs]) + + default = get_outputs(normalize=None) + w_normal = get_outputs(normalize=True) + wo_normal = get_outputs(normalize=False) + + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py new file mode 100644 index 000000000000..361e2d0e1047 --- /dev/null +++ b/tests/entrypoints/llm/test_reward.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "internlm/internlm2-1_8b-reward" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + trust_remote_code=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(softmax): + outputs = llm.reward(prompts, + pooling_params=PoolingParams(softmax=softmax), + use_tqdm=False) + return torch.cat([x.outputs.data for x in outputs]) + + default = get_outputs(softmax=None) + w_softmax = get_outputs(softmax=True) + wo_softmax = get_outputs(softmax=False) + + assert torch.allclose(default, w_softmax, + atol=1e-2), "Default should use softmax." 
+ assert not torch.allclose(w_softmax, wo_softmax, + atol=1e-2), "wo_softmax should not use softmax." + assert torch.allclose( + softmax(wo_softmax), w_softmax, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py new file mode 100644 index 000000000000..dd4eae0ccc06 --- /dev/null +++ b/tests/entrypoints/llm/test_score.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(activation): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + outputs = llm.score( + text_1, + text_2, + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) + return torch.tensor([x.outputs.score for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + softmax(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
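`tests/entrypoints/llm/test_score.py` exercises the same `activation` parameter on the cross-encoder scoring path. A minimal sketch of that usage, assuming the same reranker checkpoint as the test:

```python
import torch

from vllm import LLM, PoolingParams

llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", enforce_eager=True)

text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris."

raw = llm.score(text_1, text_2,
                pooling_params=PoolingParams(activation=False))
activated = llm.score(text_1, text_2,
                      pooling_params=PoolingParams(activation=True))

# For a single-label scorer the activation is a sigmoid, so applying it to the
# raw score should approximately reproduce the activated score.
assert torch.isclose(torch.sigmoid(torch.tensor(raw[0].outputs.score)),
                     torch.tensor(activated[0].outputs.score),
                     atol=1e-2)
```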
diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index b2472658ca81..bcf127307f73 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -3,6 +3,8 @@ import pytest import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import ClassificationResponse @@ -181,3 +183,32 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( invocation_data["probs"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_activation(server: RemoteOpenAIServer, model_name: str): + input_text = ["This product was excellent and exceeded my expectations"] + + async def get_outputs(activation): + response = requests.post(server.url_for("classify"), + json={ + "model": model_name, + "input": input_text, + "activation": activation + }) + outputs = response.json() + return torch.tensor([x['probs'] for x in outputs["data"]]) + + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index a7203befcc40..cf2442a56938 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -8,6 +8,8 @@ import pytest import pytest_asyncio import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer @@ -369,3 +371,35 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): embeddings_1_lst=[invocation_data["embedding"]], name_0="chat", name_1="invocation") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_normalize(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + async def get_outputs(normalize): + request_args = { + "model": MODEL_NAME, + "input": input_text, + "encoding_format": "float", + "normalize": normalize + } + + response = requests.post(server.url_for("v1/embeddings"), + json=request_args) + outputs = response.json() + + return torch.tensor([x['embedding'] for x in outputs["data"]]) + + default = await get_outputs(normalize=None) + w_normal = await get_outputs(normalize=True) + wo_normal = await get_outputs(normalize=False) + + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." 
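The server-side tests above toggle the same behavior over HTTP via new request fields (`activation` for `/classify`, `normalize` for `/v1/embeddings`). A hedged sketch of the request shape, with the host, port, and model names as placeholders:

```python
import requests

BASE = "http://localhost:8000"  # placeholder for a running `vllm serve` instance

# Classification without the final activation: raw logits come back in "probs".
resp = requests.post(f"{BASE}/classify", json={
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": ["This product was excellent and exceeded my expectations"],
    "activation": False,
})
raw_logits = resp.json()["data"][0]["probs"]

# Embeddings without L2 normalization, for this request only.
resp = requests.post(f"{BASE}/v1/embeddings", json={
    "model": "intfloat/multilingual-e5-small",
    "input": ["The chef prepared a delicious meal."],
    "encoding_format": "float",
    "normalize": False,
})
raw_embedding = resp.json()["data"][0]["embedding"]
```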
diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe13691..f121693e329f 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -3,6 +3,8 @@ import pytest import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import RerankResponse @@ -125,3 +127,39 @@ def test_invocations(server: RemoteOpenAIServer): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( invocations_result["relevance_score"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_activation(server: RemoteOpenAIServer, model_name: str): + + async def get_outputs(activation): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + response = requests.post(server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "activation": activation + }) + outputs = response.json() + + return torch.tensor([x['relevance_score'] for x in outputs["results"]]) + + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 187542b7bafc..1a5df1d2dbd2 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -4,6 +4,7 @@ import pytest import requests +import torch import torch.nn.functional as F from torch import tensor @@ -220,3 +221,43 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( invocation_data["score"], rel=0.01) + + def test_activation(self, server: RemoteOpenAIServer, model: dict[str, + Any]): + + def get_outputs(activation): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "activation": activation + }) + if response.status_code != 200: + return response + + outputs = response.json() + return torch.tensor([x['score'] for x in outputs["data"]]) + + if model["is_cross_encoder"]: + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
+ else: + get_outputs(activation=None) + + # The activation parameter only works for the is_cross_encoder model + response = get_outputs(activation=True) + assert response.status_code == 400 diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 93bf20da4adb..bfeafaa9e27e 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -278,23 +278,13 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("use_v1", [True, False]) def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): - + """Test that invalid attention backend names raise ValueError.""" with monkeypatch.context() as m, patch( "vllm.attention.selector.current_platform", CudaPlatform()): m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - # Test with head size 32 - backend = get_attn_backend(32, torch.float16, None, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN" - assert backend.get_name() == EXPECTED - - # when block size == 16, backend will fall back to XFORMERS - # this behavior is not yet supported on V1. - if use_v1: - # TODO: support fallback on V1! - # https://github.com/vllm-project/vllm/issues/14524 - pass - else: - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() == "XFORMERS" + # Should raise ValueError for invalid backend + with pytest.raises(ValueError) as exc_info: + get_attn_backend(32, torch.float16, None, 16, False) + assert "Invalid attention backend: 'INVALID'" in str(exc_info.value) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py new file mode 100644 index 000000000000..e87ce520bc66 --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import flashinfer +import pytest +import torch + +from vllm.platforms import current_platform + +if not current_platform.is_device_capability(100): + pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", + allow_module_level=True) + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + +MAX_Q_LEN = 1024 +MAX_KV_LEN = 4096 +BATCH_SIZES = [4, 12] +NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] +HEAD_SIZES = [128] +BLOCK_SIZES = [16, 32] +KV_LAYOUTS = ["HND"] +DTYPES = [torch.float16, torch.bfloat16] +KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
+SOFT_CAPS = [None, 50.0] + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@torch.inference_mode +def test_flashinfer_trtllm_decode_with_baseline( + batch_size: int, + num_heads: tuple[int, int], + head_size: int, + block_size: int, + kv_layout: str, + dtype: torch.dtype, + kv_cache_dtype: Optional[torch.dtype], + soft_cap: Optional[float], +) -> None: + kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = MAX_KV_LEN + max_kv_len = torch.max(kv_lens).item() + num_seqs = len(kv_lens) + + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) + kv_scale = 1.0 + if kv_cache_dtype is current_platform.fp8_dtype(): + key_value_cache, kv_scale = to_float8(key_value_cache, + current_platform.fp8_dtype()) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + k_scale = v_scale = kv_scale + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_query_heads // num_kv_heads) > 4)) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=scale, + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(query.shape, dtype=dtype) + wrapper.run(query, + key_value_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output) + + # TRTLLM Decode + 
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + output_trtllm = torch.empty(query.shape, dtype=dtype) + flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query.contiguous(), + kv_cache=key_value_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=kv_lens_tensor, + max_seq_len=max_kv_len, + bmm1_scale=k_scale * scale, + bmm2_scale=v_scale, + out=output_trtllm, + ) + + torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - output_trtllm))}" + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("kv_layout", KV_LAYOUTS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("soft_cap", [None]) +@torch.inference_mode +def test_flashinfer_trtllm_prefill_with_baseline( + batch_size: int, + num_heads: tuple[int, int], + head_size: int, + block_size: int, + kv_layout: str, + dtype: torch.dtype, + kv_cache_dtype: Optional[torch.dtype], + soft_cap: Optional[float], +) -> None: + kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype + if dtype != kv_cache_dtype: + pytest.skip(f"Not supported dtype({dtype}) with " + "kv_cache_dtype({kv_cache_dtype})") + + torch.set_default_device("cuda") + current_platform.seed_everything(0) + + q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32) + q_lens[-1] = MAX_Q_LEN + max_q_len = torch.max(q_lens).item() + q_indptr = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ]) + + kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32) + kv_lens[-1] = MAX_KV_LEN + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + num_seqs = len(seq_lens) + + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + + scale = head_size**-0.5 + + query = torch.randn(torch.sum(q_lens).item(), + num_query_heads, + head_size, + dtype=dtype) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) + kv_scale = 1.0 + if kv_cache_dtype is current_platform.fp8_dtype(): + key_value_cache, kv_scale = to_float8(key_value_cache, + current_platform.fp8_dtype()) + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + k_scale = v_scale = kv_scale + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) 
+ + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout) + wrapper.plan(q_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + causal=True, + sm_scale=scale, + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap) + + output = torch.empty(query.shape, dtype=dtype) + wrapper.run(query, + key_value_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output) + + # TRTLLM Decode + output_trtllm = torch.empty(query.shape, dtype=dtype) + flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=query.contiguous(), + kv_cache=key_value_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_q_len=max_q_len, + max_kv_len=max_seq_len, + bmm1_scale=k_scale * scale, + bmm2_scale=v_scale, + batch_size=num_seqs, + cum_seq_lens_q=q_indptr, + cum_seq_lens_kv=kv_indptr, + out=output_trtllm, + ) + + torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py deleted file mode 100644 index 2e2130fab6a2..000000000000 --- a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import flashinfer -import pytest -import torch - -from vllm.platforms import current_platform - -if not current_platform.is_device_capability(100): - pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", - allow_module_level=True) - -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 - -# KV Cache Layout for TRT-LLM -# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - -NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] -HEAD_SIZES = [128] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
-SOFT_CAPS = [None, 30.0, 50.0] - - -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("kv_layout", ["HND"]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) -@torch.inference_mode -def test_flashinfer_trtllm_decode_with_baseline( - kv_lens: list[int], - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float], - kv_layout: str, -) -> None: - torch.set_default_device("cuda") - current_platform.seed_everything(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - kv_cache_shape = None - if kv_layout == "NHD": - kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) - elif kv_layout == "HND": - kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) - else: - raise ValueError(f"Invalid kv_layout: {kv_layout}") - key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - k_scale = v_scale = 1.0 - kv_indptr = [0] - kv_indices = [] - kv_last_page_lens = [] - for i in range(num_seqs): - seq_len = kv_lens[i] - assert seq_len > 0 - num_blocks = (seq_len + block_size - 1) // block_size - kv_indices.extend(block_tables[i, :num_blocks]) - kv_indptr.append(kv_indptr[-1] + num_blocks) - kv_last_page_len = seq_len % block_size - if kv_last_page_len == 0: - kv_last_page_len = block_size - kv_last_page_lens.append(kv_last_page_len) - - kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) - kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout, - use_tensor_cores=( - (num_query_heads//num_kv_heads) > 4) - ) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) - - output = torch.empty(query.shape, dtype=dtype) - wrapper.run(query, key_value_cache, scale, out=output) - - # TRTLLM Decode - max_kv_len = max(kv_lens) - kv_lens_tensor = torch.tensor(kv_lens, - dtype=torch.int, - device=query.device) - output_trtllm = torch.empty(query.shape, dtype=dtype) - flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query.contiguous(), - key_value_cache, - workspace_buffer, - block_tables, - kv_lens_tensor, - max_kv_len, - bmm1_scale=k_scale * scale, - bmm2_scale=v_scale, - out=output_trtllm, - ) - - torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ - 
f"{torch.max(torch.abs(output - output_trtllm))}" diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py deleted file mode 100644 index 9f2229cc41df..000000000000 --- a/tests/kv_transfer/test_disagg.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import subprocess -import sys -import time -from subprocess import Popen - -import pytest -import requests -import torch - - -# Fixture to set up environment variables and teardown servers after tests -@pytest.fixture(scope="module", autouse=True) -def setup_servers(): - if torch.cuda.device_count() < 2: - pytest.skip("Skipping test: fewer than 2 GPUs available") - - # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", - shell=True).decode().strip() - os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP - - # Start prefill instance - prefill_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8100", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ - '"kv_rank":0,"kv_parallel_size":2}', - ] - prefill_env = os.environ.copy() - prefill_env["CUDA_VISIBLE_DEVICES"] = "0" - prefill_proc = Popen(prefill_cmd, env=prefill_env) - - # Start decode instance - decode_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8200", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ - '"kv_rank":1,"kv_parallel_size":2}', - ] - decode_env = os.environ.copy() - decode_env["CUDA_VISIBLE_DEVICES"] = "1" - decode_proc = Popen(decode_cmd, env=decode_env) - - # Wait for servers to be ready - assert wait_for_server(8100), "Prefill server did not start in time" - assert wait_for_server(8200), "Decode server did not start in time" - - # Yield to the test function and handle teardown after tests - yield - - # Cleanup: kill the processes - prefill_proc.terminate() - decode_proc.terminate() - - # Additional cleanup if needed - prefill_proc.wait() - decode_proc.wait() - - -# Helper function to wait for server -def wait_for_server(port, timeout=240): - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(f"http://localhost:{port}/v1/completions") - if response.status_code in [200, 405]: - return True - except requests.ConnectionError: - time.sleep(1) - return False - - -# Test function to send curl requests and validate responses -@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) -def test_disaggregated_prefilling(prompt): - # Send to prefill - response = requests.post("http://localhost:8100/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - }) - assert response.status_code == 200 - - # Send to decode - response = requests.post("http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - }) - assert response.status_code == 200 diff --git 
a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py new file mode 100644 index 000000000000..2b1c74652e76 --- /dev/null +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F + +from tests.models.utils import softmax +from vllm.config import PoolerConfig + + +@pytest.mark.parametrize( + "model", + [ + "jason9693/Qwen2.5-1.5B-apeach", + "papluca/xlm-roberta-base-language-detection" + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models_using_activation( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=False)) as vllm_model: + wo_activation_out = vllm_model.classify(example_prompts) + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=True)) as vllm_model: + w_activation_out = vllm_model.classify(example_prompts) + + for wo_activation, w_activation in zip(wo_activation_out, + w_activation_out): + wo_activation = torch.tensor(wo_activation) + w_activation = torch.tensor(w_activation) + + assert not torch.allclose( + wo_activation, w_activation, + atol=1e-2), "override_pooler_config is not working" + assert torch.allclose(softmax(wo_activation), w_activation, + 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embed_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + normalize=False)) as vllm_model: + wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig(normalize=True)) as vllm_model: + w_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + assert not torch.allclose( + wo_normalize, w_normalize, + atol=1e-2), "override_pooler_config normalize is not working" + assert torch.allclose( + F.normalize(wo_normalize, p=2, dim=-1), w_normalize, + atol=1e-2), "w_normal should be close to normal(wo_normal)." + + +@pytest.mark.parametrize( + "model", + [ + "internlm/internlm2-1_8b-reward", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_reward_models_using_softmax( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: + wo_softmax = vllm_model.encode(example_prompts) + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: + w_softmax = vllm_model.encode(example_prompts) + + for wo, w in zip(wo_softmax, w_softmax): + wo = torch.tensor(wo) + w = torch.tensor(w) + + assert not torch.allclose( + wo, w, atol=1e-2), "override_pooler_config softmax is not working" + assert torch.allclose( + softmax(wo), w, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." 
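`test_override_pooler_config.py` checks that the same switches can also be fixed at engine construction time through `PoolerConfig`. A sketch of that configuration path, with the caveat that the test goes through the `vllm_runner` fixture, so passing `override_pooler_config` directly to `LLM` is an assumption here:

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Assumed to be forwarded to the engine args, as the vllm_runner fixture does.
llm = LLM(model="intfloat/multilingual-e5-small",
          override_pooler_config=PoolerConfig(normalize=False),
          enforce_eager=True)

outputs = llm.embed(["The chef prepared a delicious meal."])
raw_embedding = outputs[0].outputs.embedding  # un-normalized vector
```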
diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index a5f7dca76d82..7add1d975c63 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -103,7 +103,7 @@ def test_prm_models( # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): - hf_output = torch.tensor(hf_output) - vllm_output = torch.tensor(vllm_output) + hf_output = torch.tensor(hf_output).float() + vllm_output = torch.tensor(vllm_output).float() assert torch.allclose(hf_output, vllm_output, 1.5e-2) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 8cb826c1144d..7ec959d57006 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -337,6 +337,10 @@ vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + # FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we + # should enable this again after the fix is released: + # https://github.com/huggingface/transformers/pull/39915 + marks=[pytest.mark.skip("HF model is broken")], ), "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], @@ -631,6 +635,18 @@ hf_model_kwargs={"llm_attn_implementation": "sdpa"}, patch_hf_runner=model_utils.ovis_patch_hf_runner, ), + "ovis2.5": VLMTestInfo( + models=["AIDC-AI/Ovis2.5-9B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index bd1c55d95dac..eb3b0a99c1ad 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -300,6 +300,7 @@ def _test_processing_correctness_one( "AIDC-AI/Ovis1.6-Gemma2-9B", "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", + "AIDC-AI/Ovis2.5-9B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-3.5-vision-instruct", diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/test_tensor_schema.py index f80e8456f02e..a4cb1a68833a 100644 --- a/tests/models/multimodal/test_tensor_schema.py +++ b/tests/models/multimodal/test_tensor_schema.py @@ -1,11 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import partial -from typing import Any from unittest.mock import patch import pytest -from transformers import PretrainedConfig from vllm.config import ModelConfig from vllm.engine.llm_engine import LLMEngine as V0LLMEngine @@ -19,6 +17,7 @@ from ...conftest import VllmRunner from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ..utils import dummy_hf_overrides ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", @@ 
-51,51 +50,6 @@ def create_batched_mm_kwargs( return mm_kwargs -# Avoid OOM and reduce initialization time by only using 1 layer -def hf_overrides(hf_config: PretrainedConfig, - exist_overrides: dict[str, Any]) -> PretrainedConfig: - hf_config.update(exist_overrides) - text_config = hf_config.get_text_config() - # Ensure at least 2 expert per group - # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) - num_experts = n_group * 2 if n_group is not None else 2 - # we use three layers for Gemma-3n to check - # both normal layer and kv_shared_layer - text_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, - # For Gemma-3n - "num_kv_shared_layers": 1, - }) - if hasattr(hf_config, "vision_config"): - hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - # e.g.: ibm-granite/granite-speech-3.3-2b - if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - # e.g.: Qwen/Qwen2-Audio-7B-Instruct - if hasattr(hf_config, "audio_config"): - hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) - return hf_config - - @pytest.mark.core_model @pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys())) def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner], @@ -110,7 +64,8 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner], model_id = model_info.default - hf_overrides_fn = partial(hf_overrides, + hf_overrides_fn = partial(dummy_hf_overrides, + model_arch=model_arch, exist_overrides=model_info.hf_overrides) model_config = ModelConfig( diff --git a/tests/models/registry.py b/tests/models/registry.py index ffa6b755adf4..029ce0248e9c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -139,8 +139,7 @@ def check_available_online( trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), - "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base", - is_available_online=False), + "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", @@ -180,8 +179,7 @@ def check_available_online( min_transformers_version="4.54"), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", - min_transformers_version="4.53"), + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), @@ -199,6 +197,7 @@ def check_available_online( {"6b": "EleutherAI/gpt-j-6b"}), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"}), + "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), 
"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 @@ -224,7 +223,10 @@ def check_available_online( trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 + extras={ + "tiny": "ai21labs/Jamba-tiny-dev", + "random": "ai21labs/Jamba-tiny-random", # noqa: E501 + }), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 @@ -240,8 +242,7 @@ def check_available_online( trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), - "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf", - min_transformers_version="4.53"), + "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"), "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 @@ -273,6 +274,8 @@ def check_available_online( "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + max_transformers_version="4.53", + transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), @@ -300,8 +303,7 @@ def check_available_online( "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), - "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", - min_transformers_version="4.53"), + "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), @@ -327,8 +329,12 @@ def check_available_online( "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True, v0_only=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), - "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), + "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 + "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 @@ -384,7 +390,7 @@ def check_available_online( trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: 
E501 - "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air", + "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", is_available_online=False), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", trust_remote_code=True, @@ -440,6 +446,7 @@ def check_available_online( "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 + "Ovis2.5": _HfExamplesInfo("AIDC-AI/Ovis2.5-9B", trust_remote_code=True,), # noqa: E501 "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 4c7da24fca32..f0aa91566b57 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import partial from unittest.mock import patch import pytest -from transformers import PretrainedConfig from vllm import LLM from vllm.config import ModelImpl @@ -16,6 +16,7 @@ from ..utils import create_new_process_for_each_test from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels) +from .utils import dummy_hf_overrides @create_new_process_for_each_test() @@ -33,64 +34,15 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") + hf_overrides_fn = partial(dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides) + if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): from vllm.model_executor.models.llama4 import Llama4ForCausalLM from vllm.model_executor.models.registry import ModelRegistry ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) - # Avoid OOM and reduce initialization time by only using 1 layer - def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: - hf_config.update(model_info.hf_overrides) - - text_config = hf_config.get_text_config() - - # Ensure at least 2 expert per group - # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) - num_experts = n_group * 2 if n_group is not None else 2 - - # we use three layers for Gemma-3n to check - # both normal layer and kv_shared_layer - num_hidden_layers = (3 if model_arch - == "Gemma3nForConditionalGeneration" else 1) - - text_config.update({ - "num_layers": 1, - "num_hidden_layers": num_hidden_layers, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, - # For Gemma-3n - "num_kv_shared_layers": 1, - }) - - if hasattr(hf_config, "vision_config"): - hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - - # e.g.: ibm-granite/granite-speech-3.3-2b - if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) - - # e.g.: Qwen/Qwen2-Audio-7B-Instruct - if hasattr(hf_config, "audio_config"): - 
hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) - - return hf_config - # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: self.cache_config.num_gpu_blocks = 0 @@ -132,7 +84,7 @@ def _initialize_kv_caches_v1(self, vllm_config): load_format="dummy", model_impl=ModelImpl.TRANSFORMERS if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, - hf_overrides=hf_overrides, + hf_overrides=hf_overrides_fn, ) diff --git a/tests/models/utils.py b/tests/models/utils.py index 3cd0721be1b6..1513db52209e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F +from transformers import PretrainedConfig from vllm.config import ModelConfig, RunnerOption from vllm.inputs import InputContext @@ -330,6 +331,13 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int): return tensor +def softmax(data): + if data.shape[-1] == 1: + return F.sigmoid(data) + else: + return F.softmax(data, dim=-1) + + class EmbedModelInfo(NamedTuple): name: str is_matryoshka: bool = False @@ -344,3 +352,63 @@ class RerankModelInfo(NamedTuple): architecture: str = "" dtype: str = "auto" enable_test: bool = True + + +def dummy_hf_overrides( + hf_config: PretrainedConfig, + model_arch: str, + exist_overrides: Optional[dict[str, Any]] = None, +) -> PretrainedConfig: + """ + Dummy HF overrides function used to create dummy model + with only minimum nums of layer. + """ + hf_config.update(exist_overrides or {}) + + text_config = hf_config.get_text_config() + + # Ensure at least 2 expert per group + # Since `grouped_topk` assumes top-2 + n_group = getattr(text_config, 'n_group', None) + num_experts = n_group * 2 if n_group is not None else 2 + + # we use three layers for Gemma-3n to check + # both normal layer and kv_shared_layer + num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration" + else 1) + text_config.update({ + "num_layers": 1, + "num_hidden_layers": num_hidden_layers, + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + # For Gemma-3n + "num_kv_shared_layers": 1, + }) + + if hasattr(hf_config, "vision_config"): + hf_config.vision_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + + # e.g.: ibm-granite/granite-speech-3.3-2b + if hasattr(hf_config, "encoder_config"): + hf_config.encoder_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + }) + + # e.g.: Qwen/Qwen2-Audio-7B-Instruct + if hasattr(hf_config, "audio_config"): + hf_config.audio_config.update({ + "num_layers": 1, + "num_hidden_layers": 1, + "encoder_layers": 1, + }) + + return hf_config diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index ef99c3dadd32..1d7e4475011d 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -4,9 +4,7 @@ import pytest import torch -from vllm.attention.selector import get_attn_backend from vllm.plugins import load_general_plugins -from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL def test_platform_plugins(): @@ -27,14 +25,6 @@ def test_platform_plugins(): f" is loaded. 
The first import:\n{_init_trace}") -def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch): - # ignore the backend env variable if it is set - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - backend = get_attn_backend(16, torch.float16, "auto", 16, False) - assert backend.get_name() == "Dummy_Backend" - - def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch): # simulate workload by running an example load_general_plugins() diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 84a656a3b9da..1e3e69e008bd 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -9,6 +9,8 @@ from tests.quantization.utils import is_quant_method_supported +from ..models.registry import HF_EXAMPLE_MODELS + MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] @@ -25,6 +27,8 @@ def test_model_experts_int8_startup( dtype: str, max_tokens: int, ) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_transformers_version(on_fail="skip") with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model: diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index e5ab7b3dd3cf..0b37c83c92c2 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -194,3 +194,36 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): ref_y, per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype)) + + # non-contiguous input with padding + m, n, padded_stride = 975, 512, 576 + padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * + 13).to(dtype) + x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1) + + assert not x_nc.is_contiguous() + assert x_nc.stride(0) == padded_stride + + # dynamic quantization + ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None) + ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype) + + # reference dynamic quantization + y_nc = quantize_ref(x_nc, inv_scale_nc) + torch.testing.assert_close( + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + + # static quantization + y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc) + torch.testing.assert_close( + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + + # padding after non-contiguous input quantization + y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, + inv_scale_nc, + num_token_padding=m + 10) + assert y_nc_pad.shape[0] == m + 10 + torch.testing.assert_close( + ref_y_nc, + per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), + inv_scale_nc, dtype)) diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py new file mode 100644 index 000000000000..52c03015483c --- /dev/null +++ b/tests/test_pooling_params.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import EmbedModelInfo +from vllm import PoolingParams +from vllm.config import ModelConfig + +EMBEDDING_MODELS = [ + EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256]), +] + + +def test_task(): + pooling_params = PoolingParams() + pooling_params.verify(task="score") + + pooling_params = PoolingParams(task="score") + pooling_params.verify(task="score") + + with pytest.raises(ValueError): + 
pooling_params.verify(task="encode") + + +def test_embed(): + task = "embed" + pooling_params = PoolingParams(normalize=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=False) + pooling_params.verify(task=task) + + invalid_parameters = ["activation", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) +def test_embed_dimensions(model_info: EmbedModelInfo): + task = "embed" + model_config = ModelConfig( + model_info.name, + task="auto", + tokenizer=model_info.name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + pooling_params = PoolingParams(dimensions=None) + pooling_params.verify(task=task, model_config=model_config) + + with pytest.raises(ValueError): + pooling_params = PoolingParams(dimensions=1) + pooling_params.verify(task=task, model_config=model_config) + + if model_info.is_matryoshka: + assert model_info.matryoshka_dimensions is not None + pooling_params = PoolingParams( + dimensions=model_info.matryoshka_dimensions[0]) + pooling_params.verify(task=task, model_config=model_config) + + +@pytest.mark.parametrize("task", ["score", "classify"]) +def test_classify(task): + pooling_params = PoolingParams(activation=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +def test_encode(): + task = "encode" + pooling_params = PoolingParams(softmax=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "activation"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 78a6509986fc..e9e574501d63 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -128,6 +128,8 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", + _Backend.XFORMERS_VLLM_V1: + "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", } if backend_name not in backend_map: diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 680e2ce98bb2..8bd142e87b06 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts, assert len(prompt_token_ids) == len(prompt_logprobs) +def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): + """Engine should return all vocabulary logprobs + + Args: + example_prompts: list of example prompts (test fixture) + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + runner = VllmRunner( + 
"facebook/opt-125m", + max_logprobs=-1, + enable_prefix_caching=False, + # 2 other llms alive during whole session + gpu_memory_utilization=0.15, + max_model_len=256) + sampling_params_logprobs_all = SamplingParams(max_tokens=5, + logprobs=-1) + results_logprobs_all = runner.llm.generate( + example_prompts, sampling_params=sampling_params_logprobs_all) + vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() + for i in range(len(results_logprobs_all)): + logprobs = results_logprobs_all[i].outputs[0].logprobs + assert logprobs is not None + for logprob in logprobs: + assert len(logprob) == vocab_size + + @pytest.mark.parametrize( "logprobs_mode", ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 42468daa62a9..456ce712d36e 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -155,7 +155,7 @@ def test_tree_attn_correctness() -> None: dim_per_head = 128 num_kv_heads = 2 - block_size = 128 + block_size = 32 max_sequence_length = 8192 randomize_blocks = True for batch_size in [1, 16, 32]: diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py index 55fee4ee1ad4..9d690851b70e 100644 --- a/tests/v1/tpu/test_mha_attn.py +++ b/tests/v1/tpu/test_mha_attn.py @@ -12,17 +12,10 @@ import torch_xla.core import torch_xla.core.xla_model -from vllm import envs from vllm.attention.layer import MultiHeadAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform -if not envs.VLLM_USE_V1: - pytest.skip( - "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", - allow_module_level=True, - ) - @pytest.fixture(autouse=True) def clear_cache(): diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index a61773a4f611..bcc2993028dd 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -4,19 +4,12 @@ import openai import pytest -from vllm import envs from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.platforms import current_platform from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS from ...utils import RemoteOpenAIServer -if not envs.VLLM_USE_V1: - pytest.skip( - "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", - allow_module_level=True, - ) - @pytest.fixture(scope="session") def base64_encoded_image() -> dict[str, str]: diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index 198bb1e16ed9..fa950e5f7f85 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -4,16 +4,10 @@ import pytest -from vllm import LLM, envs +from vllm import LLM from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -if not envs.VLLM_USE_V1: - pytest.skip( - "Skipping V1 tests. 
Rerun with `VLLM_USE_V1=1` to test.", - allow_module_level=True, - ) - @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) @pytest.mark.skipif(not current_platform.is_tpu(), diff --git a/tools/ep_kernels/README.md b/tools/ep_kernels/README.md index 273e0f378e34..85e9d2a4f812 100644 --- a/tools/ep_kernels/README.md +++ b/tools/ep_kernels/README.md @@ -13,16 +13,16 @@ All scripts accept a positional argument as workspace path for staging the build ## Usage -### Single-node - ```bash -bash install_python_libraries.sh +# for hopper +TORCH_CUDA_ARCH_LIST="9.0" bash install_python_libraries.sh +# for blackwell +TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh ``` -### Multi-node +Additional step for multi-node deployment: ```bash -bash install_python_libraries.sh sudo bash configure_system_drivers.sh sudo reboot # Reboot is required to load the new driver ``` diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index 9d1b2da3b412..e163c83e8b51 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -29,6 +29,12 @@ if [ -z "$CUDA_HOME" ]; then exit 1 fi +# assume TORCH_CUDA_ARCH_LIST is set correctly +if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then + echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture." + exit 1 +fi + # disable all features except IBGDA export NVSHMEM_IBGDA_SUPPORT=1 @@ -95,7 +101,7 @@ clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" cd pplx-kernels # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # PIP_NO_BUILD_ISOLATION=0 disables build isolation -PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install -vvv -e . +PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . 
popd # build and install deepep, require pytorch installed diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 35345b1be01c..92de39418054 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor, scale_b.shape * [128, 128] == b.shape """ assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype + assert bias is None or bias.numel( + ) == b.shape[1] and bias.dtype == out_dtype - m = a.shape[0] - n = b.shape[1] + # Massage the input to be 2D + target_shape = (*a.shape[:-1], b.shape[1]) + a = a.view(-1, a.shape[-1]) cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) if current_platform.is_rocm() or not cutlass_compatible_b: from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa triton_scaled_mm) - return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + else: + out = torch.empty((a.shape[0], b.shape[1]), + dtype=out_dtype, + device=a.device) + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - return out + return out.view(*target_shape) def cutlass_scaled_mm_azp(a: torch.Tensor, @@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor, assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert bias is None or bias.numel( ) == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) + # Massage the input to be 2D + target_shape = (*a.shape[:-1], b.shape[1]) + a = a.view(-1, a.shape[-1]) + assert azp is None or azp.numel() == a.shape[0] + out = torch.empty((a.shape[0], b.shape[1]), + dtype=out_dtype, + device=a.device) torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out + return out.view(*target_shape) def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: @@ -1279,14 +1284,13 @@ def scaled_fp8_quant( device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input.contiguous(), scale, scale_ub) + output, input, scale, scale_ub) else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(), - scale) + scale = torch.empty(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" - torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b3372ce2eca8..78d8a67e37f8 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -46,7 +46,7 @@ from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) -from vllm.utils.flashinfer import use_trtllm_decode_attention +from vllm.utils.flashinfer import use_trtllm_attention logger = init_logger(__name__) @@ -1114,7 +1114,7 @@ def forward( assert decode_meta.decode_wrapper._sm_scale == 
softmax_scale # TODO: @pavanimajety Remove this once the switch happens # inside flashinfer. - if not use_trtllm_decode_attention( + if not use_trtllm_attention( num_decode_tokens, attn_metadata.max_decode_seq_len, kv_cache_dtype, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 4f839348e522..08bfcc974cc9 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -28,6 +28,7 @@ def kernel_paged_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -95,7 +96,17 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + M = tl.full([num_queries_per_kv_padded], + float("-inf"), + dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_head_idx, + mask=head_mask, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -223,6 +234,8 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + # Optional tensor for sinks + sinks=None, ): if sm_scale is None: @@ -253,6 +266,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + sinks=sinks, ) block_size = value_cache.shape[3] @@ -281,11 +295,17 @@ def chunked_prefill_paged_decode( num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) - use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, - block_size, - num_queries_per_kv, - max_seq_len, sliding_window, - kv_cache_dtype, alibi_slopes) + use_custom = use_rocm_custom_paged_attention( + query.dtype, + head_size, + block_size, + num_queries_per_kv, + max_seq_len, + sliding_window, + kv_cache_dtype, + alibi_slopes, + sinks, + ) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // @@ -334,6 +354,7 @@ def chunked_prefill_paged_decode( query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seq_lens, alibi_slopes_ptr=alibi_slopes, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 13bef96722d2..64c90337970f 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -38,6 +38,7 @@ def _fwd_kernel(Q, V, K_cache, V_cache, + sink_ptr, B_Loc, sm_scale, k_scale, @@ -126,7 +127,15 @@ def _fwd_kernel(Q, other=0.0) # [M,D] # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + m_i = tl.load( + sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64), + mask=(offs_m < cur_batch_query_len), + other=float("-inf"), + ).to(dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] @@ -732,7 
+741,8 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None, sm_scale=None, - skip_decode=False): + skip_decode=False, + sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -781,6 +791,7 @@ def context_attention_fwd(q, sliding_window = 0 if alibi_slopes is not None: + assert sinks is None, "Sinks arg is not supported with alibi" # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -843,7 +854,7 @@ def context_attention_fwd(q, max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): - extra_kargs = {"kpack": 2, "waves_per_eu": 2} + extra_kargs = {"kpack": 1, "waves_per_eu": 2} grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) @@ -853,6 +864,7 @@ def context_attention_fwd(q, v, k_cache, v_cache, + sinks, b_loc, sm_scale, k_scale, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 0fdba569f93f..ba4299a2772d 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -52,6 +52,7 @@ def kernel_unified_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -131,7 +132,15 @@ def kernel_unified_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -292,6 +301,7 @@ def kernel_unified_attention_3d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -383,7 +393,15 @@ def kernel_unified_attention_3d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None or segm_idx != 0: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -627,6 +645,8 @@ def unified_attention( v_descale, alibi_slopes=None, qq_bias=None, + # Optional tensor for sinks + sinks=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -635,6 +655,10 @@ def unified_attention( assert q.element_size() >= 2 or block_size >= 32, \ "Block size must be at least 32 for fp8" + if sinks is not None: + assert sinks.shape[0] == q.shape[1], \ + "Sinks must be num_query_heads size" + use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -669,6 +693,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, 
value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, @@ -741,6 +766,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 2e3c8638125f..596c556e54f0 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -193,6 +193,10 @@ def _cached_get_attn_backend( backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. " + f"Valid backends are: {list(_Backend.__members__.keys())}") # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 54f00d541521..e07e52be9fdf 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -7,11 +7,13 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -if current_platform.is_cuda(): +if current_platform.is_cuda_alike(): from .fusion import FusionPass - from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .fusion_attn import AttnFusionPass +if current_platform.is_cuda(): + from .collective_fusion import AllReduceFusionPass, AsyncTPPass + from .activation_quant_fusion import ActivationQuantFusionPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context diff --git a/vllm/config.py b/vllm/config.py index 871df455ef58..899862bf541e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) -from functools import cached_property +from functools import cached_property, lru_cache from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args) @@ -377,7 +377,8 @@ class ModelConfig: max_logprobs: int = 20 """Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. The default value comes the default for the - OpenAI Chat Completions API.""" + OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * + vocab_size) logprobs are allowed to be returned and it may cause OOM.""" logprobs_mode: LogprobsMode = "raw_logprobs" """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: @@ -912,15 +913,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) - if self.is_matryoshka: - if pooler_config.normalize is None: - pooler_config.normalize = True - elif not pooler_config.normalize: - raise ValueError( - "`normalize` must be enabled (set to True) " - "for models that are compatible with " - "Matryoshka Representation.") - return pooler_config return None @@ -1107,6 +1099,21 @@ def _parse_quant_hf_config(self): if quant_cfg is None: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) + + else: + # Set quant_method for ModelOpt models. 
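+            # ModelOpt-produced configs carry a "producer" field and record
+            # the algorithm under "quantization.quant_algo"; map that
+            # algorithm to the corresponding vLLM quant_method below.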
+ producer_name = quant_cfg.get("producer", {}).get("name") + if producer_name == "modelopt": + quant_algo = quant_cfg.get("quantization", + {}).get("quant_algo") + if quant_algo == "FP8": + quant_cfg["quant_method"] = "modelopt" + elif quant_algo == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + elif quant_algo is not None: + raise ValueError( + f"Unknown ModelOpt quant algo: {quant_algo}") + return quant_cfg def _verify_quantization(self) -> None: @@ -1585,7 +1592,7 @@ def try_get_generation_config(self) -> dict[str, Any]: """ This method attempts to retrieve the non-default values of the generation config for this model. - + The generation config can contain information about special tokens, as well as sampling parameters. Which is why this method exists separately to `get_diff_sampling_param`. @@ -2066,7 +2073,7 @@ class ParallelConfig: and when data_parallel_size > 0. Enables running an AsyncLLM and API server on a "per-node" basis where vLLM load balances between local data parallel ranks, but an external LB balances - between vLLM nodes/replicas. Set explicitly in conjunction with + between vLLM nodes/replicas. Set explicitly in conjunction with --data-parallel-start-rank.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" @@ -3422,25 +3429,34 @@ class PoolerConfig: [`vllm.model_executor.layers.pooler.PoolingType`][]. """ + ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the pooled outputs. Usually, this should be set to - ``True`` for embedding outputs. + Whether to normalize the embeddings outputs. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. """ - softmax: Optional[bool] = None + ## for classification models + activation: Optional[bool] = None """ - Whether to apply softmax to the pooled outputs. Usually, this should be set - to ``True`` for classification outputs. + Whether to apply activation function to the classification outputs. """ + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + """ step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. 
""" - returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, @@ -4358,12 +4374,20 @@ def __repr__(self) -> str: "disabled_custom_ops": True, "compilation_time": True, "bs_to_padded_graph_size": True, - "pass_config": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, }, } + + # exclude default attr in pass_config + pass_config_exclude = {} + for attr, default_val in vars(PassConfig()).items(): + if getattr(self.pass_config, attr) == default_val: + pass_config_exclude[attr] = True + if pass_config_exclude: + exclude["pass_config"] = pass_config_exclude + # The cast to string is necessary because Pydantic is mocked in docs # builds and sphinx-argparse doesn't know the return type of decode() return str( @@ -5099,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, finally: _current_vllm_config = old_vllm_config _current_prefix = old_prefix + # Clear the compilation config cache when context changes + get_cached_compilation_config.cache_clear() + + +@lru_cache(maxsize=1) +def get_cached_compilation_config(): + """Cache config to avoid repeated calls to get_current_vllm_config()""" + return get_current_vllm_config().compilation_config def get_current_vllm_config() -> VllmConfig: diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index e5ba297ebcc1..46cc1c2f52d6 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -70,6 +70,7 @@ def __init__( assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned" self._comm = get_pp_group().device_communicator + assert self._comm is not None # Since we wrap around the vLLM _PP communicator, we use # the rank from the vLLM communicator, and ignore the rank diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index af6462084968..f64b516b0d04 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -251,6 +251,7 @@ def build( if global_expert_load is not None: ep_group = get_ep_group().device_group + assert ep_group is not None assert global_expert_load.shape == (model.num_moe_layers, model.num_logical_experts) assert global_expert_load.dtype == torch.int64 @@ -357,6 +358,7 @@ def step(self, # Collect load metrics from all ranks ep_group = get_ep_group().device_group + assert ep_group is not None num_tokens_list = [ torch.empty_like(num_tokens) for _ in range(ep_group.size()) ] @@ -412,6 +414,7 @@ def rearrange(self, """ ep_group = get_ep_group().device_group + assert ep_group is not None ep_rank = ep_group.rank() time_start = None diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 868b227fc899..011bbb69abb0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -1,142 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -KVConnectorBase Class for Distributed KV Cache & Hidden State communication - -The class provides two primary abstract methods: -1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states -2. 
recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states -""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Union - -import torch +"""Defines the base type for KV cache connectors.""" from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.config import VllmConfig - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - - -class KVConnectorBase(ABC): - """ - Abstract base class for a KV connector. - - The class provides two primary abstract methods: - 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states - 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states - """ - - @abstractmethod - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ): - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Close the buffer and release resources. - - This method is responsible for cleaning up resources related to the - connector when it is no longer needed. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - """ - Send KV caches and hidden states to the connector. - - This method processes the input tokens, KV caches, and - hidden/intermediate states for a given model and sends the data to the - decode instance. - - Args: - model_executable (torch.nn.Module): The model executable containing - start and end layer information. - model_input (ModelInputForGPUWithSamplingMetadata): The input - metadata from vLLM. - kv_caches (list[torch.Tensor]): List of KV caches (keys and values) - for each layer. - hidden_or_intermediate_states (Union[torch.Tensor, - IntermediateTensors]): - The hidden or intermediate states associated with the tokens. - - Returns: - None - - """ - - raise NotImplementedError - - @abstractmethod - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - """ - Receive KV caches and hidden states from the connector. - - This method attempts to retrieve KV caches and hidden states for input - tokens. If all required KV caches and hidden states are received, it - will bypass model input, else it will fall back to normal vLLM model - forwarding. - - Args: - model_executable (torch.nn.Module): - The model executable from vLLM modelrunner. - model_input (ModelInputForGPUWithSamplingMetadata): - The model input from vLLM modelrunner. - kv_caches (list[torch.Tensor]): - List of KV caches for each layer. - - Returns: - - hidden_or_intermediate_states (torch.Tensor or - IntermediateTensors): - Concatenated hidden states if all required data is retrieved, - otherwise `None`. - - bypass_model_exec (bool): - Indicates whether the model execution can be skipped (True) or - needs to be redone (False). - - model_input (ModelInputForGPUWithSamplingMetadata): - Optionally adjusted input metadata for re-execution when - `bypass_model_exec=False`. 
- - """ - - raise NotImplementedError - - @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: - """ - Get the required KV cache layout for this connector. - Args: - vllm_config (VllmConfig): the vllm config. - - Returns: - str: the required KV cache layout. e.g. HND, or NHD. - None if the connector does not require a specific layout. - """ - return None +KVConnectorBase = KVConnectorBase_V1 +KVConnectorBaseType = KVConnectorBase_V1 -KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] +__all__ = ["KVConnectorBase", "KVConnectorBaseType"] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index cf7cde2c4377..01673a0d7c87 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -5,14 +5,10 @@ from typing import TYPE_CHECKING, Callable import vllm.envs as envs -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger -from .base import KVConnectorBase - if TYPE_CHECKING: from vllm.config import VllmConfig @@ -20,7 +16,7 @@ class KVConnectorFactory: - _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} + _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -29,28 +25,23 @@ def register_connector(cls, name: str, module_path: str, if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> type[KVConnectorBaseType]: + def loader() -> type[KVConnectorBase]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector_v0(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: - if envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V0 Connector, " + def create_connector( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " f"but found {envs.VLLM_USE_V1=}") - connector_cls = cls.get_connector_class(config.kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase) - return connector_cls(rank, local_rank, config) - - @classmethod - def get_connector_class( - cls, kv_transfer_config: "KVTransferConfig" - ) -> type[KVConnectorBaseType]: - """Get the connector class by name.""" + kv_transfer_config = config.kv_transfer_config connector_name = kv_transfer_config.kv_connector if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() @@ -61,21 +52,7 @@ def get_connector_class( f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) - return connector_cls - - @classmethod - def create_connector_v1( - cls, - config: "VllmConfig", - role: KVConnectorRole, - ) -> KVConnectorBase_V1: - if not envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}") - - kv_transfer_config = 
config.kv_transfer_config - connector_cls = cls.get_connector_class(kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase_V1) + assert issubclass(connector_cls, KVConnectorBase) logger.info("Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -92,25 +69,6 @@ def create_connector_v1( # Register various connectors here. # The registration should not be done in each individual file, as we want to # only load the files corresponding to the current connector. -KVConnectorFactory.register_connector( - "PyNcclConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "LMCacheConnector", - "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", - "LMCacheConnector") - -KVConnectorFactory.register_connector( - "MooncakeStoreConnector", - "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") KVConnectorFactory.register_connector( "SharedStorageConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py deleted file mode 100644 index 78bf3095613a..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -LMCache KV Cache Connector for Distributed Machine Learning Inference - -The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker -(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; -(2) offload and share KV caches. 
-""" - -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class LMCacheConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.transfer_config = config.kv_transfer_config - self.vllm_config = config - - from lmcache.experimental.cache_engine import LMCacheEngineBuilder - from lmcache.integration.vllm.utils import ENGINE_NAME - from lmcache.integration.vllm.vllm_adapter import ( - RetrieveStatus, StoreStatus, init_lmcache_engine, - lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, - lmcache_store_kv) - logger.info("Initializing LMCacheConfig under kv_transfer_config %s", - self.transfer_config) - - # TODO (Jiayi): Find model_config, parallel_config, and cache_config - self.engine = init_lmcache_engine(config.model_config, - config.parallel_config, - config.cache_config) - self.lmcache_engine_name = ENGINE_NAME - self.lmcache_engine_builder = LMCacheEngineBuilder - - self.model_config = config.model_config - self.parallel_config = config.parallel_config - self.cache_config = config.cache_config - self.lmcache_retrieve_kv = lmcache_retrieve_kv - self.lmcache_store_kv = lmcache_store_kv - self.lmcache_should_retrieve = lmcache_should_retrieve - self.lmcache_should_store = lmcache_should_store - self.store_status = StoreStatus - self.retrieve_status = RetrieveStatus - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - retrieve_status = self.lmcache_should_retrieve(model_input) - model_input, bypass_model_exec, hidden_or_intermediate_states =\ - self.lmcache_retrieve_kv( - model_executable, model_input, self.cache_config, kv_caches, - retrieve_status) - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - store_status = self.lmcache_should_store(model_input) - self.lmcache_store_kv( - self.model_config, - self.parallel_config, - self.cache_config, - model_executable, - model_input, - kv_caches, - store_status, - ) - - def close(self): - self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py deleted file mode 100644 index 94a7ce91acf1..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ /dev/null @@ -1,203 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -MooncakeStore Connector for Distributed Machine Learning Inference -The MooncakeStoreConnector transfers KV caches between prefill vLLM workers -(KV cache producer) and decode vLLM workers (KV cache consumer) using a -database-style KVStore. 
-""" -import hashlib -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class MooncakeStoreConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - self.kv_transfer_config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - self.local_tp_rank = local_rank - - # Init kv_store - if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_store: - raise ValueError( - "To use MooncakeStoreConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 - MooncakeStore) - logger.info( - "Initializing KVStoreConnector under kv_transfer_config %s", - self.kv_transfer_config) - self.kv_store = MooncakeStore(config) - else: - logger.error("Can not find %s", - self.kv_transfer_config.kv_connector) - - assert self.kv_store is not None - - def close(self) -> None: - """Close the buffer and release resources. - This method is responsible for cleaning up resources related to the - connector when it is no longer needed. - Raises: - NotImplementedError: This method must be implemented in subclasses. 
- """ - self.kv_store.close() - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - current_tokens = input_tokens_tensor[start_pos:end_pos] - store_key_prefix = self.tensor_hash(current_tokens) - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - kvcache_to_sent = torch.stack((keys, values), dim=0) - store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" - self.kv_store.put(store_kvcache_key, kvcache_to_sent) - - hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" - self.kv_store.put(hidden_key, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - bypass_model_exec = True - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - hidden_or_intermediate_states_for_one_req = [] - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - # get roi for current seq - load_key_prefix = self.tensor_hash(current_tokens) - load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" - remote_kv = self.kv_store.get(load_kvcache_key) - hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" - hidden = self.kv_store.get(hidden_key) - - if remote_kv is None or hidden is None: - # didn't find any match. 
- bypass_model_exec = False - continue - - num_computed_tokens = current_tokens.shape[0] - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # call self.kv_store to get kv layer by layer - for layer_id in range(start_layer, end_layer): - layer = model_executable.model.layers[layer_id] - # get kvcache object - kv_cache = kv_caches[layer_id - start_layer] - - # get remote kvcache - remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ - layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - @staticmethod - def tensor_hash(tensor: torch.Tensor) -> int: - """Calculate the hash value of the tensor.""" - tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() - hash_object = hashlib.blake2b(tensor_bytes) - hash_hex = hash_object.hexdigest() - return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py deleted file mode 100644 index e7c079e1f115..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ /dev/null @@ -1,329 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Simple KV Cache Connector for Distributed Machine Learning Inference - -The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache -producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or -MooncakePipe. - -But the logic can be extended to support other pipe and lookup buffer. 
-""" -from typing import TYPE_CHECKING, Optional, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class SimpleConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - - if self.config.kv_connector == "PyNcclConnector": - from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( - PyNcclPipe) - logger.info( - "Initializing PyNcclConfig under kv_transfer_config %s", - self.config) - elif self.config.kv_connector == "MooncakeConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_distributed_pipe = os.getenv( - 'MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_distributed_pipe: - raise ValueError( - "To use MooncakeConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 - MooncakePipe) - logger.info( - "Initializing MooncakeConfig under kv_transfer_config %s", - self.config) - - self.lookup_buffer_size = self.config.kv_buffer_size - - self.producer_buffer: Optional[SimpleBuffer] = None - self.consumer_buffer: Optional[SimpleBuffer] = None - - self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - - # 2 pipes for every rank in the world - port_offset_base = 2 * rank - - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - if self.config.is_kv_producer: - - if self.config.kv_connector == "PyNcclConnector": - self.producer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.producer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.producer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - # We only need to initialize MooncakePipe once - self.producer_signal_pipe = self.producer_data_pipe - - self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, - self.producer_data_pipe, - self.config.kv_buffer_size) - - else: - - # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producer - if self.config.kv_connector == "PyNcclConnector": - self.consumer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.consumer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.consumer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - self.consumer_signal_pipe = 
self.consumer_data_pipe - - self.consumer_buffer = SimpleBuffer( - self.consumer_signal_pipe, - self.consumer_data_pipe, - self.config.kv_buffer_size, - ) - - def select(self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: - - assert self.consumer_buffer is not None, "Please initialize the "\ - "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - assert self.producer_buffer is not None, "Please initialize the "\ - "producer buffer before calling insert." - - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You have some decode requests while using " - "SimpleConnector. Their KVCache won't be sent.") - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - # When bypass_model_exec is set to False, it means that at least for one - # request its corresponding KV cache or hidden state is missing. - # In this case we need to do prefilling to recompute missing KV cache - # and hidden states. 
- bypass_model_exec = True - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to --max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - ret = self.select(current_tokens, - torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. - bypass_model_exec = False - num_computed_tokens_list.append(0) - continue - - roi: torch.Tensor = ret[1] - keys: torch.Tensor = ret[2] - values: torch.Tensor = ret[3] - hidden: torch.Tensor = ret[4] - - num_computed_tokens = roi.shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for cur_layer in range(start_layer, end_layer): - - layer_id = cur_layer - start_layer - kv_cache = kv_caches[layer_id] - layer = model_executable.model.layers[cur_layer] - - # get remote kvcache - remote_k, remote_v = keys[layer_id], values[layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # Here we will fall back to normal model forwarding - # But optionally you can adjust model_input so that you only do - # prefilling on those tokens that are missing KV caches. 
- logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def close(self): - self.producer_data_pipe.close() - self.consumer_data_pipe.close() - if self.config.kv_connector == "PyNcclConnector": - self.producer_signal_pipe.close() - self.consumer_signal_pipe.close() - elif self.config.kv_connector == "MooncakeConnector": - # MooncakePipe reuses data_pipe for signal_pipe, so we only have to - # close the data_pipe. - pass diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1a11cb6d0189..1da41790f9fb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -13,8 +13,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1) from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -106,9 +106,8 @@ def get_kv_connector_cache_layout(): vllm_config = get_current_vllm_config() kv_config = vllm_config.kv_transfer_config if kv_config is not None: - connector_cls = KVConnectorFactory.get_connector_class(kv_config) - required_kvcache_layout = connector_cls.get_required_kvcache_layout( - vllm_config) + required_kvcache_layout = ( + KVConnectorBase_V1.get_required_kvcache_layout(vllm_config)) if required_kvcache_layout is not None: return required_kvcache_layout logger.info_once("Connectors do not specify a " \ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 934a03a12ee5..62a4980bff97 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -52,7 +52,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): temp_config.kv_transfer_config = KVTransferConfig( **ktc, engine_id=engine_id) self._connectors.append( - KVConnectorFactory.create_connector_v1(temp_config, role)) + KVConnectorFactory.create_connector(temp_config, role)) # A mapping from request id to the index of the connector chosen to # load the request from (if any). 
@@ -223,9 +223,9 @@ def get_required_kvcache_layout( for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config - required_kvcache_layout = KVConnectorFactory.get_connector_class( - kv_transfer_config).get_required_kvcache_layout( - temp_vllm_config) + required_kvcache_layout = ( + KVConnectorBase_V1.get_required_kvcache_layout( + temp_vllm_config)) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py deleted file mode 100644 index 8633fdaf59f8..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector_agent.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A centralized entrypoint to perform distributed KV cache transfer. - -This implementation is a shim wrapper on two APIs exposed by `kv_connector`: -1. `send_kv_caches_and_hidden_states` -2. `recv_kv_caches_and_hidden_states -""" -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - from vllm.config import VllmConfig - -import torch - -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -logger = init_logger(__name__) - - -class KVTransferAgent: - """ - A class designated for distributed KV transfer - - Target use cases: - 1. Disaggregated prefill - 2. Remote KV cache storage - """ - - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ): - - self.config = config - - if config.kv_transfer_config is None: - raise ValueError("KVTransferConfig is not set in the VllmConfig," - " cannot initialize KVConnector.") - - assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ - "TransferAgent should only be used when kv_connector is set." 
- - self.connector = KVConnectorFactory.create_connector_v0( - rank, local_rank, config) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - self.connector.send_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches, - hidden_or_intermediate_states) - - def close(self) -> None: - self.connector.close() - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - return self.connector.recv_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches) diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 60f1d5d8bca7..5e0f64fca220 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -8,7 +8,6 @@ KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) -from vllm.distributed.parallel_state import get_world_group if TYPE_CHECKING: from vllm.config import VllmConfig @@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if (vllm_config.kv_transfer_config.is_kv_transfer_instance and _KV_CONNECTOR_AGENT is None): if envs.VLLM_USE_V1: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( config=vllm_config, role=KVConnectorRole.WORKER) else: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config, - ) + raise ValueError("V0 is no longer supported") diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f31e4766bfda..6c25cdcfb7b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -196,10 +196,11 @@ class GroupCoordinator: # 3 | 1 | 3 | 1 | 3 local_rank: int # local rank used to assign devices rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication + cpu_group: Optional[ProcessGroup] # group for CPU communication + device_group: Optional[ProcessGroup] # group for device communication use_device_communicator: bool # whether to use device communicator - device_communicator: DeviceCommunicatorBase # device communicator + device_communicator: Optional[ + DeviceCommunicatorBase] # device communicator mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( @@ -250,7 +251,7 @@ def __init__( self.use_device_communicator = use_device_communicator - self.device_communicator: DeviceCommunicatorBase = None # type: ignore + self.device_communicator = None if use_device_communicator and self.world_size > 1: device_comm_cls = resolve_obj_by_qualname( current_platform.get_device_communicator_cls()) @@ -364,6 +365,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return self._all_reduce_out_place(input_) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + if self.device_communicator is None: + raise ValueError("No device communicator 
found") return self.device_communicator.all_reduce(input_) def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: @@ -384,12 +387,16 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.all_gather(input_, dim) def all_gatherv(self, input_: Union[torch.Tensor, list[torch.Tensor]], dim: int = 0, sizes: Optional[list[int]] = None): + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.all_gatherv(input_, dim, sizes) def reduce_scatter(self, @@ -414,10 +421,14 @@ def reduce_scatterv(self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None) -> torch.Tensor: + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.reduce_scatterv(input_, dim, sizes) def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.reduce_scatter(input_, dim) def gather(self, @@ -433,6 +444,8 @@ def gather(self, # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.gather(input_, dst, dim) def broadcast(self, input_: torch.Tensor, src: int = 0): @@ -667,6 +680,8 @@ def send_tensor_dict( assert dst < self.world_size, f"Invalid dst rank ({dst})" if self.use_cpu_custom_send_recv: + if self.device_communicator is None: + raise ValueError("No device communicator found") self.device_communicator.send_tensor_dict( # type: ignore tensor_dict, dst) return None @@ -727,6 +742,8 @@ def recv_tensor_dict( assert src < self.world_size, f"Invalid src rank ({src})" if self.use_cpu_custom_send_recv: + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.recv_tensor_dict( # type: ignore src) @@ -784,6 +801,8 @@ def barrier(self): def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" + if self.device_communicator is None: + raise ValueError("No device communicator found") self.device_communicator.send(tensor, dst) def recv(self, @@ -792,6 +811,8 @@ def recv(self, src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" + if self.device_communicator is None: + raise ValueError("No device communicator found") return self.device_communicator.recv(size, dtype, src) def destroy(self): @@ -1013,6 +1034,7 @@ def initialize_model_parallel( parallelism. pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. + backend: name of torch distributed communication backend. Let's say we have a total of 8 GPUs denoted by g0 ... 
g7 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize @@ -1124,14 +1146,14 @@ def ensure_model_parallel_initialized( assert ( get_tensor_model_parallel_world_size() == tensor_model_parallel_size - ), ("tensor parallel group already initialized, but of unexpected size: " - f"{get_tensor_model_parallel_world_size()=} vs. " - f"{tensor_model_parallel_size=}") + ), ("tensor parallel group already initialized, but of unexpected size. " + f"got: {get_tensor_model_parallel_world_size()=} vs. " + f"wanted: {tensor_model_parallel_size=}") pp_world_size = get_pp_group().world_size assert (pp_world_size == pipeline_model_parallel_size), ( - "pipeline parallel group already initialized, but of unexpected size: " - f"{pp_world_size=} vs. " - f"{pipeline_model_parallel_size=}") + "pipeline parallel group already initialized, but of unexpected size. " + f"got: {pp_world_size=} vs. " + f"wanted: {pipeline_model_parallel_size=}") def prepare_communication_buffer_for_model(model: torch.nn.Module): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5eb9660cd1e8..3e2f03d56c40 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1469,6 +1469,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", "TREE_ATTN", + "XFORMERS_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py new file mode 100644 index 000000000000..6292306e7cdb --- /dev/null +++ b/vllm/entrypoints/context.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from abc import ABC, abstractmethod + +from openai_harmony import Message, Role, StreamState + +from vllm.entrypoints.harmony_utils import ( + get_encoding, get_streamable_parser_for_assistant, render_for_completion) +from vllm.entrypoints.tool import Tool +from vllm.outputs import RequestOutput + +logger = logging.getLogger(__name__) + + +class ConversationContext(ABC): + + @abstractmethod + def append_output(self, output) -> None: + pass + + @abstractmethod + async def call_tool(self) -> list[Message]: + pass + + @abstractmethod + def need_builtin_tool_call(self) -> bool: + pass + + @abstractmethod + def render_for_completion(self) -> list[int]: + pass + + +class SimpleContext(ConversationContext): + + def __init__(self): + self.last_output = None + + def append_output(self, output) -> None: + self.last_output = output + + def need_builtin_tool_call(self) -> bool: + return False + + async def call_tool(self) -> list[Message]: + raise NotImplementedError("Should not be called.") + + def render_for_completion(self) -> list[int]: + raise NotImplementedError("Should not be called.") + + +class HarmonyContext(ConversationContext): + + def __init__( + self, + messages: list, + tool_sessions: dict[str, Tool], + ): + self._messages = messages + self.tool_sessions = tool_sessions + + self.parser = get_streamable_parser_for_assistant() + self.num_init_messages = len(messages) + # TODO(woosuk): Implement the following fields. 
+ self.num_prompt_tokens = 0 + self.num_cached_tokens = 0 + self.num_output_tokens = 0 + self.num_reasoning_tokens = 0 + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + output_token_ids = output.outputs[0].token_ids + for token_id in output_token_ids: + self.parser.process(token_id) + output_msgs = self.parser.messages + else: + # Tool output. + output_msgs = output + self._messages.extend(output_msgs) + + @property + def messages(self) -> list: + return self._messages + + def need_builtin_tool_call(self) -> bool: + last_msg = self.messages[-1] + recipient = last_msg.recipient + return recipient is not None and (recipient.startswith("browser.") + or recipient.startswith("python")) + + async def call_tool(self) -> list[Message]: + if not self.messages: + return [] + last_msg = self.messages[-1] + recipient = last_msg.recipient + if recipient is not None: + if recipient.startswith("browser."): + return await self.call_search_tool( + self.tool_sessions["browser"], last_msg) + elif recipient.startswith("python"): + return await self.call_python_tool( + self.tool_sessions["python"], last_msg) + raise ValueError("No tool call found") + + def render_for_completion(self) -> list[int]: + return render_for_completion(self.messages) + + async def call_search_tool( + self, + tool_session: Tool, + last_msg: Message, + ) -> list[Message]: + return await tool_session.get_result(self) + + async def call_python_tool( + self, + tool_session: Tool, + last_msg: Message, + ) -> list[Message]: + return await tool_session.get_result(self) + + +class StreamingHarmonyContext(HarmonyContext): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_output = None + + self.parser = get_streamable_parser_for_assistant() + self.encoding = get_encoding() + self.last_tok = None + + @property + def messages(self) -> list: + return self.parser.messages + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + tok = output.outputs[0].token_ids[0] + self.parser.process(tok) + self.last_tok = tok + else: + # Handle the case of tool output in direct message format + assert len(output) == 1, "Tool output should be a single message" + msg = output[0] + # Sometimes the recipient is not set for tool messages, + # so we set it to "assistant" + if msg.author.role == Role.TOOL and msg.recipient is None: + msg.recipient = "assistant" + toks = self.encoding.render(msg) + for tok in toks: + self.parser.process(tok) + self.last_tok = toks[-1] + + def is_expecting_start(self) -> bool: + return self.parser.state == StreamState.EXPECT_START + + def is_assistant_action_turn(self) -> bool: + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( + ) + + def render_for_completion(self) -> list[int]: + # now this list of tokens as next turn's starting tokens + # `<|start|>assistant``, + # we need to process them in parser. 
+ rendered_tokens = super().render_for_completion() + + last_n = -1 + to_process = [] + while rendered_tokens[last_n] != self.last_tok: + to_process.append(rendered_tokens[last_n]) + last_n -= 1 + for tok in reversed(to_process): + self.parser.process(tok) + + return rendered_tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py new file mode 100644 index 000000000000..ecda35c9807e --- /dev/null +++ b/vllm/entrypoints/harmony_utils.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import datetime +from collections.abc import Iterable, Sequence +from typing import Literal, Optional + +from openai.types.responses.tool import Tool +from openai_harmony import (Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, + Role, StreamableParser, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date is None: + # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message(instructions: Optional[str] = None, + tools: Optional[list[Tool]] = None) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. 
+ pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_chat_input(chat_msg) -> Message: + role = chat_msg["role"] + content = chat_msg["content"] + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. + contents = [TextContent(text=c["text"]) for c in content] + msg = Message.from_role_and_contents(role, contents) + return msg + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) + + +def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: + parser = get_streamable_parser_for_assistant() + for token_id in token_ids: + parser.process(token_id) + return parser + + +def parse_chat_output( + token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: + parser = parse_output_into_messages(token_ids) + output_msgs = parser.messages + if len(output_msgs) == 0: + # The generation has stopped during reasoning. + is_tool_call = False + reasoning_content = parser.current_content + final_content = None + elif len(output_msgs) == 1: + # The generation has stopped during final message. + is_tool_call = False + reasoning_content = output_msgs[0].content[0].text + final_content = parser.current_content + else: + if len(output_msgs) != 2: + raise ValueError( + "Expected 2 output messages (reasoning and final), " + f"but got {len(output_msgs)}.") + reasoning_msg, final_msg = output_msgs + reasoning_content = reasoning_msg.content[0].text + final_content = final_msg.content[0].text + is_tool_call = final_msg.recipient is not None + return reasoning_content, final_content, is_tool_call diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 842a22ccebaa..ca24b0c32b73 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1189,6 +1189,8 @@ def classify( /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ClassificationRequestOutput]: """ @@ -1207,7 +1209,8 @@ def classify( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. 
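The helpers added in `harmony_utils.py` above are composed in exactly this order by `serving_chat.py` later in the patch (system message, developer message, user messages, then `render_for_completion`). A minimal usage sketch, assuming `openai_harmony` is installed; the prompt text is illustrative only and the resulting token ids depend on the Harmony encoding:

```python
# Build a Harmony-formatted prompt for a gpt-oss style model using the
# helpers introduced above.
from vllm.entrypoints.harmony_utils import (get_developer_message,
                                            get_system_message,
                                            get_user_message,
                                            render_for_completion)

messages = [
    get_system_message(reasoning_effort="low"),
    get_developer_message(instructions="Answer concisely."),
    get_user_message("What is vLLM?"),
]
prompt_token_ids = render_for_completion(messages)
print(len(prompt_token_ids))
```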
@@ -1220,6 +1223,7 @@ def classify( items = self.encode( prompts, use_tqdm=use_tqdm, + pooling_params=pooling_params, lora_request=lora_request, pooling_task="classify", ) @@ -1272,6 +1276,7 @@ def _embedding_score( text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: @@ -1280,6 +1285,7 @@ def _embedding_score( truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, + pooling_params=pooling_params, pooling_task="embed", ) @@ -1306,6 +1312,7 @@ def _cross_encoding_score( data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: model_config = self.llm_engine.model_config @@ -1317,7 +1324,12 @@ def _cross_encoding_score( if len(data_1) == 1: data_1 = data_1 * len(data_2) - pooling_params = PoolingParams(task="score") + if pooling_params is None: + pooling_params = PoolingParams(task="score") + + model_config = self.llm_engine.model_config + pooling_params.verify("score", model_config) + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(model_config.max_model_len, @@ -1379,6 +1391,7 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs `` or @@ -1410,7 +1423,8 @@ def score( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. 
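The `pooling_params` argument threaded through `classify()` and `score()` above pairs with the per-request `activation` and `normalize` fields that this patch adds to `PoolingParams` via the protocol and pooler changes further down. A hedged usage sketch; the model name and `task="score"` setting are examples only and assume a model that supports the score task:

```python
from vllm import LLM, PoolingParams

# Example cross-encoder; substitute any model served with the "score" task.
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

outputs = llm.score(
    "what is panda?",
    ["The giant panda is a bear species endemic to China."],
    # activation=False asks for raw logits instead of post-activation scores,
    # mirroring the new `activation` extra param on the score/rerank endpoints.
    pooling_params=PoolingParams(activation=False),
)
print(outputs[0].outputs.score)
```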
@@ -1494,6 +1508,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) else: return self._embedding_score( @@ -1502,6 +1517,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) def start_profile(self) -> None: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d77aee345843..57aa42720756 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -323,6 +323,7 @@ def to_sampling_params( if (top_p := self.top_p) is None: top_p = default_sampling_params.get( "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output guided_decoding = None @@ -340,6 +341,7 @@ def to_sampling_params( top_p=top_p, max_tokens=max_tokens, logprobs=self.top_logprobs, + stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), guided_decoding=guided_decoding, @@ -404,6 +406,8 @@ class ChatCompletionRequest(OpenAIBaseModel): Literal["required"], ChatCompletionNamedToolChoiceParam, ]] = "none" + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + include_reasoning: bool = True # NOTE this will be ignored by vLLM -- the model determines the behavior parallel_tool_calls: Optional[bool] = False @@ -1274,11 +1278,13 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1332,6 +1338,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. 
This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1344,7 +1351,8 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1375,10 +1383,12 @@ class ScoreRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1403,10 +1413,12 @@ class RerankRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankDocument(BaseModel): @@ -1553,10 +1565,12 @@ class ClassificationRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class ClassificationData(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e1d8a31672ed..6ad0a8ec54f7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,6 +12,7 @@ import partial_json_parser import regex as re from fastapi import Request +from openai_harmony import Message as OpenAIMessage from pydantic import TypeAdapter from vllm.config import ModelConfig @@ -19,6 +20,10 @@ from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, ConversationMessage, random_tool_call_id) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, get_system_message, parse_chat_input, + parse_chat_output, render_for_completion) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -125,6 +131,23 @@ def __init__( logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + if self.use_harmony: + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions()) + + # NOTE(woosuk): While OpenAI's chat completion API supports browsing + # for some models, currently vLLM doesn't support it. Please use the + # Responses API instead. 
+ self.supports_browsing = False + self.browser_tool = None + # NOTE(woosuk): Chat completion API does not support code interpreter. + # Please use the Responses API instead. + self.supports_code_interpreter = False + self.python_tool = None + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -169,7 +192,8 @@ async def create_chat_completion( if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer)): + and not isinstance(tokenizer, MistralTokenizer) + and not self.use_harmony): # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( @@ -184,25 +208,35 @@ async def create_chat_completion( else: tool_dicts = [tool.model_dump() for tool in request.tools] - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=request.chat_template_kwargs, - tool_parser=tool_parser, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if not self.use_harmony: + # Common case. + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + # For GPT-OSS. + ( + conversation, + request_prompts, + engine_prompts, + ) = self._make_request_with_harmony(request) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -436,6 +470,11 @@ async def chat_completion_stream_generator( finish_reason_sent = [False] * num_choices num_prompt_tokens = 0 num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [ + get_streamable_parser_for_assistant() + for _ in range(num_choices) + ] if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -597,7 +636,18 @@ async def chat_completion_stream_generator( else: logprobs = None - delta_text = output.text + if self.use_harmony: + harmony_parser = harmony_parsers[i] + for token_id in output.token_ids: + harmony_parser.process(token_id) + # FIXME(woosuk): Support function calling + is_final = harmony_parser.current_channel == "final" + if not (request.include_reasoning or is_final): + # Skip the reasoning content. 
+ continue + delta_text = harmony_parser.last_content_delta or "" + else: + delta_text = output.text if not delta_text and not output.token_ids and \ not previous_num_tokens[i]: @@ -607,7 +657,8 @@ async def chat_completion_stream_generator( delta_message: Optional[DeltaMessage] # just update previous_texts and previous_token_ids - if tool_choice_auto or self.reasoning_parser: + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -621,8 +672,14 @@ async def chat_completion_stream_generator( else: current_token_ids = list(output.token_ids) + if self.use_harmony: + if is_final: + delta_message = DeltaMessage(content=delta_text) + else: + delta_message = DeltaMessage( + reasoning_content=delta_text) # handle streaming deltas for tools with named tool_choice - if tool_choice_function_name: + elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] and not reasoning_parser.is_reasoning_end( previous_token_ids)): @@ -990,7 +1047,38 @@ async def chat_completion_full_generator( ) else: logprobs = None - auto_tools_called = False + + if self.use_harmony: + reasoning_content, final_content, is_tool_call = ( + parse_chat_output(token_ids)) + if not request.include_reasoning: + reasoning_content = None + + if is_tool_call: + # TODO(woosuk): Implement tool call for gpt-oss. + # For now, only Responses API supports tool call for + # gpt-oss. + raise NotImplementedError( + "Tool call in Chat Completion API is not supported " + "for gpt-oss yet. Please use Responses API instead.") + else: + # Normal message + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=final_content, + ) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if is_tool_call else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + continue if self.reasoning_parser: try: @@ -1003,10 +1091,13 @@ async def chat_completion_full_generator( reasoning_content, content = ( reasoning_parser.extract_reasoning_content( output.text, request=request)) + if not request.include_reasoning: + reasoning_content = None else: reasoning_content = None content = output.text + auto_tools_called = False # if auto tools are not enabled, and a named tool choice using # outlines is not being used if (not self.enable_auto_tools or not self.tool_parser) and \ @@ -1261,3 +1352,33 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) + + def _make_request_with_harmony( + self, + request: ChatCompletionRequest, + ): + messages: list[OpenAIMessage] = [] + + # Add system message. + # NOTE: In Chat Completion API, browsing is enabled by default + # if the model supports it. TODO: Support browsing. + assert not self.supports_browsing + assert not self.supports_code_interpreter + sys_msg = get_system_message( + reasoning_effort=request.reasoning_effort, + browser_description=None, + python_description=None) + messages.append(sys_msg) + + # Add developer message. + dev_msg = get_developer_message() + messages.append(dev_msg) + + # Add user message. + for chat_msg in request.messages: + messages.append(parse_chat_input(chat_msg)) + + # Render prompt token ids. 
+ prompt_token_ids = render_for_completion(messages) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 5e9401cbd747..e009529fbd2a 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -90,8 +90,17 @@ def __init__( logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) - # False by default. + # If False (default), the "store" option is (silently) ignored and the + # response is not stored. If True, the response is stored in memory. + # NOTE(woosuk): This may not be intuitive for users, as the default + # behavior in OpenAI's Responses API is to store the response, but + # vLLM's default behavior is not. self.enable_store = envs.VLLM_ENABLE_RESPONSES_API_STORE + if self.enable_store: + logger.warning_once( + "`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may " + "cause a memory leak since we never remove responses from " + "the store.") # HACK(woosuk): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we # never remove responses from the store. @@ -121,9 +130,25 @@ async def create_responses( if self.engine_client.errored: raise self.engine_client.dead_error - # If store is not enabled, return an error. if request.store and not self.enable_store: - return self._make_store_not_supported_error() + if request.background: + return self.create_error_response( + err_type="invalid_request_error", + message=( + "This vLLM engine does not support `store=True` and " + "therefore does not support the background mode. To " + "enable these features, set the environment variable " + "`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching " + "the vLLM server."), + status_code=HTTPStatus.BAD_REQUEST, + ) + # Disable the store option. + # NOTE(woosuk): Although returning an error is possible, we opted + # to implicitly disable store and process the request anyway, as + # we assume most users do not intend to actually store the response + # (i.e., their request's `store=True` just because it's the default + # value). + request.store = False # Handle the previous response ID. 
        prev_response_id = request.previous_response_id
diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
index b0df442dd864..834b33052b45 100644
--- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py
@@ -38,15 +38,15 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.tool_call_end_token: str = "<|tool_call_end|>"
 
         self.tool_call_regex = re.compile(
-            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>"
+            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>"
         )
 
         self.stream_tool_call_portion_regex = re.compile(
-            r"(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
+            r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
         )
 
         self.stream_tool_call_name_regex = re.compile(
-            r"(?P<tool_call_id>[\w\.]+:\d+)\s*")
+            r"(?P<tool_call_id>.+:\d+)\s*")
 
         if not self.model_tokenizer:
             raise ValueError(
@@ -374,4 +374,4 @@ def extract_tool_calls_streaming(
 
         except Exception:
             logger.exception("Error trying to handle streaming tool call.")
-            return None  # do not stream a delta. skip this token ID.
\ No newline at end of file
+            return None  # do not stream a delta. skip this token ID.
diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py
new file mode 100644
index 000000000000..01ee77414f13
--- /dev/null
+++ b/vllm/entrypoints/tool.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    # Avoid circular import.
+ from vllm.entrypoints.context import ConversationContext + +logger = init_logger(__name__) + + +class Tool(ABC): + + @abstractmethod + async def get_result(self, context: "ConversationContext") -> Any: + pass + + +class HarmonyBrowserTool(Tool): + + def __init__(self): + self.enabled = True + exa_api_key = os.getenv("EXA_API_KEY") + if not exa_api_key: + self.enabled = False + logger.warning_once("EXA_API_KEY is not set, browsing is disabled") + return + + try: + from gpt_oss.tools.simple_browser import SimpleBrowserTool + from gpt_oss.tools.simple_browser.backend import ExaBackend + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, browsing is disabled") + return + + browser_backend = ExaBackend(source="web", api_key=exa_api_key) + self.browser_tool = SimpleBrowserTool(backend=browser_backend) + logger.info_once("Browser tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.browser_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.browser_tool.tool_config + + +class HarmonyPythonTool(Tool): + + def __init__(self): + self.enabled = True + + try: + from gpt_oss.tools.python_docker.docker_tool import PythonTool + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, code interpreter is disabled") + return + + self.python_tool = PythonTool() + logger.info_once("Code interpreter tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.python_tool.tool_config diff --git a/vllm/envs.py b/vllm/envs.py index 8d3c7eab471c..f8a7197dd1bb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -17,6 +17,7 @@ LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False + VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -70,7 +71,6 @@ NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False - VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False @@ -152,6 +152,8 @@ VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False + VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False + VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False def get_default_cache_root(): @@ -327,6 +329,12 @@ def get_vllm_port() -> Optional[int]: (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in ("true", "1")), + # Use AITER triton unified attention for V1 attention + "VLLM_USE_AITER_UNIFIED_ATTENTION": + lambda: + (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in + ("true", "1")), + # Force vllm to use a specific flash-attention version (2 or 3), only valid 
# when using the flash-attention backend. "VLLM_FLASH_ATTN_VERSION": @@ -582,10 +590,6 @@ def get_vllm_port() -> Optional[int]: lambda: bool( int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))), - # If set, vllm will skip the deprecation warnings. - "VLLM_NO_DEPRECATION_WARNING": - lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), - # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": @@ -1027,9 +1031,13 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), + # If set to 1, use the TRTLLM Context Attention backend in flashinfer. + "VLLM_USE_TRTLLM_CONTEXT_ATTENTION": + lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))), + # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. "VLLM_USE_TRTLLM_DECODE_ATTENTION": - lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), + lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. @@ -1060,7 +1068,8 @@ def get_vllm_port() -> Optional[int]: # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output - # messages for those requests in memory. By default, this is disabled (0). + # messages for those requests in memory. By default, this is disabled (0), + # and the "store" option is ignored. # NOTE/WARNING: # 1. Messages are kept in memory only (not persisted to disk) and will be # lost when the vLLM server shuts down. diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index f6e79cd676f8..6b5a107396c9 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -5,7 +5,7 @@ import torch.nn as nn -from vllm.config import get_current_vllm_config +from vllm.config import get_cached_compilation_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -86,7 +86,7 @@ def forward_oot(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - compilation_config = get_current_vllm_config().compilation_config + compilation_config = get_cached_compilation_config() enabled = self.enabled() if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name]) @@ -115,7 +115,7 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: # if no name, then it was not registered - compilation_config = get_current_vllm_config().compilation_config + compilation_config = get_cached_compilation_config() custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): logger.warning_once( @@ -138,7 +138,7 @@ def default_on() -> bool: Specifying 'all' or 'none' in custom_op takes precedence. 
""" from vllm.config import CompilationLevel - compilation_config = get_current_vllm_config().compilation_config + compilation_config = get_cached_compilation_config() default_on = (compilation_config.level < CompilationLevel.PIECEWISE or not compilation_config.use_inductor) count_none = compilation_config.custom_ops.count("none") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index bd3605378b6d..ba7105c83a92 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -33,7 +33,7 @@ def deep_gemm_block_shape() -> list[int]: return [block, block] -def _valid_deep_gemm_shape(M: int, N: int, K: int): +def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: align = deep_gemm_block_shape()[0] return align <= M and N % align == 0 and K % align == 0 @@ -51,9 +51,26 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, M = hidden_states.size(0) _, K, N = w2.size() + + align = deep_gemm_block_shape()[0] + if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( - "DeepGemm disabled: unaligned problem size. M: %s, N: %s, K: %s", + "DeepGemm disabled due to unaligned problem size. " + "M: %s, N: %s, K: %s. M should >= align size " + "and N and K must be multiples of %s." + "This is not an error and we will fall back to triton.", + M, + N, + K, + align, + ) + return False + elif N <= 512: + logger.debug_once( + "DeepGemm disabled for N <= 512. M: %s, N: %s, K: %s. " + "This means we will fallback to triton " + "for this specific shape for further speed up.", M, N, K, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 56d1dfe135b3..597af08c3c9f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,10 +1360,8 @@ def fused_experts( # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - N = w1.size(1) - should_use_deep_gemm = ((N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)) - or is_blackwell_deep_gemm_used()) + should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm( + hidden_states, w1, w2) if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): assert apply_router_weight_on_input is False assert is_act_and_mul, ( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9e7296feeae1..f155a1b11fbf 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -919,9 +919,13 @@ def _load_per_tensor_weight_scale(self, shard_id: str, elif shard_id == "w2": param_data[expert_id] = loaded_weight - def _load_w13_weight_scale(self, shard_dim: int, - loaded_weight: torch.Tensor, - param: torch.Tensor, tp_rank: int): + def _load_combined_w13_weight_scale(self, shard_dim: int, + loaded_weight: torch.Tensor, + param: torch.Tensor, tp_rank: int): + """ + Load w13 weight scales assuming that w1 weight scales and w3 weight + scales are stored in the same loaded_weight tensor. 
+ """ shard_size = param.shape[shard_dim] loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) @@ -1168,24 +1172,43 @@ def weight_loader(self, uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( ) - # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale" - per_tensor_conditions = ( - "weight_scale_2" in weight_name if uses_weight_scale_2 else - "weight_scale" in weight_name) or "input_scale" in weight_name - - if "w13_weight_scale" in weight_name: - self._load_w13_weight_scale(shard_dim=shard_dim, - loaded_weight=loaded_weight, - param=param, - tp_rank=self.tp_rank) - elif per_tensor_conditions: + # Call _load_per_tensor_weight_scale() to load per-tensor (scalar) + # weights scales. + # Input scales are always per-tensor. + # Weight scales: FP4 uses "weight_scale_2" and FP8 uses + # "weight_scale" for per-tensor scales. + is_per_tensor = ("weight_scale_2" in weight_name + if uses_weight_scale_2 else "weight_scale" + in weight_name) or "input_scale" in weight_name + if is_per_tensor: self._load_per_tensor_weight_scale( shard_id=shard_id, param=param, loaded_weight=loaded_weight, expert_id=expert_id, ) - elif "weight" in weight_name: + return True if return_success else None + + # If the weight is w13_weight_scale and w13_weight_scales are + # combined into single loaded_weight, call + # _load_combined_w13_weight_scale() to load it. + # This is checked by comparing the hidden_out dims of the + # loaded_weight and the param. + if "w13_weight_scale" in weight_name: + loaded_weight_hidden_out = loaded_weight.shape[-2] + param_hidden_out = param.data.shape[-2] * self.tp_size + if loaded_weight_hidden_out == param_hidden_out: + self._load_combined_w13_weight_scale( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + tp_rank=self.tp_rank, + ) + return True if return_success else None + + # For other weights, call _load_model_weight_or_group_weight_scale() + # to load it. + if "weight" in weight_name: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 1b31368c79cd..c67f7e808301 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -107,8 +107,8 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
- if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K) - or is_blackwell_deep_gemm_used()): + if self.allow_deep_gemm and (is_blackwell_deep_gemm_used() + or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, global_num_experts, local_num_experts, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 5bfd4aaccc17..0f2e58eb9b5d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -41,35 +41,18 @@ class PoolingType(IntEnum): @dataclass(frozen=True) class ResolvedPoolingConfig: pooling_type: PoolingType - - normalize: bool - softmax: bool - step_tag_id: Optional[int] - returned_token_ids: Optional[list[int]] + task: PoolingTask @classmethod def from_config_with_defaults( cls, + task: PoolingTask, pooler_config: PoolerConfig, pooling_type: PoolingType, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, ) -> "ResolvedPoolingConfig": - return cls( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.normalize - if pooler_config.normalize is not None else normalize, - softmax=pooler_config.softmax - if pooler_config.softmax is not None else softmax, - step_tag_id=pooler_config.step_tag_id - if pooler_config.step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.returned_token_ids - if pooler_config.returned_token_ids is not None else - returned_token_ids, - ) + return cls(task=task, + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type) @dataclass(frozen=True) @@ -89,22 +72,15 @@ def for_encode( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.ALL, - default_normalize: bool = False, - default_softmax: bool = False, - default_step_tag_id: Optional[int] = None, - default_returned_token_ids: Optional[list[int]] = None, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="encode", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - step_tag_id=default_step_tag_id, - returned_token_ids=default_returned_token_ids, ) if resolved_config.pooling_type == PoolingType.STEP: - return StepPooler.from_config(resolved_config) + return StepPooler() return SimplePooler.from_config(resolved_config) @@ -113,14 +89,11 @@ def for_embed( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = True, - default_softmax: bool = False, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="embed", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) return SimplePooler.from_config(resolved_config) @@ -131,23 +104,18 @@ def for_classify( classifier: Optional[ClassifierFn], *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = False, - default_softmax: bool = True, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="classify", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) - base_pooler = SimplePooler.from_config(resolved_config) - if classifier is None: - return base_pooler + + pooling = 
PoolingMethod.from_pooling_type(resolved_config.pooling_type) return ClassifierPooler( - pooling=base_pooler.pooling, + pooling=pooling, classifier=classifier, - act_fn=base_pooler.head.activation, ) @abstractmethod @@ -198,11 +166,17 @@ def get_prompt_token_ids( ] -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: +def get_pooling_params( + pooling_metadata: PoolingMetadata) -> list[PoolingParams]: if isinstance(pooling_metadata, V0PoolingMetadata): pooling_params = [p for _, p in pooling_metadata.seq_groups] else: pooling_params = pooling_metadata.pooling_params + return pooling_params + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + pooling_params = get_pooling_params(pooling_metadata) tasks: list[PoolingTask] = [ task for pooling_param in pooling_params @@ -484,49 +458,30 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - @classmethod - def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": - if pooler_config.normalize and pooler_config.softmax: - raise ValueError("`normalize=True` and `softmax=True` should not " - "be set together") - - activation: PoolerActivation - if pooler_config.normalize: - activation = PoolerNormalize() - elif pooler_config.softmax: - activation = PoolerClassify() - else: - activation = PoolerIdentity() - - return cls(activation) - def __init__(self, activation: PoolerActivation) -> None: super().__init__() - self.activation = activation def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead - if isinstance(pooled_data, list): - for i in range(len(pooled_data)): - pooled_data[i] = pooled_data[i].to(torch.float32) - else: - pooled_data = pooled_data.to(torch.float32) + return self.activation(pooled_data) + + +class EmbeddingPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerNormalize()) + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + + pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation - if isinstance(pooling_metadata, V0PoolingMetadata): - dimensions_list = [ - pooling_param.dimensions - for _, pooling_param in pooling_metadata.seq_groups - ] - else: - assert isinstance(pooled_data, list) - dimensions_list = [ - pooling_param.dimensions - for pooling_param in pooling_metadata.pooling_params - ] + dimensions_list = [ + pooling_param.dimensions for pooling_param in pooling_params + ] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) @@ -541,7 +496,41 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], for vecs, d in zip(pooled_data, dimensions_list) ] - return self.activation(pooled_data) + # for normalize + flags = [p.normalize for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [ + self.activation(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] + + return pooled_data + + +class RewardPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerClassify()) + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + pooling_params = get_pooling_params(pooling_metadata) + + # for softmax + flags = [p.softmax for p in pooling_params] + if 
len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [ + self.activation(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] + + return pooled_data class SimplePooler(Pooler): @@ -559,8 +548,12 @@ def from_config( pooler_config: ResolvedPoolingConfig, ) -> "SimplePooler": pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) - head = PoolerHead.from_config(pooler_config) - + if pooler_config.task == "embed": + head = EmbeddingPoolerHead() + elif pooler_config.task == "encode": + head = RewardPoolerHead() + else: + raise NotImplementedError(f"Unknown task: {pooler_config.task}") return cls(pooling, head) def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: @@ -587,29 +580,11 @@ def forward( class StepPooler(Pooler): - @classmethod - def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": - assert pooler_config.pooling_type == PoolingType.STEP - - return cls( - PoolerHead.from_config(pooler_config), - step_tag_id=pooler_config.step_tag_id, - returned_token_ids=pooler_config.returned_token_ids, - ) - - def __init__( - self, - head: PoolerHead, - *, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ) -> None: + def __init__(self, ) -> None: super().__init__() self.pooling = AllPool() - self.head = head - self.step_tag_id = step_tag_id - self.returned_token_ids = returned_token_ids + self.head = RewardPoolerHead() def extract_states( self, @@ -620,10 +595,15 @@ def extract_states( prompt_token_ids = get_prompt_token_ids(pooling_metadata) pooled_data = list[torch.Tensor]() - returned_token_ids = self.returned_token_ids - step_tag_id = self.step_tag_id - for data, token_id in zip(pooled_data_lst, prompt_token_ids): + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip(pooled_data_lst, + prompt_token_ids, + pooling_params): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: data = data[:, returned_token_ids] @@ -669,14 +649,14 @@ def act_fn_for_cross_encoder(config: ModelConfig): def __init__( self, pooling: PoolingFn, - classifier: ClassifierFn, - act_fn: PoolerActivation, + classifier: Optional[ClassifierFn], + act_fn: Optional[PoolerActivation] = None, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier - self.act_fn = act_fn + self.act_fn = act_fn or PoolerClassify() def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -688,15 +668,25 @@ def forward( ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_output = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_output = self.classifier(torch.stack(pooled_data)) - else: - pooled_output = [self.classifier(data) for data in pooled_data] + if self.classifier is not None: + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_data = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_data = self.classifier(torch.stack(pooled_data)) + else: + pooled_data = [self.classifier(data) for data in pooled_data] - scores = self.act_fn(pooled_output) + pooling_params = get_pooling_params(pooling_metadata) + flags = 
[p.activation for p in pooling_params] + + if len(set(flags)) == 1: + scores = self.act_fn(pooled_data) if flags[0] else pooled_data + else: + scores = [ + self.act_fn(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] return build_output(scores) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index a96f3ee5c301..5359189caa2a 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -412,12 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if bitsandbytes.__version__ < "0.45.3": + if bitsandbytes.__version__ < "0.46.1": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.45.3.") + "install bitsandbytes>=0.46.1.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.45.3 via " - "`pip install bitsandbytes>=0.45.3` to use " + raise ImportError("Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " "bitsandbytes quantizer.") from err self.topk_indices_dtype = None self.quant_config = quant_config diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py deleted file mode 100644 index 24dd86620fe9..000000000000 --- a/vllm/model_executor/layers/rotary_embedding.py +++ /dev/null @@ -1,1967 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
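The pooler changes earlier in this diff replace the config-time `normalize`/`softmax` switches with per-request `PoolingParams` flags that the heads read at forward time. A minimal standalone sketch of that dispatch pattern, using a hypothetical helper name `apply_flagged` (illustrative only, not part of the diff): when every request in the batch carries the same flag the activation runs once over the whole batch, otherwise it falls back to per-request application.

import torch
import torch.nn.functional as F

def apply_flagged(fn, pooled: list[torch.Tensor], flags: list[bool]) -> list[torch.Tensor]:
    # Illustrative sketch of the pattern used by EmbeddingPoolerHead (normalize),
    # RewardPoolerHead (softmax) and ClassifierPooler (activation) in the hunks above.
    if len(set(flags)) == 1:
        # All requests agree: one batched call (or a no-op if the flag is off).
        return list(fn(torch.stack(pooled))) if flags[0] else pooled
    # Mixed flags: apply per request.
    return [fn(v) if f else v for v, f in zip(pooled, flags)]

vecs = [torch.randn(4) for _ in range(3)]
out = apply_flagged(lambda x: F.normalize(x, dim=-1), vecs, [True, False, True])
print([tuple(v.shape) for v in out])  # [(4,), (4,), (4,)]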
-"""Rotary Positional Embeddings.""" -import itertools -import math -from typing import Any, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform - -if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - - -def _rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., ::2] - x2 = x[..., 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) - - -def _apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - is_neox_style: bool) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, - not is_neox_style).squeeze(0) - else: - return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) - - -@CustomOp.register("rotary_embedding") -class RotaryEmbedding(CustomOp): - """Original rotary positional embedding.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. 
- inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, - self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - # key may be None in some cases, e.g. cross-layer KV sharing - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - def forward_cuda( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - from vllm import _custom_ops as ops - - # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) - # is expensive, so avoid calling it if possible - if self.cos_sin_cache.device != query.device or \ - self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) - return query, key - - def forward_xpu( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - from vllm._ipex_ops import ipex_ops as ops - - self.cos_sin_cache = self.cos_sin_cache.to(positions.device, - dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. 
- if key is None: - # XPU kernel doesn't support key=None so fall back to native impl - # TODO(sarckk): add support for optional key in - # ipex.llm.functional.rotary_embedding_batched - return self.forward_native(positions, query, key, offsets) - else: - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) - return query, key - - def forward_neuron( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - - def _apply_rotary_emb_neuron( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, - ) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - # x1 = x[..., ::2] - - # x2 = x[..., 1::2] - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) - x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - if offsets is not None: - positions = positions + offsets - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - - if self.rotary_dim == self.head_size: - query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) - query = query.reshape(query_shape) - if key is not None: - key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) - key = key.reshape(key_shape) - else: - head_size = query.shape[-1] - query_reshaped = query.view(-1, head_size) - query_pass = query_reshaped[:, self.rotary_dim:].view( - *query.shape[:-1], head_size - self.rotary_dim) - query_rot = query_reshaped[:, :self.rotary_dim].view( - *query.shape[:-1], self.rotary_dim) - query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, - self.is_neox_style) - query = torch.cat((query_rot, query_pass), - dim=-1).reshape(query_shape) - - if key is not None: - key_reshaped = key.view(-1, head_size) - key_pass = key_reshaped[:, self.rotary_dim:].view( - *key.shape[:-1], head_size - self.rotary_dim) - key_rot = key_reshaped[:, :self.rotary_dim].view( - *key.shape[:-1], self.rotary_dim) - key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - def extra_repr(self) -> str: - s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" - s += f", max_position_embeddings={self.max_position_embeddings}" - s += f", base={self.base}, is_neox_style={self.is_neox_style}" - return s - - -class LinearScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with linear scaling. - - It supports multiple scaling factors. Since multiple LoRA adapters may have - different scaling factors, we need multiple cos/sin caches. 
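For reference while reviewing the removal, here is a self-contained sketch of the neox-style rotation that the deleted `_apply_rotary_emb_torch` and `_compute_cos_sin_cache` implement together (illustrative only; it folds cache construction and application into one function and skips the GPT-J interleaved layout and the partial-rotary split):

import torch

def rope_neox(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; rotates the full head dimension.
    head_size = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
    freqs = torch.einsum("i,j->ij", positions.float(), inv_freq)  # [num_tokens, head_size // 2]
    cos, sin = freqs.cos().unsqueeze(-2), freqs.sin().unsqueeze(-2)
    x1, x2 = torch.chunk(x, 2, dim=-1)  # neox style: split into halves, not interleaved pairs
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

q = torch.randn(5, 8, 64)                    # 5 tokens, 8 heads, head_size 64
print(rope_neox(q, torch.arange(5)).shape)   # torch.Size([5, 8, 64])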
In this way, - instead of running rotary embedding kernel per lora, we can run multiple - lora in a batched way. - - In addition to that, we also keep the cos/sin cache for the scaling factor - of 1 (default) at all times. - - Exemplary for two scaling factors x=1, y and z with embeddings - [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and - [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and - [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], - - we construct the cos/sin cache as follows: - [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], - ... - [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] - - We then use offsets to index into the cos/sin cache for - the respective scaling factors. - - The offset to cache can be accessed via `scaling_factor_to_offset` API. - - Credits to the Reddit user /u/kaiokendev - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factors: Union[list[float], float], - dtype: torch.dtype, - ) -> None: - if isinstance(scaling_factors, float): - scaling_factors = [scaling_factors] - self.scaling_factors: list[float] = scaling_factors # noqa - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - # Lazy initialized. - self._scaling_factor_to_offset: dict[float, int] - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.base) - cache_list: list[torch.Tensor] = [] - # offsets to the next cache in a tensor. - # Each offset corresponds to the same index in scaling_factors. - offsets: list[int] = [] - for scaling_factor in self.scaling_factors: - # NOTE(woosuk): self.max_position_embeddings is the original - # maximum length before applying the rope scaling. - # Thus, the maximum length after applying the rope scaling is - # self.max_position_embeddings * self.scaling_factor. - max_len = self.max_position_embeddings * scaling_factor - t = torch.arange(max_len, dtype=torch.float) - t = t / scaling_factor - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - if not cache_list: - offset = 0 - else: - last_offset = offsets[-1] - next_max_len = cache_list[-1].shape[0] - offset = last_offset + next_max_len - offsets.append(offset) - cache_list.append(cache) - self._scaling_factor_to_offset = { - float(scaling_factor): offsets[i] - for i, scaling_factor in enumerate(self.scaling_factors) - } - assert len(self.scaling_factors) == len(offsets) - return torch.cat(cache_list, dim=0) - - @property - def scaling_factor_to_offset(self) -> dict[float, int]: - return self._scaling_factor_to_offset - - -class NTKScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with fixed and mixed NTK scaling. 
- https://kexue.fm/archives/9706 """ - - def __init__(self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - mixed_b: Optional[float] = None) -> None: - self.scaling_factor = scaling_factor - self.mixed_b = mixed_b - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - base = self.base * (self.scaling_factor if self.mixed_b is None else 1) - inv_freq = super()._compute_inv_freq(base) - - if self.mixed_b is None: - inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) - else: - a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / - 2)**self.mixed_b - lambda_1_m = (a * torch.arange( - 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() - inv_freq = inv_freq / lambda_1_m - - return inv_freq - - -class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with Dynamic NTK scaling. - - Credits to the Reddit users /u/bloc97 and /u/emozilla - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - ) -> None: - self.scaling_factor = scaling_factor - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_cos_sin_cache(self) -> torch.Tensor: - # NOTE(woosuk): self.max_position_embeddings is the original - # maximum length before applying the rope scaling. - # Thus, the maximum length after applying the rope scaling is - # self.max_position_embeddings * self.scaling_factor. - max_len = self.max_position_embeddings * self.scaling_factor - base = self.base * ( - (self.scaling_factor * max_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.rotary_dim / - (self.rotary_dim - 2)) - inv_freq = self._compute_inv_freq(base) - t = torch.arange(max_len, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - -class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with Dynamic NTK alpha. - - Based on the original RotaryEmbedding implementation. 
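A quick numeric check of the base rescaling in the removed `DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache` (toy numbers, not from the diff): with `max_len = orig_max * s`, the factor `s * max_len / orig_max - (s - 1)` reduces to `s**2 - s + 1`, so the rescaled base grows with the requested scaling factor.

base, orig_max, s, d = 10000.0, 2048, 4.0, 128   # toy values, not from the diff
max_len = int(orig_max * s)
new_base = base * (s * max_len / orig_max - (s - 1)) ** (d / (d - 2))
print(max_len, round(new_base))  # 8192 and a base roughly 13.5x larger than 10000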
- """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_alpha: float, - dtype: torch.dtype, - ) -> None: - self.scaling_alpha = scaling_alpha - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_cos_sin_cache(self) -> torch.Tensor: - # For Hunyuan DynamicNTKAlphaRotaryEmbedding - max_len = self.max_position_embeddings - base = self.base * self.scaling_alpha**(self.rotary_dim / - (self.rotary_dim - 2)) - inv_freq = self._compute_inv_freq(base) - t = torch.arange(max_len, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - -# Inverse dim formula to find dim based on number of rotations -def _yarn_find_correction_dim(num_rotations: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> float: - return (dim * math.log(max_position_embeddings / - (num_rotations * 2 * math.pi))) / (2 * - math.log(base)) - - -# Find dim range bounds based on rotations -def _yarn_find_correction_range( - low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> tuple[int, int]: - low = math.floor( - _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil( - _yarn_find_correction_dim(high_rot, dim, base, - max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def _yarn_linear_ramp_mask(low: float, high: float, dim: int, - dtype: torch.dtype) -> torch.Tensor: - if low == high: - high += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -def _yarn_get_mscale(scale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * math.log(scale) + 1.0 - - -class YaRNScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with YaRN method. - - Credits to Peng et al. 
github.com/jquesnelle/yarn - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - ) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation - self.mscale = float( - _yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / - self.rotary_dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - - low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) - # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - _yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - dtype=torch.float32) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) - cache = torch.cat((cos, sin), dim=-1) - return cache - - -class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): - """Phi3 family of models scaled rotary embedding. - - Based on the original RotaryEmbedding implementation. - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - original_max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - short_factor: list[float], - long_factor: list[float], - short_mscale: Optional[float] = None, - long_mscale: Optional[float] = None, - ): - super().__init__() - - if is_neox_style is False: - raise ValueError( - "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
- ) - - self.rotary_dim = rotary_dim - self.head_size = head_size - self.max_position_embeddings = max_position_embeddings - self.original_max_position_embeddings = original_max_position_embeddings - self.base = base - self.short_factor = short_factor - self.long_factor = long_factor - - scale = self.max_position_embeddings / \ - self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = math.sqrt( - 1 + math.log(scale) / - math.log(self.original_max_position_embeddings)) - if short_mscale is None: - short_mscale = scaling_factor - if long_mscale is None: - long_mscale = scaling_factor - - self.short_mscale = short_mscale - self.long_mscale = long_mscale - - short_cache = self._compute_cos_sin_cache( - original_max_position_embeddings, short_factor, short_mscale) - short_cache = short_cache.to(dtype) - - long_cache = self._compute_cos_sin_cache(max_position_embeddings, - long_factor, long_mscale) - long_cache = long_cache.to(dtype) - - long_short_cache = torch.cat([short_cache, long_cache], dim=0) - self.register_buffer("long_short_cos_sin_cache", - long_short_cache, - persistent=False) - - def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: - rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) - inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) - return inv_freq - - def _compute_cos_sin_cache( - self, - max_position_embeddings: int, - rescale_factors: list[float], - mscale: float, - ) -> torch.Tensor: - inv_freq = self._compute_inv_freq(rescale_factors) - t = torch.arange(max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * mscale - sin = freqs.sin() * mscale - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert key is not None - query = query.view(*query.shape[:-1], -1, self.head_size) - key = key.view(*key.shape[:-1], -1, self.head_size) - - k = self.original_max_position_embeddings - long_prompt_offset = (torch.any(positions > k).float() * - torch.full_like(positions, k)).long() - idx = (torch.add(positions, long_prompt_offset) - if long_prompt_offset is not None else positions) - idx = torch.add(idx, offsets) if offsets is not None else idx - cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) - - cos, sin = cos_sin.chunk(2, dim=-1) - cos = cos.repeat(1, 2).unsqueeze(-2) - sin = sin.repeat(1, 2).unsqueeze(-2) - - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = query_rot * cos + _rotate_neox(query_rot) * sin - query = torch.cat((query_rot, query_pass), dim=-1) - - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = key_rot * cos + _rotate_neox(key_rot) * sin - key = torch.cat((key_rot, key_pass), dim=-1) - - return query.flatten(-2), key.flatten(-2) - - -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -class DeepseekScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with YaRN method. - - Credits to Peng et al. 
github.com/jquesnelle/yarn - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - mscale: float = 1, - mscale_all_dim: float = 0, - ) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation. - self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) / - yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * - attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, - self.rotary_dim, - 2, - dtype=torch.float, - device=current_platform.device_type) / - self.rotary_dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - - low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) - # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - _yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device=current_platform.device_type, - dtype=torch.float32) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """PyTorch-native implementation equivalent to forward().""" - assert key is not None - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] - if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] - - if self.cos_sin_cache.device != positions.device: - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) - cos_sin = self.cos_sin_cache[torch.add(positions, offsets) - if offsets is not None else positions] - cos, sin = cos_sin.chunk(2, dim=-1) - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. 
- cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj - query_rot = query_rot * cos + rotate_fn(query_rot) * sin - key_rot = key_rot * cos + rotate_fn(key_rot) * sin - - if self.rotary_dim < self.head_size: - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - else: - query = query_rot - key = key_rot - return query, key - - -class Llama3RotaryEmbedding(RotaryEmbedding): - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - scaling_factor: float, - low_freq_factor: float, - high_freq_factor: float, - orig_max_position: int, - ) -> None: - self.scaling_factor = scaling_factor - self.low_freq_factor = low_freq_factor - self.high_freq_factor = high_freq_factor - self.orig_max_position = orig_max_position - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - inv_freqs = super()._compute_inv_freq(base) - low_freq_wavelen = self.orig_max_position / self.low_freq_factor - high_freq_wavelen = self.orig_max_position / self.high_freq_factor - - wave_len = 2 * math.pi / inv_freqs - if self.low_freq_factor != self.high_freq_factor: - smooth = (self.orig_max_position / wave_len - self.low_freq_factor - ) / (self.high_freq_factor - self.low_freq_factor) - else: - smooth = 0 - new_freqs = torch.where( - wave_len < high_freq_wavelen, - inv_freqs, - torch.where( - wave_len > low_freq_wavelen, - inv_freqs / self.scaling_factor, - (1 - smooth) * inv_freqs / self.scaling_factor + - smooth * inv_freqs, - ), - ) - return new_freqs - - -class Llama4VisionRotaryEmbedding(RotaryEmbedding): - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - ): - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - inv_freqs = super()._compute_inv_freq(base) - inv_freqs = inv_freqs[:(self.rotary_dim // 2)] - return inv_freqs - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.base) - - # self.max_position_embeddings here is number of image patches - # i.e. 
(image_size // patch_size) ** 2 - num_patches = self.max_position_embeddings - img_idx = torch.arange(num_patches, - dtype=torch.int32) \ - .reshape(num_patches, 1) - img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) - img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN - num_patches_single_dim = int(math.sqrt(num_patches)) - frequencies_x = img_idx % num_patches_single_dim - frequencies_y = img_idx // num_patches_single_dim - freqs_x = ((frequencies_x + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs_y = ((frequencies_y + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs = torch.cat([freqs_x, freqs_y], - dim=-1).float().contiguous()[..., ::2] - freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) - cache = torch.view_as_complex( - torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) - return cache - - def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert key is not None - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) - query_ = torch.view_as_complex(query.float().reshape( - *query.shape[:-1], -1, 2)) - key_ = torch.view_as_complex(key.float().reshape( - *key.shape[:-1], -1, 2)) - broadcast_shape = [ - d if i == 1 or i == (query_.ndim - 1) else 1 - for i, d in enumerate(query_.shape) - ] - freqs_ci = self.cos_sin_cache.view(*broadcast_shape) - query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) - key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) - return query_out.type_as(query), key_out.type_as(key) - - -class MRotaryEmbedding(RotaryEmbedding): - """Rotary Embedding with Multimodal Sections.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - mrope_section: Optional[list[int]] = None, - ) -> None: - # In Qwen2.5-VL, the maximum index value is related to the duration of - # the input video. We enlarge max_position_embeddings to 4 times to get - # a larger the cos and sin cache. - self.cache_max_position_num = max_position_embeddings * 4 - super().__init__(head_size, rotary_dim, self.cache_max_position_num, - base, is_neox_style, dtype) - - self.mrope_section = mrope_section - if self.mrope_section: - assert sum(self.mrope_section) == rotary_dim // 2 - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """PyTorch-native implementation equivalent to forward(). 
- - Args: - positions: - [num_tokens,] (text only) or - [3, num_tokens] (T/H/W positions with multimodal inputs) - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] - """ - assert positions.ndim == 1 or positions.ndim == 2 - assert key is not None - - num_tokens = positions.shape[-1] - cos_sin = self.cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - if positions.ndim == 2: - assert self.mrope_section - - cos = torch.cat([ - m[i] - for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) - ], - dim=-1) - sin = torch.cat([ - m[i] - for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) - ], - dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - @classmethod - def get_input_positions( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - second_per_grid_ts: Optional[list[float]], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[list[list[int]], int]: - """Get mrope input positions and delta value.""" - - image_grid_thw = [] if image_grid_thw is None else image_grid_thw - video_grid_thw = [] if video_grid_thw is None else video_grid_thw - second_per_grid_ts = [] if second_per_grid_ts is None else \ - second_per_grid_ts - - llm_positions, mrope_position_delta = \ - cls.get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - return llm_positions.tolist(), mrope_position_delta - - @classmethod - def get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - from vllm.transformers_utils.config import thinker_uses_mrope - if thinker_uses_mrope(hf_config): - return cls._omni_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - elif hf_config.model_type in ["glm4v", "glm4v_moe"]: - return cls._glm4v_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - 
video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - else: - return cls._vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - ) - - @classmethod - def _glm4v_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" - - image_token_id = hf_config.image_token_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1]): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - mm_data_idx += 1 - - elif modality_type == "video": - t, h, w = ( - video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - - for t_idx in range(llm_grid_t): - t_index = torch.tensor(t_idx).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view( - 1, -1, 1).expand(1, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view( - 1, 1, -1).expand(1, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + - st_idx) - video_frame_num = 1 - - 
else: - text_len = len(input_tokens) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() - return llm_positions, mrope_position_delta - - @classmethod - def _vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, - "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * - tokens_per_second).long().flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = 
(llm_positions.max() + 1 - - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _omni_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: Optional[list[float]] = None, - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - - Example: - - (V_i are vision position ids, A_i are audio position ids) - - |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... - |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... - """ - - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. - - thinker_config = hf_config.thinker_config - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) - - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] - - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if src_item[idx] not in [ - audio_token_id, video_token_id, image_token_id - ]: - if use_audio_in_video and idx > 0: - if src_item[idx] == vision_end_token_id and \ - src_item[idx - 1] == audio_end_token_id: - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif src_item[idx] == audio_start_token_id and \ - src_item[idx - 1] == vision_start_token_id: - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], - dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - grid_t = 
image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2) - new_src_item.extend([image_token_id] * vision_seqlen) - image_idx += 1 - elif src_item[idx] == video_token_id and not use_audio_in_video: - grid_t = video_grid_thw[video_idx][0] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) - new_src_item.extend([video_token_id] * vision_seqlen) - video_idx += 1 - else: - # read audio from video - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) - grid_t = video_grid_thw[video_idx][0] - grid_h = video_grid_thw[video_idx][1] - grid_w = video_grid_thw[video_idx][2] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 - pure_audio_len = place_num - 2 - added_audio_len = 0 - audio_llm_pos_ids_list: list[torch.Tensor] = [] - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len( - t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - new_src_item.extend([video_token_id] * - vision_ntoken_per_chunk) - vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_chunk, - grid_hs, grid_ws).split(1, dim=1) - llm_pos_ids_list.extend(vision_llm_pos_ids_list) - new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - - added_audio_len) * [audio_token_id]) - audio_start_idx = start_idx if len( - audio_llm_pos_ids_list - ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 - if min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = (torch.arange( - min(t_ntoken_per_chunk, pure_audio_len - - added_audio_len)).expand(3, -1) + - audio_start_idx).split(1, - dim=1) - else: - audio_llm_pos_ids_list = [] - added_audio_len += min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - if added_audio_len < pure_audio_len: - new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id]) - audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand( - 3, -1) + llm_pos_ids_list[-1].max() + 1).split( - 1, dim=1) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - audio_idx += 1 - video_idx += 1 - # move to the next token - idx += len(new_src_item) - new_src_item_len - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = torch.cat(llm_pos_ids_list, - dim=1).max() + 1 - len(src_item) - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - 
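A toy illustration (standalone, not part of the diff) of the `[3, num_tokens]` layout produced by the removed `_vl_get_input_positions_tensor`: text tokens share one index across the temporal/height/width rows, while image tokens enumerate their (t, h, w) grid offset past the preceding text. The helper name and numbers below are hypothetical, and the real code additionally scales the temporal index by `tokens_per_second` and accounts for vision start tokens.

import torch

def toy_mrope_positions(text_len: int, grid_t: int, grid_h: int, grid_w: int) -> torch.Tensor:
    # Text part: the same position repeated on all three (t/h/w) rows.
    text = torch.arange(text_len).view(1, -1).expand(3, -1)
    # Image part: enumerate the grid, shifted past the text prefix.
    t = torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).flatten()
    h = torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
    w = torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
    image = torch.stack([t, h, w]) + text_len
    return torch.cat([text, image], dim=1)  # shape [3, text_len + grid_t * grid_h * grid_w]

print(toy_mrope_positions(text_len=4, grid_t=1, grid_h=2, grid_w=2))
# tensor([[0, 1, 2, 3, 4, 4, 4, 4],
#         [0, 1, 2, 3, 4, 4, 5, 5],
#         [0, 1, 2, 3, 4, 5, 4, 5]])

The `mrope_position_delta` returned alongside these positions (`max + 1 - len(input_tokens)`) is what lets decode-time positions continue from the compressed multimodal range.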
@staticmethod - def _get_llm_pos_ids_for_vision( - start_idx: int, - vision_idx: int, - spatial_merge_size: int, - t_index: list[int], - grid_hs: torch.Tensor, - grid_ws: torch.Tensor, - ) -> torch.Tensor: - llm_pos_ids_list = [] - llm_grid_h = grid_hs[vision_idx] // spatial_merge_size - llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( - len(t_index), -1, llm_grid_w).flatten()) - w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( - len(t_index), llm_grid_h, -1).flatten()) - t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( - -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() - _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) - llm_pos_ids_list.append(_llm_pos_ids + start_idx) - llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) - return llm_pos_ids - - @staticmethod - def _split_list_into_ranges(lst: torch.Tensor, - interval: int) -> list[list[int]]: - ranges: list[list[int]] = [[] - for _ in range((max(lst) // interval) + 1)] - for num in lst: - index = num // interval - ranges[index].append(num) - return ranges - - @staticmethod - def get_next_input_positions( - mrope_position_delta: int, - context_len: int, - seq_len: int, - ) -> list[list[int]]: - return [ - list( - range(context_len + mrope_position_delta, - seq_len + mrope_position_delta)) for _ in range(3) - ] - - @staticmethod - def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, - mrope_position_delta: int, - context_len: int, num_new_tokens: int): - - values = np.arange(mrope_position_delta + context_len, - mrope_position_delta + context_len + num_new_tokens, - dtype=out.dtype) - out[:, out_offset:out_offset + num_new_tokens] = values - - @classmethod - def omni_get_updates_use_audio_in_video( - cls, - thinker_config: PretrainedConfig, - audio_len: int, - video_grid_thw: Union[list[int], torch.Tensor], - video_second_per_grid_t: float, - ) -> list[int]: - """Get video prompt updates when `use_audio_in_video` is True. - - In this case, audio and vision update ids will be split into - chunks and interleaved (details in `_omni_get_input_positions_tensor`). - - <|video_bos|><|VIDEO|><|video_eos|> => - <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> - """ - - audio_token_id = thinker_config.audio_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) - - grid_t = video_grid_thw[0] - grid_h = video_grid_thw[1] - grid_w = video_grid_thw[2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * video_second_per_grid_t * - tokens_per_second).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) - - updates = [audio_start_token_id] - added_audio_len = 0 - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( - spatial_merge_size**2) - updates.extend([video_token_id] * vision_ntoken_per_chunk) - - audio_chunk_size = min(t_ntoken_per_chunk, - audio_len - added_audio_len) - updates.extend(audio_chunk_size * [audio_token_id]) - added_audio_len += audio_chunk_size - if added_audio_len < audio_len: - updates.extend((audio_len - added_audio_len) * [audio_token_id]) - updates.extend([audio_end_token_id]) - - return updates - - -@CustomOp.register("dual_chunk_rotary_embedding") -class DualChunkRotaryEmbedding(CustomOp): - """Rotary positional embedding for Dual Chunk Attention.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, - chunk_size: int, - local_size: int, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.chunk_size = chunk_size - self.local_size = local_size - self.dtype = dtype - self.device = torch.device(f"cuda:{torch.cuda.current_device()}") - (q_cache, qc_cache, k_cache, qc_no_clamp_cache, - q_inter_cache) = self._compute_cos_sin_cache() - - self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) - self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) - self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) - self.register_buffer("cos_sin_qc_no_clamp_cache", - qc_no_clamp_cache, - persistent=False) - self.register_buffer("cos_sin_q_inter_cache", - q_inter_cache, - persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. - # However, we use `torch.arange(..., dtype=torch.float)` instead to - # avoid numerical issues with large base values (e.g., 10000000). - # This may cause a slight numerical difference between the HF - # implementation and ours. - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. 
- inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - chunk_len = self.chunk_size - self.local_size - q_t = torch.arange(chunk_len, dtype=torch.float) - qc_t = (torch.arange(chunk_len, dtype=torch.float) + - chunk_len).clamp(max=self.chunk_size) - k_t = torch.arange(self.max_position_embeddings, - dtype=torch.float) % chunk_len - - # count from chunk_len, no clamp(self.chunk_size) restriction - qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len - # count from self.chunk_size for q_inter's rope - q_inter_t = torch.arange(chunk_len, - dtype=torch.float) + self.chunk_size - - q_freqs = torch.outer(q_t, inv_freq) - qc_freqs = torch.outer(qc_t, inv_freq) - k_freqs = torch.outer(k_t, inv_freq) - qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) - q_inter_freqs = torch.outer(q_inter_t, inv_freq) - - q_cos = q_freqs.cos() - q_sin = q_freqs.sin() - qc_cos = qc_freqs.cos() - qc_sin = qc_freqs.sin() - k_cos = k_freqs.cos() - k_sin = k_freqs.sin() - - qc_no_clamp_cos = qc_no_clamp_freqs.cos() - qc_no_clamp_sin = qc_no_clamp_freqs.sin() - q_inter_cos = q_inter_freqs.cos() - q_inter_sin = q_inter_freqs.sin() - - q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) - q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) - return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - query = query.view(*query.shape[:-1], -1, self.head_size) - key = key.view(*key.shape[:-1], -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] - if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] - else: - query_pass = None - key_pass = None - - positions_with_offsets = (torch.add(positions, offsets) - if offsets is not None else positions) - key = self._apply_rotary_embedding( - self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) - chunk_len = self.chunk_size - self.local_size - query = self._apply_rotary_embedding( - self.cos_sin_q_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) - query_succ = self._apply_rotary_embedding( - self.cos_sin_qc_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) - query_inter = self._apply_rotary_embedding( - self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), - query_rot, query_pass) - query_succ_critical = self._apply_rotary_embedding( - self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) - query_inter_critical = self._apply_rotary_embedding( - self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) - - # merge query into one tensor to simplify the interfaces - query = torch.cat(( - query, - query_succ, - query_inter, - 
query_succ_critical, - query_inter_critical, - ), - dim=-1) - return query, key - - def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): - cos, sin = cos_sin.chunk(2, dim=-1) - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj - hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin - - if self.rotary_dim < self.head_size: - hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) - else: - hidden = hidden_rot - return hidden.flatten(-2).squeeze(0) - - def extra_repr(self) -> str: - s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" - s += f", max_position_embeddings={self.max_position_embeddings}" - s += f", base={self.base}, is_neox_style={self.is_neox_style}" - s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" - return s - - -_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} - - -def get_rope( - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool = True, - rope_scaling: Optional[dict[str, Any]] = None, - dtype: Optional[torch.dtype] = None, - partial_rotary_factor: float = 1.0, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, -) -> RotaryEmbedding: - if dtype is None: - dtype = torch.get_default_dtype() - if rope_scaling is not None: - # Transforms every value that is a list into a tuple for caching calls - rope_scaling_tuple = { - k: tuple(v) if isinstance(v, list) else v - for k, v in rope_scaling.items() - } - rope_scaling_args = tuple(rope_scaling_tuple.items()) - else: - rope_scaling_args = None - - if dual_chunk_attention_config is not None: - dual_chunk_attention_tuple = { - k: tuple(v) if isinstance(v, list) else v - for k, v in dual_chunk_attention_config.items() - if k != "sparse_attention_config" - } - dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) - else: - dual_chunk_attention_args = None - - if partial_rotary_factor < 1.0: - rotary_dim = int(rotary_dim * partial_rotary_factor) - key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dual_chunk_attention_args, dtype) - if key in _ROPE_DICT: - return _ROPE_DICT[key] - - if dual_chunk_attention_config is not None: - extra_kwargs = { - k: v - for k, v in dual_chunk_attention_config.items() - if k in ("chunk_size", "local_size") - } - rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - **extra_kwargs) - elif not rope_scaling: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, dtype) - else: - scaling_type = rope_scaling["rope_type"] - - if scaling_type == "llama3": - scaling_factor = rope_scaling["factor"] - low_freq_factor = rope_scaling["low_freq_factor"] - high_freq_factor = rope_scaling["high_freq_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - scaling_factor, low_freq_factor, - high_freq_factor, - original_max_position) - elif scaling_type == "mllama4": - rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype) - elif scaling_type == 
"default": - if "mrope_section" in rope_scaling: - rotary_emb = MRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - mrope_section=rope_scaling["mrope_section"], - ) - else: - rotary_emb = RotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - ) - elif scaling_type == "linear": - scaling_factor = rope_scaling["factor"] - rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype) - elif scaling_type == "ntk": - scaling_factor = rope_scaling["factor"] - mixed_b = rope_scaling.get('mixed_b', None) - rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype, - mixed_b) - elif scaling_type == "dynamic": - if "alpha" in rope_scaling: - scaling_alpha = rope_scaling["alpha"] - rotary_emb = DynamicNTKAlphaRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_alpha, dtype) - elif "factor" in rope_scaling: - scaling_factor = rope_scaling["factor"] - rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor, dtype) - else: - raise ValueError("Dynamic rope scaling must contain either " - "'alpha' or 'factor' field") - elif scaling_type == "yarn": - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - extra_kwargs = { - k: v - for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow") - } - rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, - original_max_position, - base, is_neox_style, - scaling_factor, dtype, - **extra_kwargs) - elif scaling_type == "deepseek_yarn": - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - # assert max_position == original_max_position * scaling_factor - extra_kwargs = { - k: v - for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow", "mscale", "mscale_all_dim") - } - rotary_emb = DeepseekScalingRotaryEmbedding( - head_size, rotary_dim, original_max_position, base, - is_neox_style, scaling_factor, dtype, **extra_kwargs) - elif scaling_type == "longrope": - short_factor = rope_scaling["short_factor"] - long_factor = rope_scaling["long_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - extra_kwargs = { - k: v - for k, v in rope_scaling.items() - if k in ("short_mscale", "long_mscale") - } - rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( - head_size, rotary_dim, max_position, original_max_position, - base, is_neox_style, dtype, short_factor, long_factor, - **extra_kwargs) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - _ROPE_DICT[key] = rotary_emb - return rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py new file mode 100644 index 000000000000..564f9a5c0075 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Rotary Positional Embeddings.""" +from typing import Any, Optional + +import torch + +from .base import RotaryEmbedding +from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding +from 
.dual_chunk_rope import DualChunkRotaryEmbedding +from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding +from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding +from .linear_scaling_rope import LinearScalingRotaryEmbedding +from .llama3_rope import Llama3RotaryEmbedding +from .llama4_vision_rope import Llama4VisionRotaryEmbedding +from .mrope import MRotaryEmbedding +from .ntk_scaling_rope import NTKScalingRotaryEmbedding +from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding +from .yarn_scaling_rope import YaRNScalingRotaryEmbedding + +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: float, + is_neox_style: bool = True, + rope_scaling: Optional[dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = (head_size, rotary_dim, max_position, base, is_neox_style, + rope_scaling_args, dual_chunk_attention_args, dtype) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + **extra_kwargs) + elif not rope_scaling: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + else: + scaling_type = rope_scaling["rope_type"] + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype) + elif 
scaling_type == "ntk": + scaling_factor = rope_scaling["factor"] + mixed_b = rope_scaling.get('mixed_b', None) + rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype, + mixed_b) + elif scaling_type == "dynamic": + if "alpha" in rope_scaling: + scaling_alpha = rope_scaling["alpha"] + rotary_emb = DynamicNTKAlphaRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_alpha, dtype) + elif "factor" in rope_scaling: + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor, dtype) + else: + raise ValueError("Dynamic rope scaling must contain either " + "'alpha' or 'factor' field") + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, dtype, + **extra_kwargs) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow", "mscale", "mscale_all_dim") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py new file mode 100644 index 000000000000..10fce857a8ae --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Rotary Positional Embeddings Base Class.""" +from typing import Optional + +import torch + +from vllm.model_executor.custom_op import CustomOp + +from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch + + +@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = 
self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # key may be None in some cases, e.g. cross-layer KV sharing + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + from vllm import _custom_ops as ops + + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if self.cos_sin_cache.device != query.device or \ + self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. 
+ if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if key is None: + # XPU kernel doesn't support key=None so fall back to native impl + # TODO(sarckk): add support for optional key in + # ipex.llm.functional.rotary_embedding_batched + return self.forward_native(positions, query, key, offsets) + else: + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + self.rotary_dim, offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_neuron( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + def _apply_rotary_emb_neuron( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + # x1 = x[..., ::2] + + # x2 = x[..., 1::2] + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) + x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + if offsets is not None: + positions = positions + offsets + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + + if self.rotary_dim == self.head_size: + query = apply_rotary_emb_dispatch(query, cos, sin, + self.is_neox_style) + query = query.reshape(query_shape) + if key is not None: + key = apply_rotary_emb_dispatch(key, cos, sin, + self.is_neox_style) + key = key.reshape(key_shape) + else: + head_size = query.shape[-1] + query_reshaped = query.view(-1, head_size) + query_pass = query_reshaped[:, self.rotary_dim:].view( + *query.shape[:-1], head_size - self.rotary_dim) + query_rot = query_reshaped[:, :self.rotary_dim].view( + *query.shape[:-1], self.rotary_dim) + query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), + dim=-1).reshape(query_shape) + + if key is not None: + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - 
self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py new file mode 100644 index 000000000000..8d821bea19e3 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch + +from vllm.platforms import current_platform + +if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + +# common functions +def rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + if current_platform.is_cuda(): + return apply_rotary_emb(x.unsqueeze(0), cos, sin, + not is_neox_style).squeeze(0) + else: + return apply_rotary_emb_torch(x, cos, sin, is_neox_style) + + +# yarn functions +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> tuple[int, int]: + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py new file mode 100644 index 000000000000..cd888b733426 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch + +from vllm.platforms import current_platform + +from .base import RotaryEmbedding +from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range, + yarn_linear_ramp_mask) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. 
+ self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + assert key is not None + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + if self.cos_sin_cache.device != positions.device: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py new file mode 100644 index 000000000000..3d8da0fa9d8f --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.model_executor.custom_op import CustomOp + +from .common import rotate_gptj, rotate_neox + + +@CustomOp.register("dual_chunk_rotary_embedding") +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, + q_inter_cache) = self._compute_cos_sin_cache() + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer("cos_sin_qc_no_clamp_cache", + qc_no_clamp_cache, + persistent=False) + self.register_buffer("cos_sin_q_inter_cache", + q_inter_cache, + persistent=False) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + + chunk_len).clamp(max=self.chunk_size) + k_t = torch.arange(self.max_position_embeddings, + dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, + dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + else: + query_pass = None + key_pass = None + + positions_with_offsets = (torch.add(positions, offsets) + if offsets is not None else positions) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, query_pass) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + + # merge query into one tensor to simplify the interfaces + query = torch.cat(( + query, + query_succ, + query_inter, + 
query_succ_critical, + query_inter_critical, + ), + dim=-1) + return query, key + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py new file mode 100644 index 000000000000..1da39bbd303b --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from .base import RotaryEmbedding + + +class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK alpha. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_alpha: float, + dtype: torch.dtype, + ) -> None: + self.scaling_alpha = scaling_alpha + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # For Hunyuan DynamicNTKAlphaRotaryEmbedding + max_len = self.max_position_embeddings + base = self.base * self.scaling_alpha**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py new file mode 100644 index 000000000000..ec2008b90cfb --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from .base import RotaryEmbedding + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache diff --git a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py new file mode 100644 index 000000000000..6e920991882d --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Union + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from .base import RotaryEmbedding + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factors: Union[list[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: list[float] = scaling_factors # noqa + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: list[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: list[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
+ max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> dict[float, int]: + return self._scaling_factor_to_offset diff --git a/vllm/model_executor/layers/rotary_embedding/llama3_rope.py b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py new file mode 100644 index 000000000000..adcef549bc4c --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch + +from .base import RotaryEmbedding + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py new file mode 100644 index 000000000000..415a85ab698b --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch + +from .base import RotaryEmbedding + + +class Llama4VisionRotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ): + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + inv_freqs = inv_freqs[:(self.rotary_dim 
// 2)] + return inv_freqs + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + + # self.max_position_embeddings here is number of image patches + # i.e. (image_size // patch_size) ** 2 + num_patches = self.max_position_embeddings + img_idx = torch.arange(num_patches, + dtype=torch.int32) \ + .reshape(num_patches, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN + num_patches_single_dim = int(math.sqrt(num_patches)) + frequencies_x = img_idx % num_patches_single_dim + frequencies_y = img_idx // num_patches_single_dim + freqs_x = ((frequencies_x + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_y = ((frequencies_y + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], + dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + cache = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + return cache + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + query_ = torch.view_as_complex(query.float().reshape( + *query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape( + *key.shape[:-1], -1, 2)) + broadcast_shape = [ + d if i == 1 or i == (query_.ndim - 1) else 1 + for i, d in enumerate(query_.shape) + ] + freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py new file mode 100644 index 000000000000..a75b9e5eb435 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +from typing import Optional, Union + +import numpy as np +import torch +from transformers import PretrainedConfig + +from .base import RotaryEmbedding +from .common import apply_rotary_emb_dispatch + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[list[int]] = None, + ) -> None: + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + super().__init__(head_size, rotary_dim, self.cache_max_position_num, + base, is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """PyTorch-native implementation equivalent to forward(). 
+ + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @classmethod + def get_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[list[list[int]], int]: + """Get mrope input positions and delta value.""" + + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else \ + second_per_grid_ts + + llm_positions, mrope_position_delta = \ + cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + return llm_positions.tolist(), mrope_position_delta + + @classmethod + def get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + elif hf_config.model_type in ["glm4v", "glm4v_moe"]: + return cls._glm4v_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + 
image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _glm4v_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + 
st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + + @classmethod + def _vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, + "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + 
mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = 
image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len( + t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * + vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, + grid_hs, grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len) * [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, + dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, -1) + llm_pos_ids_list[-1].max() + 1).split( + 1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + 
@staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( + len(t_index), -1, llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] + for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> list[list[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + @staticmethod + def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + + values = np.arange(mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype) + out[:, out_offset:out_offset + num_new_tokens] = values + + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, + audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates diff --git a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py new file mode 100644 index 000000000000..42926bad22ef --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .base import RotaryEmbedding + + +class NTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with fixed and mixed NTK scaling. 
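+
+    When mixed_b is None, the base is scaled by scaling_factor (fixed NTK);
+    otherwise each dimension i is divided by its own factor
+    lambda_i = exp(a * i**mixed_b), chosen so that lambda at i = rotary_dim/2
+    equals scaling_factor (mixed NTK).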
+ https://kexue.fm/archives/9706 """ + + def __init__(self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + mixed_b: Optional[float] = None) -> None: + self.scaling_factor = scaling_factor + self.mixed_b = mixed_b + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + base = self.base * (self.scaling_factor if self.mixed_b is None else 1) + inv_freq = super()._compute_inv_freq(base) + + if self.mixed_b is None: + inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) + else: + a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / + 2)**self.mixed_b + lambda_1_m = (a * torch.arange( + 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() + inv_freq = inv_freq / lambda_1_m + + return inv_freq diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py new file mode 100644 index 000000000000..9c36d633e2a9 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from typing import Optional + +import torch +import torch.nn as nn + +from .common import rotate_neox + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: list[float], + long_factor: list[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
+ ) + + self.rotary_dim = rotary_dim + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + + long_short_cache = torch.cat([short_cache, long_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: list[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = query_rot * cos + rotate_neox(query_rot) * sin + query = torch.cat((query_rot, query_pass), dim=-1) + + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = key_rot * cos + rotate_neox(key_rot) * sin + key = torch.cat((key_rot, key_pass), dim=-1) + + return query.flatten(-2), key.flatten(-2) diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py new file mode 100644 index 000000000000..851565c5667a --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project + +import torch + +from .base import RotaryEmbedding +from .common import (yarn_find_correction_range, yarn_get_mscale, + yarn_linear_ramp_mask) + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 4e3ba107ba7e..4cf73e2e0ea5 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -24,10 +24,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -260,6 +262,81 @@ def forward( return hidden_states, aux_hidden_states return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """Load weights, mapping q/k/v projections to fused qkv_proj.""" + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" 
in name + or "rotary_emb.sin_cached" in name): + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if "scale" in name: + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is None: + continue + name = remapped_name + + mapped = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + mapped = True + break + + if is_pp_missing_parameter(name, self): + mapped = True + break + + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore[attr-defined] + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + mapped = True + break + + if mapped: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM @@ -304,8 +381,7 @@ def __init__(self, *, vllm_config, prefix: str = "") -> None: else: # Placeholder for lm_head on non-last ranks self.lm_head = PPMissingLayer() - # Provide a reference to the model's method for generating empty - # tensors (used in pipeline parallel schedule) + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -316,7 +392,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: - # Forward pass through the Arcee model backbone model_output = self.model(input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 9030ff307bee..6f21cd267b0e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -44,6 +44,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } +class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.activation is None: + pooler_config.activation = False + + class JinaRobertaModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -155,6 +164,26 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: vllm_config.recalculate_max_model_len(max_model_len) +class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + + if pooler_config.step_tag_id is None: + pooler_config.step_tag_id = 151651 + + +class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = 
vllm_config.model_config.pooler_config + + if pooler_config.softmax is None: + pooler_config.softmax = False + + class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod @@ -218,6 +247,34 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: config.max_model_len) +class GptOssForCausalLMConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + decoding_config = vllm_config.decoding_config + if decoding_config.reasoning_backend == "": + decoding_config.reasoning_backend = "GptOss" + + # Increase the max capture size from 512 to 1024 for performance. + # NOTE(woosuk): This will increase the number of CUDA graphs + # from 67 to 83. + scheduler_config = vllm_config.scheduler_config + if len(scheduler_config.cuda_graph_sizes) == 1: + max_capture_size = scheduler_config.cuda_graph_sizes[0] + # FIXME(woosuk): When using full cuda graph with FA3, the max + # supported size is 992. + if max_capture_size < 1024: + cuda_graph_sizes = [1, 2, 4] + # Step size 8 for small batch sizes + cuda_graph_sizes += [i for i in range(8, 256, 8)] + # Step size 16 for larger batch sizes + cuda_graph_sizes += [i for i in range(256, 1025, 16)] + scheduler_config.cuda_graph_sizes = cuda_graph_sizes + logger.info( + "Overriding max cuda graph capture size to " + "%d for performance.", 1024) + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -309,8 +366,12 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, + "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, + "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, + "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, + "GptOssForCausalLM": GptOssForCausalLMConfig, } diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index c702684c6caa..bd3e27662ee7 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -123,6 +123,7 @@ def __init__( config.n_routed_experts, bias=False, quant_config=None, + params_dtype=torch.float32, prefix=f"{prefix}.gate") self.gate.e_score_correction_bias = nn.Parameter( @@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) - router_logits, _ = self.gate(hidden_states) + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py new file mode 100644 index 000000000000..896560fa24ca --- /dev/null +++ b/vllm/model_executor/models/gpt_oss.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn +from transformers import GptOssConfig + +from vllm import envs +from vllm.attention import Attention, AttentionType 
+from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv + +from .utils import extract_layer_index, maybe_prefix + + +class OAIAttention(nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + dtype=torch.float32, + rope_scaling={ + "rope_type": + "yarn", + "factor": + config.rope_scaling["factor"], + "original_max_position_embeddings": + config.rope_scaling["original_max_position_embeddings"], + "beta_fast": + config.rope_ntk_beta, + "beta_slow": + config.rope_ntk_alpha, + }, + is_neox_style=True, + ) + + tp_size = get_tensor_model_parallel_world_size() + + attention_sink_dtype = ( + torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16) + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads // tp_size, + dtype=attention_sink_dtype, + requires_grad=False)) + + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.q_size = self.num_attention_heads * self.head_dim // tp_size + self.kv_size = self.num_key_value_heads * self.head_dim // tp_size + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + + self.qkv = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_attention_heads, + total_num_kv_heads=self.num_key_value_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.num_local_attention_heads = config.num_attention_heads // tp_size + self.num_local_key_value_heads = config.num_key_value_heads // tp_size + + # Only apply sliding window to every other layer + sliding_window = (config.sliding_window if self.layer_idx % + 2 == 0 else None) + self.attn = Attention( + self.num_local_attention_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_local_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + 
attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.sinks, + ) + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + t = self.norm(hidden_states) + + qkv, _ = self.qkv(t) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + + return output + hidden_states + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + layer_idx: int, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.router = torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + self.experts = FusedMoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_token, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.router(t) + t = self.experts(hidden_states=t, router_logits=g) + return x + t + + +class TransformerBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.mlp = MLPBlock(config, + self.layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + attn_output = self.attn(hidden_states, positions) + output = self.mlp(attn_output) + return output + + +@support_torch_compile +class GptOssModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + self.layers = torch.nn.ModuleList([ + TransformerBlock( + self.config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, f"block.{layer_idx}"), + ) for layer_idx in range(self.config.num_hidden_layers) + ]) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + x = self.embedding(input_ids) + for layer in self.layers: + x = layer(x, positions) + x = self.norm(x) + return x + + +class GptOssForCausalLM(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config.hf_config + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.model_config.vocab_size, + self.model_config.hidden_size, + ) + self.logits_processor = 
LogitsProcessor(self.model_config.vocab_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + assert intermediate_tensors is None + assert inputs_embeds is None + return self.model(input_ids, positions) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + rename_mapping = { + "self_attn": "attn", + "input_layernorm.weight": "attn.norm.weight", + "post_attention_layernorm.weight": "mlp.norm.weight", + "embed_tokens": "embedding", + } + + def maybe_rename(name: str) -> str: + for remap_name, new_name in rename_mapping.items(): + if remap_name in name: + return name.replace(remap_name, new_name) + return name + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.model_config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + # Attention heads per rank + heads_per_rank = self.model_config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + use_ep = self.vllm_config.parallel_config.enable_expert_parallel + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.model_config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + for name, weight in weights: + # FIXME(woosuk): Remove this after testing. + weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
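+                    # Expert parallel keeps whole experts on this rank; the
+                    # TP branch below slices the packed intermediate dim
+                    # instead (two mxfp4 values per uint8, hence the halved
+                    # bounds).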
+ else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", + "w13_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_bias") + + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
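+                    # Expert parallel: keep only the bias rows of this rank's
+                    # experts.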
+ else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + name = name.replace("self_attn", "attn") + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + elif "q_proj" in name or "k_proj" in name or "v_proj" in name: + shard_id = ("q" if "q_proj" in name else + "k" if "k_proj" in name else "v") + name = name.replace("self_attn", "attn") + param_name = name.replace(f"{shard_id}_proj", "qkv") + param = params_dict[param_name] + weight_loader = param.weight_loader + weight_loader(param, weight, loaded_shard_id=shard_id) + loaded_params.add(param_name) + else: + # Handle all other weights with potential renaming + renamed_name = maybe_rename(name) + if renamed_name not in params_dict: + continue + param = params_dict[renamed_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(renamed_name) + + return loaded_params diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4d68227b2af8..697fa020deb4 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, +from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) import torch @@ -14,6 +14,10 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.sampling_metadata import SamplingMetadata +else: + VllmConfig = Any + Pooler = Any + SamplingMetadata = Any logger = init_logger(__name__) @@ -34,7 +38,7 @@ class VllmModel(Protocol[T_co]): def __init__( self, - vllm_config: "VllmConfig", + vllm_config: VllmConfig, prefix: str = "", ) -> None: ... @@ -96,7 +100,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - sampling_metadata: "SamplingMetadata", + sampling_metadata: SamplingMetadata, ) -> Optional[T]: """Return `None` if TP rank > 0.""" ... @@ -140,7 +144,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. 
""" - pooler: "Pooler" + pooler: Pooler """The pooler is only called on TP rank 0.""" diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 263f4c8379cf..ab21b7ce2c5f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -593,7 +593,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config, classifier=self.score, default_pooling_type=PoolingType.LAST, - default_normalize=False, - default_softmax=False, ), }) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 0c4284f7daaa..8c64f636c6a0 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -90,15 +90,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "qwen2_vl")) config = vllm_config.model_config.hf_config pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None # logit bias for sigmoid normalization self.LOGIT_BIAS = 2.65 self.score = JinaVLScorer(config) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config), diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 9c0a6ba92389..1c7ddd7df7f8 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -46,7 +46,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch from torch import nn @@ -79,6 +79,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import is_pp_missing_parameter, maybe_prefix @@ -118,15 +119,22 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -class KimiVLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: Union[torch.Tensor, list[torch.Tensor]] +class KimiVLImagePixelInputs(TensorSchema): """ - Shape:`(num_patches, num_channels, patch_size, patch_size)` + Dimensions: + - nc: Number of channels + - np: Number of patches + - ps: Patch size + - ni: Number of images """ + type: Literal["pixel_values"] = "pixel_values" - image_grid_hws: torch.Tensor - """Shape:`(num_images, 2)`""" + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("np", 3, "ps", "ps"), + ] + + image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)] # TODO: support embeds too @@ -348,8 +356,6 @@ def _parse_and_validate_image_input( pixel_values = pixel_values.reshape(-1, num_channels, patch_size, patch_size) pixel_values = pixel_values.to(self.vision_tower.dtype) - # image_grid_hws.shape = (N, 2) - assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}" return KimiVLImagePixelInputs( type="pixel_values", diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py new file mode 100644 index 000000000000..e66eba0212e3 --- /dev/null +++ b/vllm/model_executor/models/ovis2_5.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors 
to the vLLM project +""" PyTorch Ovis model.""" +from collections.abc import Iterable, Mapping +from functools import partial +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.ovis import (OvisImagePatchInputs, + VisualEmbedding) +from vllm.model_executor.models.siglip2navit import Siglip2NavitModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.ovis2_5 import (IMAGE_TOKEN, + INDICATOR_IDS, + VIDEO_TOKEN, + Ovis2_5Config) +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", + "qwen3": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, + "qwen3": 151655, +} + + +def _ovis2_5_field_config(): + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video")) + + +class VisualTokenizer(torch.nn.Module): + """ + VIT + """ + + def __init__( + self, + config: PretrainedConfig, + visual_vocab_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.vit = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.vit", + ) + # reserved tokens for INDICATOR_IDS + head_dim = visual_vocab_size - len(INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + self.config.hidden_size * self.config.hidden_stride**2, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + model_type = config.model_type + if model_type == "siglip2_navit": + return Siglip2NavitModel(config=config, ) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self): + return next(self.head.parameters()).dtype + + @property + def device(self): + return next(self.head.parameters()).device + + def tokenize(self, logits): + tokens = torch.softmax(logits, dim=-1, + dtype=torch.float32).to(logits.dtype) + return tokens + + def encode(self, pixel_values, grid_thws): + features = self.vit(pixel_values, + grid_thws, + output_hidden_states=True, + 
return_dict=True) + # refer to qwen2.5-vl patchmerger + seq_len, _ = features.shape + features = features.reshape(seq_len // (self.config.hidden_stride**2), + -1) + + return features + + def forward(self, pixel_values, grid_thws) -> torch.Tensor: + features = self.encode(pixel_values, grid_thws) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [#Token, VocabSize-4], + # so padding with [#Token, 4], after which, + # tokens' shape should become [#Token, VocabSize]; + tokens = torch.nn.functional.pad( + tokens, + (0, len(INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class Ovis2_5ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Ovis2_5Config) + + def get_hf_processor(self, **kwargs): + vit_config = self.get_hf_config().vit_config + return self.ctx.get_hf_processor( + Ovis2_5Processor, + image_pad_token=self.get_image_pad_token(), + patch_size=vit_config.patch_size, + hidden_stride=vit_config.hidden_stride, + temporal_patch_size=vit_config.temporal_patch_size, + ) + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + def get_image_size_with_most_features(self) -> ImageSize: + # NOTE(myselvess): max_pixels 1344 * 1792 hardcoded in original code + return ImageSize(width=1344, height=1792) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + temporal_patch_size = vit_config.temporal_patch_size + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + num_frames % temporal_patch_size + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = image_height // patch_size + grid_w = image_width // patch_size + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches + return num_vision_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + num_frames = 0 + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + if next_max_tokens > max_tokens: + break + num_frames = next_num_frames + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + 
return max(max_frames_per_video, 1) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[BaseImageProcessor], + ) -> int: + num_video_tokens = self.get_num_image_tokens(image_width=image_width, + image_height=image_height, + num_frames=num_frames) + return num_video_tokens + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + return mm_data + + +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] + ): + + def visual_indicators_to_visual_tokens( + self, + visual_indicators: list[int], + ) -> list[int]: + """ + Filter image indicators placeholders and convert them to corresponding + tokens in visual tokenizer. 
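+        Indicator placeholder ids (values below -300) map to the reserved
+        tail of the visual vocabulary, i.e. the last len(INDICATOR_IDS)
+        entries after the image-token logits.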
+ """ + hf_config = self.info.get_hf_config() + vte_vocab_size = hf_config.visual_vocab_size + return [ + vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 + for x in visual_indicators if x < -300 + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + hf_processor = self.info.get_hf_processor() + + if "videos" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), True) + for grid in processed_outputs["video_grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + processed_outputs["video_indicator_tokens"] = indicator_tokens + if "images" in mm_data: + visual_indicators = [ + hf_processor.construct_visual_indicators((1, 1, 1), False) + for grid in processed_outputs["grids"] + ] + indicator_tokens = [ + self.visual_indicators_to_visual_tokens(indicator) + for indicator in visual_indicators + ] + + processed_outputs["indicator_tokens"] = indicator_tokens + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _ovis2_5_field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + + def get_replacement_ovis(item_idx, modality: str): + if modality == "image": + grid = out_mm_kwargs["grids"][item_idx][0] + elif modality == "video": + grid = out_mm_kwargs["video_grids"][item_idx][0] + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_visual_placeholders(grid, ) + + return [ + PromptReplacement( + modality=modality, + target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, + replacement=partial(get_replacement_ovis, modality=modality), + ) for modality in ("image", "video") + ] + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder) +class Ovis2_5(nn.Module, SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config: PretrainedConfig = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = VisualTokenizer( + config=config.vit_config, + visual_vocab_size=config.visual_vocab_size, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + ) + + self.vte = VisualEmbedding(config.visual_vocab_size, + config.hidden_size) + + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + + # TODO(Isotr0py): PP support + # 
self.make_empty_intermediate_tensors = ( + # self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_visual_input( + self, is_video, + **kwargs: object) -> Optional[OvisImagePatchInputs]: + if is_video: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) + else: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}") + + return OvisImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + indicator_tokens = image_input["indicator_tokens"] + grid_thws = image_input["grids"] + + indicator_per_image = list( + map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype), grid_thws) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. + indicator_embeds = self.vte(indicator_tokens) + + visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) + indicator_embeds_per_image = indicator_embeds.split( + indicator_per_image) + + vision_embeddings = [] + for indicator, visual in zip(indicator_embeds_per_image, + visual_embeds_per_image): + vision_embeddings_per_image = [] + visual = visual.unsqueeze(0) + for i in range(visual.shape[0]): + vision_embeddings_per_image.append( + torch.cat([indicator[i:i + 1], visual[i]], dim=0)) + vision_embeddings_per_image.append(indicator[i + 1:]) + vision_embeddings.append( + torch.cat(vision_embeddings_per_image, dim=0)) + return tuple(vision_embeddings) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + embeddings = [] + + # NOTE: _parse_and_validate_visual_input has side-effects and pops + # keys from kwargs. We process images first, then videos. 
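+        # Videos reuse the image pathway: their patches are parsed into the
+        # same OvisImagePatchInputs structure and run through
+        # _process_image_input below.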
+ image_input = self._parse_and_validate_visual_input(False, **kwargs) + if image_input: + embeddings.extend(self._process_image_input(image_input)) + + video_input = self._parse_and_validate_visual_input(True, **kwargs) + if video_input: + embeddings.extend(self._process_image_input(video_input)) + + return tuple(embeddings) if embeddings else None + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + tmp = torch.concat(multimodal_embeddings, dim=0) + inputs_embeds[input_ids == self.image_pad_token_id] = tmp + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.compute_logits(hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index f12e9a041a94..9b6b70c75c34 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -117,8 +117,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): Pooler.for_encode( pooler_config, default_pooling_type=PoolingType.STEP, - default_normalize=False, - default_softmax=True, - default_step_tag_id=151651, ) }) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 40d77312b72c..633f8598e879 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1395,11 +1395,12 @@ def __init__( **kwargs, ): self.image_processor = Tarsier2ImageProcessor(**vision_config) - super().__init__(image_processor=self.image_processor, - tokenizer=tokenizer, - video_processor=Qwen2VLVideoProcessor(), - chat_template=None, - **kwargs) + super().__init__( + image_processor=self.image_processor, + tokenizer=tokenizer, + video_processor=Qwen2VLVideoProcessor(**vision_config), + chat_template=None, + **kwargs) class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9b6ab52d8680..e9b41ca1cd7e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -74,6 +74,7 @@ 
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), + "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -228,6 +229,7 @@ "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), + "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py new file mode 100644 index 000000000000..182706101e12 --- /dev/null +++ b/vllm/model_executor/models/siglip2navit.py @@ -0,0 +1,608 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +from typing import Optional, Union + +import torch +from einops import rearrange, repeat +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithNoAttention + +from vllm.platforms import _Backend +from vllm.transformers_utils.configs.ovis2_5 import Siglip2NavitConfig + +from .vision import get_vit_attn_backend + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Siglip2VisionEmbeddings(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.image_size = config.image_size + self.num_patches = config.num_patches + self.preserve_original_pe = config.preserve_original_pe + self.hidden_stride = config.hidden_stride + + # siglip2 naflex + if self.num_patches > 0: + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * + self.patch_size, + out_features=self.embed_dim, + ) + if self.preserve_original_pe: + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + else: + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + if self.preserve_original_pe: + self.num_patches = (self.image_size // self.patch_size)**2 + self.position_embedding_size = (self.image_size // + self.patch_size) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape ( + num_patches, + num_channels * temporal_patch_size * 
patch_size * patch_size + ) + grid_thws: (`torch.LongTensor`): + grid shape (num_patches, 3) + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + if isinstance(self.patch_embedding, nn.Linear): + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + elif isinstance(self.patch_embedding, nn.Conv2d): + pixel_values = pixel_values.view( + -1, self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, self.patch_size) + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.reshape(-1, self.embed_dim) + + if self.preserve_original_pe: + assert grid_thws is not None + pos_embed_new = torch.zeros_like(patch_embeds) + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, + -1).unsqueeze(0).permute(0, 3, 1, 2) + cnt = 0 + for t, h, w in grid_thws: + volume = t * h * w + pe = F.interpolate(positional_embeddings, + size=(h, w), + mode='bicubic', + align_corners=False) + pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) + pe = pe[0].repeat(t, 1) + pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, + w // self.hidden_stride, self.hidden_stride, + -1) + pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) + pos_embed_new[cnt:cnt + volume] = pe + cnt += volume + patch_embeds = patch_embeds + pos_embed_new + + return patch_embeds + + +# copy from flash_attn/layers/rotary.py +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_flash_attn_backend: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + if is_flash_attn_backend: + from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb + else: + apply_rotary_emb_func = apply_rotary_emb_torch + q_embed = apply_rotary_emb_func(q.float(), cos.float(), + sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), + sin.float()).type_as(k) + return q_embed, k_embed + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.use_rope = config.use_rope + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Ovis2.5 does not support {self.attn_backend} backend now.") + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, + torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(seq_length, self.num_heads, self.head_dim) + keys = keys.view(seq_length, self.num_heads, self.head_dim) + values = values.view(seq_length, self.num_heads, self.head_dim) + + if self.use_rope: + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0), + keys.unsqueeze(0), cos, sin, + self.is_flash_attn_backend) + queries = queries.squeeze(0) + keys = keys.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + if self.is_flash_attn_backend: + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( + queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, + max_seqlen).reshape(seq_length, -1) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
+ batch_size = cu_seqlens.shape[0] - 1 + queries_ = queries.view(batch_size, max_seqlen, self.num_heads, + self.head_dim) + keys_ = keys.view(batch_size, max_seqlen, self.num_heads, + self.head_dim) + values_ = values.view(batch_size, max_seqlen, self.num_heads, + self.head_dim) + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = queries_[:, start_idx:end_idx] + k_i = keys_[:, start_idx:end_idx] + v_i = values_[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + attn_output = torch.cat(outputs, dim=1) + attn_output = rearrange(attn_output, + "b s h d -> (s b) (h d)").contiguous() + attn_output = self.out_proj(attn_output) + return attn_output + + +class Siglip2MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
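+    When `fullatt_block_indexes` is set, only the listed layers use full
+    attention over each packed sequence; the remaining layers attend within
+    fixed-size windows (Qwen2.5-VL style).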
+ + Args: + config: Siglip2NavitConfig + """ + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Siglip2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + self.rotary_pos_emb = VisionRotaryEmbedding( + config.hidden_size // config.num_attention_heads // 2) + self.patch_size = config.patch_size + self.hidden_stride = config.hidden_stride + self.window_size = config.window_size + self.spatial_merge_unit = config.hidden_stride * config.hidden_stride + if config.fullatt_block_indexes is None: + self.fullatt_block_indexes = None + else: + self.fullatt_block_indexes = [ + int(i) for i in config.fullatt_block_indexes.split('|') + ] + + # copied from qwen2.5_vl + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + # patch (after merge) number in each window + vit_merger_window_size = (self.window_size // self.hidden_stride // + self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.hidden_stride, # number of patch after merge + grid_w // self.hidden_stride, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + # Ignore copy + def forward( + self, + inputs_embeds, + grid_thws: torch.Tensor, + output_hidden_states: bool = False, + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]: 
+ r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if + you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding + lookup matrix. + grid_thws (`torch.LongTensor`): + grid shape (num_patches, 3) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of + a plain tuple. + """ + rotary_pos_emb = self.rot_pos_emb(grid_thws) + window_index, cu_window_seqlens = self.get_window_index(grid_thws) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=inputs_embeds.device, + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = inputs_embeds.size() + inputs_embeds = inputs_embeds.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + inputs_embeds = inputs_embeds[window_index, :, :] + inputs_embeds = inputs_embeds.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have + # same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 + # for more information + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + reverse_indices = torch.argsort(window_index) + encoder_states = () if output_hidden_states else None + + hidden_states = inputs_embeds + for index, block in enumerate(self.layers): + if (not self.fullatt_block_indexes + or index in self.fullatt_block_indexes): + cu_seqlens_tmp = cu_seqlens + else: + cu_seqlens_tmp = cu_window_seqlens + hidden_states = block(hidden_states, cu_seqlens_tmp, + position_embeddings) + if output_hidden_states: + hidden_states_ = hidden_states.reshape( + seq_len // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + encoder_states += (hidden_states_[reverse_indices, :].reshape( + seq_len, -1), ) + # tokens = self.post_trunk_norm(tokens) + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) + + return hidden_states, encoder_states + + +class Siglip2VisionTransformer(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self._use_flash_attention_2 = \ + (config._attn_implementation == "flash_attention_2") + + def forward( + self, + pixel_values: 
torch.FloatTensor, + grid_thws: torch.LongTensor, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = True, + ) -> Union[ + tuple[torch.Tensor], + tuple[torch.Tensor, tuple[torch.Tensor, ...]], + BaseModelOutputWithNoAttention, + ]: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) + of the input images. + """ + hidden_states = self.embeddings(pixel_values, grid_thws) + + last_hidden_state, hidden_states = self.encoder( + hidden_states, grid_thws, output_hidden_states) + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + output = (last_hidden_state, ) + output += (hidden_states, ) if output_hidden_states else () + return output + + return last_hidden_state + + +class Siglip2NavitModel(torch.nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + + self.vision_model = Siglip2VisionTransformer(config) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + tuple[torch.Tensor], + tuple[torch.Tensor, tuple[torch.Tensor, ...]], + BaseModelOutputWithNoAttention, + ]: + + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + if return_dict is None: + return_dict = self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + grid_thws=grid_thws, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 5059d1e1d9fe..0c3df267edb1 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -90,7 +90,7 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): def replace_linear_class( linear: nn.Linear, style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig -) -> Union[ColumnParallelLinear, RowParallelLinear]: +) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -445,7 +445,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` - # method after v4.54.0 is released + # method once its checks are fixed in Transformers. self.text_config._attn_implementation = "vllm" with init_on_device_without_buffers("meta"), config_override: self.model: PreTrainedModel = AutoModel.from_config( @@ -520,7 +520,7 @@ def pipeline_parallel(self): for i in range(len(layers)): if start_layer <= i and i < end_layer: continue - layers[i] = PPMissingLayer(return_tuple=True) + layers[i] = PPMissingLayer() # Layers after module list for name in pp_plan[module_list_idx + 1:]: @@ -533,14 +533,16 @@ def tensor_parallel(self): Apply the model's tensor parallelization plan. Currently only supports linear layers. 
""" - if not self.model.supports_tp_plan: - if self.tp_size <= 1: - return + tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {} + if not tp_plan and self.tp_size > 1: raise ValueError( f"{type(self.model)} does not support tensor parallel yet!") - tp_plan = self.model._tp_plan + # Some weight loaders expect linear layers to inherit from vLLM's + # LinearBase class, so we set a default style which causes any + # unspecified linear layers to be replaced with ReplicatedLinear + tp_plan[".*"] = "replicated" def _tensor_parallel(module: nn.Module, prefix: str = ""): for child_name, child_module in module.named_children(): @@ -552,6 +554,7 @@ def _tensor_parallel(module: nn.Module, prefix: str = ""): child_module, style, self.quant_config) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) + break else: _tensor_parallel(child_module, prefix=qual_name) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 28508e1bac1e..fecd14dde4a8 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -534,16 +534,10 @@ class PPMissingLayer(torch.nn.Identity): def __init__(self, *args, **kwargs): super().__init__() - self.return_tuple = kwargs.get("return_tuple", False) def forward(self, *args, **kwargs): - """ - Return the first arg from args or the first value from kwargs. - - Wraps the input in a tuple if `self.return_tuple` is True. - """ - input = args[0] if args else next(iter(kwargs.values())) - return (input, ) if self.return_tuple else input + """Return the first arg from args or the first value from kwargs.""" + return args[0] if args else next(iter(kwargs.values())) _CPU_OFFLOAD_BYTES = 0 diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b61b39a9274d..dd9356e399c9 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -271,6 +271,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") @@ -291,6 +292,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, elif selected_backend == _Backend.TREE_ATTN: logger.info_once("Using Tree Attention backend on V1 engine.") return TREE_ATTN_V1 + elif selected_backend == _Backend.XFORMERS_VLLM_V1: + logger.info_once("Using XFormers backend on V1 engine.") + return XFORMERS_V1 from vllm.attention.selector import is_attn_backend_supported diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 61ce868c13b4..a85b583abc2c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -63,6 +63,7 @@ class _Backend(enum.Enum): NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() + XFORMERS_VLLM_V1 = enum.auto() class PlatformEnum(enum.Enum): diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 54ffc83cd565..d26e4b335038 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention( max_seq_len: int, sliding_window: int, kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] 
= None) -> bool: + alibi_slopes: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) @@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention( and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + and envs.VLLM_ROCM_USE_AITER) and sinks is None) else: return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 @@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 3 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) class RocmPlatform(Platform): @@ -170,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8" + "quark", "ptpc_fp8", "mxfp4" ] @classmethod @@ -469,4 +470,4 @@ def device_count(cls) -> int: @classmethod def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: - return True \ No newline at end of file + return True diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 23eb775f2dc6..7077f68353fc 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy from typing import TYPE_CHECKING, Optional import msgspec @@ -19,13 +20,25 @@ class PoolingParams( """API parameters for pooling models. Attributes: + normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. + activation: Whether to apply activation function to + the classification outputs. + softmax: Whether to apply softmax to the reward outputs. 
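+        step_tag_id: Step tag token id used by step-wise reward models.
+        returned_token_ids: Token ids whose scores are returned by
+            step-wise reward models.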
""" + ## for embeddings models dimensions: Optional[int] = None + normalize: Optional[bool] = None - output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + ## for classification models + activation: Optional[bool] = None + + ## for reward models + softmax: Optional[bool] = None + step_tag_id: Optional[int] = None + returned_token_ids: Optional[list[int]] = None task: Optional[PoolingTask] = None """Internal use only.""" @@ -33,15 +46,32 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + + @property + def all_parameters(self) -> list[str]: + return [ + "dimensions", "normalize", "activation", "softmax", "step_tag_id", + "returned_token_ids" + ] + + @property + def valid_parameters(self): + return { + "embed": ["dimensions", "normalize"], + "classify": ["activation"], + "score": ["activation"], + "encode": ["softmax", "step_tag_id", "returned_token_ids"], + } + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams( - dimensions=self.dimensions, - task=self.task, - requires_token_ids=self.requires_token_ids, - ) + return deepcopy(self) + + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: - def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: if self.task is None: self.task = task elif self.task != task: @@ -52,28 +82,91 @@ def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: # which is not available in model config. So, it's not included # in this method - if self.dimensions is not None: - if not model_config.is_matryoshka: - raise ValueError( - f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.') + self._merge_default_parameters(model_config) + self._set_default_parameters(model_config) + self._verify_valid_parameters() + + def _merge_default_parameters(self, + model_config: Optional["ModelConfig"] = None + ) -> None: + + if model_config is None: + return - mds = model_config.matryoshka_dimensions - if mds is not None: - if self.dimensions not in mds: + pooler_config = model_config.pooler_config + if pooler_config is None: + return + + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + + for k in valid_parameters: + if getattr(pooler_config, k, None) is None: + continue + + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + + def _set_default_parameters(self, model_config: Optional["ModelConfig"]): + if self.task == "embed": + if self.normalize is None: + self.normalize = True + + if self.dimensions is not None and model_config is not None: + if not model_config.is_matryoshka: raise ValueError( - f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') - elif self.dimensions < 1: - raise ValueError("Dimensions must be greater than 0") + f'Model "{model_config.served_model_name}" does not ' + f'support matryoshka representation, ' + f'changing output dimensions will lead to poor results.' 
+ ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") + + elif self.task in ["classify", "score"]: + if self.activation is None: + self.activation = True + + elif self.task == "encode": + if self.softmax is None: + self.softmax = True + else: + raise ValueError(f"Unknown pooling task: {self.task}") + + def _verify_valid_parameters(self): + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + invalid_parameters = [] + for k in self.all_parameters: + if k in valid_parameters: + continue + + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"Task {self.task} only supports {valid_parameters} " + f"parameters, does not support " + f"{invalid_parameters} parameters") def __repr__(self) -> str: return (f"PoolingParams(" - f"dimensions={self.dimensions}, " f"task={self.task}, " + f"normalize={self.normalize}, " + f"dimensions={self.dimensions}, " + f"activation={self.activation}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 1c3f78f2edbf..b987adeb6428 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -4,6 +4,7 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser +from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser @@ -20,4 +21,5 @@ "Glm4MoeModelReasoningParser", "MistralReasoningParser", "Step3ReasoningParser", + "GptOssReasoningParser", ] diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py new file mode 100644 index 000000000000..05a72ac23bf2 --- /dev/null +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("GptOss") +class GptOssReasoningParser(ReasoningParser): + """ + Reasoning parser for GptOss model. + + The GptOss model uses harmony to extract reasoning content and this parser + is only used for detecting the end of the reasoning content. 
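+
+    The end of reasoning is detected by searching the generated token ids
+    for the encoded sequence of
+    "<|start|>assistant<|channel|>final<|message|>".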
+ """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.reasoning_end_token_ids = self.model_tokenizer.encode( + "<|start|>assistant<|channel|>final<|message|>") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + end_token_ids = self.reasoning_end_token_ids + assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty" + # Check if the end sequence is present in the input_ids. + # We search from the end of input_ids to find the last match. + for i in range(len(input_ids) - len(end_token_ids), -1, -1): + if input_ids[i:i + len(end_token_ids)] == end_token_ids: + return True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + raise RuntimeError( + "GptOss model uses harmony to extract reasoning content. This " + "function should not be called.") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + raise RuntimeError( + "GptOss model uses harmony to extract reasoning content. This " + "function should not be called.") + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + raise RuntimeError( + "GptOss model uses harmony to extract reasoning content. This " + "function should not be called.") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 322e53b75394..52e4cbd09615 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -156,6 +156,7 @@ class SamplingParams( Note that the implementation follows the OpenAI API: The API will always return the log probability of the sampled token, so there may be up to `logprobs+1` elements in the response. + When set to -1, return all `vocab_size` log probabilities. prompt_logprobs: Number of log probabilities to return per prompt token. detokenize: Whether to detokenize the output. Defaults to True. skip_special_tokens: Whether to skip special tokens in the output. 
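
As context for the `logprobs` change below, a minimal, hedged sketch of how `-1` would be requested from the offline API; only the `logprobs=-1` semantics come from this diff, while the model name and prompt are purely illustrative:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    # logprobs=-1 requests the full vocab_size log-probability distribution
    # for each sampled token instead of a fixed top-k; any other negative
    # value is still rejected by _verify_args().
    params = SamplingParams(max_tokens=8, logprobs=-1)
    outputs = llm.generate(["The capital of France is"], params)
    # Each per-token logprobs entry should now cover the whole vocabulary.
    print(len(outputs[0].outputs[0].logprobs[0]))
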
@@ -414,9 +415,10 @@ def _verify_args(self) -> None: raise ValueError( f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if self.logprobs is not None and self.logprobs < 0: + if (self.logprobs is not None and self.logprobs != -1 + and self.logprobs < 0): raise ValueError( - f"logprobs must be non-negative, got {self.logprobs}.") + f"logprobs must be non-negative or -1, got {self.logprobs}.") if self.prompt_logprobs is not None and self.prompt_logprobs < 0: raise ValueError(f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}.") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cc41a771d06c..413a8abd97f6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -35,6 +35,7 @@ MllamaConfig, MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, + Ovis2_5Config, OvisConfig, RWConfig, SpeculatorsConfig, Step3TextConfig, Step3VLConfig, UltravoxConfig) @@ -85,6 +86,8 @@ def _get_hf_token() -> Optional[str]: "speculators": SpeculatorsConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, + "ovis": OvisConfig, + "ovis2_5": Ovis2_5Config, "ultravox": UltravoxConfig, "step3_vl": Step3VLConfig, "step3_text": Step3TextConfig, @@ -449,6 +452,20 @@ def get_config( model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) + # ModelOpt 0.31.0 and after saves the quantization config in the model + # config file. + quantization_config = config_dict.get("quantization_config", None) + + # ModelOpt 0.29.0 and before saves the quantization config in a separate + # "hf_quant_config.json" in the same directory as the model config file. + if quantization_config is None \ + and file_or_path_exists(model, "hf_quant_config.json", revision): + quantization_config = get_hf_file_to_dict("hf_quant_config.json", + model, revision) + + if quantization_config is not None: + config.quantization_config = quantization_config + if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) config.update(hf_overrides_kw) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 64ace167a5a0..2fa3438effdc 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -24,6 +24,8 @@ from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.ovis import OvisConfig +from vllm.transformers_utils.configs.ovis2_5 import Ovis2_5Config from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -45,6 +47,8 @@ "NemotronHConfig", "Nemotron_Nano_VL_Config", "NVLM_D_Config", + "OvisConfig", + "Ovis2_5Config", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py new file mode 100644 index 000000000000..550f5e15dbcc --- /dev/null +++ b/vllm/transformers_utils/configs/ovis.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +# ruff: noqa: E501 +# adapted from 
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 64ace167a5a0..2fa3438effdc 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,8 @@
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.ovis import OvisConfig
+from vllm.transformers_utils.configs.ovis2_5 import Ovis2_5Config
 from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
 from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
                                                       Step3VisionEncoderConfig,
@@ -45,6 +47,8 @@
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
     "NVLM_D_Config",
+    "OvisConfig",
+    "Ovis2_5Config",
     "SpeculatorsConfig",
     "UltravoxConfig",
     "Step3VLConfig",
diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py
new file mode 100644
index 000000000000..550f5e15dbcc
--- /dev/null
+++ b/vllm/transformers_utils/configs/ovis.py
@@ -0,0 +1,176 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# yapf: disable
+# ruff: noqa: E501
+# adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py
+# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py
+# Ovis Config with AimV2 config registration removed for Transformers compatibility
+from typing import Any, Optional, Union
+
+from transformers import AutoConfig, PretrainedConfig
+
+
+class AIMv2Config(PretrainedConfig):
+    """This is the configuration class to store the configuration of an [`AIMv2Model`].
+    Instantiating a configuration with the defaults will yield a similar configuration
+    to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
+    Args:
+        hidden_size: Dimension of the hidden representations.
+        intermediate_size: Dimension of the SwiGLU representations.
+        num_hidden_layers: Number of hidden layers in the Transformer.
+        num_attention_heads: Number of attention heads for each attention layer
+            in the Transformer.
+        num_channels: Number of input channels.
+        image_size: Image size.
+        patch_size: Patch size.
+        rms_norm_eps: Epsilon value used for the RMS normalization layer.
+        attention_dropout: Dropout ratio for attention probabilities.
+        projection_dropout: Dropout ratio for the projection layer after the attention.
+        qkv_bias: Whether to add a bias to the queries, keys and values.
+        use_bias: Whether to add a bias in the feed-forward and projection layers.
+        kwargs: Keyword arguments for the [`PretrainedConfig`].
+    """
+
+    model_type: str = "aimv2"
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        intermediate_size: int = 2816,
+        num_hidden_layers: int = 24,
+        num_attention_heads: int = 8,
+        num_channels: int = 3,
+        image_size: int = 224,
+        patch_size: int = 14,
+        rms_norm_eps: float = 1e-5,
+        attention_dropout: float = 0.0,
+        projection_dropout: float = 0.0,
+        qkv_bias: bool = False,
+        use_bias: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.rms_norm_eps = rms_norm_eps
+
+        self.projection_dropout = projection_dropout
+        self.qkv_bias = qkv_bias
+        self.use_bias = use_bias
+
+
+# ----------------------------------------------------------------------
+# Visual Tokenizer Configuration
+# ----------------------------------------------------------------------
+class BaseVisualTokenizerConfig(PretrainedConfig):
+
+    def __init__(self,
+                 vocab_size=16384,
+                 tokenize_function="softmax",
+                 tau=1.0,
+                 depths=None,
+                 drop_cls_token=False,
+                 backbone_config: Optional[Union[PretrainedConfig,
+                                                 dict]] = None,
+                 hidden_stride: int = 1,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.tokenize_function = tokenize_function
+        self.tau = tau
+        if isinstance(depths, str):
+            depths = [int(x) for x in depths.split('|')]
+        self.depths = depths
+        self.backbone_kwargs = dict[str, Any]()
+        self.drop_cls_token = drop_cls_token
+        if backbone_config is not None:
+            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
+                f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
+            if not isinstance(backbone_config, PretrainedConfig):
+                model_type = backbone_config['model_type']
+                if model_type != "aimv2":
+                    backbone_config.pop('model_type')
+                    backbone_config = AutoConfig.for_model(model_type, **backbone_config)
+                else:
+                    backbone_config = AIMv2Config(**backbone_config)
+            self.backbone_config = backbone_config
+        self.hidden_stride = hidden_stride
+
+
+class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
+    model_type = "aimv2_visual_tokenizer"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if self.drop_cls_token:
+            self.drop_cls_token = False
+        if self.depths:
+            assert len(self.depths) == 1
+            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
+    model_type = "siglip_visual_tokenizer"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if self.drop_cls_token:
+            self.drop_cls_token = False
+        if self.depths:
+            assert len(self.depths) == 1
+            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)
+AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
+
+
+# ----------------------------------------------------------------------
+# Ovis Configuration
+# ----------------------------------------------------------------------
+class OvisConfig(PretrainedConfig):
+    model_type = "ovis"
+
+    def __init__(self,
+                 llm_config: Optional[Union[PretrainedConfig, dict]] = None,
+                 visual_tokenizer_config: Optional[Union[PretrainedConfig,
+                                                         dict]] = None,
+                 multimodal_max_length=8192,
+                 hidden_size=None,
+                 conversation_formatter_class=None,
+                 llm_attn_implementation=None,
+                 disable_tie_weight=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        if llm_config is not None:
+            assert isinstance(llm_config, (PretrainedConfig, dict)), \
+                f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
+            if not isinstance(llm_config, PretrainedConfig):
+                model_type = llm_config['model_type']
+                llm_config.pop('model_type')
+                llm_config = AutoConfig.for_model(model_type, **llm_config)
+
+        # map llm_config to text_config
+        self.text_config = llm_config
+        if visual_tokenizer_config is not None:
+            assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
+                f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
+            if not isinstance(visual_tokenizer_config, PretrainedConfig):
+                model_type = visual_tokenizer_config['model_type']
+                visual_tokenizer_config.pop('model_type')
+                visual_tokenizer_config = AutoConfig.for_model(
+                    model_type, **visual_tokenizer_config)
+
+        self.visual_tokenizer_config = visual_tokenizer_config
+        self.multimodal_max_length = multimodal_max_length
+        self.hidden_size = hidden_size
+        self.conversation_formatter_class = conversation_formatter_class
+        self.llm_attn_implementation = llm_attn_implementation
+        self.disable_tie_weight = disable_tie_weight
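Since the `OvisConfig` added above resolves plain dicts into nested config objects, a quick sketch of how a checkpoint's `config.json` contents would be materialized. Illustration only: the field values are invented, and `qwen2` merely stands in for whatever text backbone a given Ovis checkpoint actually uses.

    from vllm.transformers_utils.configs import OvisConfig

    cfg = OvisConfig(
        llm_config={"model_type": "qwen2", "hidden_size": 1536},
        visual_tokenizer_config={
            "model_type": "aimv2_visual_tokenizer",
            "vocab_size": 65536,
            "backbone_config": {"model_type": "aimv2", "hidden_size": 1024},
        },
    )

    # llm_config is exposed to vLLM as text_config, and the aimv2 backbone
    # dict is turned into an AIMv2Config by BaseVisualTokenizerConfig.
    assert cfg.text_config.model_type == "qwen2"
    assert cfg.visual_tokenizer_config.backbone_config.hidden_size == 1024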
diff --git a/vllm/transformers_utils/configs/ovis2_5.py b/vllm/transformers_utils/configs/ovis2_5.py
new file mode 100644
index 000000000000..6c3ff07f41d5
--- /dev/null
+++ b/vllm/transformers_utils/configs/ovis2_5.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional, Union
+
+from transformers import AutoConfig, PretrainedConfig
+
+# Model Constants
+IMAGE_TOKEN = "<image>"
+VIDEO_TOKEN = "