Skip to content

Commit 7aa466e

Browse files
authored
Dump C++ and Status propagation stacktraces. (#9492)
1 parent cd3bd91 commit 7aa466e

File tree

4 files changed

+345
-128
lines changed

4 files changed

+345
-128
lines changed

test/cpp/test_status_common.h

Lines changed: 198 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
#ifndef XLA_TEST_CPP_TEST_STATUS_COMMON_H_
1919
#define XLA_TEST_CPP_TEST_STATUS_COMMON_H_
2020

21+
#include <c10/util/Exception.h>
22+
#include <gmock/gmock.h>
2123
#include <gtest/gtest.h>
2224

2325
#include <cstdlib>
26+
#include <stdexcept>
2427
#include <utility>
2528

2629
#include "absl/status/status.h"
@@ -30,7 +33,7 @@
3033

3134
namespace torch_xla {
3235

33-
// Enum to control whether C++ error context is shown in status messages
36+
// Enum to control whether C++ error context is shown in status messages.
3437
enum class CppStacktracesMode {
3538
kShow,
3639
kHide,
@@ -74,19 +77,68 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {
7477

7578
namespace testing {
7679

77-
constexpr inline char new_message[] = "New test error message";
78-
constexpr inline char message[] = "Test error message";
79-
constexpr inline char test_file[] = "test_file.cpp";
80-
constexpr inline int32_t line = 42;
80+
constexpr inline char kNewMessage[] = "New test error message";
81+
constexpr inline char kMessage[] = "Test error message";
82+
constexpr inline char kFile[] = "test_file.cpp";
83+
constexpr inline char kFunction[] = "foo";
84+
constexpr inline char kEntryPrefix[] = "\n ";
85+
constexpr inline int32_t kLine = 42;
86+
87+
// The PyTorch C++ stacktrace is ALWAYS appended to the error message.
88+
// More specifically, when `what()` function is called.
89+
//
90+
// However, it's only when the raised `c10::Error` gets translated to a
91+
// Python exception that PyTorch checks the value of the
92+
// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually
93+
// controls whether the stacktrace will get shown or not by calling
94+
// `what_without_backtraces()`, instead.
95+
//
96+
// Therefore, we need to mimic this behavior.
97+
#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \
98+
try { \
99+
block; \
100+
} catch (const c10::Error& error) { \
101+
throw std::runtime_error(IsShowCppStacktracesMode() \
102+
? error.what() \
103+
: error.what_without_backtrace()); \
104+
}
105+
106+
// Prefix of the C++ stacktrace PyTorch adds to the error message.
107+
constexpr inline char kTorchCppStacktracePrefix[] =
108+
"Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:";
109+
110+
inline std::string GetStatusPropagationTrace(const absl::Status& status) {
111+
if (status.ok()) {
112+
return "";
113+
}
114+
auto status_propagation_trace = status.GetPayload(kStatusPropagationTraceKey);
115+
return status_propagation_trace.has_value()
116+
? std::string(status_propagation_trace->Flatten())
117+
: "";
118+
}
81119

82120
TEST_P(StatusTest, MaybeThrowWithOkStatus) {
83121
absl::Status ok_status = absl::OkStatus();
84122
EXPECT_NO_THROW(MaybeThrow(ok_status));
85123
}
86124

87125
TEST_P(StatusTest, MaybeThrowWithErrorStatus) {
88-
absl::Status error_status = absl::InvalidArgumentError(message);
89-
EXPECT_THROW(MaybeThrow(error_status), std::runtime_error);
126+
auto throw_exception = [=]() {
127+
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
128+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
129+
MaybeThrow(error_status);
130+
});
131+
};
132+
133+
if (IsShowCppStacktracesMode()) {
134+
std::string expected_prefix =
135+
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
136+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
137+
::testing::StartsWith(expected_prefix)));
138+
} else {
139+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
140+
::testing::Eq(kMessage)));
141+
}
90142
}
91143

92144
TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
@@ -97,44 +149,75 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
97149
}
98150

99151
TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) {
100-
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
101-
EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error);
152+
auto throw_exception = [=]() {
153+
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
154+
absl::StatusOr<int> error_status = absl::InvalidArgumentError(kMessage);
155+
int value = GetValueOrThrow(error_status);
156+
});
157+
};
158+
if (IsShowCppStacktracesMode()) {
159+
std::string expected_prefix =
160+
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
161+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
162+
::testing::StartsWith(expected_prefix)));
163+
} else {
164+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
165+
::testing::Eq(kMessage)));
166+
}
102167
}
103168

104169
TEST_P(StatusTest, MaybeWithLocationPropagatesErrorStatus) {
105-
absl::Status error_status = absl::InvalidArgumentError(message);
106-
absl::Status result = MaybeWithLocation(error_status, test_file, line);
170+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
171+
absl::Status result =
172+
status_internal::MaybeWithLocation(error_status, kFile, kLine, kFunction);
173+
174+
ASSERT_FALSE(result.ok());
175+
EXPECT_EQ(result.code(), error_status.code());
176+
EXPECT_EQ(result.message(), error_status.message());
177+
107178
if (IsShowCppStacktracesMode()) {
108-
ASSERT_NE(result, error_status);
109-
EXPECT_FALSE(result.ok());
110-
EXPECT_EQ(result.code(), error_status.code());
111-
EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)");
179+
EXPECT_NE(result, error_status);
180+
EXPECT_EQ(GetStatusPropagationTrace(result),
181+
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
182+
":", kLine, " (error: ", kMessage, ")"));
112183
} else {
113184
EXPECT_EQ(result, error_status);
114185
}
115186
}
116187

117188
TEST_P(StatusTest, MaybeWithNewMessageEmptyNewMessage) {
118-
absl::Status error_status = absl::InvalidArgumentError(message);
119-
absl::Status result = MaybeWithNewMessage(error_status, test_file, line);
120-
EXPECT_EQ(result, error_status);
189+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
190+
absl::Status result = status_internal::MaybeWithNewMessage(
191+
error_status, kFile, kLine, kFunction);
192+
193+
ASSERT_FALSE(result.ok());
194+
EXPECT_EQ(result.code(), error_status.code());
195+
EXPECT_EQ(result.message(), error_status.message());
196+
197+
if (IsShowCppStacktracesMode()) {
198+
EXPECT_NE(result, error_status);
199+
EXPECT_EQ(GetStatusPropagationTrace(result),
200+
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
201+
":", kLine));
202+
} else {
203+
EXPECT_EQ(result, error_status);
204+
}
121205
}
122206

123207
TEST_P(StatusTest, MaybeWithNewMessageNonEmptyNewMessage) {
124-
absl::Status error_status = absl::InvalidArgumentError(message);
125-
absl::Status result =
126-
MaybeWithNewMessage(error_status, test_file, line, new_message);
208+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
209+
absl::Status result = status_internal::MaybeWithNewMessage(
210+
error_status, kFile, kLine, kFunction, kNewMessage);
127211

128-
ASSERT_NE(result, error_status);
129212
ASSERT_FALSE(result.ok());
130213
EXPECT_EQ(result.code(), error_status.code());
214+
EXPECT_EQ(result.message(), std::string_view(kNewMessage));
215+
EXPECT_NE(result, error_status);
131216

132217
if (IsShowCppStacktracesMode()) {
133-
EXPECT_EQ(result.message(),
134-
absl::StrCat("New test error message (at test_file.cpp:42)\n"
135-
"From Error: Test error message"));
136-
} else {
137-
EXPECT_EQ(result.message(), std::string_view(new_message));
218+
EXPECT_EQ(GetStatusPropagationTrace(result),
219+
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
220+
":", kLine, " (error: ", kNewMessage, ")"));
138221
}
139222
}
140223

@@ -154,29 +237,30 @@ TEST_P(StatusTest, MacroReturnIfError) {
154237

155238
TEST_P(StatusTest, MacroReturnIfErrorWithError) {
156239
auto test_function = [=]() -> absl::Status {
157-
absl::Status error_status = absl::InvalidArgumentError(message);
240+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
158241
XLA_RETURN_IF_ERROR(error_status);
159242
return absl::OkStatus();
160243
};
161244

162245
absl::Status result = test_function();
163246
ASSERT_FALSE(result.ok());
164247
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
165-
EXPECT_EQ(result.message(), std::string_view(message));
248+
EXPECT_EQ(result.message(), std::string_view(kMessage));
166249
}
167250

168251
TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
169-
int32_t errline = 0;
170-
auto inner_test_function = [&errline]() -> absl::Status {
171-
errline = __LINE__ + 1;
172-
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
252+
int32_t errline0 = __LINE__ + 2;
253+
auto inner_test_function = []() -> absl::Status {
254+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
173255
};
174256

257+
int32_t errline1 = __LINE__ + 2;
175258
auto test_function = [&]() -> absl::Status {
176259
XLA_RETURN_IF_ERROR(inner_test_function());
177260
return absl::OkStatus();
178261
};
179262

263+
int32_t errline2 = __LINE__ + 2;
180264
auto outer_test_function = [&]() -> absl::Status {
181265
XLA_RETURN_IF_ERROR(test_function());
182266
return absl::OkStatus();
@@ -185,34 +269,37 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
185269
absl::Status result = outer_test_function();
186270
ASSERT_FALSE(result.ok());
187271
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
272+
EXPECT_EQ(result.message(), std::string_view(kMessage));
188273

189274
if (IsShowCppStacktracesMode()) {
190-
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ",
191-
__FILE__, ":", errline, ")"));
192-
} else {
193-
EXPECT_EQ(result.message(), std::string_view(message));
275+
auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
276+
":", errline0, " (error: ", kMessage, ")");
277+
auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
278+
":", errline1);
279+
auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
280+
":", errline2);
281+
EXPECT_EQ(GetStatusPropagationTrace(result),
282+
absl::StrCat(frame0, frame1, frame2));
194283
}
195284
}
196285

197286
TEST_P(StatusTest, MacroReturnIfErrorWithErrorWithNewMessage) {
198-
int32_t errline = 0;
199-
auto test_function = [&errline]() -> absl::Status {
200-
absl::Status error_status = absl::InvalidArgumentError(message);
201-
errline = __LINE__ + 1;
202-
XLA_RETURN_IF_ERROR(error_status, new_message);
287+
int32_t errline = __LINE__ + 3;
288+
auto test_function = []() -> absl::Status {
289+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
290+
XLA_RETURN_IF_ERROR(error_status, kNewMessage);
203291
return absl::OkStatus();
204292
};
205293

206294
absl::Status result = test_function();
207295
ASSERT_FALSE(result.ok());
208296
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
297+
EXPECT_EQ(result.message(), std::string_view(kNewMessage));
209298

210299
if (IsShowCppStacktracesMode()) {
211-
EXPECT_EQ(result.message(),
212-
absl::StrCat("New test error message (at ", __FILE__, ":",
213-
errline, ")\nFrom Error: Test error message"));
214-
} else {
215-
EXPECT_EQ(result.message(), std::string_view(new_message));
300+
EXPECT_EQ(GetStatusPropagationTrace(result),
301+
absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":",
302+
errline, " (error: ", kNewMessage, ")"));
216303
}
217304
}
218305

@@ -233,51 +320,98 @@ TEST_P(StatusTest, MacroAssignOrReturn) {
233320

234321
TEST_P(StatusTest, MacroAssignOrReturnWithError) {
235322
auto test_function = []() -> absl::StatusOr<int> {
236-
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
323+
absl::StatusOr<int> status_or = absl::InvalidArgumentError(kMessage);
237324
XLA_ASSIGN_OR_RETURN(int value, status_or);
238325
return value * 2;
239326
};
240327

241328
absl::StatusOr<int> result = test_function();
242329
ASSERT_FALSE(result.ok());
243330
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
244-
EXPECT_EQ(result.status().message(), std::string_view(message));
331+
EXPECT_EQ(result.status().message(), std::string_view(kMessage));
245332
}
246333

247334
TEST_P(StatusTest, MacroAssignOrReturnWithErrorWithNewMessage) {
248-
int32_t errline = 0;
249-
250-
auto test_function = [&errline]() -> absl::StatusOr<int> {
251-
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
252-
errline = __LINE__ + 1;
253-
XLA_ASSIGN_OR_RETURN(int value, status_or, new_message);
335+
int32_t errline = __LINE__ + 3;
336+
auto test_function = []() -> absl::StatusOr<int> {
337+
absl::StatusOr<int> status_or = absl::InvalidArgumentError(kMessage);
338+
XLA_ASSIGN_OR_RETURN(int value, status_or, kNewMessage);
254339
return value * 2;
255340
};
256341

257342
absl::StatusOr<int> result = test_function();
258343
ASSERT_FALSE(result.ok());
259344
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
345+
EXPECT_EQ(result.status().message(), std::string_view(kNewMessage));
260346

261347
if (IsShowCppStacktracesMode()) {
262-
EXPECT_EQ(result.status().message(),
263-
absl::StrCat("New test error message (at ", __FILE__, ":",
264-
errline, ")\nFrom Error: Test error message"));
265-
} else {
266-
EXPECT_EQ(result.status().message(), std::string_view(new_message));
348+
EXPECT_EQ(GetStatusPropagationTrace(result.status()),
349+
absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":",
350+
errline, " (error: ", kNewMessage, ")"));
267351
}
268352
}
269353

270354
TEST_P(StatusTest, MacroErrorWithLocation) {
271-
absl::Status error_status = absl::InvalidArgumentError(message);
355+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
272356
int32_t errline = __LINE__ + 1;
273357
absl::Status result = XLA_ERROR_WITH_LOCATION(error_status);
274358
ASSERT_FALSE(result.ok());
275359
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
360+
EXPECT_EQ(result.message(), std::string_view(kMessage));
361+
if (IsShowCppStacktracesMode()) {
362+
EXPECT_EQ(GetStatusPropagationTrace(result),
363+
absl::StrCat(kEntryPrefix, "From: ", __FUNCTION__, " at ",
364+
__FILE__, ":", errline, " (error: ", kMessage, ")"));
365+
}
366+
}
367+
368+
TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
369+
int32_t errline0 = __LINE__ + 2;
370+
auto innerfn = [&]() -> absl::Status {
371+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
372+
};
373+
374+
int32_t errline1 = __LINE__ + 2;
375+
auto midfn = [&]() -> absl::Status {
376+
XLA_RETURN_IF_ERROR(innerfn(), kNewMessage);
377+
return absl::OkStatus();
378+
};
379+
380+
int32_t errline2 = __LINE__ + 2;
381+
auto outerfn = [&]() -> absl::Status {
382+
XLA_RETURN_IF_ERROR(midfn());
383+
return absl::OkStatus();
384+
};
385+
386+
auto throw_exception = [&]() {
387+
THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn()));
388+
};
389+
276390
if (IsShowCppStacktracesMode()) {
277-
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ",
278-
__FILE__, ":", errline, ")"));
391+
// Expected Error Message Prefix
392+
// =============================
393+
//
394+
// New test error kMessage
395+
//
396+
// Status Propagation Stacktrace:
397+
// From: ./test/cpp/test_status_common.h:329 (error: Test error
398+
// kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test
399+
// error kMessage) From: ./test/cpp/test_status_common.h:342
400+
//
401+
// C++ Stacktrace:
402+
//
403+
std::string expected_prefix = absl::StrCat(
404+
kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix,
405+
"From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage,
406+
")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1,
407+
" (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ",
408+
__FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix);
409+
410+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
411+
::testing::StartsWith(expected_prefix)));
279412
} else {
280-
EXPECT_EQ(result.message(), std::string_view(message));
413+
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
414+
::testing::Eq(kNewMessage)));
281415
}
282416
}
283417

0 commit comments

Comments
 (0)