
Commit 0350c7e

malfet authored and pytorchmergebot committed
[BE] Introduce torch.AcceleratorError (pytorch#152023)
Which inherits from `RuntimeError` and contains `error_code`, which in the case of CUDA should contain the error returned by `cudaGetLastError`.

`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's [`PyErr_SetString`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L282), namely:
- Convert the C string into a Python string with `PyUnicode_FromString`
- Create the new exception object using `PyObject_CallOneArg`, just like it's done in [`_PyErr_CreateException`](https://github.com/python/cpython/blob/cb8a72b301f47e76d93a7fe5b259e9a5758792e1/Python/errors.c#L32)
- Set the `error_code` property using `PyObject_SetAttrString`
- Decref all temporary references

Test that it works and captures the C++ backtrace (in addition to CI) by running

```python
import os

os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'

import torch

x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
    x[y] = 2
    print(x)
except torch.AcceleratorError as e:
    print("Exception was raised", e.args[0])
    print("Captured error code is ", e.error_code)
```

which produces the following output

```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0

Captured error code is 710
```

Pull Request resolved: pytorch#152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: pytorch#154436
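Because the new exception subclasses `RuntimeError`, existing `except RuntimeError` handlers keep catching accelerator failures, while `error_code` lets callers branch on the specific runtime error. A minimal sketch of that kind of handling, assuming a CUDA build that contains this change (the helper name `guarded_assign` is hypothetical; 710 is `cudaErrorAssert`, as checked in `test_cuda.py` below):

```python
import torch


def guarded_assign(x, y):
    # Hypothetical helper, not part of this commit: wraps an indexing
    # assignment and inspects torch.AcceleratorError.error_code on failure.
    try:
        x[y] = 2
        return x
    except torch.AcceleratorError as e:  # still catchable as RuntimeError
        if e.error_code == 710:  # cudaErrorAssert, e.g. out-of-bounds index
            raise RuntimeError(
                "device-side assert; rerun with CUDA_LAUNCH_BLOCKING=1 for an exact stack"
            ) from e
        raise
```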
1 parent f7c09f8 commit 0350c7e

File tree

9 files changed, +59 -4 lines


c10/cuda/CUDAException.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -38,8 +38,8 @@ void c10_cuda_check_implementation(
         "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
   }
 #endif
-
-  TORCH_CHECK(false, check_message);
+  throw c10::AcceleratorError(
+      {__func__, __FILE__, int32_t(__LINE__)}, err, check_message);
 }

 } // namespace c10::cuda
```
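Other accelerator backends could adopt the same shape of throw site. A rough sketch under that assumption; `myAccelGetLastError` and `myAccelGetErrorString` are made-up stand-ins for a vendor runtime, only `c10::AcceleratorError` comes from this commit:

```cpp
#include <c10/util/Exception.h>

#include <cstdint>
#include <string>

// Hypothetical vendor runtime API, declared here only to keep the sketch
// self-contained; a real backend would call its own error-query functions.
int32_t myAccelGetLastError();
const char* myAccelGetErrorString(int32_t err);

void my_accel_check_implementation(const char* api_name) {
  const int32_t err = myAccelGetLastError();
  if (err == 0) {
    return; // success, nothing to report
  }
  std::string check_message = std::string("Accelerator error in ") + api_name +
      ": " + myAccelGetErrorString(err);
  // Same shape as c10_cuda_check_implementation above: preserve the raw
  // error code so Python code can read it back as AcceleratorError.error_code.
  throw c10::AcceleratorError(
      {__func__, __FILE__, int32_t(__LINE__)}, err, check_message);
}
```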

c10/util/Exception.h

Lines changed: 13 additions & 0 deletions
```diff
@@ -295,6 +295,19 @@ class C10_API SyntaxError : public Error {
   using Error::Error;
 };

+// Raised when an accelerator API call hits an error.
+// These turn into AcceleratorError when they cross into Python
+class C10_API AcceleratorError : public Error {
+  int32_t error_code;
+
+ public:
+  AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg)
+      : Error(loc, msg), error_code(code) {}
+  int32_t get_error_code() const {
+    return error_code;
+  }
+};
+
 // Base error type for all distributed errors.
 // These turn into DistError when they cross into Python.
 class C10_API DistError : public Error {
```
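On the C++ side the error code travels with the exception and can be read back through `get_error_code()`. A small sketch, assuming some function `launch_kernel()` that may throw (the name is made up for illustration):

```cpp
#include <c10/util/Exception.h>

#include <iostream>

// Hypothetical function that drives an accelerator runtime and may throw
// c10::AcceleratorError; not part of this commit.
void launch_kernel();

void run_with_reporting() {
  try {
    launch_kernel();
  } catch (const c10::AcceleratorError& e) {
    // get_error_code() returns the raw runtime error code (for CUDA, the
    // value captured at the throw site after cudaGetLastError).
    std::cerr << "accelerator failure, code " << e.get_error_code() << ": "
              << e.what_without_backtrace() << '\n';
    throw; // let the Python binding layer translate it at the boundary
  }
}
```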

docs/source/cuda.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -40,6 +40,7 @@ torch.cuda
     temperature
     power_draw
     clock_rate
+    AcceleratorError
     OutOfMemoryError

 Random Number Generator
```

test/test_cuda.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1387,6 +1387,8 @@ def _spawn_method(self, method, arg):
         for e in errors:
             if "device-side assert triggered" not in str(e):
                 self.fail(e)
+            if e.error_code != 710:  # cudaErrorAssert == 710
+                self.fail(e)

     @staticmethod
     def _test_index_bounds_cuda(idx):
```

test/test_public_bindings.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -59,6 +59,7 @@ def test_no_new_bindings(self):
         #
         # {elem for elem in dir(torch._C) if not elem.startswith("_")}
         torch_C_allowlist_superset = {
+            "AcceleratorError",
             "AggregationType",
             "AliasDb",
             "AnyType",
```

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
```diff
@@ -2628,6 +2628,7 @@ def _will_engine_execute_node(node: _Node) -> _bool: ...
 def _dispatch_key_set(tensor) -> str: ...

 # Defined in torch/csrc/Exceptions.cpp
+class AcceleratorError(RuntimeError): ...
 class OutOfMemoryError(RuntimeError): ...
 class _DistError(RuntimeError): ...
 class _DistBackendError(RuntimeError): ...
```

torch/csrc/Exceptions.cpp

Lines changed: 28 additions & 1 deletion
```diff
@@ -14,7 +14,8 @@
 PyObject *THPException_FatalError, *THPException_LinAlgError,
     *THPException_OutOfMemoryError, *THPException_DistError,
     *THPException_DistBackendError, *THPException_DistNetworkError,
-    *THPException_DistStoreError, *THPException_DistQueueEmptyError;
+    *THPException_DistStoreError, *THPException_DistQueueEmptyError,
+    *THPException_AcceleratorError;

 #define ASSERT_TRUE(cond) \
   if (!(cond))            \
@@ -125,6 +126,18 @@ could not be completed because the input matrix is singular.",
           module, "_DistQueueEmptyError", THPException_DistQueueEmptyError) ==
       0);

+  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
+  ASSERT_TRUE(
+      THPException_AcceleratorError = PyErr_NewExceptionWithDoc(
+          "torch.AcceleratorError",
+          "Exception raised while executing on device",
+          PyExc_RuntimeError,
+          nullptr));
+  type = (PyTypeObject*)THPException_AcceleratorError;
+  ASSERT_TRUE(
+      PyModule_AddObject(
+          module, "AcceleratorError", THPException_AcceleratorError) == 0);
+
   return true;
 }

@@ -341,4 +354,18 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
   }
 }

+namespace detail {
+PyObject* _new_accelerator_error_object(const c10::AcceleratorError& e) {
+  auto msg = torch::get_cpp_stacktraces_enabled() ? e.what()
+                                                  : e.what_without_backtrace();
+
+  auto py_msg = PyUnicode_FromString(msg);
+  auto rc = PyObject_CallOneArg(THPException_AcceleratorError, py_msg);
+  auto error_code = PyLong_FromLong(e.get_error_code());
+  PyObject_SetAttrString(rc, "error_code", error_code);
+  Py_XDECREF(py_msg);
+  Py_XDECREF(error_code);
+  return rc;
+}
+} // namespace detail
 } // namespace torch
```
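The net effect of the registration above is a new public exception type on the extension module. A quick sanity check one could run against a build containing this commit (a hedged expectation, not an official test):

```python
import torch

# torch.AcceleratorError is the object registered above; it was created with
# PyExc_RuntimeError as its base, so existing RuntimeError handlers still work.
assert issubclass(torch.AcceleratorError, RuntimeError)
assert torch.AcceleratorError.__doc__ == "Exception raised while executing on device"
```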

torch/csrc/Exceptions.h

Lines changed: 10 additions & 1 deletion
```diff
@@ -86,6 +86,12 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
       DistQueueEmptyError, THPException_DistQueueEmptyError, retstmnt)        \
   _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
   _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt)           \
+  catch (c10::AcceleratorError & e) {                                         \
+    auto exc = torch::detail::_new_accelerator_error_object(e);               \
+    PyErr_SetObject(THPException_AcceleratorError, exc);                      \
+    Py_XDECREF(exc);                                                          \
+    retstmnt;                                                                 \
+  }                                                                           \
   _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)                   \
   catch (torch::PyTorchError & e) {                                           \
     auto msg = torch::processErrorMsg(e.what());                              \
@@ -141,7 +147,8 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
 extern PyObject *THPException_FatalError, *THPException_LinAlgError,
     *THPException_OutOfMemoryError, *THPException_DistError,
     *THPException_DistBackendError, *THPException_DistNetworkError,
-    *THPException_DistStoreError, *THPException_DistQueueEmptyError;
+    *THPException_DistStoreError, *THPException_DistQueueEmptyError,
+    *THPException_AcceleratorError;

 // Throwing this exception means that the python error flags have been already
 // set and control should be immediately returned to the interpreter.
@@ -369,6 +376,8 @@ auto wrap_pybind_function_impl_(
     END_HANDLE_TH_ERRORS_PYBIND
   };
 }
+
+PyObject* _new_accelerator_error_object(const c10::AcceleratorError&);
 } // namespace detail

 // Wrap a function with TH error and warning handling.
```
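The new catch clause sits inside the generic error-translation macros that wrap the Python bindings, so a `c10::AcceleratorError` escaping C++ is converted automatically. A sketch of a hypothetical binding (`example_binding` is not part of this commit) showing where the conversion happens:

```cpp
#include <torch/csrc/Exceptions.h>

// Hypothetical binding: anything thrown between HANDLE_TH_ERRORS and
// END_HANDLE_TH_ERRORS is translated into a Python exception, and with the
// catch clause above a c10::AcceleratorError becomes torch.AcceleratorError
// with its error_code attribute populated via _new_accelerator_error_object.
static PyObject* example_binding(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  const bool device_failed = true; // stand-in for a real runtime status check
  if (device_failed) {
    throw c10::AcceleratorError(
        {__func__, __FILE__, int32_t(__LINE__)}, /*code=*/710, "example failure");
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```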

torch/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -333,6 +333,7 @@ class DeferredCudaCallError(Exception):
     pass


+AcceleratorError = torch._C.AcceleratorError
 OutOfMemoryError = torch._C.OutOfMemoryError


```
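`torch.cuda` simply aliases the class defined on `torch._C`, mirroring the existing `OutOfMemoryError` re-export, so either spelling catches the same exception. A small hedged check, assuming a build with this commit:

```python
import torch

# The alias added above means either name can appear in an except clause.
assert torch.cuda.AcceleratorError is torch._C.AcceleratorError
assert torch.cuda.AcceleratorError is torch.AcceleratorError
```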