Skip to content

Commit 7be5d11

Browse files
authored
[CPU] Refactor CPU W8A8 scaled_mm (#23071)
Signed-off-by: jiang1.li <[email protected]>
1 parent b029de9 commit 7be5d11

File tree

17 files changed

+1527
-1275
lines changed

17 files changed

+1527
-1275
lines changed

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ function cpu_tests() {
4646
set -e
4747
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
4848

49+
# Run kernel tests
50+
docker exec cpu-test-"$NUMA_NODE" bash -c "
51+
set -e
52+
pytest -v -s tests/kernels/test_onednn.py"
53+
4954
# Run basic model test
5055
docker exec cpu-test-"$NUMA_NODE" bash -c "
5156
set -e
@@ -99,4 +104,4 @@ function cpu_tests() {
99104

100105
# All of CPU tests are expected to be finished less than 40 mins.
101106
export -f cpu_tests
102-
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
107+
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"

cmake/cpu_extension.cmake

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,17 @@ endif()
182182
#
183183
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
184184
# Flag to enable ACL kernels for AARCH64 platforms
185-
if ( VLLM_BUILD_ACL STREQUAL "ON")
185+
if (VLLM_BUILD_ACL STREQUAL "ON")
186186
set(USE_ACL ON)
187187
else()
188188
set(USE_ACL OFF)
189189
endif()
190190

191-
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
191+
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
192192
FetchContent_Declare(
193193
oneDNN
194194
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
195-
GIT_TAG v3.8.1
195+
GIT_TAG v3.9
196196
GIT_PROGRESS TRUE
197197
GIT_SHALLOW TRUE
198198
)
@@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
204204
endif()
205205
set(ONEDNN_AARCH64_USE_ACL "ON")
206206
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
207-
endif()
207+
endif()
208208

209209
set(ONEDNN_LIBRARY_TYPE "STATIC")
210210
set(ONEDNN_BUILD_DOC "OFF")
@@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
217217
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
218218
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
219219
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
220+
set(ONEDNN_VERBOSE "OFF")
220221
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
221222

222223
FetchContent_MakeAvailable(oneDNN)
223-
224-
list(APPEND LIBS dnnl)
225-
elseif(POWER10_FOUND)
226-
FetchContent_Declare(
227-
oneDNN
228-
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
229-
GIT_TAG v3.7.2
230-
GIT_PROGRESS TRUE
231-
GIT_SHALLOW TRUE
224+
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
225+
target_include_directories(
226+
dnnl_ext
227+
PUBLIC ${oneDNN_SOURCE_DIR}/include
228+
PUBLIC ${oneDNN_BINARY_DIR}/include
229+
PRIVATE ${oneDNN_SOURCE_DIR}/src
232230
)
233-
234-
set(ONEDNN_LIBRARY_TYPE "STATIC")
235-
set(ONEDNN_BUILD_DOC "OFF")
236-
set(ONEDNN_BUILD_EXAMPLES "OFF")
237-
set(ONEDNN_BUILD_TESTS "OFF")
238-
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
239-
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
240-
set(ONEDNN_BUILD_GRAPH "OFF")
241-
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
242-
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
243-
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
244-
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
245-
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
246-
247-
set(DNNL_CPU_RUNTIME "OMP")
248-
249-
FetchContent_MakeAvailable(oneDNN)
250-
251-
list(APPEND LIBS dnnl)
231+
target_link_libraries(dnnl_ext dnnl)
232+
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
233+
list(APPEND LIBS dnnl_ext)
234+
set(USE_ONEDNN ON)
235+
else()
236+
set(USE_ONEDNN OFF)
252237
endif()
253238

254239
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@@ -275,7 +260,6 @@ set(VLLM_EXT_SRC
275260

276261
if (AVX512_FOUND AND NOT AVX512_DISABLED)
277262
set(VLLM_EXT_SRC
278-
"csrc/cpu/quant.cpp"
279263
"csrc/cpu/shm.cpp"
280264
${VLLM_EXT_SRC})
281265
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
289273
${VLLM_EXT_SRC})
290274
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
291275
endif()
292-
elseif(POWER10_FOUND)
293-
set(VLLM_EXT_SRC
294-
"csrc/cpu/quant.cpp"
295-
${VLLM_EXT_SRC})
296276
endif()
297-
if (ASIMD_FOUND)
277+
278+
if(USE_ONEDNN)
298279
set(VLLM_EXT_SRC
299-
"csrc/cpu/quant.cpp"
280+
"csrc/cpu/dnnl_kernels.cpp"
300281
${VLLM_EXT_SRC})
301282
endif()
302283

csrc/cpu/cpu_types_x86.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
8989

9090
explicit FP16Vec16(const FP32Vec16&);
9191

92-
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
92+
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
9393

9494
void save(void* ptr, const int elem_num) const {
9595
constexpr uint32_t M = 0xFFFFFFFF;
@@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
126126

127127
explicit BF16Vec16(const FP32Vec16&);
128128

129-
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
129+
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
130130

131131
void save(void* ptr, const int elem_num) const {
132132
constexpr uint32_t M = 0xFFFFFFFF;
@@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
180180
(__m128i)vec8_data.reg, 1)) {}
181181

182182
void save(void* ptr) const {
183-
*reinterpret_cast<__m256i*>(ptr) = reg_low;
184-
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
183+
_mm256_storeu_si256((__m256i*)ptr, reg_low);
184+
_mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
185185
}
186186
};
187187
#endif

0 commit comments

Comments
 (0)