@@ -264,13 +264,7 @@ inline detail::MemoryLogger &GlobalMemoryLogger() {
264264}
265265
266266#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
267-
268- #if (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
269- using DeviceAsyncResourceRef = cuda::mr::resource_ref<cuda::mr::device_accessible>;
270- #else // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
271- using DeviceAsyncResourceRef = cuda::mr::async_resource_ref<cuda::mr::device_accessible>;
272- #endif // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
273-
267+ using DeviceAsyncResourceRef = rmm::device_async_resource_ref;
274268#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
275269
276270namespace detail {
@@ -280,13 +274,6 @@ namespace detail {
280274 */
281275template <typename T>
282276class ThrustAllocMrAdapter : public thrust ::device_malloc_allocator<T> {
283- // TODO(hcho3): Remove this guard once we require Rapids 25.12+
284- #if (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
285- DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource_ref ()};
286- #else // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
287- DeviceAsyncResourceRef mr_{rmm::mr::get_current_device_resource ()};
288- #endif // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
289-
290277 public:
291278 using Super = thrust::device_malloc_allocator<T>;
292279 using pointer = typename Super::pointer; // NOLINT(readability-identifier-naming)
@@ -299,22 +286,38 @@ class ThrustAllocMrAdapter : public thrust::device_malloc_allocator<T> {
299286
300287 ThrustAllocMrAdapter () = default ;
301288 pointer allocate (size_type n) { // NOLINT(readability-identifier-naming)
289+
290+ // TODO(hcho3): Remove this guard once we require Rapids 25.12+
291+ #if (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
292+ DeviceAsyncResourceRef mr{rmm::mr::get_current_device_resource_ref ()};
293+ #else // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
294+ DeviceAsyncResourceRef mr{rmm::mr::get_current_device_resource ()};
295+ #endif // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
296+
302297 auto n_bytes = xgboost::common::SizeBytes<T>(n);
303298 auto s = cuda::stream_ref{::xgboost::curt::DefaultStream ()};
304299#if (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
305- auto p = static_cast <T *>(mr_ .allocate (s, n_bytes, std::alignment_of_v<T>));
306- #else // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
307- auto p = static_cast <T *>(mr_ .allocate_async (n_bytes, std::alignment_of_v<T>, s));
300+ auto p = static_cast <T *>(mr .allocate (s, n_bytes, std::alignment_of_v<T>));
301+ #else // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
302+ auto p = static_cast <T *>(mr .allocate_async (n_bytes, std::alignment_of_v<T>, s));
308303#endif // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
309304 return thrust::device_pointer_cast (p);
310305 }
311306 void deallocate (pointer ptr, size_type n) { // NOLINT(readability-identifier-naming)
307+
308+ // TODO(hcho3): Remove this guard once we require Rapids 25.12+
309+ #if (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
310+ DeviceAsyncResourceRef mr{rmm::mr::get_current_device_resource_ref ()};
311+ #else // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
312+ DeviceAsyncResourceRef mr{rmm::mr::get_current_device_resource ()};
313+ #endif // (RMM_VERSION_MAJOR == 25 && RMM_VERSION_MINOR == 12) || RMM_VERSION_MAJOR >= 26
314+
312315 auto n_bytes = xgboost::common::SizeBytes<T>(n);
313316 auto s = ::xgboost::curt::DefaultStream ();
314317#if (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
315- return mr_ .deallocate (cuda::stream_ref{s}, thrust::raw_pointer_cast (ptr), n_bytes);
316- #else // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
317- return mr_ .deallocate_async (thrust::raw_pointer_cast (ptr), n_bytes, cuda::stream_ref{s});
318+ return mr .deallocate (cuda::stream_ref{s}, thrust::raw_pointer_cast (ptr), n_bytes);
319+ #else // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
320+ return mr .deallocate_async (thrust::raw_pointer_cast (ptr), n_bytes, cuda::stream_ref{s});
318321#endif // (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1) || CCCL_MAJOR_VERSION > 3
319322 }
320323};
0 commit comments