Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 198 additions & 64 deletions test/cpp/test_status_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
#ifndef XLA_TEST_CPP_TEST_STATUS_COMMON_H_
#define XLA_TEST_CPP_TEST_STATUS_COMMON_H_

#include <c10/util/Exception.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdlib>
#include <stdexcept>
#include <utility>

#include "absl/status/status.h"
Expand All @@ -30,7 +33,7 @@

namespace torch_xla {

// Enum to control whether C++ error context is shown in status messages
// Enum to control whether C++ error context is shown in status messages.
enum class CppStacktracesMode {
kShow,
kHide,
Expand Down Expand Up @@ -74,19 +77,68 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {

namespace testing {

constexpr inline char new_message[] = "New test error message";
constexpr inline char message[] = "Test error message";
constexpr inline char test_file[] = "test_file.cpp";
constexpr inline int32_t line = 42;
constexpr inline char kNewMessage[] = "New test error message";
constexpr inline char kMessage[] = "Test error message";
constexpr inline char kFile[] = "test_file.cpp";
constexpr inline char kFunction[] = "foo";
constexpr inline char kEntryPrefix[] = "\n ";
constexpr inline int32_t kLine = 42;

// The PyTorch C++ stacktrace is ALWAYS appended to the error message.
// More specifically, when `what()` function is called.
//
// However, it's only when the raised `c10::Error` gets translated to a
// Python exception that PyTorch checks the value of the
// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually
// controls whether the stacktrace will get shown or not by calling
// `what_without_backtraces()`, instead.
//
// Therefore, we need to mimic this behavior.
#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \
try { \
block; \
} catch (const c10::Error& error) { \
throw std::runtime_error(IsShowCppStacktracesMode() \
? error.what() \
: error.what_without_backtrace()); \
}

// Prefix of the C++ stacktrace PyTorch adds to the error message.
constexpr inline char kTorchCppStacktracePrefix[] =
"Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:";

inline std::string GetStatusPropagationTrace(const absl::Status& status) {
if (status.ok()) {
return "";
}
auto status_propagation_trace = status.GetPayload(kStatusPropagationTraceKey);
return status_propagation_trace.has_value()
? std::string(status_propagation_trace->Flatten())
: "";
}

TEST_P(StatusTest, MaybeThrowWithOkStatus) {
absl::Status ok_status = absl::OkStatus();
EXPECT_NO_THROW(MaybeThrow(ok_status));
}

TEST_P(StatusTest, MaybeThrowWithErrorStatus) {
absl::Status error_status = absl::InvalidArgumentError(message);
EXPECT_THROW(MaybeThrow(error_status), std::runtime_error);
auto throw_exception = [=]() {
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
absl::Status error_status = absl::InvalidArgumentError(kMessage);
MaybeThrow(error_status);
});
};

if (IsShowCppStacktracesMode()) {
std::string expected_prefix =
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::StartsWith(expected_prefix)));
} else {
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::Eq(kMessage)));
}
}

TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
Expand All @@ -97,44 +149,75 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
}

TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) {
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error);
auto throw_exception = [=]() {
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
absl::StatusOr<int> error_status = absl::InvalidArgumentError(kMessage);
int value = GetValueOrThrow(error_status);
});
};
if (IsShowCppStacktracesMode()) {
std::string expected_prefix =
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::StartsWith(expected_prefix)));
} else {
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::Eq(kMessage)));
}
}

TEST_P(StatusTest, MaybeWithLocationPropagatesErrorStatus) {
absl::Status error_status = absl::InvalidArgumentError(message);
absl::Status result = MaybeWithLocation(error_status, test_file, line);
absl::Status error_status = absl::InvalidArgumentError(kMessage);
absl::Status result =
status_internal::MaybeWithLocation(error_status, kFile, kLine, kFunction);

ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), error_status.code());
EXPECT_EQ(result.message(), error_status.message());

if (IsShowCppStacktracesMode()) {
ASSERT_NE(result, error_status);
EXPECT_FALSE(result.ok());
EXPECT_EQ(result.code(), error_status.code());
EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)");
EXPECT_NE(result, error_status);
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
":", kLine, " (error: ", kMessage, ")"));
} else {
EXPECT_EQ(result, error_status);
}
}

TEST_P(StatusTest, MaybeWithNewMessageEmptyNewMessage) {
absl::Status error_status = absl::InvalidArgumentError(message);
absl::Status result = MaybeWithNewMessage(error_status, test_file, line);
EXPECT_EQ(result, error_status);
absl::Status error_status = absl::InvalidArgumentError(kMessage);
absl::Status result = status_internal::MaybeWithNewMessage(
error_status, kFile, kLine, kFunction);

ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), error_status.code());
EXPECT_EQ(result.message(), error_status.message());

if (IsShowCppStacktracesMode()) {
EXPECT_NE(result, error_status);
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
":", kLine));
} else {
EXPECT_EQ(result, error_status);
}
}

TEST_P(StatusTest, MaybeWithNewMessageNonEmptyNewMessage) {
absl::Status error_status = absl::InvalidArgumentError(message);
absl::Status result =
MaybeWithNewMessage(error_status, test_file, line, new_message);
absl::Status error_status = absl::InvalidArgumentError(kMessage);
absl::Status result = status_internal::MaybeWithNewMessage(
error_status, kFile, kLine, kFunction, kNewMessage);

ASSERT_NE(result, error_status);
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), error_status.code());
EXPECT_EQ(result.message(), std::string_view(kNewMessage));
EXPECT_NE(result, error_status);

if (IsShowCppStacktracesMode()) {
EXPECT_EQ(result.message(),
absl::StrCat("New test error message (at test_file.cpp:42)\n"
"From Error: Test error message"));
} else {
EXPECT_EQ(result.message(), std::string_view(new_message));
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile,
":", kLine, " (error: ", kNewMessage, ")"));
}
}

Expand All @@ -154,29 +237,30 @@ TEST_P(StatusTest, MacroReturnIfError) {

TEST_P(StatusTest, MacroReturnIfErrorWithError) {
auto test_function = [=]() -> absl::Status {
absl::Status error_status = absl::InvalidArgumentError(message);
absl::Status error_status = absl::InvalidArgumentError(kMessage);
XLA_RETURN_IF_ERROR(error_status);
return absl::OkStatus();
};

absl::Status result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), std::string_view(message));
EXPECT_EQ(result.message(), std::string_view(kMessage));
}

TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
int32_t errline = 0;
auto inner_test_function = [&errline]() -> absl::Status {
errline = __LINE__ + 1;
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
int32_t errline0 = __LINE__ + 2;
auto inner_test_function = []() -> absl::Status {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
};

int32_t errline1 = __LINE__ + 2;
auto test_function = [&]() -> absl::Status {
XLA_RETURN_IF_ERROR(inner_test_function());
return absl::OkStatus();
};

int32_t errline2 = __LINE__ + 2;
auto outer_test_function = [&]() -> absl::Status {
XLA_RETURN_IF_ERROR(test_function());
return absl::OkStatus();
Expand All @@ -185,34 +269,37 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
absl::Status result = outer_test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), std::string_view(kMessage));

if (IsShowCppStacktracesMode()) {
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ",
__FILE__, ":", errline, ")"));
} else {
EXPECT_EQ(result.message(), std::string_view(message));
auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
":", errline0, " (error: ", kMessage, ")");
auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
":", errline1);
auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
":", errline2);
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(frame0, frame1, frame2));
}
}

TEST_P(StatusTest, MacroReturnIfErrorWithErrorWithNewMessage) {
int32_t errline = 0;
auto test_function = [&errline]() -> absl::Status {
absl::Status error_status = absl::InvalidArgumentError(message);
errline = __LINE__ + 1;
XLA_RETURN_IF_ERROR(error_status, new_message);
int32_t errline = __LINE__ + 3;
auto test_function = []() -> absl::Status {
absl::Status error_status = absl::InvalidArgumentError(kMessage);
XLA_RETURN_IF_ERROR(error_status, kNewMessage);
return absl::OkStatus();
};

absl::Status result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), std::string_view(kNewMessage));

if (IsShowCppStacktracesMode()) {
EXPECT_EQ(result.message(),
absl::StrCat("New test error message (at ", __FILE__, ":",
errline, ")\nFrom Error: Test error message"));
} else {
EXPECT_EQ(result.message(), std::string_view(new_message));
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":",
errline, " (error: ", kNewMessage, ")"));
}
}

Expand All @@ -233,51 +320,98 @@ TEST_P(StatusTest, MacroAssignOrReturn) {

TEST_P(StatusTest, MacroAssignOrReturnWithError) {
auto test_function = []() -> absl::StatusOr<int> {
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
absl::StatusOr<int> status_or = absl::InvalidArgumentError(kMessage);
XLA_ASSIGN_OR_RETURN(int value, status_or);
return value * 2;
};

absl::StatusOr<int> result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.status().message(), std::string_view(message));
EXPECT_EQ(result.status().message(), std::string_view(kMessage));
}

TEST_P(StatusTest, MacroAssignOrReturnWithErrorWithNewMessage) {
int32_t errline = 0;

auto test_function = [&errline]() -> absl::StatusOr<int> {
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
errline = __LINE__ + 1;
XLA_ASSIGN_OR_RETURN(int value, status_or, new_message);
int32_t errline = __LINE__ + 3;
auto test_function = []() -> absl::StatusOr<int> {
absl::StatusOr<int> status_or = absl::InvalidArgumentError(kMessage);
XLA_ASSIGN_OR_RETURN(int value, status_or, kNewMessage);
return value * 2;
};

absl::StatusOr<int> result = test_function();
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.status().message(), std::string_view(kNewMessage));

if (IsShowCppStacktracesMode()) {
EXPECT_EQ(result.status().message(),
absl::StrCat("New test error message (at ", __FILE__, ":",
errline, ")\nFrom Error: Test error message"));
} else {
EXPECT_EQ(result.status().message(), std::string_view(new_message));
EXPECT_EQ(GetStatusPropagationTrace(result.status()),
absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":",
errline, " (error: ", kNewMessage, ")"));
}
}

TEST_P(StatusTest, MacroErrorWithLocation) {
absl::Status error_status = absl::InvalidArgumentError(message);
absl::Status error_status = absl::InvalidArgumentError(kMessage);
int32_t errline = __LINE__ + 1;
absl::Status result = XLA_ERROR_WITH_LOCATION(error_status);
ASSERT_FALSE(result.ok());
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
EXPECT_EQ(result.message(), std::string_view(kMessage));
if (IsShowCppStacktracesMode()) {
EXPECT_EQ(GetStatusPropagationTrace(result),
absl::StrCat(kEntryPrefix, "From: ", __FUNCTION__, " at ",
__FILE__, ":", errline, " (error: ", kMessage, ")"));
}
}

TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
int32_t errline0 = __LINE__ + 2;
auto innerfn = [&]() -> absl::Status {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
};

int32_t errline1 = __LINE__ + 2;
auto midfn = [&]() -> absl::Status {
XLA_RETURN_IF_ERROR(innerfn(), kNewMessage);
return absl::OkStatus();
};

int32_t errline2 = __LINE__ + 2;
auto outerfn = [&]() -> absl::Status {
XLA_RETURN_IF_ERROR(midfn());
return absl::OkStatus();
};

auto throw_exception = [&]() {
THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn()));
};

if (IsShowCppStacktracesMode()) {
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ",
__FILE__, ":", errline, ")"));
// Expected Error Message Prefix
// =============================
//
// New test error kMessage
//
// Status Propagation Stacktrace:
// From: ./test/cpp/test_status_common.h:329 (error: Test error
// kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test
// error kMessage) From: ./test/cpp/test_status_common.h:342
//
// C++ Stacktrace:
//
std::string expected_prefix = absl::StrCat(
kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix,
"From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage,
")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1,
" (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ",
__FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix);

EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::StartsWith(expected_prefix)));
} else {
EXPECT_EQ(result.message(), std::string_view(message));
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
::testing::Eq(kNewMessage)));
}
}

Expand Down
Loading
Loading