Skip to content

Commit 60ddee7

Browse files
authored
[rocshmem] Expose CTX Array (#4121)
AIROCSHMEM-8
1 parent 77cb138 commit 60ddee7

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

projects/rocshmem/include/rocshmem/rocshmem_common.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ extern "C" __device__ rocshmem_ctx_t __attribute__((visibility("default"))) ROC
133133
* valid context.
134134
*/
135135
extern __constant__ rocshmem_ctx_t ROCSHMEM_CTX_INVALID;
136+
137+
extern __constant__ rocshmem_ctx_t *rocshmem_ctx_array;
136138
/**
137139
* Used internally to set default context.
138140
*/
@@ -180,7 +182,7 @@ const rocshmem_team_t ROCSHMEM_TEAM_INVALID = nullptr;
180182
using rocshmem_uniqueid_t = std::array<uint8_t, ROCSHMEM_UNIQUE_ID_BYTES>;
181183

182184
/**
183-
* @brief Data structure used for attribute based
185+
* @brief Data structure used for attribute based
184186
* initialization
185187
*/
186188
struct rocshmem_init_attr_t {

projects/rocshmem/src/gda/backend_gda.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,42 @@ void GDABackend::setup_ctxs() {
178178
new (&ctx_array[i]) GDAContext(this, i + 1, gda_provider);
179179
ctx_free_list.get()->push_back(ctx_array + i);
180180
}
181+
182+
rocshmem_ctx_t *rocshmem_ctx_array_device = nullptr;
183+
rocshmem_ctx_t *rocshmem_ctx_array_ptr = nullptr;
184+
size_t ctx_array_size = sizeof(rocshmem_ctx_t) * envvar::max_num_contexts;
185+
186+
CHECK_HIP(hipMalloc((void**)&rocshmem_ctx_array_device, ctx_array_size));
187+
188+
for (size_t i = 0; i < envvar::max_num_contexts; i++) {
189+
rocshmem_ctx_array_device[i].ctx_opaque = &ctx_array[i];
190+
rocshmem_ctx_array_device[i].team_opaque = team_tracker.get_team_world()->tinfo_wrt_world;
191+
}
192+
193+
CHECK_HIP(hipGetSymbolAddress(reinterpret_cast<void**>(&rocshmem_ctx_array_ptr),
194+
HIP_SYMBOL(rocshmem_ctx_array)));
195+
196+
CHECK_HIP(hipMemcpy(rocshmem_ctx_array_ptr,
197+
&rocshmem_ctx_array_device,
198+
sizeof(rocshmem_ctx*),
199+
hipMemcpyDefault));
181200
}
182201

183202
void GDABackend::cleanup_ctxs() {
203+
/* Free ctx array */
204+
rocshmem_ctx_t *rocshmem_ctx_array_ptr = nullptr;
205+
rocshmem_ctx_t *rocshmem_ctx_array_device = nullptr;
206+
207+
CHECK_HIP(hipGetSymbolAddress(reinterpret_cast<void**>(&rocshmem_ctx_array_ptr),
208+
HIP_SYMBOL(rocshmem_ctx_array)));
209+
210+
CHECK_HIP(hipMemcpy(&rocshmem_ctx_array_device,
211+
rocshmem_ctx_array_ptr,
212+
sizeof(rocshmem_ctx*),
213+
hipMemcpyDefault));
214+
215+
CHECK_HIP(hipFree(rocshmem_ctx_array_device));
216+
184217
ctx_free_list.~FreeListProxy();
185218
for (size_t i = 0; i < envvar::max_num_contexts; i++) {
186219
ctx_array[i].~GDAContext();

projects/rocshmem/src/rocshmem_gpu.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ namespace rocshmem {
7777
__device__ rocshmem_ctx_t
7878
__attribute__((visibility("default"))) ROCSHMEM_CTX_DEFAULT{};
7979

80+
__constant__ rocshmem_ctx_t *rocshmem_ctx_array;
81+
8082
__constant__ Backend *device_backend_proxy;
8183

8284
__constant__ constmem_t constmem;

0 commit comments

Comments
 (0)