18
18
#ifndef XLA_TEST_CPP_TEST_STATUS_COMMON_H_
19
19
#define XLA_TEST_CPP_TEST_STATUS_COMMON_H_
20
20
21
+ #include < c10/util/Exception.h>
22
+ #include < gmock/gmock.h>
21
23
#include < gtest/gtest.h>
22
24
23
25
#include < cstdlib>
26
+ #include < stdexcept>
24
27
#include < utility>
25
28
26
29
#include " absl/status/status.h"
30
33
31
34
namespace torch_xla {
32
35
33
- // Enum to control whether C++ error context is shown in status messages
36
+ // Enum to control whether C++ error context is shown in status messages.
34
37
enum class CppStacktracesMode {
35
38
kShow ,
36
39
kHide ,
@@ -74,19 +77,68 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {
74
77
75
78
namespace testing {
76
79
77
- constexpr inline char new_message[] = " New test error message" ;
78
- constexpr inline char message[] = " Test error message" ;
79
- constexpr inline char test_file[] = " test_file.cpp" ;
80
- constexpr inline int32_t line = 42 ;
80
+ constexpr inline char kNewMessage [] = " New test error message" ;
81
+ constexpr inline char kMessage [] = " Test error message" ;
82
+ constexpr inline char kFile [] = " test_file.cpp" ;
83
+ constexpr inline char kFunction [] = " foo" ;
84
+ constexpr inline char kEntryPrefix [] = " \n " ;
85
+ constexpr inline int32_t kLine = 42 ;
86
+
87
+ // The PyTorch C++ stacktrace is ALWAYS appended to the error message.
88
+ // More specifically, when `what()` function is called.
89
+ //
90
+ // However, it's only when the raised `c10::Error` gets translated to a
91
+ // Python exception that PyTorch checks the value of the
92
+ // `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually
93
+ // controls whether the stacktrace will get shown or not by calling
94
+ // `what_without_backtraces()`, instead.
95
+ //
96
+ // Therefore, we need to mimic this behavior.
97
+ #define THROW_RUNTIME_ERROR_FROM_C10_ERROR (block ) \
98
+ try { \
99
+ block; \
100
+ } catch (const c10::Error& error) { \
101
+ throw std::runtime_error (IsShowCppStacktracesMode () \
102
+ ? error.what () \
103
+ : error.what_without_backtrace ()); \
104
+ }
105
+
106
+ // Prefix of the C++ stacktrace PyTorch adds to the error message.
107
+ constexpr inline char kTorchCppStacktracePrefix [] =
108
+ " Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:" ;
109
+
110
+ inline std::string GetStatusPropagationTrace (const absl::Status& status) {
111
+ if (status.ok ()) {
112
+ return " " ;
113
+ }
114
+ auto status_propagation_trace = status.GetPayload (kStatusPropagationTraceKey );
115
+ return status_propagation_trace.has_value ()
116
+ ? std::string (status_propagation_trace->Flatten ())
117
+ : " " ;
118
+ }
81
119
82
120
TEST_P (StatusTest, MaybeThrowWithOkStatus) {
83
121
absl::Status ok_status = absl::OkStatus ();
84
122
EXPECT_NO_THROW (MaybeThrow (ok_status));
85
123
}
86
124
87
125
TEST_P (StatusTest, MaybeThrowWithErrorStatus) {
88
- absl::Status error_status = absl::InvalidArgumentError (message);
89
- EXPECT_THROW (MaybeThrow (error_status), std::runtime_error);
126
+ auto throw_exception = [=]() {
127
+ THROW_RUNTIME_ERROR_FROM_C10_ERROR ({
128
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
129
+ MaybeThrow (error_status);
130
+ });
131
+ };
132
+
133
+ if (IsShowCppStacktracesMode ()) {
134
+ std::string expected_prefix =
135
+ absl::StrCat (kMessage , " \n\n " , kTorchCppStacktracePrefix );
136
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
137
+ ::testing::StartsWith (expected_prefix)));
138
+ } else {
139
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
140
+ ::testing::Eq (kMessage )));
141
+ }
90
142
}
91
143
92
144
TEST_P (StatusTest, GetValueOrThrowWithOkStatusOr) {
@@ -97,44 +149,75 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
97
149
}
98
150
99
151
TEST_P (StatusTest, GetValueOrThrowWithErrorStatusOr) {
100
- absl::StatusOr<int > status_or = absl::InvalidArgumentError (message);
101
- EXPECT_THROW (GetValueOrThrow (std::move (status_or)), std::runtime_error);
152
+ auto throw_exception = [=]() {
153
+ THROW_RUNTIME_ERROR_FROM_C10_ERROR ({
154
+ absl::StatusOr<int > error_status = absl::InvalidArgumentError (kMessage );
155
+ int value = GetValueOrThrow (error_status);
156
+ });
157
+ };
158
+ if (IsShowCppStacktracesMode ()) {
159
+ std::string expected_prefix =
160
+ absl::StrCat (kMessage , " \n\n " , kTorchCppStacktracePrefix );
161
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
162
+ ::testing::StartsWith (expected_prefix)));
163
+ } else {
164
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
165
+ ::testing::Eq (kMessage )));
166
+ }
102
167
}
103
168
104
169
TEST_P (StatusTest, MaybeWithLocationPropagatesErrorStatus) {
105
- absl::Status error_status = absl::InvalidArgumentError (message);
106
- absl::Status result = MaybeWithLocation (error_status, test_file, line);
170
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
171
+ absl::Status result =
172
+ status_internal::MaybeWithLocation (error_status, kFile , kLine , kFunction );
173
+
174
+ ASSERT_FALSE (result.ok ());
175
+ EXPECT_EQ (result.code (), error_status.code ());
176
+ EXPECT_EQ (result.message (), error_status.message ());
177
+
107
178
if (IsShowCppStacktracesMode ()) {
108
- ASSERT_NE (result, error_status);
109
- EXPECT_FALSE (result. ok ());
110
- EXPECT_EQ (result. code (), error_status. code ());
111
- EXPECT_EQ (result. message (), " Test error message (at test_file.cpp:42) " );
179
+ EXPECT_NE (result, error_status);
180
+ EXPECT_EQ ( GetStatusPropagationTrace (result),
181
+ absl::StrCat ( kEntryPrefix , " From: " , kFunction , " at " , kFile ,
182
+ " : " , kLine , " ( error: " , kMessage , " ) " ) );
112
183
} else {
113
184
EXPECT_EQ (result, error_status);
114
185
}
115
186
}
116
187
117
188
TEST_P (StatusTest, MaybeWithNewMessageEmptyNewMessage) {
118
- absl::Status error_status = absl::InvalidArgumentError (message);
119
- absl::Status result = MaybeWithNewMessage (error_status, test_file, line);
120
- EXPECT_EQ (result, error_status);
189
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
190
+ absl::Status result = status_internal::MaybeWithNewMessage (
191
+ error_status, kFile , kLine , kFunction );
192
+
193
+ ASSERT_FALSE (result.ok ());
194
+ EXPECT_EQ (result.code (), error_status.code ());
195
+ EXPECT_EQ (result.message (), error_status.message ());
196
+
197
+ if (IsShowCppStacktracesMode ()) {
198
+ EXPECT_NE (result, error_status);
199
+ EXPECT_EQ (GetStatusPropagationTrace (result),
200
+ absl::StrCat (kEntryPrefix , " From: " , kFunction , " at " , kFile ,
201
+ " :" , kLine ));
202
+ } else {
203
+ EXPECT_EQ (result, error_status);
204
+ }
121
205
}
122
206
123
207
TEST_P (StatusTest, MaybeWithNewMessageNonEmptyNewMessage) {
124
- absl::Status error_status = absl::InvalidArgumentError (message );
125
- absl::Status result =
126
- MaybeWithNewMessage ( error_status, test_file, line, new_message );
208
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
209
+ absl::Status result = status_internal::MaybeWithNewMessage (
210
+ error_status, kFile , kLine , kFunction , kNewMessage );
127
211
128
- ASSERT_NE (result, error_status);
129
212
ASSERT_FALSE (result.ok ());
130
213
EXPECT_EQ (result.code (), error_status.code ());
214
+ EXPECT_EQ (result.message (), std::string_view (kNewMessage ));
215
+ EXPECT_NE (result, error_status);
131
216
132
217
if (IsShowCppStacktracesMode ()) {
133
- EXPECT_EQ (result.message (),
134
- absl::StrCat (" New test error message (at test_file.cpp:42)\n "
135
- " From Error: Test error message" ));
136
- } else {
137
- EXPECT_EQ (result.message (), std::string_view (new_message));
218
+ EXPECT_EQ (GetStatusPropagationTrace (result),
219
+ absl::StrCat (kEntryPrefix , " From: " , kFunction , " at " , kFile ,
220
+ " :" , kLine , " (error: " , kNewMessage , " )" ));
138
221
}
139
222
}
140
223
@@ -154,29 +237,30 @@ TEST_P(StatusTest, MacroReturnIfError) {
154
237
155
238
TEST_P (StatusTest, MacroReturnIfErrorWithError) {
156
239
auto test_function = [=]() -> absl::Status {
157
- absl::Status error_status = absl::InvalidArgumentError (message );
240
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
158
241
XLA_RETURN_IF_ERROR (error_status);
159
242
return absl::OkStatus ();
160
243
};
161
244
162
245
absl::Status result = test_function ();
163
246
ASSERT_FALSE (result.ok ());
164
247
EXPECT_EQ (result.code (), absl::StatusCode::kInvalidArgument );
165
- EXPECT_EQ (result.message (), std::string_view (message ));
248
+ EXPECT_EQ (result.message (), std::string_view (kMessage ));
166
249
}
167
250
168
251
TEST_P (StatusTest, MacroReturnIfErrorWithNestedError) {
169
- int32_t errline = 0 ;
170
- auto inner_test_function = [&errline]() -> absl::Status {
171
- errline = __LINE__ + 1 ;
172
- return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (message));
252
+ int32_t errline0 = __LINE__ + 2 ;
253
+ auto inner_test_function = []() -> absl::Status {
254
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (kMessage ));
173
255
};
174
256
257
+ int32_t errline1 = __LINE__ + 2 ;
175
258
auto test_function = [&]() -> absl::Status {
176
259
XLA_RETURN_IF_ERROR (inner_test_function ());
177
260
return absl::OkStatus ();
178
261
};
179
262
263
+ int32_t errline2 = __LINE__ + 2 ;
180
264
auto outer_test_function = [&]() -> absl::Status {
181
265
XLA_RETURN_IF_ERROR (test_function ());
182
266
return absl::OkStatus ();
@@ -185,34 +269,37 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
185
269
absl::Status result = outer_test_function ();
186
270
ASSERT_FALSE (result.ok ());
187
271
EXPECT_EQ (result.code (), absl::StatusCode::kInvalidArgument );
272
+ EXPECT_EQ (result.message (), std::string_view (kMessage ));
188
273
189
274
if (IsShowCppStacktracesMode ()) {
190
- EXPECT_EQ (result.message (), absl::StrCat (" Test error message (at " ,
191
- __FILE__, " :" , errline, " )" ));
192
- } else {
193
- EXPECT_EQ (result.message (), std::string_view (message));
275
+ auto frame0 = absl::StrCat (kEntryPrefix , " From: operator() at " , __FILE__,
276
+ " :" , errline0, " (error: " , kMessage , " )" );
277
+ auto frame1 = absl::StrCat (kEntryPrefix , " From: operator() at " , __FILE__,
278
+ " :" , errline1);
279
+ auto frame2 = absl::StrCat (kEntryPrefix , " From: operator() at " , __FILE__,
280
+ " :" , errline2);
281
+ EXPECT_EQ (GetStatusPropagationTrace (result),
282
+ absl::StrCat (frame0, frame1, frame2));
194
283
}
195
284
}
196
285
197
286
TEST_P (StatusTest, MacroReturnIfErrorWithErrorWithNewMessage) {
198
- int32_t errline = 0 ;
199
- auto test_function = [&errline]() -> absl::Status {
200
- absl::Status error_status = absl::InvalidArgumentError (message);
201
- errline = __LINE__ + 1 ;
202
- XLA_RETURN_IF_ERROR (error_status, new_message);
287
+ int32_t errline = __LINE__ + 3 ;
288
+ auto test_function = []() -> absl::Status {
289
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
290
+ XLA_RETURN_IF_ERROR (error_status, kNewMessage );
203
291
return absl::OkStatus ();
204
292
};
205
293
206
294
absl::Status result = test_function ();
207
295
ASSERT_FALSE (result.ok ());
208
296
EXPECT_EQ (result.code (), absl::StatusCode::kInvalidArgument );
297
+ EXPECT_EQ (result.message (), std::string_view (kNewMessage ));
209
298
210
299
if (IsShowCppStacktracesMode ()) {
211
- EXPECT_EQ (result.message (),
212
- absl::StrCat (" New test error message (at " , __FILE__, " :" ,
213
- errline, " )\n From Error: Test error message" ));
214
- } else {
215
- EXPECT_EQ (result.message (), std::string_view (new_message));
300
+ EXPECT_EQ (GetStatusPropagationTrace (result),
301
+ absl::StrCat (kEntryPrefix , " From: operator() at " , __FILE__, " :" ,
302
+ errline, " (error: " , kNewMessage , " )" ));
216
303
}
217
304
}
218
305
@@ -233,51 +320,98 @@ TEST_P(StatusTest, MacroAssignOrReturn) {
233
320
234
321
TEST_P (StatusTest, MacroAssignOrReturnWithError) {
235
322
auto test_function = []() -> absl::StatusOr<int > {
236
- absl::StatusOr<int > status_or = absl::InvalidArgumentError (message );
323
+ absl::StatusOr<int > status_or = absl::InvalidArgumentError (kMessage );
237
324
XLA_ASSIGN_OR_RETURN (int value, status_or);
238
325
return value * 2 ;
239
326
};
240
327
241
328
absl::StatusOr<int > result = test_function ();
242
329
ASSERT_FALSE (result.ok ());
243
330
EXPECT_EQ (result.status ().code (), absl::StatusCode::kInvalidArgument );
244
- EXPECT_EQ (result.status ().message (), std::string_view (message ));
331
+ EXPECT_EQ (result.status ().message (), std::string_view (kMessage ));
245
332
}
246
333
247
334
TEST_P (StatusTest, MacroAssignOrReturnWithErrorWithNewMessage) {
248
- int32_t errline = 0 ;
249
-
250
- auto test_function = [&errline]() -> absl::StatusOr<int > {
251
- absl::StatusOr<int > status_or = absl::InvalidArgumentError (message);
252
- errline = __LINE__ + 1 ;
253
- XLA_ASSIGN_OR_RETURN (int value, status_or, new_message);
335
+ int32_t errline = __LINE__ + 3 ;
336
+ auto test_function = []() -> absl::StatusOr<int > {
337
+ absl::StatusOr<int > status_or = absl::InvalidArgumentError (kMessage );
338
+ XLA_ASSIGN_OR_RETURN (int value, status_or, kNewMessage );
254
339
return value * 2 ;
255
340
};
256
341
257
342
absl::StatusOr<int > result = test_function ();
258
343
ASSERT_FALSE (result.ok ());
259
344
EXPECT_EQ (result.status ().code (), absl::StatusCode::kInvalidArgument );
345
+ EXPECT_EQ (result.status ().message (), std::string_view (kNewMessage ));
260
346
261
347
if (IsShowCppStacktracesMode ()) {
262
- EXPECT_EQ (result.status ().message (),
263
- absl::StrCat (" New test error message (at " , __FILE__, " :" ,
264
- errline, " )\n From Error: Test error message" ));
265
- } else {
266
- EXPECT_EQ (result.status ().message (), std::string_view (new_message));
348
+ EXPECT_EQ (GetStatusPropagationTrace (result.status ()),
349
+ absl::StrCat (kEntryPrefix , " From: operator() at " , __FILE__, " :" ,
350
+ errline, " (error: " , kNewMessage , " )" ));
267
351
}
268
352
}
269
353
270
354
TEST_P (StatusTest, MacroErrorWithLocation) {
271
- absl::Status error_status = absl::InvalidArgumentError (message );
355
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
272
356
int32_t errline = __LINE__ + 1 ;
273
357
absl::Status result = XLA_ERROR_WITH_LOCATION (error_status);
274
358
ASSERT_FALSE (result.ok ());
275
359
EXPECT_EQ (result.code (), absl::StatusCode::kInvalidArgument );
360
+ EXPECT_EQ (result.message (), std::string_view (kMessage ));
361
+ if (IsShowCppStacktracesMode ()) {
362
+ EXPECT_EQ (GetStatusPropagationTrace (result),
363
+ absl::StrCat (kEntryPrefix , " From: " , __FUNCTION__, " at " ,
364
+ __FILE__, " :" , errline, " (error: " , kMessage , " )" ));
365
+ }
366
+ }
367
+
368
+ TEST_P (StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
369
+ int32_t errline0 = __LINE__ + 2 ;
370
+ auto innerfn = [&]() -> absl::Status {
371
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (kMessage ));
372
+ };
373
+
374
+ int32_t errline1 = __LINE__ + 2 ;
375
+ auto midfn = [&]() -> absl::Status {
376
+ XLA_RETURN_IF_ERROR (innerfn (), kNewMessage );
377
+ return absl::OkStatus ();
378
+ };
379
+
380
+ int32_t errline2 = __LINE__ + 2 ;
381
+ auto outerfn = [&]() -> absl::Status {
382
+ XLA_RETURN_IF_ERROR (midfn ());
383
+ return absl::OkStatus ();
384
+ };
385
+
386
+ auto throw_exception = [&]() {
387
+ THROW_RUNTIME_ERROR_FROM_C10_ERROR (MaybeThrow (outerfn ()));
388
+ };
389
+
276
390
if (IsShowCppStacktracesMode ()) {
277
- EXPECT_EQ (result.message (), absl::StrCat (" Test error message (at " ,
278
- __FILE__, " :" , errline, " )" ));
391
+ // Expected Error Message Prefix
392
+ // =============================
393
+ //
394
+ // New test error kMessage
395
+ //
396
+ // Status Propagation Stacktrace:
397
+ // From: ./test/cpp/test_status_common.h:329 (error: Test error
398
+ // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test
399
+ // error kMessage) From: ./test/cpp/test_status_common.h:342
400
+ //
401
+ // C++ Stacktrace:
402
+ //
403
+ std::string expected_prefix = absl::StrCat (
404
+ kNewMessage , " \n\n Status Propagation Trace:" , kEntryPrefix ,
405
+ " From: operator() at " , __FILE__, " :" , errline0, " (error: " , kMessage ,
406
+ " )" , kEntryPrefix , " From: operator() at " , __FILE__, " :" , errline1,
407
+ " (error: " , kNewMessage , " )" , kEntryPrefix , " From: operator() at " ,
408
+ __FILE__, " :" , errline2, " \n\n " , kTorchCppStacktracePrefix );
409
+
410
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
411
+ ::testing::StartsWith (expected_prefix)));
279
412
} else {
280
- EXPECT_EQ (result.message (), std::string_view (message));
413
+ EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
414
+ ::testing::Eq (kNewMessage )));
281
415
}
282
416
}
283
417
0 commit comments