Skip to content

Commit 9065034

Browse files
noemotiovontheo77186
authored and committed
CANN: Improve device ID handling and aclnnArange checks (ggml-org#16752)
* cann: improve device ID handling and aclnnArange checks
  - Stop relying on CANN's internal device ID retrieval; use a global variable instead.
  - Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.
* cann: use thread local var
1 parent a1b3682 commit 9065034

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
22342234
ACL_MEM_MALLOC_HUGE_FIRST));
22352235

22362236
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
2237-
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2237+
theta_scale_ne, theta_scale_nb, 1);
22382238

22392239
float start = 0;
22402240
float step = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
22512251
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
22522252
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
22532253
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
2254-
theta_scale_nb, GGML_MAX_DIMS);
2254+
theta_scale_nb, 1);
22552255
float zero_value = 0, one_value = 1;
22562256
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
22572257
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,30 @@
6767
GGML_ABORT("CANN error");
6868
}
6969

70+
// Thread-local variable to record the current device of this thread.
71+
thread_local int g_current_cann_device = -1;
72+
7073
/**
71-
* @brief Sets the device to be used by CANN.
74+
* @brief Set the CANN device to be used.
7275
*
73-
* @param device The device ID to set.
76+
* @param device The target device ID to set.
7477
*/
7578
void ggml_cann_set_device(const int32_t device) {
76-
int current_device = -1;
77-
aclrtGetDevice(&current_device);
79+
// int current_device = -1;
80+
// Note: In some CANN versions, if no device has been set yet,
81+
// aclrtGetDevice(&current_device) may return 0 by default.
82+
// aclrtGetDevice(&current_device);
7883

79-
if (device == current_device) {
84+
// If the current device is already the target one, no need to switch.
85+
if (device == g_current_cann_device) {
8086
return;
8187
}
88+
89+
// Switch to the new device.
8290
ACL_CHECK(aclrtSetDevice(device));
91+
92+
// Update this thread's (thread-local) device record.
93+
g_current_cann_device = device;
8394
}
8495

8596
/**

0 commit comments

Comments
 (0)