diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index 7cb63d4f38a..4d4b173f643 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -18,9 +18,12 @@ #ifndef XLA_TEST_CPP_TEST_STATUS_COMMON_H_ #define XLA_TEST_CPP_TEST_STATUS_COMMON_H_ +#include +#include #include #include +#include #include #include "absl/status/status.h" @@ -30,7 +33,7 @@ namespace torch_xla { -// Enum to control whether C++ error context is shown in status messages +// Enum to control whether C++ error context is shown in status messages. enum class CppStacktracesMode { kShow, kHide, @@ -74,10 +77,45 @@ class StatusTest : public testing::TestWithParam { namespace testing { -constexpr inline char new_message[] = "New test error message"; -constexpr inline char message[] = "Test error message"; -constexpr inline char test_file[] = "test_file.cpp"; -constexpr inline int32_t line = 42; +constexpr inline char kNewMessage[] = "New test error message"; +constexpr inline char kMessage[] = "Test error message"; +constexpr inline char kFile[] = "test_file.cpp"; +constexpr inline char kFunction[] = "foo"; +constexpr inline char kEntryPrefix[] = "\n "; +constexpr inline int32_t kLine = 42; + +// The PyTorch C++ stacktrace is ALWAYS appended to the error message. +// More specifically, when `what()` function is called. +// +// However, it's only when the raised `c10::Error` gets translated to a +// Python exception that PyTorch checks the value of the +// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually +// controls whether the stacktrace will get shown or not by calling +// `what_without_backtraces()`, instead. +// +// Therefore, we need to mimic this behavior. +#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \ + try { \ + block; \ + } catch (const c10::Error& error) { \ + throw std::runtime_error(IsShowCppStacktracesMode() \ + ? error.what() \ + : error.what_without_backtrace()); \ + } + +// Prefix of the C++ stacktrace PyTorch adds to the error message. +constexpr inline char kTorchCppStacktracePrefix[] = + "Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:"; + +inline std::string GetStatusPropagationTrace(const absl::Status& status) { + if (status.ok()) { + return ""; + } + auto status_propagation_trace = status.GetPayload(kStatusPropagationTraceKey); + return status_propagation_trace.has_value() + ? std::string(status_propagation_trace->Flatten()) + : ""; +} TEST_P(StatusTest, MaybeThrowWithOkStatus) { absl::Status ok_status = absl::OkStatus(); @@ -85,8 +123,22 @@ TEST_P(StatusTest, MaybeThrowWithOkStatus) { } TEST_P(StatusTest, MaybeThrowWithErrorStatus) { - absl::Status error_status = absl::InvalidArgumentError(message); - EXPECT_THROW(MaybeThrow(error_status), std::runtime_error); + auto throw_exception = [=]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR({ + absl::Status error_status = absl::InvalidArgumentError(kMessage); + MaybeThrow(error_status); + }); + }; + + if (IsShowCppStacktracesMode()) { + std::string expected_prefix = + absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); + } else { + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kMessage))); + } } TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { @@ -97,44 +149,75 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { } TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) { - absl::StatusOr status_or = absl::InvalidArgumentError(message); - EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error); + auto throw_exception = [=]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR({ + absl::StatusOr error_status = absl::InvalidArgumentError(kMessage); + int value = GetValueOrThrow(error_status); + }); + }; + if (IsShowCppStacktracesMode()) { + std::string expected_prefix = + absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); + } else { + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kMessage))); + } } TEST_P(StatusTest, MaybeWithLocationPropagatesErrorStatus) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithLocation(error_status, test_file, line); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = + status_internal::MaybeWithLocation(error_status, kFile, kLine, kFunction); + + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), error_status.message()); + if (IsShowCppStacktracesMode()) { - ASSERT_NE(result, error_status); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.code(), error_status.code()); - EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)"); + EXPECT_NE(result, error_status); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine, " (error: ", kMessage, ")")); } else { EXPECT_EQ(result, error_status); } } TEST_P(StatusTest, MaybeWithNewMessageEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithNewMessage(error_status, test_file, line); - EXPECT_EQ(result, error_status); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = status_internal::MaybeWithNewMessage( + error_status, kFile, kLine, kFunction); + + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), error_status.message()); + + if (IsShowCppStacktracesMode()) { + EXPECT_NE(result, error_status); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine)); + } else { + EXPECT_EQ(result, error_status); + } } TEST_P(StatusTest, MaybeWithNewMessageNonEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = - MaybeWithNewMessage(error_status, test_file, line, new_message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = status_internal::MaybeWithNewMessage( + error_status, kFile, kLine, kFunction, kNewMessage); - ASSERT_NE(result, error_status); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), std::string_view(kNewMessage)); + EXPECT_NE(result, error_status); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), - absl::StrCat("New test error message (at test_file.cpp:42)\n" - "From Error: Test error message")); - } else { - EXPECT_EQ(result.message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine, " (error: ", kNewMessage, ")")); } } @@ -154,7 +237,7 @@ TEST_P(StatusTest, MacroReturnIfError) { TEST_P(StatusTest, MacroReturnIfErrorWithError) { auto test_function = [=]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); XLA_RETURN_IF_ERROR(error_status); return absl::OkStatus(); }; @@ -162,21 +245,22 @@ TEST_P(StatusTest, MacroReturnIfErrorWithError) { absl::Status result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.message(), std::string_view(message)); + EXPECT_EQ(result.message(), std::string_view(kMessage)); } TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) { - int32_t errline = 0; - auto inner_test_function = [&errline]() -> absl::Status { - errline = __LINE__ + 1; - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message)); + int32_t errline0 = __LINE__ + 2; + auto inner_test_function = []() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); }; + int32_t errline1 = __LINE__ + 2; auto test_function = [&]() -> absl::Status { XLA_RETURN_IF_ERROR(inner_test_function()); return absl::OkStatus(); }; + int32_t errline2 = __LINE__ + 2; auto outer_test_function = [&]() -> absl::Status { XLA_RETURN_IF_ERROR(test_function()); return absl::OkStatus(); @@ -185,34 +269,37 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) { absl::Status result = outer_test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", - __FILE__, ":", errline, ")")); - } else { - EXPECT_EQ(result.message(), std::string_view(message)); + auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline0, " (error: ", kMessage, ")"); + auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline1); + auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline2); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(frame0, frame1, frame2)); } } TEST_P(StatusTest, MacroReturnIfErrorWithErrorWithNewMessage) { - int32_t errline = 0; - auto test_function = [&errline]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(message); - errline = __LINE__ + 1; - XLA_RETURN_IF_ERROR(error_status, new_message); + int32_t errline = __LINE__ + 3; + auto test_function = []() -> absl::Status { + absl::Status error_status = absl::InvalidArgumentError(kMessage); + XLA_RETURN_IF_ERROR(error_status, kNewMessage); return absl::OkStatus(); }; absl::Status result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kNewMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), - absl::StrCat("New test error message (at ", __FILE__, ":", - errline, ")\nFrom Error: Test error message")); - } else { - EXPECT_EQ(result.message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":", + errline, " (error: ", kNewMessage, ")")); } } @@ -233,7 +320,7 @@ TEST_P(StatusTest, MacroAssignOrReturn) { TEST_P(StatusTest, MacroAssignOrReturnWithError) { auto test_function = []() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError(message); + absl::StatusOr status_or = absl::InvalidArgumentError(kMessage); XLA_ASSIGN_OR_RETURN(int value, status_or); return value * 2; }; @@ -241,43 +328,90 @@ TEST_P(StatusTest, MacroAssignOrReturnWithError) { absl::StatusOr result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.status().message(), std::string_view(message)); + EXPECT_EQ(result.status().message(), std::string_view(kMessage)); } TEST_P(StatusTest, MacroAssignOrReturnWithErrorWithNewMessage) { - int32_t errline = 0; - - auto test_function = [&errline]() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError(message); - errline = __LINE__ + 1; - XLA_ASSIGN_OR_RETURN(int value, status_or, new_message); + int32_t errline = __LINE__ + 3; + auto test_function = []() -> absl::StatusOr { + absl::StatusOr status_or = absl::InvalidArgumentError(kMessage); + XLA_ASSIGN_OR_RETURN(int value, status_or, kNewMessage); return value * 2; }; absl::StatusOr result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), std::string_view(kNewMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.status().message(), - absl::StrCat("New test error message (at ", __FILE__, ":", - errline, ")\nFrom Error: Test error message")); - } else { - EXPECT_EQ(result.status().message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result.status()), + absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":", + errline, " (error: ", kNewMessage, ")")); } } TEST_P(StatusTest, MacroErrorWithLocation) { - absl::Status error_status = absl::InvalidArgumentError(message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); int32_t errline = __LINE__ + 1; absl::Status result = XLA_ERROR_WITH_LOCATION(error_status); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kMessage)); + if (IsShowCppStacktracesMode()) { + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", __FUNCTION__, " at ", + __FILE__, ":", errline, " (error: ", kMessage, ")")); + } +} + +TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) { + int32_t errline0 = __LINE__ + 2; + auto innerfn = [&]() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); + }; + + int32_t errline1 = __LINE__ + 2; + auto midfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(innerfn(), kNewMessage); + return absl::OkStatus(); + }; + + int32_t errline2 = __LINE__ + 2; + auto outerfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(midfn()); + return absl::OkStatus(); + }; + + auto throw_exception = [&]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn())); + }; + if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", - __FILE__, ":", errline, ")")); + // Expected Error Message Prefix + // ============================= + // + // New test error kMessage + // + // Status Propagation Stacktrace: + // From: ./test/cpp/test_status_common.h:329 (error: Test error + // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test + // error kMessage) From: ./test/cpp/test_status_common.h:342 + // + // C++ Stacktrace: + // + std::string expected_prefix = absl::StrCat( + kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix, + "From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage, + ")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1, + " (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ", + __FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix); + + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); } else { - EXPECT_EQ(result.message(), std::string_view(message)); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kNewMessage))); } } diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 6c34eca1450..31ab65dbbca 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -377,6 +377,7 @@ cc_library( hdrs = ["status.h"], deps = [ "@torch//:headers", + "@tsl//tsl/platform:stacktrace", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", ], diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 1eb3511cb33..270f3487867 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -1,19 +1,49 @@ #include "torch_xla/csrc/status.h" +#include #include +#include +#include +#include + #include "absl/log/absl_check.h" +#include "tsl/platform/stacktrace.h" namespace torch_xla { -// Common function for generating file location information with a space in the -// beginning. -static std::string LocationStrWithSpace(const char* file, const int32_t line) { - return absl::StrCat(" (at ", file, ":", line, ")"); +// Indent the stack frame representation so that it's easier to see. +constexpr char kFramePrefix[] = "\n "; + +// Creates the stack frame representation for the status propagation trace +// entry. +// +// The resulting string will be appended to the existing status propagation +// trace of the status currently being processed. +// +// Example: +// \n From: at : [(error: )] +// +static std::string GetStackFrame(const char* file, const int32_t line, + const char* function, + const std::string_view new_message) { + auto error_suffix = + new_message.empty() ? "" : absl::StrCat(" (error: ", new_message, ")"); + return absl::StrCat(kFramePrefix, "From: ", function, " at ", file, ":", line, + error_suffix); +} + +// Convenient function that retrieves the status propagation trace payload +// if it exists. Otherwise, returns an empty absl::Cord. +static absl::Cord GetStatusPropagationTraceOrEmpty(const absl::Status& status) { + auto opt = status.GetPayload(kStatusPropagationTraceKey); + return opt.has_value() ? *opt : absl::Cord(); } -absl::Status MaybeWithLocation(const absl::Status& status, const char* file, - const int32_t line) { +absl::Status status_internal::MaybeWithLocation(const absl::Status& status, + const char* file, + const int32_t line, + const char* function) { ABSL_CHECK(!status.ok()); // Return the same status if we don't need to add the C++ source location. @@ -21,14 +51,19 @@ absl::Status MaybeWithLocation(const absl::Status& status, const char* file, return status; } - return absl::Status( - status.code(), - absl::StrCat(status.message(), LocationStrWithSpace(file, line))); + // Make sure this is only called on fresh `status` instances. + ABSL_CHECK(GetStatusPropagationTraceOrEmpty(status).empty()); + + // Adding source location to `status` has the same semantics as overwriting + // the status message: + // 1. An stack frame will be added to the status propagation trace + // 2. The status' message will be the same + return MaybeWithNewMessage(status, file, line, function, status.message()); } -absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, - const int32_t line, - const std::string_view new_message) { +absl::Status status_internal::MaybeWithNewMessage( + const absl::Status& status, const char* file, const int32_t line, + const char* function, const std::string_view new_message) { ABSL_CHECK(!status.ok()); // Return the same status if: @@ -38,39 +73,55 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, return status; } - std::string_view old_message = status.message(); - // Replace the old status message with `new_message`, if it's not empty. // // The idea is that whenever `new_message` is given, it should have more // context to give a better error message to the user. - std::string_view message = new_message.empty() ? old_message : new_message; + auto new_status = absl::Status( + status.code(), new_message.empty() ? status.message() : new_message); - // If `TORCH_SHOW_CPP_STACKTRACES` is set, show the context of this error. - // In other words, show: - // 1. The error location - // 2. The old messages that were replaced by `new_message`. + // If `TORCH_SHOW_CPP_STACKTRACES` is set: // - // This should give more context for developers. Showing the older error - // messages alongside their debug information. + // 1. append the current stack frame to the status propagation trace + // payload // - // Note that we also condition showing source location information by (2) - // (i.e. `new_message` is not empty) because we don't really wish to show - // a stacktrace. Instead, we show only the history of error messages that - // has led to the current error. - const std::string context = - (torch::get_cpp_stacktraces_enabled() && !new_message.empty()) - ? absl::StrCat(LocationStrWithSpace(file, line), - "\nFrom Error: ", old_message) - : ""; - - return absl::Status(status.code(), absl::StrCat(message, context)); + // 2. append the new error message, if not empty + if (torch::get_cpp_stacktraces_enabled()) { + auto status_propagation_trace = GetStatusPropagationTraceOrEmpty(status); + status_propagation_trace.Append( + GetStackFrame(file, line, function, new_message)); + new_status.SetPayload(kStatusPropagationTraceKey, status_propagation_trace); + } + + return new_status; +} + +// Get a formatted string representation of the status propagation trace +// if it's not empty. +static std::string GetFormattedStatusPropagationTrace( + const absl::Status& status) { + auto status_propagation_trace = GetStatusPropagationTraceOrEmpty(status); + return status_propagation_trace.empty() + ? "" + : absl::StrCat("\nStatus Propagation Trace:", + status_propagation_trace.Flatten(), "\n"); +} + +// Get the status message followed by a line break, if we are printing the +// C++ stacktraces. +// +// This is needed so we have a blank line in between the status message and +// the dumped C++ traces (either the status propagation one, or the C++ +// stacktrace). +static std::string MaybeGetMessageWithLineBreak(const absl::Status& status) { + return torch::get_cpp_stacktraces_enabled() + ? absl::StrCat(status.message(), "\n") + : std::string(status.message()); } void MaybeThrow(const absl::Status& status) { - if (!status.ok()) { - throw std::runtime_error(std::string(status.message())); - } + TORCH_CHECK(status.ok(), MaybeGetMessageWithLineBreak(status), + GetFormattedStatusPropagationTrace(status)); } } // namespace torch_xla diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index d64a78ba58d..2f53b37381f 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -14,6 +14,24 @@ namespace torch_xla { +// `type_url` for retrieving the status propagation trace payload of a given +// status. +// +// The payload is composed of multiple lines, where each line represents a stack +// frame in the status propagation trace. Each line is in the following format: +// +// \n From: :[ErrorSuffix] +// | ---- | +// | | |_ error message produced in that source +// | | location (it might be overwritten later). +// | | +// | |_ leading 4 spaces for improved readability. +// | +// |_ start with a line break. +// +constexpr char kStatusPropagationTraceKey[] = + "type.googleapis.com/torch_xla.status_trace"; + // If `TORCH_SHOW_CPP_STACKTRACES` is set, creates a new Status instance, // appending the current location (e.g. file and line information) to the // status message. @@ -28,10 +46,12 @@ namespace torch_xla { // // If `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be: // -// Error message. (at :) +// RuntimeError: Error message. +// From: : (error: Error message.) // -#define XLA_ERROR_WITH_LOCATION(status) \ - ::torch_xla::MaybeWithLocation(status, __FILE__, __LINE__) +#define XLA_ERROR_WITH_LOCATION(status) \ + ::torch_xla::status_internal::MaybeWithLocation(status, __FILE__, __LINE__, \ + __FUNCTION__) #define XLA_CONCAT_(a, b) XLA_CONCAT_IMPL_(a, b) #define XLA_CONCAT_IMPL_(a, b) a##b @@ -41,15 +61,15 @@ namespace torch_xla { // Provides a flexible way to handle error checking with optional message // modification. It evaluates `expr`, checks if it's OK, and either: -// 1. Returns early with an error status (potentially modified by the provided -// additional messages) -// 2. Proceeds with the given `then` block if successful -#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ - auto var = (expr); \ - if (!var.ok()) { \ - return ::torch_xla::MaybeWithNewMessage( \ - ::torch_xla::GetStatus(var), __FILE__, __LINE__, ##__VA_ARGS__); \ - } \ +// 1. Returns early with an error status +// 2. Proceeds with the given `then` block if successful +#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ + auto var = (expr); \ + if (!var.ok()) { \ + return ::torch_xla::status_internal::MaybeWithNewMessage( \ + ::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \ + __FUNCTION__, ##__VA_ARGS__); \ + } \ then // Propagates `rexpr`, in case it's a non-ok status. @@ -65,9 +85,13 @@ namespace torch_xla { // we early return a non-ok status. Then, if `TORCH_SHOW_CPP_STACKTRACES` is // set, the error shown will be: // -// New error message. (at :) -// Previous error message. (at :) -// ... +// RuntimeError: New error message. +// +// Status Propagation Stacktrace: +// ... +// From: : (error: Previous error message.) +// ... +// From: : (error: New error message.) // #define XLA_RETURN_IF_ERROR(rexpr, ...) \ do { \ @@ -93,26 +117,29 @@ namespace torch_xla { // If the function call results in an ok status, execution continues with // `result` set to `ret.value()`, where `ret` is the returned value of the // function. Otherwise, we early return a non-ok status. Then, if -// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be: -// -// New error message. (at :) -// Previous error message. (at :) -// ... +// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be similar to +// the one above. // #define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \ XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, \ lhs = std::move(XLA_STATUS_VAR_).value(), \ ##__VA_ARGS__) -// Maybe shows location information in the status message. +namespace status_internal { + +// Adds source location information to the status propagation trace if +// `TORCH_SHOW_CPP_STACKTRACES` is set. // -// This function assumes that `status` is a non-ok status. +// This function assumes that: +// +// 1. `status` is a non-ok status. +// 2. `status` doesn't have a status propagation trace payload +// +// If any of the above assumptions is false, this function crashes the +// whole program. // -// If `TORCH_SHOW_CPP_STACKTRACES` is set, appends the current source -// location information to the status message. Otherwise, it simply returns -// `status`. absl::Status MaybeWithLocation(const absl::Status& status, const char* file, - int32_t line); + int32_t line, const char* function); // Returns an `absl::Status` from an `absl::Status`. // In this case, this function is a no-op. It simply returns the argument. @@ -126,7 +153,8 @@ const absl::Status& GetStatus(const absl::StatusOr& status) { return status.status(); } -// Maybe replace the current `status` message with `new_message`. +// Maybe replace the current `status` message with `new_message`, and also +// add source location information if enabled. // // This function assumes that `status` is a non-ok status. // @@ -137,12 +165,15 @@ const absl::Status& GetStatus(const absl::StatusOr& status) { // Rationale: if given, `new_message` has more context, which makes it possible // to construct better error messages to the user. // -// This function also appends file location information to the error message, if +// This function also appends the source location information to the status +// propagation trace payload (creates a new one if needed), if // `TORCH_SHOW_CPP_STACKTRACES` is set. absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, - int32_t line, + int32_t line, const char* function, std::string_view new_message = ""); +} // namespace status_internal + // Maybe throws an exception if `status` has a non-ok code. // // Ideally, this function should be used only used in the project's