From 1738e8e59a15694be169c7b4f9e0c068ff2ec856 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Fri, 11 Jul 2025 01:46:16 -0400 Subject: [PATCH 1/6] add shm allreduce --- gloo/CMakeLists.txt | 2 + gloo/allreduce.cc | 11 + gloo/allreduce.h | 24 ++ gloo/allreduce_shm.cc | 741 ++++++++++++++++++++++++++++++++++++++++++ gloo/allreduce_shm.h | 8 + 5 files changed, 786 insertions(+) create mode 100644 gloo/allreduce_shm.cc create mode 100644 gloo/allreduce_shm.h diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 186fe1288..fb65defd5 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.cc" @@ -34,6 +35,7 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring_chunked.h" + "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.h" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.h" diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index 080f7f302..511e8d3d3 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -15,6 +15,7 @@ #include "gloo/common/logging.h" #include "gloo/math.h" #include "gloo/types.h" +#include "gloo/allreduce_shm.h" namespace gloo { @@ -95,6 +96,7 @@ BroadcastRangeFunction genLocalBroadcastFunction(const BufferVector& out) { } void allreduce(const detail::AllreduceOptionsImpl& opts) { + //printf("In gloo::allreduce\n"); if (opts.elements == 0) { return; } @@ -153,6 +155,15 @@ void ring( const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag); const size_t totalBytes = opts.elements * opts.elementSize; + + if (is_intra_node(context->size)) { + shm(opts); + return; + } + + //shm(opts); + //return; + // Note: context->size > 1 const auto recvRank = (context->size + context->rank + 1) % context->size; const auto sendRank = (context->size + context->rank - 1) % context->size; diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 904eb8b32..2133cf2f3 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -11,9 +11,13 @@ #include #include #include +#include #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" +#include "gloo/types.h" +//#include "gloo/allreduce_shm.h" + namespace gloo { @@ -41,6 +45,12 @@ struct AllreduceOptionsImpl { BCUBE = 2, }; + enum ScalarType { + BFLOAT16, + HALF, + FLOAT, + }; + explicit AllreduceOptionsImpl(const std::shared_ptr& context) : context(context), timeout(context->getTimeout()), @@ -54,6 +64,9 @@ struct AllreduceOptionsImpl { // Algorithm selection. Algorithm algorithm; + // Scalar type + ScalarType scalarType; + // Input and output buffers. // The output is used as input if input is not specified. 
std::vector> in; @@ -90,6 +103,7 @@ class AllreduceOptions { public: using Func = detail::AllreduceOptionsImpl::Func; using Algorithm = detail::AllreduceOptionsImpl::Algorithm; + using ScalarType = detail::AllreduceOptionsImpl::ScalarType; explicit AllreduceOptions(const std::shared_ptr& context) : impl_(context) {} @@ -154,6 +168,16 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { + //printf("set outputs\n"); + if (std::is_same_v) { + printf("output type is float\n"); + impl_.scalarType = ScalarType::FLOAT; + } else if (std::is_same_v) { + printf("output type is float16\n"); + impl_.scalarType = ScalarType::HALF; + } else { + printf("Unknown datatype\n"); + } setOutputs(ptrs.data(), ptrs.size(), elements); } diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc new file mode 100644 index 000000000..1a6498a14 --- /dev/null +++ b/gloo/allreduce_shm.cc @@ -0,0 +1,741 @@ +#include "gloo/allreduce_shm.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace gloo { + +namespace { +#define VECTOR_LENGTH_IN_BYTES 32 +// states for collectives +enum coll_state { + coll_begin = 0, + coll_allreduce_naive__copy_in_done, + coll_allreduce_naive__reduce_done, + // alternative state when allreduce is working on alternative buffer + // of the double buffer. + coll_alt1_allreduce_naive__copy_in_done, + coll_alt2_allreduce_naive__copy_in_done, + coll_alt1_allreduce_naive__reduce_done, +}; + +// SHM building blocks +struct SharedData { + const char* name; + int descriptor; + void* bytes; + size_t nbytes; +}; + +void shared_open(SharedData* data, const char* name, size_t nbytes) { + int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0); + data->name = name; + data->descriptor = d; + data->bytes = bytes; + data->nbytes = nbytes; + } else { + if (errno != ENOENT) { + // don't print if shm can not be found because we want to loop over from + // caller again until the other ranks created the shm + printf("shared_open %s failed, errno=%d\n", name, errno); + } + data->descriptor = -1; + } +} + +void shared_create( + SharedData* data, + const char* name, + void* bytes, + size_t nbytes) { + int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + if (nbytes = write(d, bytes, nbytes)) { + shared_open(data, name, nbytes); + } + } else { + printf("shared_create %s failed\n", name); + } +} + +static int world_rank = -1; +static int world_size = -1; +static bool is_initialized = false; + +// SHM based allreduce helper functions +// buffer that holds shm name +#define NAME_BUF_SIZE 1000 +#define MAX_BUF_SIZE 1048576 * 32 +#define NAIVE_ALLREDUCE_THRESHOLD 1048576 +#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" +struct allreduce_workspace { + enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce + // idx=1 -- state for distributed_naive_all_reduce + // double buffer to avoid syncing between rounds + // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for + // symmetric_naive_all_reduce after that : buffer for + // distributed_naive_all_reduce + char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; +}; + +#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD +#define BUFFER1_OFFSET(current_buffer) \ + 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE + +struct allreduce_workspace** workspace; + +// buffer for small messages, double buffer 
+char** symmetric_buffer[2]; +// buffer for large messages, double buffer +char** distributed_buffer[2]; + +void wait_buffer_state_until_2( + int index, + enum coll_state state0, + enum coll_state state1, + int state_group) { + volatile enum coll_state* state_ptr = + &(workspace[index]->states[state_group]); + + while (1) { + volatile enum coll_state cur_state = *state_ptr; + if (cur_state == state0 || cur_state == state1) + break; + } +} + +__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_bf16_to_fp32(const __m256i src) { + auto y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) + __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_bf16(const __m512 src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_fp16_to_fp32(const __m256i src) { + return _mm512_cvtph_ps(src); +} + +inline __m256i cvt_fp32_to_fp16(const __m512 src) + __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_fp16(const __m512 src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +void reduce_bf16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_fp16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_fp32_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_all_buffers( + int start_elements, + int num_elements, + AllreduceOptions::ScalarType scalar_type, + int to_buffer_idx, + char* to_buffer, + char** buffers) { + switch (scalar_type) { + case AllreduceOptions::ScalarType::BFLOAT16: + assert(!"BFloat16 not supported in gloo yet."); + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case AllreduceOptions::ScalarType::HALF: + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case AllreduceOptions::ScalarType::FLOAT: + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + break; + default: + assert(!"Should not get here"); + } +} + +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = \ + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +// Reduce functions down below use vectorized algorithm, the number of bytes +// processed each iteration depends on vector length. 256bit vector ==> 32 +// bytes, 512bit vector ==> 64 bytes If you change implementation of +// reduce_bf16_buffers, etc. 
, check whether this number needs to be changed +#define VECTOR_LENGTH_IN_BYTES 32 + +void reduce_bf16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_BF16(15); + case 15: + CVT_ADD_BF16(14); + case 14: + CVT_ADD_BF16(13); + case 13: + CVT_ADD_BF16(12); + case 12: + CVT_ADD_BF16(11); + case 11: + CVT_ADD_BF16(10); + case 10: + CVT_ADD_BF16(9); + case 9: + CVT_ADD_BF16(8); + case 8: + CVT_ADD_BF16(7); + case 7: + CVT_ADD_BF16(6); + case 6: + CVT_ADD_BF16(5); + case 5: + CVT_ADD_BF16(4); + case 4: + CVT_ADD_BF16(3); + case 3: + CVT_ADD_BF16(2); + case 2: + CVT_ADD_BF16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); + } + + // process remaining part + // todo: support bfloat16 + /* + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(at::BFloat16*)(buffers[j] + i); + } + *(at::BFloat16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } + */ +} + +#define CVT_ADD_FP16(x) \ + do { \ + auto in##x##_val = \ + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_FP16(15); + case 15: + CVT_ADD_FP16(14); + case 14: + CVT_ADD_FP16(13); + case 13: + CVT_ADD_FP16(12); + case 12: + CVT_ADD_FP16(11); + case 11: + CVT_ADD_FP16(10); + case 10: + CVT_ADD_FP16(9); + case 9: + CVT_ADD_FP16(8); + case 8: + CVT_ADD_FP16(7); + case 7: + CVT_ADD_FP16(6); + case 6: + CVT_ADD_FP16(5); + case 5: + CVT_ADD_FP16(4); + case 4: + CVT_ADD_FP16(3); + case 3: + CVT_ADD_FP16(2); + case 2: + CVT_ADD_FP16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val)); + } + + + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float16 val =float16(0.0f); + for (int j = 0; 
j < world_size; j++) { + val += *(float16*)(buffers[j] + i); + } + *(float16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ + inout_val = _mm256_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp32_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 4; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); + switch (world_size) { + case 16: + CVT_ADD_F32(15); + case 15: + CVT_ADD_F32(14); + case 14: + CVT_ADD_F32(13); + case 13: + CVT_ADD_F32(12); + case 12: + CVT_ADD_F32(11); + case 11: + CVT_ADD_F32(10); + case 10: + CVT_ADD_F32(9); + case 9: + CVT_ADD_F32(8); + case 8: + CVT_ADD_F32(7); + case 7: + CVT_ADD_F32(6); + case 6: + CVT_ADD_F32(5); + case 5: + CVT_ADD_F32(4); + case 4: + CVT_ADD_F32(3); + case 3: + CVT_ADD_F32(2); + case 2: + CVT_ADD_F32(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); + inout_val = _mm256_add_ps(inout_val, in_val); + } + } + _mm256_storeu_ps((float*)(to_buffer + i), inout_val); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(float*)(buffers[j] + i); + } + *(float*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +void shm_initialize(int size, int rank, char* addr_string, char* port_string) { + world_size = size; + world_rank = rank; + + char shm_name_prefix[NAME_BUF_SIZE]; + char shm_name[NAME_BUF_SIZE]; + snprintf( + shm_name_prefix, + NAME_BUF_SIZE, + "%s_%d_%s_%s", + SHM_BUFFER_NAME, + getuid(), + addr_string, + port_string); + // create shared workspace for SHM based allreduce + SharedData allreduce_buffer; + // allocate workspace_buf for current rank + struct allreduce_workspace* workspace_buf; + struct allreduce_workspace* workspace_buf_other; + workspace_buf = + (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); + int written = snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + if (written >= NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + shared_create( + &allreduce_buffer, + shm_name, + workspace_buf, + sizeof(struct allreduce_workspace)); + workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; + + // create the workspace pointer list + workspace = (struct allreduce_workspace**)malloc( + size * sizeof(struct allreduce_workspace*)); + symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); + + // map shm of all ranks + for (int i = 0; i < size; i++) { + if (i != rank) { + int written = snprintf(shm_name, 
NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + if (written >= NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + // printf("open %s, %d\n", shm_name, rank); + do { + shared_open( + &allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); + } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); + workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace[i] = workspace_buf_other; + } else { + workspace[i] = workspace_buf; + } + symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); + symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); + distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); + distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1); + } +} + +static void parallel_memcpy(void* to, void* from, size_t n_bytes) + __attribute__((target("avx512bw"))); +static void parallel_memcpy(void* to, void* from, size_t n_bytes) { + auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES); + // process aligned part +#pragma omp parallel for + for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { + auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); + _mm256_storeu_si256((__m256i*)((char*)to + i), val); + } + + // process remaining part + for (int i = aligned_bytes; i < n_bytes; i++) { + *((char*)to + i) = *((char*)from + i); + } +} + +#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod)) +#define rank_mod(rank) positive_mod(rank, world_size) +size_t slice_size(size_t chunk_el, int slice_idx) { + size_t slice_size = chunk_el / world_size; + return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) + : slice_size; +} + +char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) { + size_t slice_size = chunk_el / world_size; + size_t el_offset = slice_size * slice_idx; + return data_ptr + el_offset * el_size; +} + +size_t slice_el_start(size_t chunk_el, int slice_idx) { + size_t slice_size = chunk_el / world_size; + return slice_size * slice_idx; +} + +void symmetric_naive_all_reduce( + char* data_ptr, + AllreduceOptions::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) { + const int state_group = 0; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next; + + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + copy_next = coll_alt2_allreduce_naive__copy_in_done; + break; + case 2: + copy_current = coll_alt2_allreduce_naive__copy_in_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 3; + + parallel_memcpy( + symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until the other rank copy the buffer + if (i != world_rank) { + wait_buffer_state_until_2(i, copy_current, copy_next, state_group); + } + } + + // each rank reduce the buffer independently so therre is no need for + // synchronization afterward + reduce_all_buffers( + 0, + chunk_el, + scalar_type, + world_rank, + data_ptr, + symmetric_buffer[current_buffer]); + + // switch buffer + current_buffer = 1 - current_buffer; +} + +// 
naive allreduce distributed, each rank do naive reduce on its slice +void distributed_naive_reduce( + char* data_ptr, + AllreduceOptions::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) { + const int state_group = 1; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next, reduce_current; + + // similar to symmetric_naive_allreduce, but here we only need two sets of + // states, because distributed naive reduce has two barriers in the algorithm + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + reduce_current = coll_allreduce_naive__reduce_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + reduce_current = coll_alt1_allreduce_naive__reduce_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 2; + + int data_size = chunk_size / chunk_el; + parallel_memcpy( + distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + if (i != world_rank) + wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); + } + + // reduce scatter + reduce_all_buffers( + slice_el_start(chunk_el, world_rank), + slice_size(chunk_el, world_rank), + scalar_type, + world_rank, + distributed_buffer[current_buffer][world_rank], + distributed_buffer[current_buffer]); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = reduce_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks reduce the buffer + if (i != world_rank) + wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); + } + + for (int i = 0; i < world_size; i++) { + int rank = (i + world_rank) % world_size; + parallel_memcpy( + slice_data(data_ptr, chunk_el, data_size, rank), + slice_data( + distributed_buffer[current_buffer][rank], + chunk_el, + chunk_size / chunk_el, + rank), + slice_size(chunk_el, rank) * data_size); + } + + current_buffer = 1 - current_buffer; +} + +} // namespace + +bool is_intra_node(const int size) { + // must launch with torchrun + auto local_size_string = std::getenv("LOCAL_WORLD_SIZE"); + int local_size = 0; + if (local_size_string != NULL) { + local_size = std::stoi(local_size_string); + } + + return size > 1 && size == local_size; +} + + +void shm(const detail::AllreduceOptionsImpl& opts) { + + //printf("In shm allreduce\n"); + const auto& context = opts.context; + if (!is_initialized) { + + //int size = context->size; + //int rank = context->rank; + + int size = std::stoi(std::getenv("PMI_SIZE")); + int rank = std::stoi(std::getenv("PMI_RANK")); + + world_size = size; + world_rank = rank; + is_initialized = true; + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { + addr_string = ""; + } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { + port_string = ""; + } + // std::cout << "size: " << size << std::endl; + // std::cout << "rank: " << rank << std::endl; + // std::cout << "addr_string: " << addr_string << std::endl; + // std::cout << "port_string: " << port_string << std::endl; + shm_initialize(size, rank, addr_string, port_string); + } + + const size_t data_size = 
opts.elements * opts.elementSize; + const std::vector>& out = opts.out; + void* data = out[0].get()->ptr; + + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data) + offset); + size_t chunk_size = + data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / opts.elements); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { + symmetric_naive_all_reduce( + data_ptr, opts.scalarType, chunk_size, chunk_el); + } else { + distributed_naive_reduce( + data_ptr, opts.scalarType, chunk_size, chunk_el); + } + } + +} + +} //namespace gloo + diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h new file mode 100644 index 000000000..e9236759c --- /dev/null +++ b/gloo/allreduce_shm.h @@ -0,0 +1,8 @@ +#include "gloo/allreduce.h" + +namespace gloo { + +bool is_intra_node(const int size); +void shm(const detail::AllreduceOptionsImpl& opts); + +} // namespace gloo \ No newline at end of file From 76d111461f802df7a5c634437d6244aec89a0951 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Wed, 16 Jul 2025 03:52:12 -0400 Subject: [PATCH 2/6] add bf16 and half support --- gloo/CMakeLists.txt | 5 +++++ gloo/allreduce.h | 34 +++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index fb65defd5..6b0ac60b0 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -188,6 +188,11 @@ if(USE_ROCM) endif() endif() +message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") +if(GLOO_USE_TORCH_DTYPES) +target_include_directories(gloo PRIVATE ${GLOO_TORCH_DIR}) +endif() + # Install if necessary. # If the Gloo build is included from another project's build, it may # want to statically link with Gloo and not install any artifacts. diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 2133cf2f3..2ca69ca94 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -18,6 +18,11 @@ #include "gloo/types.h" //#include "gloo/allreduce_shm.h" +#define GPF_PRINT(...) 
do {\ + printf("GPF_DEBUG:");\ + printf(__VA_ARGS__);\ + printf("\n");\ +}while(0) namespace gloo { @@ -39,6 +44,11 @@ struct AllreduceOptionsImpl { // using Func = std::function; +#if GLOO_USE_TORCH_DTYPES +using BFloat16 = c10::BFloat16; +using Half = c10::Half; +#endif + enum Algorithm { UNSPECIFIED = 0, RING = 1, @@ -49,6 +59,7 @@ struct AllreduceOptionsImpl { BFLOAT16, HALF, FLOAT, + UNKNOWN, }; explicit AllreduceOptionsImpl(const std::shared_ptr& context) @@ -169,18 +180,23 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { //printf("set outputs\n"); - if (std::is_same_v) { - printf("output type is float\n"); - impl_.scalarType = ScalarType::FLOAT; - } else if (std::is_same_v) { - printf("output type is float16\n"); - impl_.scalarType = ScalarType::HALF; - } else { - printf("Unknown datatype\n"); - } + // default is float + impl_.scalarType = ScalarType::FLOAT; + +#if GLOO_USE_TORCH_DTYPES +if (std::is_same_v) { + //GPF_PRINT("output type is half"); + impl_.scalarType = ScalarType::HALF; +} else if (std::is_same_v) { + impl_.scalarType = ScalarType::BFLOAT16; + //GPF_PRINT("output type is bfloat16"); +} +#endif setOutputs(ptrs.data(), ptrs.size(), elements); } + + template void setOutputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; From 2d152a33ccbc0012d92320b5e5b90a0f32921349 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Fri, 18 Jul 2025 01:35:01 -0400 Subject: [PATCH 3/6] remove bf16 support --- gloo/allreduce.h | 41 ++++++++++++++++++++--------------------- gloo/allreduce_shm.cc | 8 ++++---- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 2ca69ca94..4ef031073 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -16,7 +16,6 @@ #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" #include "gloo/types.h" -//#include "gloo/allreduce_shm.h" #define GPF_PRINT(...) 
do {\ printf("GPF_DEBUG:");\ @@ -44,11 +43,6 @@ struct AllreduceOptionsImpl { // using Func = std::function; -#if GLOO_USE_TORCH_DTYPES -using BFloat16 = c10::BFloat16; -using Half = c10::Half; -#endif - enum Algorithm { UNSPECIFIED = 0, RING = 1, @@ -179,24 +173,9 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { - //printf("set outputs\n"); - // default is float - impl_.scalarType = ScalarType::FLOAT; - -#if GLOO_USE_TORCH_DTYPES -if (std::is_same_v) { - //GPF_PRINT("output type is half"); - impl_.scalarType = ScalarType::HALF; -} else if (std::is_same_v) { - impl_.scalarType = ScalarType::BFLOAT16; - //GPF_PRINT("output type is bfloat16"); -} -#endif setOutputs(ptrs.data(), ptrs.size(), elements); } - - template void setOutputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; @@ -230,6 +209,26 @@ if (std::is_same_v) { friend void allreduce(const AllreduceOptions&); }; +#if GLOO_USE_TORCH_DTYPES + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::HALF; + setOutputs(ptrs.data(), ptrs.size(), elements); + } + + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::BFLOAT16; + setOutputs(ptrs.data(), ptrs.size(), elements); + } + + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::FLOAT; + setOutputs(ptrs.data(), ptrs.size(), elements); + } +#endif + void allreduce(const AllreduceOptions& opts); } // namespace gloo diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 1a6498a14..0169c3f6e 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -179,7 +179,7 @@ void reduce_all_buffers( char** buffers) { switch (scalar_type) { case AllreduceOptions::ScalarType::BFLOAT16: - assert(!"BFloat16 not supported in gloo yet."); + GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); break; case AllreduceOptions::ScalarType::HALF: @@ -273,9 +273,9 @@ void reduce_bf16_buffers( while (remain_elements > 0) { float val = 0.0f; for (int j = 0; j < world_size; j++) { - val += *(at::BFloat16*)(buffers[j] + i); + val += *(c10::BFloat16*)(buffers[j] + i); } - *(at::BFloat16*)(to_buffer + i) = val; + *(BFloat16*)(to_buffer + i) = val; remain_elements--; i += element_size; } @@ -688,7 +688,6 @@ bool is_intra_node(const int size) { void shm(const detail::AllreduceOptionsImpl& opts) { - //printf("In shm allreduce\n"); const auto& context = opts.context; if (!is_initialized) { @@ -715,6 +714,7 @@ void shm(const detail::AllreduceOptionsImpl& opts) { // std::cout << "addr_string: " << addr_string << std::endl; // std::cout << "port_string: " << port_string << std::endl; shm_initialize(size, rank, addr_string, port_string); + GPF_PRINT("SHM reduce has been initialized"); } const size_t data_size = opts.elements * opts.elementSize; From 8c29eeba1c15aa203393dd916ad92ebaa44b3964 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Tue, 22 Jul 2025 01:38:39 -0400 Subject: [PATCH 4/6] add bf16 support --- gloo/allreduce_shm.cc | 110 +++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 0169c3f6e..848533f5b 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,9 @@ 
namespace gloo { namespace { + +using ReductionFunction = AllreduceOptions::Func; + #define VECTOR_LENGTH_IN_BYTES 32 // states for collectives enum coll_state { @@ -156,19 +160,47 @@ void reduce_bf16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); void reduce_fp16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); void reduce_fp32_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); + +void reduce_remaining_part( + int start_elements, + int num_elements, + int remain_elements, + int main_elements, + int element_size, + char *to_buffer, + char **buffers, + ReductionFunction fn){ + size_t offset = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + memcpy(to_buffer + offset, buffers[0] + offset, element_size); + for (int j = 1; j < world_size; j++) { + + fn(to_buffer + offset, + to_buffer + offset, + buffers[j] + offset, + 1); + + } + remain_elements--; + offset += element_size; + } +} void reduce_all_buffers( int start_elements, @@ -176,17 +208,18 @@ void reduce_all_buffers( AllreduceOptions::ScalarType scalar_type, int to_buffer_idx, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { switch (scalar_type) { case AllreduceOptions::ScalarType::BFLOAT16: - GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); - reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + //GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; case AllreduceOptions::ScalarType::HALF: - reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; case AllreduceOptions::ScalarType::FLOAT: - reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; default: assert(!"Should not get here"); @@ -210,7 +243,8 @@ void reduce_bf16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); @@ -267,19 +301,7 @@ void reduce_bf16_buffers( } // process remaining part - // todo: support bfloat16 - /* - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float val = 0.0f; - for (int j = 0; j < world_size; j++) { - val += *(c10::BFloat16*)(buffers[j] + i); - } - *(BFloat16*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } - */ + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } #define CVT_ADD_FP16(x) \ @@ -293,7 +315,8 @@ void reduce_fp16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; 
int main_elements = num_elements - (num_elements % vector_length); @@ -352,16 +375,7 @@ void reduce_fp16_buffers( // process remaining part - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float16 val =float16(0.0f); - for (int j = 0; j < world_size; j++) { - val += *(float16*)(buffers[j] + i); - } - *(float16*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } #define CVT_ADD_F32(x) \ @@ -374,7 +388,8 @@ void reduce_fp32_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 4; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); @@ -429,16 +444,7 @@ void reduce_fp32_buffers( } // process remaining part - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float val = 0.0f; - for (int j = 0; j < world_size; j++) { - val += *(float*)(buffers[j] + i); - } - *(float*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } void shm_initialize(int size, int rank, char* addr_string, char* port_string) { @@ -547,7 +553,8 @@ void symmetric_naive_all_reduce( char* data_ptr, AllreduceOptions::ScalarType scalar_type, size_t chunk_size, - size_t chunk_el) { + size_t chunk_el, + ReductionFunction fn) { const int state_group = 0; static int current_buffer = 0; static int state_idx = 0; @@ -592,7 +599,8 @@ void symmetric_naive_all_reduce( scalar_type, world_rank, data_ptr, - symmetric_buffer[current_buffer]); + symmetric_buffer[current_buffer], + fn); // switch buffer current_buffer = 1 - current_buffer; @@ -603,7 +611,8 @@ void distributed_naive_reduce( char* data_ptr, AllreduceOptions::ScalarType scalar_type, size_t chunk_size, - size_t chunk_el) { + size_t chunk_el, + ReductionFunction fn) { const int state_group = 1; static int current_buffer = 0; static int state_idx = 0; @@ -647,7 +656,8 @@ void distributed_naive_reduce( scalar_type, world_rank, distributed_buffer[current_buffer][world_rank], - distributed_buffer[current_buffer]); + distributed_buffer[current_buffer], + fn); std::atomic_thread_fence(std::memory_order_release); workspace[world_rank]->states[state_group] = reduce_current; @@ -728,10 +738,10 @@ void shm(const detail::AllreduceOptionsImpl& opts) { size_t chunk_el = chunk_size / (data_size / opts.elements); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el); + data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); } else { distributed_naive_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el); + data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); } } From 554d3176fefbef1fcbfd7043a540d98595e9b177 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Thu, 24 Jul 2025 03:23:17 -0400 Subject: [PATCH 5/6] use reduce function to do reduce job --- gloo/CMakeLists.txt | 4 - gloo/allreduce.h | 31 ---- gloo/allreduce_shm.cc | 357 +++++------------------------------------- 3 files changed, 43 insertions(+), 349 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 6b0ac60b0..db54496ef 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -188,10 
+188,6 @@ if(USE_ROCM) endif() endif() -message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") -if(GLOO_USE_TORCH_DTYPES) -target_include_directories(gloo PRIVATE ${GLOO_TORCH_DIR}) -endif() # Install if necessary. # If the Gloo build is included from another project's build, it may diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 4ef031073..8fc494358 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -49,13 +49,6 @@ struct AllreduceOptionsImpl { BCUBE = 2, }; - enum ScalarType { - BFLOAT16, - HALF, - FLOAT, - UNKNOWN, - }; - explicit AllreduceOptionsImpl(const std::shared_ptr& context) : context(context), timeout(context->getTimeout()), @@ -69,9 +62,6 @@ struct AllreduceOptionsImpl { // Algorithm selection. Algorithm algorithm; - // Scalar type - ScalarType scalarType; - // Input and output buffers. // The output is used as input if input is not specified. std::vector> in; @@ -108,7 +98,6 @@ class AllreduceOptions { public: using Func = detail::AllreduceOptionsImpl::Func; using Algorithm = detail::AllreduceOptionsImpl::Algorithm; - using ScalarType = detail::AllreduceOptionsImpl::ScalarType; explicit AllreduceOptions(const std::shared_ptr& context) : impl_(context) {} @@ -209,26 +198,6 @@ class AllreduceOptions { friend void allreduce(const AllreduceOptions&); }; -#if GLOO_USE_TORCH_DTYPES - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::HALF; - setOutputs(ptrs.data(), ptrs.size(), elements); - } - - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::BFLOAT16; - setOutputs(ptrs.data(), ptrs.size(), elements); - } - - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::FLOAT; - setOutputs(ptrs.data(), ptrs.size(), elements); - } -#endif - void allreduce(const AllreduceOptions& opts); } // namespace gloo diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 848533f5b..6af27d3ad 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -118,333 +118,62 @@ void wait_buffer_state_until_2( } } -__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); -inline __m512 cvt_bf16_to_fp32(const __m256i src) { - auto y = _mm512_cvtepu16_epi32(src); - return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); -} - -inline __m256i cvt_fp32_to_bf16(const __m512 src) - __attribute__((target("avx512bw"))); -inline __m256i cvt_fp32_to_bf16(const __m512 src) { - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); -} - -__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); -inline __m512 cvt_fp16_to_fp32(const __m256i src) { - return _mm512_cvtph_ps(src); -} - -inline __m256i cvt_fp32_to_fp16(const __m512 src) - 
__attribute__((target("avx512bw"))); -inline __m256i cvt_fp32_to_fp16(const __m512 src) { - return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); -} - -void reduce_bf16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_fp16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_fp32_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_remaining_part( - int start_elements, - int num_elements, - int remain_elements, - int main_elements, - int element_size, - char *to_buffer, - char **buffers, - ReductionFunction fn){ - size_t offset = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - memcpy(to_buffer + offset, buffers[0] + offset, element_size); - for (int j = 1; j < world_size; j++) { - - fn(to_buffer + offset, - to_buffer + offset, - buffers[j] + offset, - 1); - - } - remain_elements--; - offset += element_size; - } -} - void reduce_all_buffers( int start_elements, int num_elements, - AllreduceOptions::ScalarType scalar_type, + int element_size, int to_buffer_idx, char* to_buffer, char** buffers, ReductionFunction fn) { - switch (scalar_type) { - case AllreduceOptions::ScalarType::BFLOAT16: - //GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); - reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - case AllreduceOptions::ScalarType::HALF: - reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - case AllreduceOptions::ScalarType::FLOAT: - reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - default: - assert(!"Should not get here"); - } -} - -#define CVT_ADD_BF16(x) \ - do { \ - auto in##x##_val = \ - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ - } while (0) - -// Reduce functions down below use vectorized algorithm, the number of bytes -// processed each iteration depends on vector length. 256bit vector ==> 32 -// bytes, 512bit vector ==> 64 bytes If you change implementation of -// reduce_bf16_buffers, etc. 
, check whether this number needs to be changed -#define VECTOR_LENGTH_IN_BYTES 32 - -void reduce_bf16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); int remain_elements = num_elements % vector_length; - - // process aligned part + #pragma omp parallel for for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); - switch (world_size) { - case 16: - CVT_ADD_BF16(15); - case 15: - CVT_ADD_BF16(14); - case 14: - CVT_ADD_BF16(13); - case 13: - CVT_ADD_BF16(12); - case 12: - CVT_ADD_BF16(11); - case 11: - CVT_ADD_BF16(10); - case 10: - CVT_ADD_BF16(9); - case 9: - CVT_ADD_BF16(8); - case 8: - CVT_ADD_BF16(7); - case 7: - CVT_ADD_BF16(6); - case 6: - CVT_ADD_BF16(5); - case 5: - CVT_ADD_BF16(4); - case 4: - CVT_ADD_BF16(3); - case 3: - CVT_ADD_BF16(2); - case 2: - CVT_ADD_BF16(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); - inout_val = _mm512_add_ps(inout_val, in_val); + memcpy(to_buffer + i, buffers[0] + i, element_size); + switch (world_size){ + case 16: fn(to_buffer + i, to_buffer + i, buffers[15] + i, vector_length); + case 15: fn(to_buffer + i, to_buffer + i, buffers[14] + i, vector_length); + case 14: fn(to_buffer + i, to_buffer + i, buffers[13] + i, vector_length); + case 13: fn(to_buffer + i, to_buffer + i, buffers[12] + i, vector_length); + case 12: fn(to_buffer + i, to_buffer + i, buffers[11] + i, vector_length); + case 11: fn(to_buffer + i, to_buffer + i, buffers[10] + i, vector_length); + case 10: fn(to_buffer + i, to_buffer + i, buffers[9] + i, vector_length); + case 9: fn(to_buffer + i, to_buffer + i, buffers[8] + i, vector_length); + case 8: fn(to_buffer + i, to_buffer + i, buffers[7] + i, vector_length); + case 7: fn(to_buffer + i, to_buffer + i, buffers[6] + i, vector_length); + case 6: fn(to_buffer + i, to_buffer + i, buffers[5] + i, vector_length); + case 5: fn(to_buffer + i, to_buffer + i, buffers[4] + i, vector_length); + case 4: fn(to_buffer + i, to_buffer + i, buffers[3] + i, vector_length); + case 3: fn(to_buffer + i, to_buffer + i, buffers[2] + i, vector_length); + case 2: fn(to_buffer + i, to_buffer + i, buffers[1] + i, vector_length); + case 1: break; + default: + for (int j = 1; j < world_size; j++) { + fn(to_buffer + i, to_buffer + i, buffers[j] + i, vector_length); + } } - } - _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); - } - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); -} - -#define CVT_ADD_FP16(x) \ - do { \ - auto in##x##_val = \ - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ - } while (0) - -void reduce_fp16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 2; - const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; - int main_elements = num_elements - (num_elements % vector_length); - int remain_elements = num_elements % vector_length; - - // process 
aligned part -#pragma omp parallel for - for (int i = start_elements * element_size; - i < (start_elements + main_elements) * element_size; - i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); - switch (world_size) { - case 16: - CVT_ADD_FP16(15); - case 15: - CVT_ADD_FP16(14); - case 14: - CVT_ADD_FP16(13); - case 13: - CVT_ADD_FP16(12); - case 12: - CVT_ADD_FP16(11); - case 11: - CVT_ADD_FP16(10); - case 10: - CVT_ADD_FP16(9); - case 9: - CVT_ADD_FP16(8); - case 8: - CVT_ADD_FP16(7); - case 7: - CVT_ADD_FP16(6); - case 6: - CVT_ADD_FP16(5); - case 5: - CVT_ADD_FP16(4); - case 4: - CVT_ADD_FP16(3); - case 3: - CVT_ADD_FP16(2); - case 2: - CVT_ADD_FP16(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); - inout_val = _mm512_add_ps(inout_val, in_val); } - } - _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val)); - } - - - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); -} - -#define CVT_ADD_F32(x) \ - do { \ - auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ - inout_val = _mm256_add_ps(inout_val, in##x##_val); \ - } while (0) - -void reduce_fp32_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 4; - const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; - int main_elements = num_elements - (num_elements % vector_length); - int remain_elements = num_elements % vector_length; - // process aligned part -#pragma omp parallel for - for (int i = start_elements * element_size; - i < (start_elements + main_elements) * element_size; - i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); - switch (world_size) { - case 16: - CVT_ADD_F32(15); - case 15: - CVT_ADD_F32(14); - case 14: - CVT_ADD_F32(13); - case 13: - CVT_ADD_F32(12); - case 12: - CVT_ADD_F32(11); - case 11: - CVT_ADD_F32(10); - case 10: - CVT_ADD_F32(9); - case 9: - CVT_ADD_F32(8); - case 8: - CVT_ADD_F32(7); - case 7: - CVT_ADD_F32(6); - case 6: - CVT_ADD_F32(5); - case 5: - CVT_ADD_F32(4); - case 4: - CVT_ADD_F32(3); - case 3: - CVT_ADD_F32(2); - case 2: - CVT_ADD_F32(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); - inout_val = _mm256_add_ps(inout_val, in_val); - } + size_t offset = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + memcpy(to_buffer + offset, buffers[0] + offset, element_size); + for (int j = 1; j < world_size; j++) { + + fn(to_buffer + offset, + to_buffer + offset, + buffers[j] + offset, + 1); + } - _mm256_storeu_ps((float*)(to_buffer + i), inout_val); + remain_elements--; + offset += element_size; } - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); + } void shm_initialize(int size, int rank, char* addr_string, char* port_string) { @@ -551,7 +280,7 @@ size_t slice_el_start(size_t chunk_el, int slice_idx) { void symmetric_naive_all_reduce( char* data_ptr, - AllreduceOptions::ScalarType scalar_type, + int element_size, size_t chunk_size, size_t chunk_el, ReductionFunction fn) { @@ -596,7 +325,7 @@ void symmetric_naive_all_reduce( reduce_all_buffers( 0, 
chunk_el, - scalar_type, + element_size, world_rank, data_ptr, symmetric_buffer[current_buffer], @@ -609,7 +338,7 @@ void symmetric_naive_all_reduce( // naive allreduce distributed, each rank do naive reduce on its slice void distributed_naive_reduce( char* data_ptr, - AllreduceOptions::ScalarType scalar_type, + int element_size, size_t chunk_size, size_t chunk_el, ReductionFunction fn) { @@ -653,7 +382,7 @@ void distributed_naive_reduce( reduce_all_buffers( slice_el_start(chunk_el, world_rank), slice_size(chunk_el, world_rank), - scalar_type, + element_size, world_rank, distributed_buffer[current_buffer][world_rank], distributed_buffer[current_buffer], @@ -738,10 +467,10 @@ void shm(const detail::AllreduceOptionsImpl& opts) { size_t chunk_el = chunk_size / (data_size / opts.elements); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); } else { distributed_naive_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); } } From 0fdde35a948e81970d07f2c149d6eda3801ed44d Mon Sep 17 00:00:00 2001 From: gaopengf Date: Thu, 24 Jul 2025 03:27:29 -0400 Subject: [PATCH 6/6] refine format --- gloo/CMakeLists.txt | 1 - gloo/allreduce.cc | 1 - gloo/allreduce.h | 8 -------- gloo/allreduce_shm.cc | 6 ++++++ gloo/allreduce_shm.h | 2 +- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index db54496ef..fb65defd5 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -188,7 +188,6 @@ if(USE_ROCM) endif() endif() - # Install if necessary. # If the Gloo build is included from another project's build, it may # want to statically link with Gloo and not install any artifacts. diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index 511e8d3d3..4099dd757 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -96,7 +96,6 @@ BroadcastRangeFunction genLocalBroadcastFunction(const BufferVector& out) { } void allreduce(const detail::AllreduceOptionsImpl& opts) { - //printf("In gloo::allreduce\n"); if (opts.elements == 0) { return; } diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 8fc494358..904eb8b32 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -11,17 +11,9 @@ #include #include #include -#include #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" -#include "gloo/types.h" - -#define GPF_PRINT(...) do {\ - printf("GPF_DEBUG:");\ - printf(__VA_ARGS__);\ - printf("\n");\ -}while(0) namespace gloo { diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 6af27d3ad..f972cd3ea 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -11,6 +11,12 @@ #include #include +#define GPF_PRINT(...) do {\ + printf("GPF_DEBUG:");\ + printf(__VA_ARGS__);\ + printf("\n");\ +}while(0) + namespace gloo { diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index e9236759c..3271ba4a2 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -5,4 +5,4 @@ namespace gloo { bool is_intra_node(const int size); void shm(const detail::AllreduceOptionsImpl& opts); -} // namespace gloo \ No newline at end of file +} // namespace gloo
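
Reviewer note, appended after the series rather than part of any patch: the sketch below shows how a caller would end up on the new shm() path. It assumes a single-node launch (torchrun or a similar launcher) that sets LOCAL_WORLD_SIZE, PMI_RANK and PMI_SIZE, which are the environment variables the patches actually read; the FileStore path, the network interface, the element count and the sum() reduce function are illustrative choices only, not anything the series prescribes.

// Minimal usage sketch (assumptions as stated above; not part of the patches).
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "gloo/allreduce.h"
#include "gloo/math.h"
#include "gloo/rendezvous/context.h"
#include "gloo/rendezvous/file_store.h"
#include "gloo/transport/tcp/device.h"

int main() {
  // shm() calls std::stoi(std::getenv(...)) on these directly, so the launcher
  // must provide them; the fallbacks here only keep the sketch self-contained.
  const char* rankEnv = std::getenv("PMI_RANK");
  const char* sizeEnv = std::getenv("PMI_SIZE");
  const int rank = rankEnv ? std::stoi(rankEnv) : 0;
  const int size = sizeEnv ? std::stoi(sizeEnv) : 1;

  // Standard Gloo rendezvous over TCP; nothing in this part is patch-specific.
  gloo::transport::tcp::attr attr;
  attr.iface = "lo";  // assumption: loopback suffices for a single-node job
  auto dev = gloo::transport::tcp::CreateDevice(attr);
  auto store = gloo::rendezvous::FileStore("/tmp/gloo_rendezvous");  // placeholder path
  auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
  context->connectFullMesh(store, dev);

  std::vector<float> data(1024, 1.0f);

  gloo::AllreduceOptions opts(context);
  opts.setOutput(data.data(), data.size());

  // Patches 4-5 drop the dtype-specific AVX kernels and reduce through
  // opts.reduce, so the caller-supplied function (summation here) is what
  // runs on the shared-memory buffers.
  void (*sumFn)(void*, const void*, const void*, size_t) = &gloo::sum<float>;
  opts.setReduceFunction(sumFn);

  // ring() now checks is_intra_node(context->size), which compares
  // LOCAL_WORLD_SIZE against the context size; when they match (and size > 1)
  // the call is served by shm() instead of the ring exchange.
  gloo::allreduce(opts);
  return 0;
}

Two follow-up notes on the same assumptions: MASTER_ADDR and MASTER_PORT are only folded into the shared-memory segment name and fall back to empty strings when unset, and because the reduction runs through opts.reduce, the registered reduce function must match the element type implied by opts.elementSize.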