Skip to content

Commit c37d821

Browse files
authored
gloo/ibverbs: support put/get operations
Differential Revision: D75919074 Pull Request resolved: #450
1 parent c7b7b02 commit c37d821

19 files changed

+579
-28
lines changed

gloo/context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <memory>
1313
#include <vector>
1414

15+
#include <gloo/transport/context.h>
1516
#include <gloo/transport/pair.h>
1617

1718
namespace gloo {
@@ -51,6 +52,11 @@ class Context {
5152

5253
std::chrono::milliseconds getTimeout() const;
5354

55+
std::unique_ptr<transport::RemoteKey> deserializeRemoteKey(
56+
const std::string& serialized) {
57+
return transportContext_->deserializeRemoteKey(serialized);
58+
}
59+
5460
protected:
5561
std::shared_ptr<transport::Device> device_;
5662
std::shared_ptr<transport::Context> transportContext_;

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ set(GLOO_TEST_SRCS
1414
"${CMAKE_CURRENT_SOURCE_DIR}/main.cc"
1515
"${CMAKE_CURRENT_SOURCE_DIR}/memory_test.cc"
1616
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_test.cc"
17+
"${CMAKE_CURRENT_SOURCE_DIR}/remote_key_test.cc"
1718
"${CMAKE_CURRENT_SOURCE_DIR}/send_recv_test.cc"
1819
)
1920
set(GLOO_TEST_LIBRARIES)

gloo/test/base_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ const std::vector<Transport> kTransportsForClassAlgorithms = {
2626
#endif
2727
};
2828

29+
// Transports that the remote key (put/get) tests are instantiated against.
// Holds only Transport::IBVERBS, and only when GLOO_HAVE_TRANSPORT_IBVERBS
// is set; otherwise the list is empty.
const std::vector<Transport> kTransportsForRDMA = {
30+
#if GLOO_HAVE_TRANSPORT_IBVERBS
31+
Transport::IBVERBS,
32+
#endif
33+
};
34+
2935
// Transports that function algorithms can be tested against.
3036
// This is the new style of calling collectives and must be
3137
// preferred over the instantiated style.

gloo/test/base_test.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ enum Transport {
7373

7474
extern const std::vector<Transport> kTransportsForClassAlgorithms;
7575
extern const std::vector<Transport> kTransportsForFunctionAlgorithms;
76+
extern const std::vector<Transport> kTransportsForRDMA;
7677

7778
std::shared_ptr<::gloo::transport::Device> createDevice(Transport transport);
7879

gloo/test/remote_key_test.cc

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <array>
10+
#include <functional>
11+
#include <vector>
12+
13+
#include "gloo/allgather.h"
14+
#include "gloo/common/common.h"
15+
#include "gloo/test/base_test.h"
16+
17+
namespace gloo {
18+
namespace test {
19+
namespace {
20+
21+
// Test parameterization.
22+
using Param = std::tuple<Transport, int, int>;
23+
24+
std::vector<std::unique_ptr<gloo::transport::RemoteKey>> exchangeKeys(
25+
std::shared_ptr<Context>& context,
26+
const std::unique_ptr<gloo::transport::RemoteKey>& key) {
27+
auto selfRemoteKey = key->serialize();
28+
GLOO_ENFORCE_GT(selfRemoteKey.size(), 0);
29+
30+
std::array<char, 1024> keyBuf;
31+
for (int i = 0; i < selfRemoteKey.size(); ++i) {
32+
keyBuf[i] = selfRemoteKey[i];
33+
}
34+
keyBuf[selfRemoteKey.size()] = '\0';
35+
36+
std::unique_ptr<char[]> outPtr{
37+
gloo::make_unique<char[]>(keyBuf.size() * context->size)};
38+
39+
AllgatherOptions opts(context);
40+
opts.setInput(&keyBuf[0], keyBuf.size());
41+
opts.setOutput(outPtr.get(), keyBuf.size() * context->size);
42+
allgather(opts);
43+
44+
std::vector<std::unique_ptr<gloo::transport::RemoteKey>> remoteKeys;
45+
46+
for (int i = 0; i < context->size; i++) {
47+
std::string remoteKeyStr{&outPtr[i * keyBuf.size()]};
48+
GLOO_ENFORCE_GT(remoteKeyStr.size(), 0);
49+
auto remoteKey = context->deserializeRemoteKey(remoteKeyStr);
50+
GLOO_ENFORCE(remoteKey.get() != nullptr);
51+
52+
remoteKeys.push_back(std::move(remoteKey));
53+
}
54+
55+
return remoteKeys;
56+
}
57+
58+
// Test fixture.
59+
// Value-parameterized by Param; the tests unpack the tuple as
// (transport, context size, data size).
class RemoteKeyTest : public BaseTest,
60+
public ::testing::WithParamInterface<Param> {};
61+
62+
// Checks one-sided reads through remote keys.
//
// Every rank exports a buffer of `dataSize` bytes filled with its own rank,
// exchanges remote keys with all peers, then reads each peer's buffer with
// get() and checks that every byte equals that peer's rank. Afterwards it
// checks that get() throws gloo::EnforceNotMet when either offset or the
// length is far out of range.
TEST_P(RemoteKeyTest, Get) {
  const auto transport = std::get<0>(GetParam());
  const auto contextSize = std::get<1>(GetParam());
  const auto dataSize = std::get<2>(GetParam());

  Barrier barrier(contextSize);

  spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
    int rank = context->rank;

    // Buffer exported to the other ranks; every byte holds this rank's id.
    std::unique_ptr<char[]> sharedPtr{gloo::make_unique<char[]>(dataSize)};
    for (int i = 0; i < dataSize; ++i) {
      sharedPtr[i] = rank;
    }
    auto sharedBuf = context->createUnboundBuffer(sharedPtr.get(), dataSize);

    // Destination for the data fetched from each peer.
    std::unique_ptr<char[]> localPtr{gloo::make_unique<char[]>(dataSize)};
    auto localBuf = context->createUnboundBuffer(localPtr.get(), dataSize);

    auto remoteKeys = exchangeKeys(context, sharedBuf->getRemoteKey());

    for (int i = 0; i < contextSize; ++i) {
      if (i == rank) {
        continue;
      }

      localBuf->get(*remoteKeys.at(i), context->nextSlot(), 0, 0, dataSize);
      localBuf->waitSend();

      // Report only the first mismatch, and do it with a non-fatal
      // expectation: a fatal ASSERT_* would return from this lambda before
      // barrier.wait() below, so this rank would skip the rendezvous and
      // destroy its exported buffer while peers may still be reading it.
      // NOTE(review): presumably the peers would then block in
      // barrier.wait() -- confirm against the Barrier helper.
      int firstMismatch = -1;
      for (int j = 0; j < dataSize; ++j) {
        if (localPtr[j] != i) {
          firstMismatch = j;
          break;
        }
      }
      EXPECT_EQ(firstMismatch, -1)
          << "rank " << rank << ": unexpected byte at offset "
          << firstMismatch << " in data fetched from rank " << i;
    }

    // Rendezvous of all spawned ranks before the bound checks below.
    barrier.wait();

    // Do bound checking
    auto testRank = (rank + 1) % contextSize;
    auto& testKey = remoteKeys.at(testRank);
    EXPECT_THROW(
        localBuf->get(*testKey, context->nextSlot(), 1000000000, 0, 1),
        gloo::EnforceNotMet);
    EXPECT_THROW(
        localBuf->get(*testKey, context->nextSlot(), 0, 1000000000, 1),
        gloo::EnforceNotMet);
    EXPECT_THROW(
        localBuf->get(*testKey, context->nextSlot(), 0, 0, 1000000000),
        gloo::EnforceNotMet);
  });
}
111+
112+
// Checks one-sided writes through remote keys.
//
// Every rank exports a buffer with one byte per rank and writes a single
// byte holding its own rank, at offset `rank`, into each peer's exported
// buffer. Once all ranks have passed the barrier, each rank checks that
// byte `j` of its exported buffer equals `j` for every peer `j`. Finally
// it checks that put() throws gloo::EnforceNotMet when either offset or the
// length is far out of range.
TEST_P(RemoteKeyTest, Put) {
  const auto transport = std::get<0>(GetParam());
  const auto contextSize = std::get<1>(GetParam());

  Barrier barrier(contextSize);

  spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
    const int self = context->rank;

    // Target buffer that the peers write into.
    std::unique_ptr<char[]> exportPtr{gloo::make_unique<char[]>(contextSize)};
    auto exportBuf = context->createUnboundBuffer(exportPtr.get(), contextSize);

    // Source buffer; every byte holds this rank's id.
    std::unique_ptr<char[]> localPtr{gloo::make_unique<char[]>(contextSize)};
    for (int idx = 0; idx < contextSize; ++idx) {
      localPtr[idx] = self;
    }
    auto localBuf = context->createUnboundBuffer(localPtr.get(), contextSize);

    auto remoteKeys = exchangeKeys(context, exportBuf->getRemoteKey());

    // Write one byte to every peer, skipping ourselves.
    for (int peer = 0; peer < contextSize; ++peer) {
      if (peer != self) {
        localBuf->put(*remoteKeys.at(peer), context->nextSlot(), self, self, 1);
        localBuf->waitSend();
      }
    }

    barrier.wait();

    // Every peer should have left its own rank in its own slot.
    for (int peer = 0; peer < contextSize; ++peer) {
      if (peer != self) {
        ASSERT_EQ(exportPtr[peer], peer);
      }
    }

    // Do bound checking
    const auto nextRank = (self + 1) % contextSize;
    auto& nextKey = remoteKeys.at(nextRank);
    EXPECT_THROW(
        localBuf->put(*nextKey, context->nextSlot(), 1000000000, 0, 1),
        gloo::EnforceNotMet);
    EXPECT_THROW(
        localBuf->put(*nextKey, context->nextSlot(), 0, 1000000000, 1),
        gloo::EnforceNotMet);
    EXPECT_THROW(
        localBuf->put(*nextKey, context->nextSlot(), 0, 0, 1000000000),
        gloo::EnforceNotMet);
  });
}
163+
164+
// Runs RemoteKeyTest for every combination of: each transport in
// kTransportsForRDMA, context sizes 2 and 4, and data sizes 0, 1024 and
// 1000000 (the tuple order matches Param).
INSTANTIATE_TEST_CASE_P(
165+
RemoteKeyTestBasics,
166+
RemoteKeyTest,
167+
::testing::Combine(
168+
::testing::ValuesIn(kTransportsForRDMA),
169+
::testing::Values(2, 4),
170+
::testing::Values(0, 1024, 1000000)));
171+
172+
// kTransportsForRDMA is empty when GLOO_HAVE_TRANSPORT_IBVERBS is not set,
// in which case the instantiation above produces no test cases.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(RemoteKeyTest);
173+
174+
} // namespace
175+
} // namespace test
176+
} // namespace gloo

gloo/transport/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ set(GLOO_TRANSPORT_HDRS
1313
"${CMAKE_CURRENT_SOURCE_DIR}/context.h"
1414
"${CMAKE_CURRENT_SOURCE_DIR}/device.h"
1515
"${CMAKE_CURRENT_SOURCE_DIR}/pair.h"
16+
"${CMAKE_CURRENT_SOURCE_DIR}/remote_key.h"
1617
"${CMAKE_CURRENT_SOURCE_DIR}/unbound_buffer.h"
1718
)
1819

gloo/transport/context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ class Context {
6868
return timeout_;
6969
}
7070

71+
// Rebuilds a RemoteKey from its serialized string form.
// This default implementation ignores `serialized` and always throws
// std::runtime_error("Not implemented"); transport contexts that support
// remote keys override it.
virtual std::unique_ptr<RemoteKey> deserializeRemoteKey(
72+
const std::string& serialized) {
73+
throw std::runtime_error("Not implemented");
74+
}
75+
7176
protected:
7277
// Protects access to the pending operations and expected
7378
// notifications vectors. These vectors can only be mutated by an

gloo/transport/ibverbs/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ list(APPEND GLOO_TRANSPORT_SRCS
55
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
66
"${CMAKE_CURRENT_SOURCE_DIR}/memory_region.cc"
77
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"
8+
"${CMAKE_CURRENT_SOURCE_DIR}/remote_key.cc"
89
"${CMAKE_CURRENT_SOURCE_DIR}/unbound_buffer.cc"
910
)
1011

@@ -15,6 +16,7 @@ list(APPEND GLOO_TRANSPORT_HDRS
1516
"${CMAKE_CURRENT_SOURCE_DIR}/device.h"
1617
"${CMAKE_CURRENT_SOURCE_DIR}/memory_region.h"
1718
"${CMAKE_CURRENT_SOURCE_DIR}/pair.h"
19+
"${CMAKE_CURRENT_SOURCE_DIR}/remote_key.h"
1820
"${CMAKE_CURRENT_SOURCE_DIR}/unbound_buffer.h"
1921
)
2022

gloo/transport/ibverbs/context.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "gloo/common/error.h"
1212
#include "gloo/transport/ibverbs/device.h"
1313
#include "gloo/transport/ibverbs/pair.h"
14+
#include "gloo/transport/ibverbs/remote_key.h"
1415
#include "gloo/transport/ibverbs/unbound_buffer.h"
1516

1617
namespace gloo {
@@ -46,6 +47,11 @@ void Context::signalException(const std::string& msg) {
4647
}
4748
}
4849

50+
// ibverbs implementation of deserializeRemoteKey: delegates to the static
// RemoteKey::deserialize and returns the result as a transport-level
// RemoteKey pointer.
std::unique_ptr<::gloo::transport::RemoteKey> Context::deserializeRemoteKey(
51+
const std::string& serialized) {
52+
return RemoteKey::deserialize(serialized);
53+
}
54+
4955
} // namespace ibverbs
5056
} // namespace transport
5157
} // namespace gloo

gloo/transport/ibverbs/context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class Context : public ::gloo::transport::Context,
3838
// out. All pairs should be signaled and closed in that event.
3939
void signalException(const std::string& msg);
4040

41+
// Rebuilds a remote key from its serialized string form. Overrides the
// virtual method of the same name in the base class.
std::unique_ptr<::gloo::transport::RemoteKey> deserializeRemoteKey(
42+
const std::string& serialized) override;
43+
4144
protected:
4245
std::shared_ptr<Device> device_;
4346

0 commit comments

Comments
 (0)