Skip to content

Commit f8af6b8

Browse files
authored
Merge branch 'vllm-project:main' into Eagle-mulitmodal-support-Qwen2.5vl
2 parents 8874e16 + 0933f9d commit f8af6b8

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+1661
-413
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,7 @@ steps:
669669
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
670670
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
671671
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
672+
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
672673
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
673674
# Fusion
674675
- pytest -v -s tests/compile/test_fusion_all_reduce.py

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
351351
set_gencode_flags_for_srcs(
352352
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
353353
CUDA_ARCHS "${MARLIN_ARCHS}")
354+
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
355+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
354356

355357
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
356358

@@ -364,7 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
364366
set_gencode_flags_for_srcs(
365367
SRCS "${MARLIN_SRCS}"
366368
CUDA_ARCHS "${MARLIN_ARCHS}")
369+
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
370+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
367371
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
372+
368373
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
369374
else()
370375
message(STATUS "Not building Marlin kernels as no compatible archs found"
@@ -854,6 +859,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
854859
set_gencode_flags_for_srcs(
855860
SRCS "${MOE_WNAA16_MARLIN_SRC}"
856861
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
862+
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
863+
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
857864

858865
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
859866

benchmarks/kernels/benchmark_machete.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
236236
a=bt.a,
237237
c=None,
238238
b_q_weight=w_q,
239+
b_bias=None,
239240
b_scales=w_s,
240241
global_scale=None,
241242
b_zeros=w_zp,

csrc/core/scalar_type.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
321321
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
322322
static inline constexpr auto kFE4M3fn =
323323
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
324+
static inline constexpr auto kFE8M0fnu =
325+
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
324326
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
325327
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
326328
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

csrc/moe/marlin_moe_wna16/generate_kernels.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
TEMPLATE = ("template __global__ void Marlin<"
2121
"{{scalar_t}}, "
2222
"{{w_type_id}}, "
23+
"{{s_type_id}}, "
2324
"{{threads}}, "
2425
"{{thread_m_blocks}}, "
2526
"{{thread_n_blocks}}, "
@@ -77,6 +78,7 @@ def generate_new_kernels():
7778
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
7879
continue
7980
# nvfp4 only supports group_size == 16
81+
# mxfp4 only supports group_size == 32
8082
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
8183
continue
8284
# other quantization methods don't support group_size = 16
@@ -89,9 +91,22 @@ def generate_new_kernels():
8991

9092
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
9193

94+
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
95+
s_type = "vllm::kFE4M3fn"
96+
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
97+
s_type = "vllm::kFE8M0fnu"
98+
if dtype == "fp16":
99+
# we cannot safely dequantize e8m0 to fp16, so skip this
100+
continue
101+
elif dtype == "fp16":
102+
s_type = "vllm::kFloat16"
103+
elif dtype == "bf16":
104+
s_type = "vllm::kBFloat16"
105+
92106
template_str = jinja2.Template(TEMPLATE).render(
93107
scalar_t=c_dtype,
94108
w_type_id=scalar_type + ".id()",
109+
s_type_id=s_type + ".id()",
95110
threads=threads,
96111
thread_m_blocks=max(m_blocks, 1),
97112
thread_n_blocks=n_blocks,

csrc/moe/marlin_moe_wna16/kernel.h

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,25 @@
77
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
88
#include "core/scalar_type.hpp"
99

10-
#define MARLIN_KERNEL_PARAMS \
11-
const int4 *__restrict__ A, const int4 *__restrict__ B, \
12-
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
13-
const int4 *__restrict__ scales_ptr, \
14-
const uint16_t *__restrict__ scale2_ptr, \
15-
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
16-
const int32_t *__restrict__ sorted_token_ids_ptr, \
17-
const int32_t *__restrict__ expert_ids_ptr, \
18-
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
19-
const float *__restrict__ topk_weights_ptr, int top_k, \
20-
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
21-
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
10+
#define MARLIN_KERNEL_PARAMS \
11+
const int4 *__restrict__ A, const int4 *__restrict__ B, \
12+
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
13+
const int4 *__restrict__ b_bias_ptr, \
14+
const int4 *__restrict__ scales_ptr, \
15+
const uint16_t *__restrict__ scale2_ptr, \
16+
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
17+
const int32_t *__restrict__ sorted_token_ids_ptr, \
18+
const int32_t *__restrict__ expert_ids_ptr, \
19+
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
20+
const float *__restrict__ topk_weights_ptr, int top_k, \
21+
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
22+
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
2223
bool use_fp32_reduce, int max_shared_mem
2324

2425
namespace MARLIN_NAMESPACE_NAME {
2526
template <typename scalar_t, // compute dtype, half or nv_float16
2627
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
28+
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
2729
const int threads, // number of threads in a threadblock
2830
const int thread_m_blocks, // number of 16x16 blocks in the m
2931
// dimension (batchsize) of the

0 commit comments

Comments (0)