From 86c06fc5d46a4433a724173dda81129c7f033504 Mon Sep 17 00:00:00 2001 From: liqingping Date: Thu, 23 Jan 2025 11:27:38 +0800 Subject: [PATCH 1/9] build: build gloo --- .gitignore | 60 +++++++++++++++++- examples/CMakeLists.txt | 14 +++++ examples/build.sh | 4 ++ examples/run.sh | 7 +++ examples/test_allreduce.cpp | 92 +++++++++++++++++++++++++++ examples/test_allreduce_ib.cpp | 110 +++++++++++++++++++++++++++++++++ examples/test_send_recv_ib.cpp | 108 ++++++++++++++++++++++++++++++++ examples/test_type.cpp | 47 ++++++++++++++ gloo/common/CMakeLists.txt | 1 + gloo/rendezvous/store.cc | 11 +++- hack/bench.sh | 25 ++++++++ hack/build.sh | 15 +++++ 12 files changed, 491 insertions(+), 3 deletions(-) create mode 100644 examples/CMakeLists.txt create mode 100644 examples/build.sh create mode 100644 examples/run.sh create mode 100644 examples/test_allreduce.cpp create mode 100644 examples/test_allreduce_ib.cpp create mode 100644 examples/test_send_recv_ib.cpp create mode 100644 examples/test_type.cpp create mode 100644 hack/bench.sh create mode 100644 hack/build.sh diff --git a/.gitignore b/.gitignore index e0d1d4e65..cdb50ee2d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,60 @@ -/build* +# Created by https://www.toptal.com/developers/gitignore/api/c++,cmake +# Edit at https://www.toptal.com/developers/gitignore?templates=c++,cmake + +### C++ ### +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +### CMake ### +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + +### CMake Patch ### +CMakeUserPresets.json + +# External projects +*-prefix/ + +# End of https://www.toptal.com/developers/gitignore/api/c++,cmake + +build/ *.pyc diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 000000000..092af62b3 --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,14 @@ +set(GLOO_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_allreduce.cpp") +add_executable(gloo_test ${GLOO_TEST_SRCS}) +target_link_libraries(gloo_test gloo ${GLOO_TEST_LIBRARIES}) + +set(IB_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_allreduce_ib.cpp") +add_executable(ib_test ${IB_TEST_SRCS}) +target_link_libraries(ib_test gloo ${IB_TEST_LIBRARIES}) + +set(SEND_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_send_recv_ib.cpp") +add_executable(send_test ${SEND_TEST_SRCS}) +target_link_libraries(send_test gloo ${IB_TEST_LIBRARIES}) + + +add_executable(type_test "${CMAKE_CURRENT_SOURCE_DIR}/test_type.cpp") \ No newline at end of file diff --git a/examples/build.sh b/examples/build.sh new file mode 100644 index 000000000..b74eafd4a --- /dev/null +++ b/examples/build.sh @@ -0,0 +1,4 @@ +mkdir build +cd build +cmake .. +make \ No newline at end of file diff --git a/examples/run.sh b/examples/run.sh new file mode 100644 index 000000000..35e4cb618 --- /dev/null +++ b/examples/run.sh @@ -0,0 +1,7 @@ +rm -rf /mnt/public/liqingping/opensource/gloo/tmp/file_store/* + +# in rank 0 +IB_DEVICE=mlx5_10 RANK=0 WORLD_SIZE=2 ./build/send_test + +# in rank 1 +IB_DEVICE=mlx5_10 RANK=1 WORLD_SIZE=2 ./build/send_test \ No newline at end of file diff --git a/examples/test_allreduce.cpp b/examples/test_allreduce.cpp new file mode 100644 index 000000000..902158853 --- /dev/null +++ b/examples/test_allreduce.cpp @@ -0,0 +1,92 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace gloo; + +int main() +{ + // Initialize context + auto rank = getenv("RANK"); + auto world_size = getenv("WORLD_SIZE"); + auto myRank = atoi(rank); + auto contextSize = atoi(world_size); + gloo::rendezvous::Context context(myRank, contextSize); + + // Perform rendezvous for TCP pairs + gloo::transport::tcp::attr attr("localhost"); + auto dev = gloo::transport::tcp::CreateDevice(attr); + // gloo::transport::ibverbs::attr attr = { + // "mlx5_10", 1, 1}; + // auto dev = gloo::transport::ibverbs::CreateDevice(attr); + gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); + context.connectFullMesh(store, dev); + + std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; + + size_t data_size = 3; + std::vector inputs{new float[data_size * 2]}; + for (auto i = 0; i < data_size; i++) + { + inputs[0][i] = i + 1; + inputs[0][i + data_size] = i + 1; + } + std::vector outputs{new float[data_size * 2]}; + for (auto i = 0; i < data_size; i++) + { + outputs[0][i] = 0; + outputs[0][i + data_size] = 0; + } + + for (auto i = 0; i < data_size; i++) + { + std::cout << "inputs[0][" << i << "] = " << inputs[0][i] << std::endl; + std::cout << "inputs[1][" << i << "] = " << inputs[0][i + data_size] << std::endl; + } + + std::shared_ptr rzv_context = std::make_shared(context); + AllreduceOptions opts(rzv_context); + auto algorithm = gloo::AllreduceOptions::Algorithm::RING; + opts.setAlgorithm(algorithm); + opts.setOutputs(outputs, data_size * 2); + std::cout << "##### before setInputs #####" << std::endl; + opts.setInputs(inputs, data_size * 2); + outputs.clear(); + + // gloo::AllreduceOptions::Func fn = [](void *a, const void *b, const void *c, size_t n) + // { + // return gloo::sum(a, b, c, n); + // }; + + // opts.setReduceFunction(fn); + opts.setReduceFunction([](void *a, const void *b, const void *c, size_t n) + { + std::cout << "a = " << a << ", b = " << b << ", c = " << c << ", n = " << n << std::endl; + auto ua = static_cast(a); + const auto ub = static_cast(b); + const auto uc = static_cast(c); + for (size_t i = 0; i < n; i++) { + ua[i] = ub[i] + uc[i]; + std::cout << "ua[" << i << "] = " << ua[i] << " = " << ub[i] << " + " << uc[i] << std::endl; + } }); + + // A small maximum segment size triggers code paths where we'll + // have a number of segments larger than the lower bound of + // twice the context size. + opts.setMaxSegmentSize(128); + + gloo::allreduce(opts); + + for (auto i = 0; i < data_size; i++) + { + std::cout << "outputs[0][" << i << "] = " << outputs[0][i] << std::endl; + std::cout << "outputs[0][" << i << "] = " << outputs[0][i + data_size] << std::endl; + } + return 0; +} diff --git a/examples/test_allreduce_ib.cpp b/examples/test_allreduce_ib.cpp new file mode 100644 index 000000000..c1f2444b5 --- /dev/null +++ b/examples/test_allreduce_ib.cpp @@ -0,0 +1,110 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace gloo; + +// Function to instantiate and run algorithm. +using Func = void( + std::shared_ptr<::gloo::Context>, + std::vector dataPtrs, + int dataSize); + +// RAII handle for aligned buffer +template +#ifdef _WIN32 +std::vector newBuffer(int size) +{ + return std::vector(size); +#else +std::vector> newBuffer(int size) +{ + return std::vector>(size); +#endif +} + +int main() +{ + // Initialize context + auto rank = getenv("RANK"); + if (!rank) + { + rank = "0"; + } + auto world_size = getenv("WORLD_SIZE"); + if (!world_size) + { + world_size = "1"; + } + auto ib_device = getenv("IB_DEVICE"); + if (!ib_device) + { + ib_device = "mlx5_0"; + } + auto myRank = atoi(rank); + auto contextSize = atoi(world_size); + gloo::rendezvous::Context context(myRank, contextSize); + + // Perform rendezvous for TCP pairs + // gloo::transport::tcp::attr attr("localhost"); + // auto dev = gloo::transport::tcp::CreateDevice(attr); + gloo::transport::ibverbs::attr attr = { + ib_device, 1, 1}; + auto dev = gloo::transport::ibverbs::CreateDevice(attr); + gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); + context.connectFullMesh(store, dev); + + std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; + + size_t data_size = 3; + static std::function allreduceRing = + [](std::shared_ptr<::gloo::Context> context, + std::vector dataPtrs, + int dataSize) + { + ::gloo::AllreduceRing algorithm(context, dataPtrs, dataSize); + algorithm.run(); + }; + + std::shared_ptr rzv_context = std::make_shared(context); + + // std::vector ptr{new float[data_size * 2]}; + // for (auto i = 0; i < data_size; i++) + // { + // ptr[0][i] = i + 1; + // ptr[0][i + data_size] = i + 1; + // } + // for (auto i = 0; i < data_size; i++) + // { + // std::cout << "ptr[0][" << i << "] = " << ptr[0][i] << std::endl; + // std::cout << "ptr[0][" << i + data_size << "] = " << ptr[0][i + data_size] << std::endl; + // } + + // allreduceRing(rzv_context, ptr, data_size * 2); + + const auto contextRank = rzv_context->rank; + auto buffer = newBuffer(data_size * 2); + auto *ptr = buffer.data(); + + for (int i = 0; i < data_size; i++) + { + ptr[i] = i + 1; + ptr[i + data_size] = i + 1; + } + + allreduceRing(rzv_context, std::vector{ptr}, data_size * 2); + + for (auto i = 0; i < data_size * 2; i++) + { + std::cout << "ptr[" << i << "] = " << ptr[i] << std::endl; + } + return 0; +} diff --git a/examples/test_send_recv_ib.cpp b/examples/test_send_recv_ib.cpp new file mode 100644 index 000000000..85332e86f --- /dev/null +++ b/examples/test_send_recv_ib.cpp @@ -0,0 +1,108 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace gloo; + +// Function to instantiate and run algorithm. +using Func = void( + std::shared_ptr<::gloo::Context>, + std::vector dataPtrs, + int dataSize); + +// RAII handle for aligned buffer +template +#ifdef _WIN32 +std::vector newBuffer(int size) +{ + return std::vector(size); +#else +std::vector> newBuffer(int size) +{ + return std::vector>(size); +#endif +} + +int main() +{ + // Initialize context + auto rank = getenv("RANK"); + if (!rank) + { + rank = "0"; + } + auto world_size = getenv("WORLD_SIZE"); + if (!world_size) + { + world_size = "1"; + } + auto ib_device = getenv("IB_DEVICE"); + if (!ib_device) + { + ib_device = "mlx5_0"; + } + auto myRank = atoi(rank); + auto contextSize = atoi(world_size); + gloo::rendezvous::Context context(myRank, contextSize); + + // // Perform rendezvous for TCP pairs + // gloo::transport::tcp::attr attr("localhost"); + // auto dev = gloo::transport::tcp::CreateDevice(attr); + gloo::transport::ibverbs::attr attr = { + ib_device, 1, 1}; + auto dev = gloo::transport::ibverbs::CreateDevice(attr); + gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); + context.connectFullMesh(store, dev); + + std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; + + std::shared_ptr rzv_context = std::make_shared(context); + size_t data_size = 3; + float sends[data_size] = {1 + float(myRank), 2 + float(myRank), 3 + float(myRank)}; + float recvs[data_size] = {0, 0, 0}; + for (auto i = 0; i < data_size; i++) + { + std::cout << "sends[" << i << "] = " << sends[i] << std::endl; + } + + auto slot = context.nextSlot(); + int peer; + if (context.rank == 0) + peer = 1; + else + peer = 0; + + std::cout << "peer = " << peer << std::endl; + int bytes_ = sizeof(float) * data_size; + auto inbox_ = static_cast(malloc(bytes_)); + auto outbox_ = static_cast(malloc(bytes_)); + auto &pair = context.getPair(peer); + std::unique_ptr<::gloo::transport::Buffer> sendBuf = pair->createSendBuffer(slot, outbox_, bytes_); + std::unique_ptr<::gloo::transport::Buffer> recvBuf = pair->createRecvBuffer(slot, inbox_, bytes_); + + std::memcpy(outbox_, sends, bytes_); + + sendBuf->send(); + recvBuf->waitRecv(); + sendBuf->waitSend(); + + std::memcpy(recvs, inbox_, bytes_); + + for (auto i = 0; i < data_size; i++) + { + std::cout << "recvs[" << i << "] = " << recvs[i] << std::endl; + } + + free(inbox_); + free(outbox_); + return 0; +} diff --git a/examples/test_type.cpp b/examples/test_type.cpp new file mode 100644 index 000000000..2438e5aeb --- /dev/null +++ b/examples/test_type.cpp @@ -0,0 +1,47 @@ +#include +#include +#include // For std::memcpy +#include + +#include + +using namespace gloo; + +// RAII handle for aligned buffer +template +#ifdef _WIN32 +std::vector newBuffer(int size) +{ + return std::vector(size); +#else +std::vector> newBuffer(int size) +{ + return std::vector>(size); +#endif +} + +int main() +{ + // Example size and alignment + constexpr std::size_t kBufferAlignment = 64; + constexpr std::size_t size = 10; + + // Create the aligned vector + auto a = std::vector>(size); + + // Simulate intptr_t pointing to external data + int external_data[size] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + intptr_t b = reinterpret_cast(external_data); + + // Write the data from b into a + std::memcpy(a.data(), reinterpret_cast(b), size * sizeof(int)); + + // Print the result + for (auto val : a) + { + std::cout << val << " "; + } + std::cout << std::endl; + + return 0; +} \ No newline at end of file diff --git a/gloo/common/CMakeLists.txt b/gloo/common/CMakeLists.txt index a8d474498..2b63d43a2 100644 --- a/gloo/common/CMakeLists.txt +++ b/gloo/common/CMakeLists.txt @@ -8,6 +8,7 @@ set(GLOO_COMMON_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/common.h" "${CMAKE_CURRENT_SOURCE_DIR}/error.h" "${CMAKE_CURRENT_SOURCE_DIR}/logging.h" + "${CMAKE_CURRENT_SOURCE_DIR}/memory.h" "${CMAKE_CURRENT_SOURCE_DIR}/store.h" "${CMAKE_CURRENT_SOURCE_DIR}/string.h" "${CMAKE_CURRENT_SOURCE_DIR}/utils.h" diff --git a/gloo/rendezvous/store.cc b/gloo/rendezvous/store.cc index 9f901cf80..d30e9bae2 100644 --- a/gloo/rendezvous/store.cc +++ b/gloo/rendezvous/store.cc @@ -6,6 +6,13 @@ * LICENSE file in the root directory of this source tree. */ -namespace gloo { -namespace rendezvous {} // namespace rendezvous +#include + +namespace gloo +{ + namespace rendezvous + { + // 定义静态成员变量 + constexpr std::chrono::milliseconds Store::kDefaultTimeout; + } // namespace rendezvous } // namespace gloo diff --git a/hack/bench.sh b/hack/bench.sh new file mode 100644 index 000000000..b18a3de28 --- /dev/null +++ b/hack/bench.sh @@ -0,0 +1,25 @@ +ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +cd $ROOT_DIR/.. + +cd build/gloo/benchmark + +NNODES=${WORLD_SIZE:-1} +NODE_RANK=${RANK:-0} +REDIS_HOST=${REDIS_HOST:-"localhost"} +REDIS_PORT=${REDIS_PORT:-6379} + +./benchmark \ + --size ${NNODES} \ + --rank ${NODE_RANK} \ + --redis-host ${REDIS_HOST} \ + --redis-port ${REDIS_PORT} \ + --prefix test-for-benchmark \ + --transport ibverbs \ + --ib-device mlx5_10 \ + --ib-port 1 \ + --elements $(( 1024 * 1024 )) \ + --inputs 4 \ + --iteration-time 2s \ + allreduce_ring + # allreduce_ring_chunked diff --git a/hack/build.sh b/hack/build.sh new file mode 100644 index 000000000..c0a98db5f --- /dev/null +++ b/hack/build.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +cd $ROOT_DIR/.. + +mkdir -p build +cd build +cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ \ + -DCMAKE_BUILD_TYPE=Debug \ + -DUSE_IBVERBS=1 -DBUILD_BENCHMARK=1 -DUSE_REDIS=1 \ + -DBUILD_SHARED_LIBS=1 \ + ../ +make +make install \ No newline at end of file From 5c7c6de7ace44121ff0d4ad8c190d4d8df99b9b0 Mon Sep 17 00:00:00 2001 From: liqingping Date: Thu, 13 Feb 2025 16:17:25 +0800 Subject: [PATCH 2/9] feat: add tcp store --- gloo/rendezvous/CMakeLists.txt | 2 + gloo/rendezvous/tcp_store.cc | 295 +++++++++++++++++++++++++++++++++ gloo/rendezvous/tcp_store.h | 87 ++++++++++ 3 files changed, 384 insertions(+) create mode 100644 gloo/rendezvous/tcp_store.cc create mode 100644 gloo/rendezvous/tcp_store.h diff --git a/gloo/rendezvous/CMakeLists.txt b/gloo/rendezvous/CMakeLists.txt index 0fae748cc..84c618cca 100644 --- a/gloo/rendezvous/CMakeLists.txt +++ b/gloo/rendezvous/CMakeLists.txt @@ -3,6 +3,7 @@ set(GLOO_RENDEZVOUS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/file_store.cc" "${CMAKE_CURRENT_SOURCE_DIR}/hash_store.cc" "${CMAKE_CURRENT_SOURCE_DIR}/prefix_store.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/tcp_store.cc" "${CMAKE_CURRENT_SOURCE_DIR}/store.cc" "${CMAKE_CURRENT_SOURCE_DIR}/context.cc" ) @@ -11,6 +12,7 @@ set(GLOO_RENDEZVOUS_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/file_store.h" "${CMAKE_CURRENT_SOURCE_DIR}/hash_store.h" "${CMAKE_CURRENT_SOURCE_DIR}/prefix_store.h" + "${CMAKE_CURRENT_SOURCE_DIR}/tcp_store.h" "${CMAKE_CURRENT_SOURCE_DIR}/store.h" "${CMAKE_CURRENT_SOURCE_DIR}/context.h" ) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc new file mode 100644 index 000000000..01979c5d8 --- /dev/null +++ b/gloo/rendezvous/tcp_store.cc @@ -0,0 +1,295 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "tcp_store.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#else +#include +#endif + +#include "gloo/common/error.h" +#include "gloo/common/logging.h" + +#define BUFFER_SIZE 1024 +const std::string POST_ACTION_SET = "set"; +const std::string POST_ACTION_GET = "get"; +const std::string NOT_FOUND = "NOT_FOUND"; + +namespace gloo +{ + namespace rendezvous + { + TCPStore::~TCPStore() + { + close(server_fd); + } + + TCPStore::TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout) + : hostname_(hostname), + host_ip_(host_to_ip(hostname)), + port_(port), + world_size_(world_size), + is_master_(is_master), + timeout_(timeout), + data_({}) + { + uint16_t PORT = static_cast(port); + std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size + << ", is_master: " << is_master << std::endl; + if (is_master) + { + // 创建 socket + server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // 设置服务器地址信息 + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; // 监听所有的网络接口 + address.sin_port = htons(PORT); + + // 绑定 socket 到地址 + if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) + { + auto err = std::string("Socket bind failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // 开始监听 + if (listen(server_fd, 3) < 0) + { + auto err = std::string("Socket listen failed: ") + strerror(errno); + GLOO_THROW(err); + } + + std::thread(&TCPStore::accept_func, this).detach(); + } + else + { + // 创建 socket + server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // 设置服务器地址信息 + server_address.sin_family = AF_INET; + server_address.sin_port = htons(PORT); + + // 将 IP 地址从文本转换为二进制形式 + if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + { + auto err = std::string("Invalid address: ") + strerror(errno); + GLOO_THROW(err); + } + + // 连接服务器 + if (connect(server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) + { + auto err = std::string("Connection to server failed: ") + strerror(errno); + GLOO_THROW(err); + } + } + } + + void TCPStore::accept_func() + { + // 接受客户端连接 + int new_socket; + struct sockaddr_in client_address; + socklen_t addr_len = sizeof(client_address); + new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len); + if (new_socket < 0) + { + auto err = std::string("Accept client connection failed: ") + strerror(errno); + GLOO_THROW(err); + } + + std::cout << "Connection established with client." << std::endl; + + // 服务器进入循环,持续接受客户端连接 + while (true) + { + // 读取客户端消息 + char buffer[BUFFER_SIZE] = {0}; + int valread = read(new_socket, buffer, BUFFER_SIZE); + if (valread > 0) + { + std::string buffer_str = std::string(buffer); + std::vector buffer_split = str_split(buffer_str, ':'); + if (buffer_split.size() < 2) + { + GLOO_THROW("Invalid message format, must be formated as [action]:[key]:[value] or [action]:[key]!"); + } + + std::string action = buffer_split[0]; + if (action == POST_ACTION_SET) + { + std::string key = buffer_split[1]; + std::string value = buffer_split[2]; + std::vector value_vec(value.begin(), value.end()); + mtx.lock(); + data_[key] = value_vec; + mtx.unlock(); + + // 向客户端发送响应 + const char *response = "OK"; + send(new_socket, response, strlen(response), 0); + // std::cout << "Response sent to client." << std::endl; + } + else if (action == POST_ACTION_GET) + { + std::string key = buffer_split[1]; + bool found = false; + std::vector value = {}; + + mtx.lock(); + if (data_.find(key) != data_.end()) + { + found = true; + value = data_[key]; + } + mtx.unlock(); + + std::string value_str(value.begin(), value.end()); + value_str = found ? value_str : NOT_FOUND; + const char *response = value_str.c_str(); + send(new_socket, response, strlen(response), 0); + } + else + { + // 向客户端发送响应 + const char *response = "OK"; + send(new_socket, response, strlen(response), 0); + // std::cout << "Response sent to client." << std::endl; + } + } + } + close(new_socket); + } + + void TCPStore::set(const std::string &key, const std::vector &data) + { + if (is_master_) + { + mtx.lock(); + data_[key] = data; + mtx.unlock(); + } + else + { + // 向服务器发送消息 + std::string key_with_data = POST_ACTION_SET + ":" + key + ":" + std::string(data.begin(), data.end()); + const char *message = key_with_data.c_str(); + send(server_fd, message, strlen(message), 0); + // std::cout << "Message sent to server." << std::endl; + + // 读取服务器响应 + char buffer[BUFFER_SIZE] = {0}; + int valread = read(server_fd, buffer, BUFFER_SIZE); + // std::cout << "Server response: " << buffer << std::endl; + } + } + + std::vector TCPStore::get(const std::string &key) + { + if (is_master_) + { + bool found = false; + std::vector value = {}; + + mtx.lock(); + if (data_.find(key) != data_.end()) + { + found = true; + value = data_[key]; + } + mtx.unlock(); + + std::string value_str(value.begin(), value.end()); + value_str = found ? value_str : NOT_FOUND; + return std::vector(value_str.begin(), value_str.end()); + } + else + { + // 向服务器发送消息 + std::string key_with_data = POST_ACTION_GET + ":" + key; + const char *message = key_with_data.c_str(); + send(server_fd, message, strlen(message), 0); + // std::cout << "Message sent to server." << std::endl; + + // 读取服务器响应 + char buffer[BUFFER_SIZE] = {0}; + int valread = read(server_fd, buffer, BUFFER_SIZE); + if (valread > 0) + { + std::string buffer_str = std::string(buffer); + // std::cout << "Server response: " << buffer_str << std::endl; + + return std::vector(buffer_str.begin(), buffer_str.end()); + } + else + { + GLOO_THROW("Server response failed!"); + } + } + } + + void TCPStore::wait( + const std::vector &keys, + const std::chrono::milliseconds &timeout) + { + const auto start = std::chrono::steady_clock::now(); + auto check = [&](const std::vector &keys) -> bool + { + for (const auto &key : keys) + { + auto data = get(key); + std::string buffer_str(data.begin(), data.end()); + // std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; + if (buffer_str == NOT_FOUND) + { + return false; + } + } + return true; + }; + + while (!check(keys)) + { + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + if (timeout != kNoTimeout && elapsed > timeout) + { + GLOO_THROW_IO_EXCEPTION(GLOO_ERROR_MSG( + "Wait timeout for key(s): ", ::gloo::MakeString(keys))); + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + } // namespace rendezvous +} // namespace gloo diff --git a/gloo/rendezvous/tcp_store.h b/gloo/rendezvous/tcp_store.h new file mode 100644 index 000000000..fba9e6aa1 --- /dev/null +++ b/gloo/rendezvous/tcp_store.h @@ -0,0 +1,87 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "gloo/rendezvous/store.h" + +#include +#include +#include +#include +#include + +using namespace gloo; + +namespace gloo +{ + namespace rendezvous + { + + class TCPStore : public gloo::rendezvous::Store + { + public: + explicit TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout = 30); + virtual ~TCPStore(); + + virtual void set(const std::string &key, const std::vector &data) + override; + + virtual std::vector get(const std::string &key) override; + + virtual void wait(const std::vector &keys) override + { + auto timeout = std::chrono::seconds(timeout_); + wait(keys, timeout); + } + + virtual void wait( + const std::vector &keys, + const std::chrono::milliseconds &timeout) override; + + virtual void accept_func(); + + std::string host_to_ip(const std::string &host) + { + hostent *hostname = gethostbyname(host.c_str()); + if (hostname) + return std::string(inet_ntoa(**(in_addr **)hostname->h_addr_list)); + return {}; + } + + std::vector str_split(const std::string &str, char delimiter) + { + std::vector tokens; + std::stringstream ss(str); + std::string token; + + while (std::getline(ss, token, delimiter)) + { + tokens.push_back(token); + } + + return tokens; + } + + protected: + std::string hostname_; + std::string host_ip_; + int port_; + int world_size_; + bool is_master_; + int timeout_; + + std::mutex mtx; + + int server_fd; + struct sockaddr_in server_address; + std::map> data_; + }; + + } // namespace rendezvous +} // namespace pygloo From ab9b40eff8e10c0857b9600679573bb0d367ff9a Mon Sep 17 00:00:00 2001 From: liqingping Date: Thu, 13 Feb 2025 23:07:09 +0800 Subject: [PATCH 3/9] feat: tcp store success --- gloo/rendezvous/tcp_store.cc | 300 ++++++++++++++++++---------- gloo/rendezvous/tcp_store.h | 8 +- gloo/transport/context.cc | 366 +++++++++++++++++++---------------- 3 files changed, 403 insertions(+), 271 deletions(-) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc index 01979c5d8..5d4799bff 100644 --- a/gloo/rendezvous/tcp_store.cc +++ b/gloo/rendezvous/tcp_store.cc @@ -27,6 +27,10 @@ #include "gloo/common/logging.h" #define BUFFER_SIZE 1024 +#define ACTION_SIZE 3 +#define SIZE_OF_SIZE 16 +#define RESPONSE_SIZE 2 + const std::string POST_ACTION_SET = "set"; const std::string POST_ACTION_GET = "get"; const std::string NOT_FOUND = "NOT_FOUND"; @@ -43,13 +47,12 @@ namespace gloo TCPStore::TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout) : hostname_(hostname), host_ip_(host_to_ip(hostname)), - port_(port), + port_(static_cast(port)), world_size_(world_size), is_master_(is_master), timeout_(timeout), data_({}) { - uint16_t PORT = static_cast(port); std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size << ", is_master: " << is_master << std::endl; if (is_master) @@ -66,7 +69,7 @@ namespace gloo struct sockaddr_in address; address.sin_family = AF_INET; address.sin_addr.s_addr = INADDR_ANY; // 监听所有的网络接口 - address.sin_port = htons(PORT); + address.sin_port = htons(port_); // 绑定 socket 到地址 if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) @@ -84,110 +87,115 @@ namespace gloo std::thread(&TCPStore::accept_func, this).detach(); } - else + } + + void TCPStore::accept_func() + { + + // 服务器进入循环,持续接受客户端连接 + while (true) { - // 创建 socket - server_fd = socket(AF_INET, SOCK_STREAM, 0); - if (server_fd == -1) + // 接受客户端连接 + int new_socket; + struct sockaddr_in client_address; + socklen_t addr_len = sizeof(client_address); + new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len); + if (new_socket < 0) { - auto err = std::string("Socket creation failed: ") + strerror(errno); + auto err = std::string("Accept client connection failed: ") + strerror(errno); GLOO_THROW(err); } - // 设置服务器地址信息 - server_address.sin_family = AF_INET; - server_address.sin_port = htons(PORT); + std::cout << "Connection established with client." << std::endl; - // 将 IP 地址从文本转换为二进制形式 - if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + // 读取客户端消息 + char act_buffer[ACTION_SIZE + 1] = {0}; + int valread = read(new_socket, act_buffer, ACTION_SIZE); + std::string action = std::string(act_buffer); + if (action == POST_ACTION_SET) { - auto err = std::string("Invalid address: ") + strerror(errno); - GLOO_THROW(err); - } + std::cout << "Set request received." << std::endl; - // 连接服务器 - if (connect(server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) - { - auto err = std::string("Connection to server failed: ") + strerror(errno); - GLOO_THROW(err); - } - } - } + // read key size + char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; + int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); + int key_size = atoi(key_size_buffer); + std::cout << "key size: " << key_size << std::endl; - void TCPStore::accept_func() - { - // 接受客户端连接 - int new_socket; - struct sockaddr_in client_address; - socklen_t addr_len = sizeof(client_address); - new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len); - if (new_socket < 0) - { - auto err = std::string("Accept client connection failed: ") + strerror(errno); - GLOO_THROW(err); - } + // read key + char key_buffer[key_size + 1] = {0}; + valread = read(new_socket, key_buffer, key_size); + std::string key = std::string(key_buffer); + std::cout << "key: " << key << std::endl; - std::cout << "Connection established with client." << std::endl; + // read data size + char data_size_buffer[SIZE_OF_SIZE + 1] = {0}; + valread = read(new_socket, data_size_buffer, SIZE_OF_SIZE); + int data_size = atoi(data_size_buffer); + std::cout << "data size: " << data_size << std::endl; - // 服务器进入循环,持续接受客户端连接 - while (true) - { - // 读取客户端消息 - char buffer[BUFFER_SIZE] = {0}; - int valread = read(new_socket, buffer, BUFFER_SIZE); - if (valread > 0) + // read data + char data_buffer[data_size + 1] = {0}; + valread = read(new_socket, data_buffer, data_size); + std::string value = std::string(data_buffer); + std::vector value_vec(data_buffer, data_buffer + data_size); + std::cout << "data_buffer: <" << data_buffer << ">" << std::endl; + std::cout << "value read: " << valread << "value: <" << value << ">" << std::endl; + + mtx.lock(); + data_[key] = value_vec; + mtx.unlock(); + + // 向客户端发送响应 + const char *response = "OK"; + send(new_socket, response, strlen(response), 0); + // std::cout << "Response sent to client." << std::endl; + } + else if (action == POST_ACTION_GET) { - std::string buffer_str = std::string(buffer); - std::vector buffer_split = str_split(buffer_str, ':'); - if (buffer_split.size() < 2) + std::cout << "Get request received." << std::endl; + // read key size + char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; + int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); + int key_size = atoi(key_size_buffer); + + // read key + char key_buffer[key_size + 1] = {0}; + valread = read(new_socket, key_buffer, key_size); + std::string key = std::string(key_buffer); + std::cout << "get key: " << key << std::endl; + + bool found = false; + std::vector value = {}; + + mtx.lock(); + if (data_.find(key) != data_.end()) { - GLOO_THROW("Invalid message format, must be formated as [action]:[key]:[value] or [action]:[key]!"); + found = true; + value = data_[key]; } + mtx.unlock(); - std::string action = buffer_split[0]; - if (action == POST_ACTION_SET) + if (found) { - std::string key = buffer_split[1]; - std::string value = buffer_split[2]; - std::vector value_vec(value.begin(), value.end()); - mtx.lock(); - data_[key] = value_vec; - mtx.unlock(); - - // 向客户端发送响应 - const char *response = "OK"; - send(new_socket, response, strlen(response), 0); - // std::cout << "Response sent to client." << std::endl; - } - else if (action == POST_ACTION_GET) - { - std::string key = buffer_split[1]; - bool found = false; - std::vector value = {}; - - mtx.lock(); - if (data_.find(key) != data_.end()) - { - found = true; - value = data_[key]; - } - mtx.unlock(); - - std::string value_str(value.begin(), value.end()); - value_str = found ? value_str : NOT_FOUND; - const char *response = value_str.c_str(); - send(new_socket, response, strlen(response), 0); + send(new_socket, value.data(), value.size(), 0); } else { - // 向客户端发送响应 - const char *response = "OK"; + const char *response = NOT_FOUND.c_str(); send(new_socket, response, strlen(response), 0); - // std::cout << "Response sent to client." << std::endl; } } + else + { + // 向客户端发送响应 + const char *response = "OK"; + send(new_socket, response, strlen(response), 0); + std::cout << "Response sent to client." << std::endl; + } + + close(new_socket); } - close(new_socket); } void TCPStore::set(const std::string &key, const std::vector &data) @@ -200,16 +208,68 @@ namespace gloo } else { - // 向服务器发送消息 - std::string key_with_data = POST_ACTION_SET + ":" + key + ":" + std::string(data.begin(), data.end()); - const char *message = key_with_data.c_str(); - send(server_fd, message, strlen(message), 0); - // std::cout << "Message sent to server." << std::endl; + // 创建 socket + int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (new_server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // 设置服务器地址信息 + server_address.sin_family = AF_INET; + server_address.sin_port = htons(port_); + + // 将 IP 地址从文本转换为二进制形式 + if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + { + auto err = std::string("Invalid address: ") + strerror(errno); + GLOO_THROW(err); + } + + // 连接服务器 + if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) + { + auto err = std::string("Connection to server failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // send action + std::string act_data = POST_ACTION_SET; + const char *message = act_data.c_str(); + send(new_server_fd, message, strlen(message), 0); + + // send key size + size_t len = key.length(); + std::string len_str = std::to_string(len); + len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + message = len_str.c_str(); + send(new_server_fd, message, strlen(message), 0); + std::cout << "key size: " << len_str << std::endl; + + // send key + message = key.c_str(); + send(new_server_fd, message, strlen(message), 0); + std::cout << "key: " << key << std::endl; + + // send data size + len = data.size(); + len_str = std::to_string(len); + len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + message = len_str.c_str(); + send(new_server_fd, message, strlen(message), 0); + std::cout << "data size: " << len_str << std::endl; + + // send data + void *data_ptr = static_cast(const_cast(data.data())); + send(new_server_fd, data_ptr, len, 0); // 读取服务器响应 - char buffer[BUFFER_SIZE] = {0}; - int valread = read(server_fd, buffer, BUFFER_SIZE); - // std::cout << "Server response: " << buffer << std::endl; + char buffer[RESPONSE_SIZE] = {0}; + int valread = read(new_server_fd, buffer, RESPONSE_SIZE); + std::cout << key << " set request, server response: " << buffer << std::endl; + + close(new_server_fd); } } @@ -234,26 +294,65 @@ namespace gloo } else { - // 向服务器发送消息 - std::string key_with_data = POST_ACTION_GET + ":" + key; - const char *message = key_with_data.c_str(); - send(server_fd, message, strlen(message), 0); + // 创建 socket + int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (new_server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // 设置服务器地址信息 + server_address.sin_family = AF_INET; + server_address.sin_port = htons(port_); + + // 将 IP 地址从文本转换为二进制形式 + if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + { + auto err = std::string("Invalid address: ") + strerror(errno); + GLOO_THROW(err); + } + + // 连接服务器 + if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) + { + auto err = std::string("Connection to server failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // send action + std::string act_data = POST_ACTION_GET; + const char *message = act_data.c_str(); + send(new_server_fd, message, strlen(message), 0); // std::cout << "Message sent to server." << std::endl; + // send key size + size_t len = key.length(); + std::string len_str = std::to_string(len); + len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + message = len_str.c_str(); + send(new_server_fd, message, strlen(message), 0); + + // send key + message = key.c_str(); + send(new_server_fd, message, strlen(message), 0); + // 读取服务器响应 char buffer[BUFFER_SIZE] = {0}; - int valread = read(server_fd, buffer, BUFFER_SIZE); + int valread = read(new_server_fd, buffer, BUFFER_SIZE); if (valread > 0) { std::string buffer_str = std::string(buffer); - // std::cout << "Server response: " << buffer_str << std::endl; + std::cout << key << " get request, server response: " << buffer_str << std::endl; - return std::vector(buffer_str.begin(), buffer_str.end()); + return std::vector(buffer, buffer + valread); } else { GLOO_THROW("Server response failed!"); } + + close(new_server_fd); } } @@ -268,7 +367,7 @@ namespace gloo { auto data = get(key); std::string buffer_str(data.begin(), data.end()); - // std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; + std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; if (buffer_str == NOT_FOUND) { return false; @@ -288,6 +387,7 @@ namespace gloo } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } diff --git a/gloo/rendezvous/tcp_store.h b/gloo/rendezvous/tcp_store.h index fba9e6aa1..264d54b94 100644 --- a/gloo/rendezvous/tcp_store.h +++ b/gloo/rendezvous/tcp_store.h @@ -8,7 +8,7 @@ #pragma once -#include "gloo/rendezvous/store.h" +#include "store.h" #include #include @@ -16,14 +16,12 @@ #include #include -using namespace gloo; - namespace gloo { namespace rendezvous { - class TCPStore : public gloo::rendezvous::Store + class TCPStore : public Store { public: explicit TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout = 30); @@ -71,7 +69,7 @@ namespace gloo protected: std::string hostname_; std::string host_ip_; - int port_; + uint16_t port_; int world_size_; bool is_master_; int timeout_; diff --git a/gloo/transport/context.cc b/gloo/transport/context.cc index 0b1f45b2c..7d9740e38 100644 --- a/gloo/transport/context.cc +++ b/gloo/transport/context.cc @@ -9,181 +9,215 @@ #include "gloo/transport/context.h" #include "gloo/common/utils.h" -namespace gloo { -namespace transport { - -Context::Context(int rank, int size) : rank(rank), size(size) { - pairs_.resize(size); -} - -// Have to provide implementation for pure virtual destructor. -Context::~Context() {} - -std::unique_ptr& Context::getPair(int rank_2) { - return pairs_.at(rank_2); -} - -void Context::createAndConnectAllPairs(IStore& store) { - // this is the default un-optimized version of the rendezvous protocol - // where each rank would write N pairs to the store - // and then for each remote peer load the N addresses - // and only pick the 1 useful - // A more efficient version (for transport supporting multiplexing like TCP) - // can be seen in gloo/transport/tcp/context.cc - - std::vector allBytes; - int localRank = 0; - - auto localHostName = getHostname(); - // Add global rank <> hostname pair to the Store. This store is then passed - // to Gloo when connectFullMesh is called, where Gloo uses the global rank <> - // hostname mapping to compute local ranks. - std::string localKey("rank_" + std::to_string(rank)); - const std::vector value(localHostName.begin(), localHostName.end()); - store.set(localKey, value); - - for (int i = 0; i < size; i++) { - if (i == rank) { - break; +#include + +namespace gloo +{ + namespace transport + { + + Context::Context(int rank, int size) : rank(rank), size(size) + { + pairs_.resize(size); + } + + // Have to provide implementation for pure virtual destructor. + Context::~Context() {} + + std::unique_ptr &Context::getPair(int rank_2) + { + return pairs_.at(rank_2); + } + + void Context::createAndConnectAllPairs(IStore &store) + { + // this is the default un-optimized version of the rendezvous protocol + // where each rank would write N pairs to the store + // and then for each remote peer load the N addresses + // and only pick the 1 useful + // A more efficient version (for transport supporting multiplexing like TCP) + // can be seen in gloo/transport/tcp/context.cc + + std::vector allBytes; + int localRank = 0; + + auto localHostName = getHostname(); + // Add global rank <> hostname pair to the Store. This store is then passed + // to Gloo when connectFullMesh is called, where Gloo uses the global rank <> + // hostname mapping to compute local ranks. + std::string localKey("rank_" + std::to_string(rank)); + const std::vector value(localHostName.begin(), localHostName.end()); + store.set(localKey, value); + + for (int i = 0; i < size; i++) + { + if (i == rank) + { + break; + } + + std::string key("rank_" + std::to_string(i)); + auto val = store.get(key); + auto hostName = std::string((const char *)val.data(), val.size()); + + if (hostName == localHostName) + { + localRank++; + } + } + + // Create pairs + for (int i = 0; i < size; i++) + { + if (i == rank) + { + continue; + } + + auto &pair = createPair(i); + pair->setLocalRank(localRank); + auto addrBytes = pair->address().bytes(); + // std::cout << "len: " << addrBytes.size() << "addrBytes: " << std::string(addrBytes.begin(), addrBytes.end()) << std::endl; + + allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end()); + } + + store.set(std::to_string(rank), allBytes); + + // Connect every pair + for (int i = 0; i < size; i++) + { + if (i == rank) + { + continue; + } + + // Wait for address of other side of this pair to become available + std::ostringstream key; + key << i; + store.wait({key.str()}, getTimeout()); + + // Connect to other side of this pair + auto allAddrs = store.get(key.str()); + auto addr = extractAddress(allAddrs, i); + getPair(i)->connect(addr); + } + } + + std::vector Context::extractAddress( + const std::vector &allAddrs, + int i) const + { + // Extract address from the list of all addresses + int adjRank = (rank > i ? rank - 1 : rank); + // Adjust for the fact that nodes do not store address for themselves + int addrSize = allAddrs.size() / (size - 1); + return std::vector( + allAddrs.begin() + adjRank * addrSize, + allAddrs.begin() + (adjRank + 1) * addrSize); } - std::string key("rank_" + std::to_string(i)); - auto val = store.get(key); - auto hostName = std::string((const char*)val.data(), val.size()); + Context::LazyTally::LazyTally(std::vector &vec, slot_t slot) + : vec_(vec), slot_(slot), initialized_(false) {} - if (hostName == localHostName) { - localRank++; + Context::LazyTally::~LazyTally() + { + // Remove empty tally from vector. + if (initialized_ && it_ != vec_.end() && it_->empty()) + { + vec_.erase(it_); + } } - } - // Create pairs - for (int i = 0; i < size; i++) { - if (i == rank) { - continue; + bool Context::LazyTally::exists() + { + initialize_iterator(); + return it_ != vec_.end(); } - auto& pair = createPair(i); - pair->setLocalRank(localRank); - auto addrBytes = pair->address().bytes(); - allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end()); - } + Context::Tally &Context::LazyTally::get() + { + initialize_iterator(); + if (it_ == vec_.end()) + { + vec_.emplace_back(slot_); + it_ = (vec_.end() - 1); + } + return *it_; + } + + void Context::LazyTally::initialize_iterator() + { + if (initialized_) + { + return; + } + + it_ = + std::find_if(vec_.begin(), vec_.end(), [this](const Context::Tally &op) + { return op.slot == slot_; }); + initialized_ = true; + } - store.set(std::to_string(rank), allBytes); + Context::Mutator::Mutator(Context &context, slot_t slot, rank_t rank) + : lock_(context.mutex_), + context_(context), + slot_(slot), + rank_(rank), + pendingOperations_(context_.pendingOperations_, slot_), + expectedNotifications_(context_.expectedNotifications_, slot_) {} + + void Context::Mutator::pushRemotePendingRecv() + { + pendingOperations_.get().pushRecv(rank_); + } + + void Context::Mutator::pushRemotePendingSend() + { + pendingOperations_.get().pushSend(rank_); + } + + bool Context::Mutator::shiftRemotePendingRecv() + { + if (!pendingOperations_.exists()) + { + return false; + } + return pendingOperations_.get().shiftRecv(rank_); + } + + bool Context::Mutator::shiftRemotePendingSend() + { + if (!pendingOperations_.exists()) + { + return false; + } + return pendingOperations_.get().shiftSend(rank_); + } + + void Context::Mutator::pushExpectedSendNotification() + { + expectedNotifications_.get().pushSend(rank_); + } + + bool Context::Mutator::shiftExpectedSendNotification() + { + if (!expectedNotifications_.exists()) + { + return false; + } + return expectedNotifications_.get().shiftSend(rank_); + } - // Connect every pair - for (int i = 0; i < size; i++) { - if (i == rank) { - continue; + std::vector::iterator Context::findPendingOperations( + slot_t slot) + { + return std::find_if( + pendingOperations_.begin(), + pendingOperations_.end(), + [slot](const Tally &op) + { return op.slot == slot; }); } - // Wait for address of other side of this pair to become available - std::ostringstream key; - key << i; - store.wait({key.str()}, getTimeout()); - - // Connect to other side of this pair - auto allAddrs = store.get(key.str()); - auto addr = extractAddress(allAddrs, i); - getPair(i)->connect(addr); - } -} - -std::vector Context::extractAddress( - const std::vector& allAddrs, - int i) const { - // Extract address from the list of all addresses - int adjRank = (rank > i ? rank - 1 : rank); - // Adjust for the fact that nodes do not store address for themselves - int addrSize = allAddrs.size() / (size - 1); - return std::vector( - allAddrs.begin() + adjRank * addrSize, - allAddrs.begin() + (adjRank + 1) * addrSize); -} - -Context::LazyTally::LazyTally(std::vector& vec, slot_t slot) - : vec_(vec), slot_(slot), initialized_(false) {} - -Context::LazyTally::~LazyTally() { - // Remove empty tally from vector. - if (initialized_ && it_ != vec_.end() && it_->empty()) { - vec_.erase(it_); - } -} - -bool Context::LazyTally::exists() { - initialize_iterator(); - return it_ != vec_.end(); -} - -Context::Tally& Context::LazyTally::get() { - initialize_iterator(); - if (it_ == vec_.end()) { - vec_.emplace_back(slot_); - it_ = (vec_.end() - 1); - } - return *it_; -} - -void Context::LazyTally::initialize_iterator() { - if (initialized_) { - return; - } - - it_ = - std::find_if(vec_.begin(), vec_.end(), [this](const Context::Tally& op) { - return op.slot == slot_; - }); - initialized_ = true; -} - -Context::Mutator::Mutator(Context& context, slot_t slot, rank_t rank) - : lock_(context.mutex_), - context_(context), - slot_(slot), - rank_(rank), - pendingOperations_(context_.pendingOperations_, slot_), - expectedNotifications_(context_.expectedNotifications_, slot_) {} - -void Context::Mutator::pushRemotePendingRecv() { - pendingOperations_.get().pushRecv(rank_); -} - -void Context::Mutator::pushRemotePendingSend() { - pendingOperations_.get().pushSend(rank_); -} - -bool Context::Mutator::shiftRemotePendingRecv() { - if (!pendingOperations_.exists()) { - return false; - } - return pendingOperations_.get().shiftRecv(rank_); -} - -bool Context::Mutator::shiftRemotePendingSend() { - if (!pendingOperations_.exists()) { - return false; - } - return pendingOperations_.get().shiftSend(rank_); -} - -void Context::Mutator::pushExpectedSendNotification() { - expectedNotifications_.get().pushSend(rank_); -} - -bool Context::Mutator::shiftExpectedSendNotification() { - if (!expectedNotifications_.exists()) { - return false; - } - return expectedNotifications_.get().shiftSend(rank_); -} - -std::vector::iterator Context::findPendingOperations( - slot_t slot) { - return std::find_if( - pendingOperations_.begin(), - pendingOperations_.end(), - [slot](const Tally& op) { return op.slot == slot; }); -} - -} // namespace transport + } // namespace transport } // namespace gloo From 7c03bf23b4050e77230306fa2af90ddd47898df1 Mon Sep 17 00:00:00 2001 From: liqingping Date: Fri, 14 Feb 2025 11:36:55 +0800 Subject: [PATCH 4/9] format: remove output code --- gloo/rendezvous/tcp_store.cc | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc index 5d4799bff..c0e03ceba 100644 --- a/gloo/rendezvous/tcp_store.cc +++ b/gloo/rendezvous/tcp_store.cc @@ -53,8 +53,8 @@ namespace gloo timeout_(timeout), data_({}) { - std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size - << ", is_master: " << is_master << std::endl; + // std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size + // << ", is_master: " << is_master << std::endl; if (is_master) { // 创建 socket @@ -106,7 +106,7 @@ namespace gloo GLOO_THROW(err); } - std::cout << "Connection established with client." << std::endl; + // std::cout << "Connection established with client." << std::endl; // 读取客户端消息 char act_buffer[ACTION_SIZE + 1] = {0}; @@ -114,33 +114,31 @@ namespace gloo std::string action = std::string(act_buffer); if (action == POST_ACTION_SET) { - std::cout << "Set request received." << std::endl; // read key size char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); int key_size = atoi(key_size_buffer); - std::cout << "key size: " << key_size << std::endl; + // std::cout << "key size: " << key_size << std::endl; // read key char key_buffer[key_size + 1] = {0}; valread = read(new_socket, key_buffer, key_size); std::string key = std::string(key_buffer); - std::cout << "key: " << key << std::endl; + // std::cout << "key: " << key << std::endl; // read data size char data_size_buffer[SIZE_OF_SIZE + 1] = {0}; valread = read(new_socket, data_size_buffer, SIZE_OF_SIZE); int data_size = atoi(data_size_buffer); - std::cout << "data size: " << data_size << std::endl; + // std::cout << "data size: " << data_size << std::endl; // read data char data_buffer[data_size + 1] = {0}; valread = read(new_socket, data_buffer, data_size); std::string value = std::string(data_buffer); std::vector value_vec(data_buffer, data_buffer + data_size); - std::cout << "data_buffer: <" << data_buffer << ">" << std::endl; - std::cout << "value read: " << valread << "value: <" << value << ">" << std::endl; + // std::cout << "value read: " << valread << "value: <" << value << ">" << std::endl; mtx.lock(); data_[key] = value_vec; @@ -153,7 +151,6 @@ namespace gloo } else if (action == POST_ACTION_GET) { - std::cout << "Get request received." << std::endl; // read key size char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); @@ -163,7 +160,7 @@ namespace gloo char key_buffer[key_size + 1] = {0}; valread = read(new_socket, key_buffer, key_size); std::string key = std::string(key_buffer); - std::cout << "get key: " << key << std::endl; + // std::cout << "get key: " << key << std::endl; bool found = false; std::vector value = {}; @@ -191,7 +188,6 @@ namespace gloo // 向客户端发送响应 const char *response = "OK"; send(new_socket, response, strlen(response), 0); - std::cout << "Response sent to client." << std::endl; } close(new_socket); @@ -245,12 +241,12 @@ namespace gloo len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; message = len_str.c_str(); send(new_server_fd, message, strlen(message), 0); - std::cout << "key size: " << len_str << std::endl; + // std::cout << "key size: " << len_str << std::endl; // send key message = key.c_str(); send(new_server_fd, message, strlen(message), 0); - std::cout << "key: " << key << std::endl; + // std::cout << "key: " << key << std::endl; // send data size len = data.size(); @@ -258,7 +254,7 @@ namespace gloo len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; message = len_str.c_str(); send(new_server_fd, message, strlen(message), 0); - std::cout << "data size: " << len_str << std::endl; + // std::cout << "data size: " << len_str << std::endl; // send data void *data_ptr = static_cast(const_cast(data.data())); @@ -267,7 +263,7 @@ namespace gloo // 读取服务器响应 char buffer[RESPONSE_SIZE] = {0}; int valread = read(new_server_fd, buffer, RESPONSE_SIZE); - std::cout << key << " set request, server response: " << buffer << std::endl; + // std::cout << key << " set request, server response: " << buffer << std::endl; close(new_server_fd); } @@ -343,7 +339,7 @@ namespace gloo if (valread > 0) { std::string buffer_str = std::string(buffer); - std::cout << key << " get request, server response: " << buffer_str << std::endl; + // std::cout << key << " get request, server response: " << buffer_str << std::endl; return std::vector(buffer, buffer + valread); } @@ -367,7 +363,7 @@ namespace gloo { auto data = get(key); std::string buffer_str(data.begin(), data.end()); - std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; + // std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; if (buffer_str == NOT_FOUND) { return false; From 78fe917a001f14f24e8a98b5e479a88d72919c73 Mon Sep 17 00:00:00 2001 From: liqingping Date: Fri, 14 Feb 2025 11:50:16 +0800 Subject: [PATCH 5/9] build: do not build redis store --- hack/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hack/build.sh b/hack/build.sh index c0a98db5f..e9d90a5a4 100644 --- a/hack/build.sh +++ b/hack/build.sh @@ -8,7 +8,7 @@ mkdir -p build cd build cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ \ -DCMAKE_BUILD_TYPE=Debug \ - -DUSE_IBVERBS=1 -DBUILD_BENCHMARK=1 -DUSE_REDIS=1 \ + -DUSE_IBVERBS=1 -DBUILD_BENCHMARK=1 \ -DBUILD_SHARED_LIBS=1 \ ../ make From 8f08819eb6d8304f07510d72fb991120454cc911 Mon Sep 17 00:00:00 2001 From: liqingping Date: Mon, 17 Feb 2025 14:33:03 +0800 Subject: [PATCH 6/9] chore: remove useless code --- gloo/rendezvous/tcp_store.cc | 139 ++++++++++++++--------------------- gloo/rendezvous/tcp_store.h | 17 +---- 2 files changed, 59 insertions(+), 97 deletions(-) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc index c0e03ceba..aa3665ef0 100644 --- a/gloo/rendezvous/tcp_store.cc +++ b/gloo/rendezvous/tcp_store.cc @@ -28,7 +28,7 @@ #define BUFFER_SIZE 1024 #define ACTION_SIZE 3 -#define SIZE_OF_SIZE 16 +#define LENGTH_OF_DATA_SIZE 16 #define RESPONSE_SIZE 2 const std::string POST_ACTION_SET = "set"; @@ -53,11 +53,9 @@ namespace gloo timeout_(timeout), data_({}) { - // std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size - // << ", is_master: " << is_master << std::endl; if (is_master) { - // 创建 socket + // create socket server_fd = socket(AF_INET, SOCK_STREAM, 0); if (server_fd == -1) { @@ -65,20 +63,20 @@ namespace gloo GLOO_THROW(err); } - // 设置服务器地址信息 + // config server address struct sockaddr_in address; address.sin_family = AF_INET; - address.sin_addr.s_addr = INADDR_ANY; // 监听所有的网络接口 + address.sin_addr.s_addr = INADDR_ANY; // listening on all interfaces address.sin_port = htons(port_); - // 绑定 socket 到地址 + // bind socket to address if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) { auto err = std::string("Socket bind failed: ") + strerror(errno); GLOO_THROW(err); } - // 开始监听 + // start listening if (listen(server_fd, 3) < 0) { auto err = std::string("Socket listen failed: ") + strerror(errno); @@ -91,11 +89,8 @@ namespace gloo void TCPStore::accept_func() { - - // 服务器进入循环,持续接受客户端连接 while (true) { - // 接受客户端连接 int new_socket; struct sockaddr_in client_address; socklen_t addr_len = sizeof(client_address); @@ -108,7 +103,7 @@ namespace gloo // std::cout << "Connection established with client." << std::endl; - // 读取客户端消息 + // read action char act_buffer[ACTION_SIZE + 1] = {0}; int valread = read(new_socket, act_buffer, ACTION_SIZE); std::string action = std::string(act_buffer); @@ -116,8 +111,8 @@ namespace gloo { // read key size - char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; - int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); + char key_size_buffer[LENGTH_OF_DATA_SIZE + 1] = {0}; + int valread = read(new_socket, key_size_buffer, LENGTH_OF_DATA_SIZE); int key_size = atoi(key_size_buffer); // std::cout << "key size: " << key_size << std::endl; @@ -128,23 +123,22 @@ namespace gloo // std::cout << "key: " << key << std::endl; // read data size - char data_size_buffer[SIZE_OF_SIZE + 1] = {0}; - valread = read(new_socket, data_size_buffer, SIZE_OF_SIZE); + char data_size_buffer[LENGTH_OF_DATA_SIZE + 1] = {0}; + valread = read(new_socket, data_size_buffer, LENGTH_OF_DATA_SIZE); int data_size = atoi(data_size_buffer); // std::cout << "data size: " << data_size << std::endl; // read data char data_buffer[data_size + 1] = {0}; valread = read(new_socket, data_buffer, data_size); - std::string value = std::string(data_buffer); std::vector value_vec(data_buffer, data_buffer + data_size); - // std::cout << "value read: " << valread << "value: <" << value << ">" << std::endl; + // std::cout << "value read: " << valread << "value: <" << data_buffer << ">" << std::endl; + // update server data_ mtx.lock(); data_[key] = value_vec; mtx.unlock(); - // 向客户端发送响应 const char *response = "OK"; send(new_socket, response, strlen(response), 0); // std::cout << "Response sent to client." << std::endl; @@ -152,8 +146,8 @@ namespace gloo else if (action == POST_ACTION_GET) { // read key size - char key_size_buffer[SIZE_OF_SIZE + 1] = {0}; - int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE); + char key_size_buffer[LENGTH_OF_DATA_SIZE + 1] = {0}; + int valread = read(new_socket, key_size_buffer, LENGTH_OF_DATA_SIZE); int key_size = atoi(key_size_buffer); // read key @@ -185,7 +179,6 @@ namespace gloo } else { - // 向客户端发送响应 const char *response = "OK"; send(new_socket, response, strlen(response), 0); } @@ -194,6 +187,38 @@ namespace gloo } } + int TCPStore::create_server_fd() + { + // create socket + int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (new_server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } + + // config server address + struct sockaddr_in server_address; + server_address.sin_family = AF_INET; + server_address.sin_port = htons(port_); + + // set server address ip + if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + { + auto err = std::string("Invalid address: ") + strerror(errno); + GLOO_THROW(err); + } + + // connect to server + if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) + { + auto err = std::string("Connection to server failed: ") + strerror(errno); + GLOO_THROW(err); + } + + return new_server_fd; + } + void TCPStore::set(const std::string &key, const std::vector &data) { if (is_master_) @@ -204,31 +229,8 @@ namespace gloo } else { - // 创建 socket - int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); - if (new_server_fd == -1) - { - auto err = std::string("Socket creation failed: ") + strerror(errno); - GLOO_THROW(err); - } - - // 设置服务器地址信息 - server_address.sin_family = AF_INET; - server_address.sin_port = htons(port_); - - // 将 IP 地址从文本转换为二进制形式 - if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) - { - auto err = std::string("Invalid address: ") + strerror(errno); - GLOO_THROW(err); - } - - // 连接服务器 - if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) - { - auto err = std::string("Connection to server failed: ") + strerror(errno); - GLOO_THROW(err); - } + // create socket + int new_server_fd = create_server_fd(); // send action std::string act_data = POST_ACTION_SET; @@ -238,7 +240,7 @@ namespace gloo // send key size size_t len = key.length(); std::string len_str = std::to_string(len); - len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + len_str = std::string(LENGTH_OF_DATA_SIZE - len_str.length(), '0') + len_str; message = len_str.c_str(); send(new_server_fd, message, strlen(message), 0); // std::cout << "key size: " << len_str << std::endl; @@ -251,7 +253,7 @@ namespace gloo // send data size len = data.size(); len_str = std::to_string(len); - len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + len_str = std::string(LENGTH_OF_DATA_SIZE - len_str.length(), '0') + len_str; message = len_str.c_str(); send(new_server_fd, message, strlen(message), 0); // std::cout << "data size: " << len_str << std::endl; @@ -260,7 +262,7 @@ namespace gloo void *data_ptr = static_cast(const_cast(data.data())); send(new_server_fd, data_ptr, len, 0); - // 读取服务器响应 + // get response char buffer[RESPONSE_SIZE] = {0}; int valread = read(new_server_fd, buffer, RESPONSE_SIZE); // std::cout << key << " set request, server response: " << buffer << std::endl; @@ -290,31 +292,8 @@ namespace gloo } else { - // 创建 socket - int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); - if (new_server_fd == -1) - { - auto err = std::string("Socket creation failed: ") + strerror(errno); - GLOO_THROW(err); - } - - // 设置服务器地址信息 - server_address.sin_family = AF_INET; - server_address.sin_port = htons(port_); - - // 将 IP 地址从文本转换为二进制形式 - if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) - { - auto err = std::string("Invalid address: ") + strerror(errno); - GLOO_THROW(err); - } - - // 连接服务器 - if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) - { - auto err = std::string("Connection to server failed: ") + strerror(errno); - GLOO_THROW(err); - } + // create socket + int new_server_fd = create_server_fd(); // send action std::string act_data = POST_ACTION_GET; @@ -325,7 +304,7 @@ namespace gloo // send key size size_t len = key.length(); std::string len_str = std::to_string(len); - len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str; + len_str = std::string(LENGTH_OF_DATA_SIZE - len_str.length(), '0') + len_str; message = len_str.c_str(); send(new_server_fd, message, strlen(message), 0); @@ -333,22 +312,19 @@ namespace gloo message = key.c_str(); send(new_server_fd, message, strlen(message), 0); - // 读取服务器响应 + // get response char buffer[BUFFER_SIZE] = {0}; int valread = read(new_server_fd, buffer, BUFFER_SIZE); + close(new_server_fd); if (valread > 0) { std::string buffer_str = std::string(buffer); - // std::cout << key << " get request, server response: " << buffer_str << std::endl; - return std::vector(buffer, buffer + valread); } else { GLOO_THROW("Server response failed!"); } - - close(new_server_fd); } } @@ -363,7 +339,6 @@ namespace gloo { auto data = get(key); std::string buffer_str(data.begin(), data.end()); - // std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl; if (buffer_str == NOT_FOUND) { return false; diff --git a/gloo/rendezvous/tcp_store.h b/gloo/rendezvous/tcp_store.h index 264d54b94..cf2113705 100644 --- a/gloo/rendezvous/tcp_store.h +++ b/gloo/rendezvous/tcp_store.h @@ -44,6 +44,8 @@ namespace gloo virtual void accept_func(); + virtual int create_server_fd(); + std::string host_to_ip(const std::string &host) { hostent *hostname = gethostbyname(host.c_str()); @@ -52,20 +54,6 @@ namespace gloo return {}; } - std::vector str_split(const std::string &str, char delimiter) - { - std::vector tokens; - std::stringstream ss(str); - std::string token; - - while (std::getline(ss, token, delimiter)) - { - tokens.push_back(token); - } - - return tokens; - } - protected: std::string hostname_; std::string host_ip_; @@ -77,7 +65,6 @@ namespace gloo std::mutex mtx; int server_fd; - struct sockaddr_in server_address; std::map> data_; }; From 7e19c284a3964f2d46763e0e11ac5ba1d6bdb69e Mon Sep 17 00:00:00 2001 From: liqingping Date: Mon, 17 Feb 2025 14:41:30 +0800 Subject: [PATCH 7/9] chore: remove debug files --- .gitignore | 60 +----- examples/CMakeLists.txt | 14 -- examples/build.sh | 4 - examples/run.sh | 7 - examples/test_allreduce.cpp | 92 --------- examples/test_allreduce_ib.cpp | 110 ---------- examples/test_send_recv_ib.cpp | 108 ---------- examples/test_type.cpp | 47 ----- gloo/rendezvous/store.cc | 1 - gloo/transport/context.cc | 366 +++++++++++++++------------------ hack/bench.sh | 25 --- hack/build.sh | 15 -- 12 files changed, 167 insertions(+), 682 deletions(-) delete mode 100644 examples/CMakeLists.txt delete mode 100644 examples/build.sh delete mode 100644 examples/run.sh delete mode 100644 examples/test_allreduce.cpp delete mode 100644 examples/test_allreduce_ib.cpp delete mode 100644 examples/test_send_recv_ib.cpp delete mode 100644 examples/test_type.cpp delete mode 100644 hack/bench.sh delete mode 100644 hack/build.sh diff --git a/.gitignore b/.gitignore index cdb50ee2d..e0d1d4e65 100644 --- a/.gitignore +++ b/.gitignore @@ -1,60 +1,2 @@ -# Created by https://www.toptal.com/developers/gitignore/api/c++,cmake -# Edit at https://www.toptal.com/developers/gitignore?templates=c++,cmake - -### C++ ### -# Prerequisites -*.d - -# Compiled Object files -*.slo -*.lo -*.o -*.obj - -# Precompiled Headers -*.gch -*.pch - -# Compiled Dynamic libraries -*.so -*.dylib -*.dll - -# Fortran module files -*.mod -*.smod - -# Compiled Static libraries -*.lai -*.la -*.a -*.lib - -# Executables -*.exe -*.out -*.app - -### CMake ### -CMakeLists.txt.user -CMakeCache.txt -CMakeFiles -CMakeScripts -Testing -Makefile -cmake_install.cmake -install_manifest.txt -compile_commands.json -CTestTestfile.cmake -_deps - -### CMake Patch ### -CMakeUserPresets.json - -# External projects -*-prefix/ - -# End of https://www.toptal.com/developers/gitignore/api/c++,cmake - -build/ +/build* *.pyc diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt deleted file mode 100644 index 092af62b3..000000000 --- a/examples/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -set(GLOO_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_allreduce.cpp") -add_executable(gloo_test ${GLOO_TEST_SRCS}) -target_link_libraries(gloo_test gloo ${GLOO_TEST_LIBRARIES}) - -set(IB_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_allreduce_ib.cpp") -add_executable(ib_test ${IB_TEST_SRCS}) -target_link_libraries(ib_test gloo ${IB_TEST_LIBRARIES}) - -set(SEND_TEST_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/test_send_recv_ib.cpp") -add_executable(send_test ${SEND_TEST_SRCS}) -target_link_libraries(send_test gloo ${IB_TEST_LIBRARIES}) - - -add_executable(type_test "${CMAKE_CURRENT_SOURCE_DIR}/test_type.cpp") \ No newline at end of file diff --git a/examples/build.sh b/examples/build.sh deleted file mode 100644 index b74eafd4a..000000000 --- a/examples/build.sh +++ /dev/null @@ -1,4 +0,0 @@ -mkdir build -cd build -cmake .. -make \ No newline at end of file diff --git a/examples/run.sh b/examples/run.sh deleted file mode 100644 index 35e4cb618..000000000 --- a/examples/run.sh +++ /dev/null @@ -1,7 +0,0 @@ -rm -rf /mnt/public/liqingping/opensource/gloo/tmp/file_store/* - -# in rank 0 -IB_DEVICE=mlx5_10 RANK=0 WORLD_SIZE=2 ./build/send_test - -# in rank 1 -IB_DEVICE=mlx5_10 RANK=1 WORLD_SIZE=2 ./build/send_test \ No newline at end of file diff --git a/examples/test_allreduce.cpp b/examples/test_allreduce.cpp deleted file mode 100644 index 902158853..000000000 --- a/examples/test_allreduce.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include - -using namespace gloo; - -int main() -{ - // Initialize context - auto rank = getenv("RANK"); - auto world_size = getenv("WORLD_SIZE"); - auto myRank = atoi(rank); - auto contextSize = atoi(world_size); - gloo::rendezvous::Context context(myRank, contextSize); - - // Perform rendezvous for TCP pairs - gloo::transport::tcp::attr attr("localhost"); - auto dev = gloo::transport::tcp::CreateDevice(attr); - // gloo::transport::ibverbs::attr attr = { - // "mlx5_10", 1, 1}; - // auto dev = gloo::transport::ibverbs::CreateDevice(attr); - gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); - context.connectFullMesh(store, dev); - - std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; - - size_t data_size = 3; - std::vector inputs{new float[data_size * 2]}; - for (auto i = 0; i < data_size; i++) - { - inputs[0][i] = i + 1; - inputs[0][i + data_size] = i + 1; - } - std::vector outputs{new float[data_size * 2]}; - for (auto i = 0; i < data_size; i++) - { - outputs[0][i] = 0; - outputs[0][i + data_size] = 0; - } - - for (auto i = 0; i < data_size; i++) - { - std::cout << "inputs[0][" << i << "] = " << inputs[0][i] << std::endl; - std::cout << "inputs[1][" << i << "] = " << inputs[0][i + data_size] << std::endl; - } - - std::shared_ptr rzv_context = std::make_shared(context); - AllreduceOptions opts(rzv_context); - auto algorithm = gloo::AllreduceOptions::Algorithm::RING; - opts.setAlgorithm(algorithm); - opts.setOutputs(outputs, data_size * 2); - std::cout << "##### before setInputs #####" << std::endl; - opts.setInputs(inputs, data_size * 2); - outputs.clear(); - - // gloo::AllreduceOptions::Func fn = [](void *a, const void *b, const void *c, size_t n) - // { - // return gloo::sum(a, b, c, n); - // }; - - // opts.setReduceFunction(fn); - opts.setReduceFunction([](void *a, const void *b, const void *c, size_t n) - { - std::cout << "a = " << a << ", b = " << b << ", c = " << c << ", n = " << n << std::endl; - auto ua = static_cast(a); - const auto ub = static_cast(b); - const auto uc = static_cast(c); - for (size_t i = 0; i < n; i++) { - ua[i] = ub[i] + uc[i]; - std::cout << "ua[" << i << "] = " << ua[i] << " = " << ub[i] << " + " << uc[i] << std::endl; - } }); - - // A small maximum segment size triggers code paths where we'll - // have a number of segments larger than the lower bound of - // twice the context size. - opts.setMaxSegmentSize(128); - - gloo::allreduce(opts); - - for (auto i = 0; i < data_size; i++) - { - std::cout << "outputs[0][" << i << "] = " << outputs[0][i] << std::endl; - std::cout << "outputs[0][" << i << "] = " << outputs[0][i + data_size] << std::endl; - } - return 0; -} diff --git a/examples/test_allreduce_ib.cpp b/examples/test_allreduce_ib.cpp deleted file mode 100644 index c1f2444b5..000000000 --- a/examples/test_allreduce_ib.cpp +++ /dev/null @@ -1,110 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace gloo; - -// Function to instantiate and run algorithm. -using Func = void( - std::shared_ptr<::gloo::Context>, - std::vector dataPtrs, - int dataSize); - -// RAII handle for aligned buffer -template -#ifdef _WIN32 -std::vector newBuffer(int size) -{ - return std::vector(size); -#else -std::vector> newBuffer(int size) -{ - return std::vector>(size); -#endif -} - -int main() -{ - // Initialize context - auto rank = getenv("RANK"); - if (!rank) - { - rank = "0"; - } - auto world_size = getenv("WORLD_SIZE"); - if (!world_size) - { - world_size = "1"; - } - auto ib_device = getenv("IB_DEVICE"); - if (!ib_device) - { - ib_device = "mlx5_0"; - } - auto myRank = atoi(rank); - auto contextSize = atoi(world_size); - gloo::rendezvous::Context context(myRank, contextSize); - - // Perform rendezvous for TCP pairs - // gloo::transport::tcp::attr attr("localhost"); - // auto dev = gloo::transport::tcp::CreateDevice(attr); - gloo::transport::ibverbs::attr attr = { - ib_device, 1, 1}; - auto dev = gloo::transport::ibverbs::CreateDevice(attr); - gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); - context.connectFullMesh(store, dev); - - std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; - - size_t data_size = 3; - static std::function allreduceRing = - [](std::shared_ptr<::gloo::Context> context, - std::vector dataPtrs, - int dataSize) - { - ::gloo::AllreduceRing algorithm(context, dataPtrs, dataSize); - algorithm.run(); - }; - - std::shared_ptr rzv_context = std::make_shared(context); - - // std::vector ptr{new float[data_size * 2]}; - // for (auto i = 0; i < data_size; i++) - // { - // ptr[0][i] = i + 1; - // ptr[0][i + data_size] = i + 1; - // } - // for (auto i = 0; i < data_size; i++) - // { - // std::cout << "ptr[0][" << i << "] = " << ptr[0][i] << std::endl; - // std::cout << "ptr[0][" << i + data_size << "] = " << ptr[0][i + data_size] << std::endl; - // } - - // allreduceRing(rzv_context, ptr, data_size * 2); - - const auto contextRank = rzv_context->rank; - auto buffer = newBuffer(data_size * 2); - auto *ptr = buffer.data(); - - for (int i = 0; i < data_size; i++) - { - ptr[i] = i + 1; - ptr[i + data_size] = i + 1; - } - - allreduceRing(rzv_context, std::vector{ptr}, data_size * 2); - - for (auto i = 0; i < data_size * 2; i++) - { - std::cout << "ptr[" << i << "] = " << ptr[i] << std::endl; - } - return 0; -} diff --git a/examples/test_send_recv_ib.cpp b/examples/test_send_recv_ib.cpp deleted file mode 100644 index 85332e86f..000000000 --- a/examples/test_send_recv_ib.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace gloo; - -// Function to instantiate and run algorithm. -using Func = void( - std::shared_ptr<::gloo::Context>, - std::vector dataPtrs, - int dataSize); - -// RAII handle for aligned buffer -template -#ifdef _WIN32 -std::vector newBuffer(int size) -{ - return std::vector(size); -#else -std::vector> newBuffer(int size) -{ - return std::vector>(size); -#endif -} - -int main() -{ - // Initialize context - auto rank = getenv("RANK"); - if (!rank) - { - rank = "0"; - } - auto world_size = getenv("WORLD_SIZE"); - if (!world_size) - { - world_size = "1"; - } - auto ib_device = getenv("IB_DEVICE"); - if (!ib_device) - { - ib_device = "mlx5_0"; - } - auto myRank = atoi(rank); - auto contextSize = atoi(world_size); - gloo::rendezvous::Context context(myRank, contextSize); - - // // Perform rendezvous for TCP pairs - // gloo::transport::tcp::attr attr("localhost"); - // auto dev = gloo::transport::tcp::CreateDevice(attr); - gloo::transport::ibverbs::attr attr = { - ib_device, 1, 1}; - auto dev = gloo::transport::ibverbs::CreateDevice(attr); - gloo::rendezvous::FileStore store("/mnt/public/liqingping/opensource/gloo/tmp/file_store"); - context.connectFullMesh(store, dev); - - std::cout << "rank = " << context.rank << ", size = " << context.size << std::endl; - - std::shared_ptr rzv_context = std::make_shared(context); - size_t data_size = 3; - float sends[data_size] = {1 + float(myRank), 2 + float(myRank), 3 + float(myRank)}; - float recvs[data_size] = {0, 0, 0}; - for (auto i = 0; i < data_size; i++) - { - std::cout << "sends[" << i << "] = " << sends[i] << std::endl; - } - - auto slot = context.nextSlot(); - int peer; - if (context.rank == 0) - peer = 1; - else - peer = 0; - - std::cout << "peer = " << peer << std::endl; - int bytes_ = sizeof(float) * data_size; - auto inbox_ = static_cast(malloc(bytes_)); - auto outbox_ = static_cast(malloc(bytes_)); - auto &pair = context.getPair(peer); - std::unique_ptr<::gloo::transport::Buffer> sendBuf = pair->createSendBuffer(slot, outbox_, bytes_); - std::unique_ptr<::gloo::transport::Buffer> recvBuf = pair->createRecvBuffer(slot, inbox_, bytes_); - - std::memcpy(outbox_, sends, bytes_); - - sendBuf->send(); - recvBuf->waitRecv(); - sendBuf->waitSend(); - - std::memcpy(recvs, inbox_, bytes_); - - for (auto i = 0; i < data_size; i++) - { - std::cout << "recvs[" << i << "] = " << recvs[i] << std::endl; - } - - free(inbox_); - free(outbox_); - return 0; -} diff --git a/examples/test_type.cpp b/examples/test_type.cpp deleted file mode 100644 index 2438e5aeb..000000000 --- a/examples/test_type.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include -#include -#include // For std::memcpy -#include - -#include - -using namespace gloo; - -// RAII handle for aligned buffer -template -#ifdef _WIN32 -std::vector newBuffer(int size) -{ - return std::vector(size); -#else -std::vector> newBuffer(int size) -{ - return std::vector>(size); -#endif -} - -int main() -{ - // Example size and alignment - constexpr std::size_t kBufferAlignment = 64; - constexpr std::size_t size = 10; - - // Create the aligned vector - auto a = std::vector>(size); - - // Simulate intptr_t pointing to external data - int external_data[size] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - intptr_t b = reinterpret_cast(external_data); - - // Write the data from b into a - std::memcpy(a.data(), reinterpret_cast(b), size * sizeof(int)); - - // Print the result - for (auto val : a) - { - std::cout << val << " "; - } - std::cout << std::endl; - - return 0; -} \ No newline at end of file diff --git a/gloo/rendezvous/store.cc b/gloo/rendezvous/store.cc index d30e9bae2..4d44e808a 100644 --- a/gloo/rendezvous/store.cc +++ b/gloo/rendezvous/store.cc @@ -12,7 +12,6 @@ namespace gloo { namespace rendezvous { - // 定义静态成员变量 constexpr std::chrono::milliseconds Store::kDefaultTimeout; } // namespace rendezvous } // namespace gloo diff --git a/gloo/transport/context.cc b/gloo/transport/context.cc index 7d9740e38..0b1f45b2c 100644 --- a/gloo/transport/context.cc +++ b/gloo/transport/context.cc @@ -9,215 +9,181 @@ #include "gloo/transport/context.h" #include "gloo/common/utils.h" -#include - -namespace gloo -{ - namespace transport - { - - Context::Context(int rank, int size) : rank(rank), size(size) - { - pairs_.resize(size); - } - - // Have to provide implementation for pure virtual destructor. - Context::~Context() {} - - std::unique_ptr &Context::getPair(int rank_2) - { - return pairs_.at(rank_2); - } - - void Context::createAndConnectAllPairs(IStore &store) - { - // this is the default un-optimized version of the rendezvous protocol - // where each rank would write N pairs to the store - // and then for each remote peer load the N addresses - // and only pick the 1 useful - // A more efficient version (for transport supporting multiplexing like TCP) - // can be seen in gloo/transport/tcp/context.cc - - std::vector allBytes; - int localRank = 0; - - auto localHostName = getHostname(); - // Add global rank <> hostname pair to the Store. This store is then passed - // to Gloo when connectFullMesh is called, where Gloo uses the global rank <> - // hostname mapping to compute local ranks. - std::string localKey("rank_" + std::to_string(rank)); - const std::vector value(localHostName.begin(), localHostName.end()); - store.set(localKey, value); - - for (int i = 0; i < size; i++) - { - if (i == rank) - { - break; - } - - std::string key("rank_" + std::to_string(i)); - auto val = store.get(key); - auto hostName = std::string((const char *)val.data(), val.size()); - - if (hostName == localHostName) - { - localRank++; - } - } - - // Create pairs - for (int i = 0; i < size; i++) - { - if (i == rank) - { - continue; - } - - auto &pair = createPair(i); - pair->setLocalRank(localRank); - auto addrBytes = pair->address().bytes(); - // std::cout << "len: " << addrBytes.size() << "addrBytes: " << std::string(addrBytes.begin(), addrBytes.end()) << std::endl; - - allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end()); - } - - store.set(std::to_string(rank), allBytes); - - // Connect every pair - for (int i = 0; i < size; i++) - { - if (i == rank) - { - continue; - } - - // Wait for address of other side of this pair to become available - std::ostringstream key; - key << i; - store.wait({key.str()}, getTimeout()); - - // Connect to other side of this pair - auto allAddrs = store.get(key.str()); - auto addr = extractAddress(allAddrs, i); - getPair(i)->connect(addr); - } - } - - std::vector Context::extractAddress( - const std::vector &allAddrs, - int i) const - { - // Extract address from the list of all addresses - int adjRank = (rank > i ? rank - 1 : rank); - // Adjust for the fact that nodes do not store address for themselves - int addrSize = allAddrs.size() / (size - 1); - return std::vector( - allAddrs.begin() + adjRank * addrSize, - allAddrs.begin() + (adjRank + 1) * addrSize); +namespace gloo { +namespace transport { + +Context::Context(int rank, int size) : rank(rank), size(size) { + pairs_.resize(size); +} + +// Have to provide implementation for pure virtual destructor. +Context::~Context() {} + +std::unique_ptr& Context::getPair(int rank_2) { + return pairs_.at(rank_2); +} + +void Context::createAndConnectAllPairs(IStore& store) { + // this is the default un-optimized version of the rendezvous protocol + // where each rank would write N pairs to the store + // and then for each remote peer load the N addresses + // and only pick the 1 useful + // A more efficient version (for transport supporting multiplexing like TCP) + // can be seen in gloo/transport/tcp/context.cc + + std::vector allBytes; + int localRank = 0; + + auto localHostName = getHostname(); + // Add global rank <> hostname pair to the Store. This store is then passed + // to Gloo when connectFullMesh is called, where Gloo uses the global rank <> + // hostname mapping to compute local ranks. + std::string localKey("rank_" + std::to_string(rank)); + const std::vector value(localHostName.begin(), localHostName.end()); + store.set(localKey, value); + + for (int i = 0; i < size; i++) { + if (i == rank) { + break; } - Context::LazyTally::LazyTally(std::vector &vec, slot_t slot) - : vec_(vec), slot_(slot), initialized_(false) {} + std::string key("rank_" + std::to_string(i)); + auto val = store.get(key); + auto hostName = std::string((const char*)val.data(), val.size()); - Context::LazyTally::~LazyTally() - { - // Remove empty tally from vector. - if (initialized_ && it_ != vec_.end() && it_->empty()) - { - vec_.erase(it_); - } + if (hostName == localHostName) { + localRank++; } + } - bool Context::LazyTally::exists() - { - initialize_iterator(); - return it_ != vec_.end(); + // Create pairs + for (int i = 0; i < size; i++) { + if (i == rank) { + continue; } - Context::Tally &Context::LazyTally::get() - { - initialize_iterator(); - if (it_ == vec_.end()) - { - vec_.emplace_back(slot_); - it_ = (vec_.end() - 1); - } - return *it_; - } - - void Context::LazyTally::initialize_iterator() - { - if (initialized_) - { - return; - } - - it_ = - std::find_if(vec_.begin(), vec_.end(), [this](const Context::Tally &op) - { return op.slot == slot_; }); - initialized_ = true; - } + auto& pair = createPair(i); + pair->setLocalRank(localRank); + auto addrBytes = pair->address().bytes(); + allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end()); + } - Context::Mutator::Mutator(Context &context, slot_t slot, rank_t rank) - : lock_(context.mutex_), - context_(context), - slot_(slot), - rank_(rank), - pendingOperations_(context_.pendingOperations_, slot_), - expectedNotifications_(context_.expectedNotifications_, slot_) {} - - void Context::Mutator::pushRemotePendingRecv() - { - pendingOperations_.get().pushRecv(rank_); - } - - void Context::Mutator::pushRemotePendingSend() - { - pendingOperations_.get().pushSend(rank_); - } - - bool Context::Mutator::shiftRemotePendingRecv() - { - if (!pendingOperations_.exists()) - { - return false; - } - return pendingOperations_.get().shiftRecv(rank_); - } - - bool Context::Mutator::shiftRemotePendingSend() - { - if (!pendingOperations_.exists()) - { - return false; - } - return pendingOperations_.get().shiftSend(rank_); - } - - void Context::Mutator::pushExpectedSendNotification() - { - expectedNotifications_.get().pushSend(rank_); - } - - bool Context::Mutator::shiftExpectedSendNotification() - { - if (!expectedNotifications_.exists()) - { - return false; - } - return expectedNotifications_.get().shiftSend(rank_); - } + store.set(std::to_string(rank), allBytes); - std::vector::iterator Context::findPendingOperations( - slot_t slot) - { - return std::find_if( - pendingOperations_.begin(), - pendingOperations_.end(), - [slot](const Tally &op) - { return op.slot == slot; }); + // Connect every pair + for (int i = 0; i < size; i++) { + if (i == rank) { + continue; } - } // namespace transport + // Wait for address of other side of this pair to become available + std::ostringstream key; + key << i; + store.wait({key.str()}, getTimeout()); + + // Connect to other side of this pair + auto allAddrs = store.get(key.str()); + auto addr = extractAddress(allAddrs, i); + getPair(i)->connect(addr); + } +} + +std::vector Context::extractAddress( + const std::vector& allAddrs, + int i) const { + // Extract address from the list of all addresses + int adjRank = (rank > i ? rank - 1 : rank); + // Adjust for the fact that nodes do not store address for themselves + int addrSize = allAddrs.size() / (size - 1); + return std::vector( + allAddrs.begin() + adjRank * addrSize, + allAddrs.begin() + (adjRank + 1) * addrSize); +} + +Context::LazyTally::LazyTally(std::vector& vec, slot_t slot) + : vec_(vec), slot_(slot), initialized_(false) {} + +Context::LazyTally::~LazyTally() { + // Remove empty tally from vector. + if (initialized_ && it_ != vec_.end() && it_->empty()) { + vec_.erase(it_); + } +} + +bool Context::LazyTally::exists() { + initialize_iterator(); + return it_ != vec_.end(); +} + +Context::Tally& Context::LazyTally::get() { + initialize_iterator(); + if (it_ == vec_.end()) { + vec_.emplace_back(slot_); + it_ = (vec_.end() - 1); + } + return *it_; +} + +void Context::LazyTally::initialize_iterator() { + if (initialized_) { + return; + } + + it_ = + std::find_if(vec_.begin(), vec_.end(), [this](const Context::Tally& op) { + return op.slot == slot_; + }); + initialized_ = true; +} + +Context::Mutator::Mutator(Context& context, slot_t slot, rank_t rank) + : lock_(context.mutex_), + context_(context), + slot_(slot), + rank_(rank), + pendingOperations_(context_.pendingOperations_, slot_), + expectedNotifications_(context_.expectedNotifications_, slot_) {} + +void Context::Mutator::pushRemotePendingRecv() { + pendingOperations_.get().pushRecv(rank_); +} + +void Context::Mutator::pushRemotePendingSend() { + pendingOperations_.get().pushSend(rank_); +} + +bool Context::Mutator::shiftRemotePendingRecv() { + if (!pendingOperations_.exists()) { + return false; + } + return pendingOperations_.get().shiftRecv(rank_); +} + +bool Context::Mutator::shiftRemotePendingSend() { + if (!pendingOperations_.exists()) { + return false; + } + return pendingOperations_.get().shiftSend(rank_); +} + +void Context::Mutator::pushExpectedSendNotification() { + expectedNotifications_.get().pushSend(rank_); +} + +bool Context::Mutator::shiftExpectedSendNotification() { + if (!expectedNotifications_.exists()) { + return false; + } + return expectedNotifications_.get().shiftSend(rank_); +} + +std::vector::iterator Context::findPendingOperations( + slot_t slot) { + return std::find_if( + pendingOperations_.begin(), + pendingOperations_.end(), + [slot](const Tally& op) { return op.slot == slot; }); +} + +} // namespace transport } // namespace gloo diff --git a/hack/bench.sh b/hack/bench.sh deleted file mode 100644 index b18a3de28..000000000 --- a/hack/bench.sh +++ /dev/null @@ -1,25 +0,0 @@ -ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -cd $ROOT_DIR/.. - -cd build/gloo/benchmark - -NNODES=${WORLD_SIZE:-1} -NODE_RANK=${RANK:-0} -REDIS_HOST=${REDIS_HOST:-"localhost"} -REDIS_PORT=${REDIS_PORT:-6379} - -./benchmark \ - --size ${NNODES} \ - --rank ${NODE_RANK} \ - --redis-host ${REDIS_HOST} \ - --redis-port ${REDIS_PORT} \ - --prefix test-for-benchmark \ - --transport ibverbs \ - --ib-device mlx5_10 \ - --ib-port 1 \ - --elements $(( 1024 * 1024 )) \ - --inputs 4 \ - --iteration-time 2s \ - allreduce_ring - # allreduce_ring_chunked diff --git a/hack/build.sh b/hack/build.sh deleted file mode 100644 index e9d90a5a4..000000000 --- a/hack/build.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -cd $ROOT_DIR/.. - -mkdir -p build -cd build -cmake -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ \ - -DCMAKE_BUILD_TYPE=Debug \ - -DUSE_IBVERBS=1 -DBUILD_BENCHMARK=1 \ - -DBUILD_SHARED_LIBS=1 \ - ../ -make -make install \ No newline at end of file From 337f3c598168b9e459752f6e3936264980f36f4a Mon Sep 17 00:00:00 2001 From: liqingping Date: Tue, 18 Feb 2025 16:37:19 +0800 Subject: [PATCH 8/9] feat: wait until connect to master or timeout --- gloo/rendezvous/tcp_store.cc | 25 ++++++++++++++++++++----- gloo/rendezvous/tcp_store.h | 6 ++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc index aa3665ef0..5ed1cd5c6 100644 --- a/gloo/rendezvous/tcp_store.cc +++ b/gloo/rendezvous/tcp_store.cc @@ -16,6 +16,7 @@ #include #include #include +#include #ifndef _WIN32 #include @@ -209,11 +210,26 @@ namespace gloo GLOO_THROW(err); } - // connect to server - if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0) + const auto start = std::chrono::steady_clock::now(); + auto timeout = std::chrono::seconds(timeout_); + while (true) { - auto err = std::string("Connection to server failed: ") + strerror(errno); - GLOO_THROW(err); + // connect to server + if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) == 0) + { + break; + } + + // check timeout + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + if (timeout != kNoTimeout && elapsed > timeout) + { + GLOO_THROW_IO_EXCEPTION(GLOO_ERROR_MSG( + "Connection to master timeout for " + std::to_string(timeout_) + " seconds")); + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } return new_server_fd; @@ -358,7 +374,6 @@ namespace gloo } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } } diff --git a/gloo/rendezvous/tcp_store.h b/gloo/rendezvous/tcp_store.h index cf2113705..25f83283a 100644 --- a/gloo/rendezvous/tcp_store.h +++ b/gloo/rendezvous/tcp_store.h @@ -16,6 +16,8 @@ #include #include +#define SOCKET_INIT_TIMEOUT_SECONDS 30 + namespace gloo { namespace rendezvous @@ -24,7 +26,7 @@ namespace gloo class TCPStore : public Store { public: - explicit TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout = 30); + explicit TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout = SOCKET_INIT_TIMEOUT_SECONDS); virtual ~TCPStore(); virtual void set(const std::string &key, const std::vector &data) @@ -64,7 +66,7 @@ namespace gloo std::mutex mtx; - int server_fd; + int server_fd = -1; std::map> data_; }; From 444fb9ff2cda17a9c5351cc44f4f6c94c6363b50 Mon Sep 17 00:00:00 2001 From: liqingping Date: Thu, 20 Feb 2025 20:42:54 +0800 Subject: [PATCH 9/9] fix: socket TIME_WAIT --- gloo/rendezvous/tcp_store.cc | 69 ++++++++++++++++++++---------------- gloo/rendezvous/tcp_store.h | 32 ++++++++++++++++- 2 files changed, 70 insertions(+), 31 deletions(-) diff --git a/gloo/rendezvous/tcp_store.cc b/gloo/rendezvous/tcp_store.cc index 5ed1cd5c6..dba8f9968 100644 --- a/gloo/rendezvous/tcp_store.cc +++ b/gloo/rendezvous/tcp_store.cc @@ -42,7 +42,6 @@ namespace gloo { TCPStore::~TCPStore() { - close(server_fd); } TCPStore::TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout) @@ -52,13 +51,15 @@ namespace gloo world_size_(world_size), is_master_(is_master), timeout_(timeout), - data_({}) + data_({}), + server_fd_(-1) { if (is_master) { // create socket - server_fd = socket(AF_INET, SOCK_STREAM, 0); - if (server_fd == -1) + int server_fd = socket(AF_INET, SOCK_STREAM, 0); + server_fd_.reset(server_fd); + if (server_fd_.get() == -1) { auto err = std::string("Socket creation failed: ") + strerror(errno); GLOO_THROW(err); @@ -71,14 +72,14 @@ namespace gloo address.sin_port = htons(port_); // bind socket to address - if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) + if (bind(server_fd_.get(), (struct sockaddr *)&address, sizeof(address)) < 0) { auto err = std::string("Socket bind failed: ") + strerror(errno); GLOO_THROW(err); } // start listening - if (listen(server_fd, 3) < 0) + if (listen(server_fd_.get(), 3) < 0) { auto err = std::string("Socket listen failed: ") + strerror(errno); GLOO_THROW(err); @@ -95,7 +96,8 @@ namespace gloo int new_socket; struct sockaddr_in client_address; socklen_t addr_len = sizeof(client_address); - new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len); + std::cout << "server fd: <" << server_fd_.get() << ">" << std::endl; + new_socket = accept(server_fd_.get(), (struct sockaddr *)&client_address, &addr_len); if (new_socket < 0) { auto err = std::string("Accept client connection failed: ") + strerror(errno); @@ -190,36 +192,45 @@ namespace gloo int TCPStore::create_server_fd() { - // create socket - int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); - if (new_server_fd == -1) + while (true) { - auto err = std::string("Socket creation failed: ") + strerror(errno); - GLOO_THROW(err); - } + // create socket + int new_server_fd = socket(AF_INET, SOCK_STREAM, 0); + if (new_server_fd == -1) + { + auto err = std::string("Socket creation failed: ") + strerror(errno); + GLOO_THROW(err); + } - // config server address - struct sockaddr_in server_address; - server_address.sin_family = AF_INET; - server_address.sin_port = htons(port_); + // config server address + struct sockaddr_in server_address; + server_address.sin_family = AF_INET; + server_address.sin_port = htons(port_); - // set server address ip - if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) - { - auto err = std::string("Invalid address: ") + strerror(errno); - GLOO_THROW(err); - } + // set server address ip + if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0) + { + close(new_server_fd); + auto err = std::string("Invalid address: ") + strerror(errno); + GLOO_THROW(err); + } + + const auto start = std::chrono::steady_clock::now(); + auto timeout = std::chrono::seconds(timeout_); - const auto start = std::chrono::steady_clock::now(); - auto timeout = std::chrono::seconds(timeout_); - while (true) - { // connect to server if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) == 0) { - break; + struct linger so_linger; + so_linger.l_onoff = 1; // enable LINGER + so_linger.l_linger = 0; // send RST to close the connection immediately + setsockopt(new_server_fd, SOL_SOCKET, SO_LINGER, &so_linger, sizeof(so_linger)); + + return new_server_fd; } + close(new_server_fd); + // check timeout const auto elapsed = std::chrono::duration_cast( std::chrono::steady_clock::now() - start); @@ -231,8 +242,6 @@ namespace gloo /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - - return new_server_fd; } void TCPStore::set(const std::string &key, const std::vector &data) diff --git a/gloo/rendezvous/tcp_store.h b/gloo/rendezvous/tcp_store.h index 25f83283a..dc7020999 100644 --- a/gloo/rendezvous/tcp_store.h +++ b/gloo/rendezvous/tcp_store.h @@ -16,12 +16,42 @@ #include #include +#ifndef _WIN32 +#include +#else +#include +#endif + #define SOCKET_INIT_TIMEOUT_SECONDS 30 namespace gloo { namespace rendezvous { + class Socket + { + public: + explicit Socket(int fd) : fd_(fd) {} + ~Socket() + { + if (fd_ != -1) + { + close(fd_); + } + } + int get() const { return fd_; } + void reset(int fd = -1) + { + if (fd_ != -1) + { + close(fd_); + } + fd_ = fd; + } + + private: + int fd_; + }; class TCPStore : public Store { @@ -66,7 +96,7 @@ namespace gloo std::mutex mtx; - int server_fd = -1; + Socket server_fd_; std::map> data_; };