Skip to content

Commit 60554bc

Browse files
committed
Added SVE implementation to improve the performance on ARM architecture
1 parent cc3b56f commit 60554bc

File tree

2 files changed

+99
-7
lines changed

2 files changed

+99
-7
lines changed

CMakeLists.txt

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,51 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "OS400")
265265
set(CMAKE_CXX_ARCHIVE_CREATE "<CMAKE_AR> -X64 qc <TARGET> <OBJECTS>")
266266
endif()
267267

268+
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
269+
include(CheckCSourceCompiles)
270+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve")
271+
check_c_source_compiles("
272+
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
273+
#include <arm_sve.h>
274+
int main() {
275+
svfloat64_t a;
276+
a = svdup_n_f64(0);
277+
return 0;
278+
}
279+
#endif
280+
" COMPILER_HAS_ARM_SVE)
281+
282+
if(COMPILER_HAS_ARM_SVE)
283+
message(STATUS "ARM SVE compiler support detected")
284+
set(SOURCE_CODE "
285+
#include <sys/prctl.h>
286+
int main() {
287+
int ret = prctl(PR_SVE_GET_VL);
288+
return ret >= 0 ? 0 : 1;
289+
}
290+
")
291+
file(WRITE ${CMAKE_BINARY_DIR}/check_sve_support.c "${SOURCE_CODE}")
292+
try_run(RUN_RESULT COMPILE_RESULT
293+
${CMAKE_BINARY_DIR}/check_sve_support_output
294+
${CMAKE_BINARY_DIR}/check_sve_support.c
295+
)
296+
297+
if(RUN_RESULT EQUAL 0)
298+
message(STATUS "ARM SVE hardware support detected")
299+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8-a+sve")
300+
string(APPEND CMAKE_CXX_FLAGS " -DSVE_SUPPORT_DETECTED")
301+
else()
302+
message(STATUS "ARM SVE hardware support not detected")
303+
endif()
304+
else()
305+
message(STATUS "ARM SVE compiler support not detected")
306+
endif()
307+
308+
set(CMAKE_C_FLAGS "${ORIGINAL_CMAKE_C_FLAGS}")
309+
else()
310+
message(STATUS "Not an aarch64 architecture")
311+
endif()
312+
268313
if(USE_NCCL)
269314
find_package(Nccl REQUIRED)
270315
endif()

src/common/hist_util.cc

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/**
22
* Copyright 2017-2023 by XGBoost Contributors
3+
* Copyright 2024 FUJITSU LIMITED
34
* \file hist_util.cc
45
*/
56
#include "hist_util.h"
@@ -15,6 +16,10 @@
1516
#include "xgboost/context.h" // for Context
1617
#include "xgboost/data.h" // for SparsePage, SortedCSCPage
1718

19+
#if defined(SVE_SUPPORT_DETECTED)
20+
#include <arm_sve.h> //to leverage sve intrinsics
21+
#endif
22+
1823
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
1924
#include <xmmintrin.h>
2025
#define PREFETCH_READ_T0(addr) _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0)
@@ -252,13 +257,55 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair, Span<bst_idx_t cons
252257

253258
// The trick with pgh_t buffer helps the compiler to generate faster binary.
254259
const float pgh_t[] = {p_gpair[idx_gh], p_gpair[idx_gh + 1]};
255-
for (size_t j = 0; j < row_size; ++j) {
256-
const uint32_t idx_bin =
257-
two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
258-
auto hist_local = hist_data + idx_bin;
259-
*(hist_local) += pgh_t[0];
260-
*(hist_local + 1) += pgh_t[1];
261-
}
260+
261+
#if defined(SVE_SUPPORT_DETECTED)
262+
svfloat64_t pgh_t0_vec = svdup_n_f64(pgh_t[0]);
263+
svfloat64_t pgh_t1_vec = svdup_n_f64(pgh_t[1]);
264+
265+
for (size_t j = 0; j < row_size; j+=svcntw()) {
266+
svbool_t pg32 = svwhilelt_b32(j,row_size);
267+
svbool_t pg64 = svwhilelt_b64(j,row_size);
268+
269+
svuint32_t gr_index_vec = svld1ub_u32(pg32, reinterpret_cast<const uint8_t *> (&gr_index_local[j]));
270+
svuint32_t offsets_vec = svld1(pg32, &offsets[j]);
271+
272+
svuint32_t idx_bin_vec;
273+
if (kAnyMissing) {
274+
idx_bin_vec = svmul_n_u32_x(pg32, gr_index_vec, two);
275+
} else {
276+
svuint32_t temp = svadd_u32_m(pg32, gr_index_vec, offsets_vec);
277+
idx_bin_vec = svmul_n_u32_x(pg32, temp, two);
278+
}
279+
280+
svuint64_t idx_bin_vec0_0 = svunpklo_u64(idx_bin_vec);
281+
svuint64_t idx_bin_vec0_1 = svunpkhi_u64(idx_bin_vec);
282+
svuint64_t idx_bin_vec1_0 = svadd_n_u64_m(pg64, idx_bin_vec0_0, 1);
283+
svuint64_t idx_bin_vec1_1 = svadd_n_u64_m(pg64, idx_bin_vec0_1, 1);
284+
285+
svfloat64_t hist0_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_0);
286+
svfloat64_t hist0_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_1);
287+
svfloat64_t hist1_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_0);
288+
svfloat64_t hist1_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_1);
289+
290+
hist0_vec0 = svadd_f64_m(pg64, hist0_vec0, pgh_t0_vec);
291+
hist0_vec1 = svadd_f64_m(pg64, hist0_vec1, pgh_t0_vec);
292+
hist1_vec0 = svadd_f64_m(pg64, hist1_vec0, pgh_t1_vec);
293+
hist1_vec1 = svadd_f64_m(pg64, hist1_vec1, pgh_t1_vec);
294+
295+
svst1_scatter_index(pg64, hist_data, idx_bin_vec0_0, hist0_vec0);
296+
svst1_scatter_index(pg64, hist_data, idx_bin_vec0_1, hist0_vec1);
297+
svst1_scatter_index(pg64, hist_data, idx_bin_vec1_0, hist1_vec0);
298+
svst1_scatter_index(pg64, hist_data, idx_bin_vec1_1, hist1_vec1);
299+
}
300+
#else
301+
for (size_t j = 0; j < row_size; ++j) {
302+
const uint32_t idx_bin =
303+
two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
304+
auto hist_local = hist_data + idx_bin;
305+
*(hist_local) += pgh_t[0];
306+
*(hist_local + 1) += pgh_t[1];
307+
}
308+
#endif
262309
}
263310
}
264311

0 commit comments

Comments
 (0)