
Commit 77d76f8

Fix crash in optimized log softmax (#13551)

1 parent c5ff74c · commit 77d76f8

2 files changed (+76, -11 lines)

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 15 additions & 11 deletions
@@ -103,18 +103,15 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
 template <
     typename OUT_T,
     std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
-void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
+bool log_softmax_wrapper(KernelRuntimeContext& context, const Tensor& X, int64_t dim, Tensor& out) {
   auto input_scalar_type = X.scalar_type();
   switch (input_scalar_type) {
     // TODO: support Double as well
     case ScalarType::Float:
       log_softmax_kernel<float, OUT_T>(X, dim, out);
-      break;
+      return true;
     default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled input dtype %" PRId8,
-          static_cast<int8_t>(input_scalar_type));
+      return false; // Unsupported input dtype
   }
 }
 } // namespace

@@ -146,14 +143,21 @@ Tensor& opt_log_softmax_out(
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
-    case ScalarType::Float:
-      log_softmax_wrapper<float>(self, dim, out);
+    case ScalarType::Float: {
+      bool success = log_softmax_wrapper<float>(context, self, dim, out);
+      ET_KERNEL_CHECK(
+          context,
+          success,
+          InvalidArgument,
+          out);
       break;
+    }
     default:
-      ET_CHECK_MSG(
+      ET_KERNEL_CHECK(
+          context,
           false,
-          "Unhandled out dtype %" PRId8,
-          static_cast<int8_t>(out_scalar_type));
+          InvalidArgument,
+          out);
   }
   return out;
 }
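
Read together, the two hunks change the failure mode for an unsupported dtype from a process abort (ET_CHECK_MSG) to a recoverable kernel error (ET_KERNEL_CHECK with InvalidArgument). Below is a condensed sketch of the patched control flow, reconstructed only from the hunks above; the includes, the earlier argument checks in opt_log_softmax_out, and its exact parameter list (inferred here from the call sites) are assumptions, not copied from the file.

// Sketch only: reconstructed from the diff above, not the complete file.
namespace {

// The wrapper now reports an unsupported input dtype by returning false
// instead of aborting the whole process via ET_CHECK_MSG.
template <
    typename OUT_T,
    std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
bool log_softmax_wrapper(
    KernelRuntimeContext& context,
    const Tensor& X,
    int64_t dim,
    Tensor& out) {
  switch (X.scalar_type()) {
    case ScalarType::Float:
      log_softmax_kernel<float, OUT_T>(X, dim, out);
      return true;
    default:
      return false; // e.g. Double: unsupported, but no longer a crash
  }
}

} // namespace

// Parameter list inferred from the call sites; it may differ from the real file.
Tensor& opt_log_softmax_out(
    KernelRuntimeContext& context,
    const Tensor& self,
    int64_t dim,
    bool half_to_float,
    Tensor& out) {
  // ... shape/dtype argument checks elided ...
  switch (out.scalar_type()) {
    case ScalarType::Float: {
      bool success = log_softmax_wrapper<float>(context, self, dim, out);
      // On failure, ET_KERNEL_CHECK records InvalidArgument on the runtime
      // context and returns `out` early instead of terminating the process.
      ET_KERNEL_CHECK(context, success, InvalidArgument, out);
      break;
    }
    default:
      ET_KERNEL_CHECK(context, false, InvalidArgument, out);
  }
  return out;
}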

kernels/test/op_log_softmax_test.cpp

Lines changed: 61 additions & 0 deletions
@@ -447,3 +447,64 @@ TEST_F(OpLogSoftmaxOutTest, DynamicShapeUnbound) {
   Tensor ret = op_log_softmax_out(x, 1, false, out);
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
+
+TEST_F(OpLogSoftmaxOutTest, DoubleCase) {
+  TensorFactory<ScalarType::Double> tf;
+
+  // Test case with specific inputs:
+  //   Input tensor: torch.float64 (8, 5, 7)
+  //   Dim: 2
+  //   half_to_float: False
+  Tensor input = tf.zeros({8, 5, 7});
+  auto in_data = input.mutable_data_ptr<double>();
+
+  // Fill with some test data (sequential values scaled)
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    in_data[i] = static_cast<double>(i) * 0.01;
+  }
+
+  // Output tensor with same shape
+  Tensor out = tf.zeros({8, 5, 7});
+
+  // Apply log_softmax along dimension 2 (the last dimension, with size 7)
+  op_log_softmax_out(input, /*dim=*/2, /*half_to_float=*/false, out);
+
+  if (!SupportedFeatures::get()->op_log_softmax_dtype_double) {
+    // For optimized kernels, we expect the call above to fail gracefully
+    expect_failure();
+    GTEST_SKIP() << "This kernel does not support dtype double";
+  }
+
+  // Verify output dimensions
+  EXPECT_EQ(out.size(0), 8);
+  EXPECT_EQ(out.size(1), 5);
+  EXPECT_EQ(out.size(2), 7);
+
+  // Verify that output has reasonable values
+  auto out_data = out.const_data_ptr<double>();
+
+  // Check for NaN or Inf values
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    EXPECT_FALSE(std::isnan(out_data[i])) << "Output should not contain NaN at index " << i;
+    EXPECT_FALSE(std::isinf(out_data[i])) << "Output should not contain Inf at index " << i;
+  }
+
+  // For log_softmax, all values should be <= 0 (since softmax values are <= 1, log is <= 0)
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    EXPECT_LE(out_data[i], 0.0) << "Log softmax values should be <= 0 at index " << i;
+  }
+
+  // Verify that each slice along dimension 2 sums to approximately 1 when exp'd.
+  // This tests the core property of softmax: sum(softmax(x)) = 1
+  for (int batch = 0; batch < 8; batch++) {
+    for (int channel = 0; channel < 5; channel++) {
+      double sum_exp = 0.0;
+      for (int dim2 = 0; dim2 < 7; dim2++) {
+        int idx = batch * 5 * 7 + channel * 7 + dim2;
+        sum_exp += std::exp(out_data[idx]);
+      }
+      EXPECT_NEAR(sum_exp, 1.0, 1e-6)
+          << "Sum of exp(log_softmax) should be 1.0 for batch=" << batch << ", channel=" << channel;
+    }
+  }
+}
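
The final block of the new test leans on the defining identity of log-softmax: along the reduced dimension, log_softmax(x)_i = x_i - log(sum_j exp(x_j)), so every output is <= 0 and the exponentials of the outputs sum to 1. As a standalone illustration of that identity (ordinary C++, independent of the ExecuTorch test harness and its TensorFactory types), a minimal sketch might look like this:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Minimal, self-contained check of the property the test asserts:
// exp(log_softmax(x)) sums to 1 along the reduced dimension, and every
// log_softmax value is <= 0.
int main() {
  // One "row" of 7 values, similar in spirit to the test's 0.01-scaled data.
  std::vector<double> x = {0.00, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06};

  // Numerically stable log-sum-exp: subtract the max before exponentiating.
  double max_x = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - max_x);
  double log_sum_exp = max_x + std::log(sum);

  double sum_exp_out = 0.0;
  for (double v : x) {
    double y = v - log_sum_exp; // log_softmax value, always <= 0
    sum_exp_out += std::exp(y);
  }
  std::printf("sum of exp(outputs) = %.12f (expect 1.0)\n", sum_exp_out);
  return 0;
}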
