Commit debd095

cyyever authored and pytorchmergebot committed
Avoid index integer overflow in gemm_notrans_ (pytorch#154809)
Use uint64_t index types to avoid

```
torch_np/numpy_tests/core/test_einsum.py::TestEinsum::test_einsum_broadcast
/var/lib/jenkins/workspace/aten/src/ATen/native/cpu/BlasKernel.cpp:132:24: runtime error: signed integer overflow: 9223365439786057728 + 13194139533312 cannot be represented in type 'long'
    #0 0x7f30d26166ba in std::enable_if<std::is_same_v<long, long>, void>::type at::native::cpublas::(anonymous namespace)::gemm_notrans_<long, long, long>(long, long, long, long, long const*, long, long const*, long, long, long*, long) /var/lib/jenkins/workspace/aten/src/ATen/native/cpu/BlasKernel.cpp:132:24
    #1 0x7f30d26166ba in void at::native::cpublas::(anonymous namespace)::gemm_core_<long, long, long>(at::native::TransposeType, at::native::TransposeType, long, long, long, long, long const*, long, long const*, long, long, long*, long) /var/lib/jenkins/workspace/aten/src/ATen/native/cpu/BlasKernel.cpp:451:12
    #2 0x7f30d25fba1b in at::native::cpublas::(anonymous namespace)::cpublas_gemm_impl(c10::ScalarType, at::native::TransposeType, at::native::TransposeType, long, long, long, c10::Scalar const&, void const*, long, void const*, long, c10::Scalar const&, void*, long)::$_2::operator()() const::'lambda2'()::operator()() const /var/lib/jenkins/workspace/aten/src/ATen/native/cpu/BlasKernel.cpp:485:3
    #3 0x7f30d25fba1b in at::native::cpublas::(anonymous namespace)::cpublas_gemm_impl(c10::ScalarType, at::native::TransposeType, at::native::TransposeType, long, long, long, c10::Scalar const&, void const*, long, void const*, long, c10::Scalar const&, void*, long)::$_2::operator()() const /var/lib/jenkins/workspace/aten/src/ATen/native/cpu/BlasKernel.cpp:485:3
```

Pull Request resolved: pytorch#154809
Approved by: https://github.com/soulitzer
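To make the reported failure concrete, here is a minimal sketch of the arithmetic in the sanitizer message. The helper names `signed_add_would_overflow` and `unsigned_add` are hypothetical and are not part of ATen. The sketch only shows that the two operands from the log cannot be added in a signed 64-bit type, while the same addition in `uint64_t` is well defined because unsigned arithmetic is performed modulo 2^64. It makes no claim about whether the resulting index is meaningful in the original test.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper: reports whether a + b would exceed INT64_MAX.
// Assumes both operands are non-negative, as in the sanitizer message.
bool signed_add_would_overflow(int64_t a, int64_t b) {
  return a > std::numeric_limits<int64_t>::max() - b;
}

// Hypothetical helper: unsigned addition never has undefined behavior;
// the result is reduced modulo 2^64.
uint64_t unsigned_add(uint64_t a, uint64_t b) {
  return a + b;
}
```

With the operands from the log, the signed check reports an overflow, while the unsigned sum is 9223378633925591040, which is above `INT64_MAX` but still below 2^64.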
1 parent 10c3e6e commit debd095

1 file changed: +6 -5 lines changed
aten/src/ATen/native/cpu/BlasKernel.cpp

Lines changed: 6 additions & 5 deletions
@@ -117,18 +117,19 @@ gemm_notrans_(
   scale_(m, n, beta, c, ldc);
 
   // c += alpha * (a @ b)
-  for (const auto l : c10::irange(k)) {
-    for (const auto j : c10::irange(n)) {
+  const uint64_t unsigned_m = static_cast<int64_t>(m);
+  const uint64_t i_m = unsigned_m / 4;
+  for (const uint64_t l : c10::irange(k)) {
+    for (const uint64_t j : c10::irange(n)) {
       opmath_t val = b[l + j * ldb] * alpha;
-      int64_t i_m = m / 4;
       for (const auto i_i : c10::irange(i_m)) {
         c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
         c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
         c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
         c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
       }
-      int64_t i = i_m * 4;
-      for (; i < m; i++)
+      uint64_t i = i_m * 4;
+      for (; i < unsigned_m; i++)
         c[j * ldc + i] += a[i + l * lda] * val;
     }
   }
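
As a rough, self-contained illustration of the loop structure after this change, the sketch below accumulates `c += alpha * (a @ b)` for column-major matrices using `uint64_t` indices. It is not the ATen implementation. `gemm_notrans_sketch` is a hypothetical name, and the sketch works on `double` only, omits the `scale_` step (so it behaves as if `beta` were 1), and uses plain loops instead of `c10::irange`. It also casts the leading dimensions to `uint64_t` explicitly. The diff leaves `lda`, `ldb`, and `ldc` as signed parameters, so mixed expressions there become unsigned through the usual arithmetic conversions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical standalone analogue of the patched loop (not the ATen function).
// Computes c += alpha * (a @ b) for column-major a (m x k), b (k x n), c (m x n).
void gemm_notrans_sketch(int64_t m, int64_t n, int64_t k, double alpha,
                         const double* a, int64_t lda,
                         const double* b, int64_t ldb,
                         double* c, int64_t ldc) {
  // Assumption of this sketch: sizes and leading dimensions are non-negative.
  assert(m >= 0 && n >= 0 && k >= 0 && lda >= 0 && ldb >= 0 && ldc >= 0);

  // Do all index arithmetic in uint64_t, where wraparound is well defined.
  const uint64_t um = static_cast<uint64_t>(m);
  const uint64_t un = static_cast<uint64_t>(n);
  const uint64_t uk = static_cast<uint64_t>(k);
  const uint64_t ulda = static_cast<uint64_t>(lda);
  const uint64_t uldb = static_cast<uint64_t>(ldb);
  const uint64_t uldc = static_cast<uint64_t>(ldc);

  const uint64_t i_m = um / 4;  // number of full groups of four rows
  for (uint64_t l = 0; l < uk; ++l) {
    for (uint64_t j = 0; j < un; ++j) {
      const double val = b[l + j * uldb] * alpha;
      // Rows handled four at a time, mirroring the unrolled body in the diff.
      for (uint64_t i_i = 0; i_i < i_m; ++i_i) {
        for (uint64_t u = 0; u < 4; ++u) {
          c[j * uldc + i_i * 4 + u] += a[i_i * 4 + u + l * ulda] * val;
        }
      }
      // Remaining rows when m is not a multiple of four.
      for (uint64_t i = i_m * 4; i < um; ++i) {
        c[j * uldc + i] += a[i + l * ulda] * val;
      }
    }
  }
}
```

Calling it with `m = 5` exercises both the four-row groups and the tail loop, which is where the comparison against `unsigned_m` in the diff matters.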
