Skip to content

Commit c9c9685

Browse files
committed
Add workaround for cuTENSOR 2.3.x - 2.5.x OOB host write bug impacting large problem sizes (#113)
* Add workaround for cuTENSOR 2.3.x - 2.5.x OOB host write bug impacting large problem sizes. Signed-off-by: Josh Romero <joshr@nvidia.com>
* Formatting. Signed-off-by: Josh Romero <joshr@nvidia.com>
* Bump cuDecomp version to prep for hotfix release. Signed-off-by: Josh Romero <joshr@nvidia.com>
---------
Signed-off-by: Josh Romero <joshr@nvidia.com>
1 parent 87d5551 commit c9c9685

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

include/cudecomp.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
#define CUDECOMP_MAJOR 0
3131
#define CUDECOMP_MINOR 6
32-
#define CUDECOMP_PATCH 1
32+
#define CUDECOMP_PATCH 2
3333

3434
#ifdef __cplusplus
3535
extern "C" {

include/internal/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ struct cudecompHandle {
7676

7777
cutensorHandle_t cutensor_handle; // cuTENSOR handle;
7878
#if CUTENSOR_MAJOR >= 2
79-
cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
79+
cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
80+
bool cutensor_needs_permute_chunking = false; // Flag to enable large tensor workaround
8081
#endif
8182

8283
std::vector<std::array<char, MPI_MAX_PROCESSOR_NAME>> hostnames; // list of hostnames by rank

include/internal/transpose.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,50 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t
8282
if (extent_out[i] == 0) return;
8383
}
8484

85+
// Workaround for an out-of-bounds host write bug in cuTENSOR triggered when the
86+
// total number of tensor elements exceeds INT32_MAX/2. We split the tensor so each
87+
// cuTENSOR call stays below that limit.
88+
static constexpr int64_t CUTENSOR_EXTENT_LIMIT = (int64_t)std::numeric_limits<int32_t>::max() / 2;
89+
int64_t total_elems = extent_in[0] * extent_in[1] * extent_in[2];
90+
if (handle->cutensor_needs_permute_chunking && total_elems > CUTENSOR_EXTENT_LIMIT) {
91+
92+
// Always pass explicit strides when splitting
93+
std::array<int64_t, 3> actual_strides_in = strides_in;
94+
if (!anyNonzeros(strides_in)) { actual_strides_in = {extent_in[1] * extent_in[2], extent_in[2], 1}; }
95+
std::array<int64_t, 3> actual_strides_out = strides_out;
96+
if (!anyNonzeros(strides_out)) { actual_strides_out = {extent_out[1] * extent_out[2], extent_out[2], 1}; }
97+
// Try to split on input dims, starting with outermost dim.
98+
std::array<int, 3> inv_order_out;
99+
for (int i = 0; i < 3; ++i)
100+
inv_order_out[order_out[i]] = i;
101+
int split_dim_in = -1;
102+
int64_t elems_per_slice = 0;
103+
for (int j = 2; j >= 0; --j) {
104+
elems_per_slice = total_elems / extent_in[j];
105+
if (elems_per_slice <= CUTENSOR_EXTENT_LIMIT) {
106+
split_dim_in = j;
107+
break;
108+
}
109+
}
110+
111+
if (split_dim_in >= 0) {
112+
// Run localPermute multiple times, once per chunk.
113+
int64_t chunk = std::max((int64_t)1, CUTENSOR_EXTENT_LIMIT / elems_per_slice);
114+
for (int64_t offset = 0; offset < extent_in[split_dim_in]; offset += chunk) {
115+
int64_t this_chunk = std::min(chunk, extent_in[split_dim_in] - offset);
116+
std::array<int64_t, 3> chunk_extent_in = extent_in;
117+
chunk_extent_in[split_dim_in] = this_chunk;
118+
localPermute(handle, chunk_extent_in, order_out, actual_strides_in, actual_strides_out,
119+
input + offset * actual_strides_in[split_dim_in],
120+
output + offset * actual_strides_out[inv_order_out[split_dim_in]], stream);
121+
}
122+
return;
123+
}
124+
// All pairwise products exceed the limit so splitting isn't possible (requires each dimension > sqrt(INT32_MAX/2)
125+
// ~= 32768). This is not a realistic scenario, but throw an error here for completeness.
126+
THROW_INTERNAL_ERROR("Input too large to work around CUTENSOR large-tensor bug");
127+
}
128+
85129
auto strides_in_ptr = anyNonzeros(strides_in) ? strides_in.data() : nullptr;
86130
auto strides_out_ptr = anyNonzeros(strides_out) ? strides_out.data() : nullptr;
87131

src/cudecomp.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,11 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) {
439439
CHECK_CUTENSOR(cutensorCreate(&handle->cutensor_handle));
440440
CHECK_CUTENSOR(cutensorCreatePlanPreference(handle->cutensor_handle, &handle->cutensor_plan_pref,
441441
CUTENSOR_ALGO_DEFAULT, CUTENSOR_JIT_MODE_NONE));
442+
// cuTENSOR versions 2.3.x - 2.5.x have a bug where cutensorCreatePlan performs an out-of-bounds
443+
// host write when the total number of tensor elements exceeds INT32_MAX/2. Set a flag
444+
// to enable workaround in localPermute to split large tensors.
445+
size_t cutensor_ver = cutensorGetVersion();
446+
handle->cutensor_needs_permute_chunking = (cutensor_ver >= 20300 && cutensor_ver < 20600);
442447
#else
443448
CHECK_CUTENSOR(cutensorInit(&handle->cutensor_handle));
444449
#endif

0 commit comments

Comments (0)