@@ -178,9 +178,42 @@ void GDABackend::setup_ctxs() {
178178 new (&ctx_array[i]) GDAContext (this , i + 1 , gda_provider);
179179 ctx_free_list.get ()->push_back (ctx_array + i);
180180 }
181+
182+ rocshmem_ctx_t *rocshmem_ctx_array_device = nullptr ;
183+ rocshmem_ctx_t *rocshmem_ctx_array_ptr = nullptr ;
184+ size_t ctx_array_size = sizeof (rocshmem_ctx_t ) * envvar::max_num_contexts;
185+
186+ CHECK_HIP (hipMalloc ((void **)&rocshmem_ctx_array_device, ctx_array_size));
187+
188+ for (size_t i = 0 ; i < envvar::max_num_contexts; i++) {
189+ rocshmem_ctx_array_device[i].ctx_opaque = &ctx_array[i];
190+ rocshmem_ctx_array_device[i].team_opaque = team_tracker.get_team_world ()->tinfo_wrt_world ;
191+ }
192+
193+ CHECK_HIP (hipGetSymbolAddress (reinterpret_cast <void **>(&rocshmem_ctx_array_ptr),
194+ HIP_SYMBOL (rocshmem_ctx_array)));
195+
196+ CHECK_HIP (hipMemcpy (rocshmem_ctx_array_ptr,
197+ &rocshmem_ctx_array_device,
198+ sizeof (rocshmem_ctx*),
199+ hipMemcpyDefault));
181200}
182201
183202void GDABackend::cleanup_ctxs () {
203+ /* Free ctx array */
204+ rocshmem_ctx_t *rocshmem_ctx_array_ptr = nullptr ;
205+ rocshmem_ctx_t *rocshmem_ctx_array_device = nullptr ;
206+
207+ CHECK_HIP (hipGetSymbolAddress (reinterpret_cast <void **>(&rocshmem_ctx_array_ptr),
208+ HIP_SYMBOL (rocshmem_ctx_array)));
209+
210+ CHECK_HIP (hipMemcpy (&rocshmem_ctx_array_device,
211+ rocshmem_ctx_array_ptr,
212+ sizeof (rocshmem_ctx*),
213+ hipMemcpyDefault));
214+
215+ CHECK_HIP (hipFree (rocshmem_ctx_array_device));
216+
184217 ctx_free_list.~FreeListProxy ();
185218 for (size_t i = 0 ; i < envvar::max_num_contexts; i++) {
186219 ctx_array[i].~GDAContext ();
0 commit comments