Skip to content

Commit a9022b7

Browse files
authored
Improve error handling in ThreadPool (#6011)
* Bug fix: clear all stored errors when rethrowing.
* Semantics improvement: rethrow the actual exception instead of just keeping its message. This also helps with errors not derived from `std::exception`.
* Performance: remove the unnecessary mutex guarding the per-thread error lists.

Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 0158ef1 commit a9022b7

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

dali/pipeline/util/thread_pool.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,16 @@ void ThreadPool::WaitForWork(bool checkForErrors) {
9090
started_ = false;
9191
if (checkForErrors) {
9292
// Check for errors
93-
std::lock_guard lock(error_mutex_);
93+
std::exception_ptr err;
9494
for (size_t i = 0; i < threads_.size(); ++i) {
95-
if (!tl_errors_[i].empty()) {
95+
if (!err && !tl_errors_[i].empty()) {
9696
// Throw the first error that occurred
97-
string error = make_string("Error in thread ", i, ": ", tl_errors_[i].front());
98-
tl_errors_[i].pop();
99-
throw std::runtime_error(error);
97+
err = std::move(tl_errors_[i].front());
10098
}
99+
tl_errors_[i] = {};
101100
}
101+
if (err)
102+
std::rethrow_exception(err);
102103
}
103104
}
104105

@@ -150,12 +151,8 @@ void ThreadPool::ThreadMain(int thread_id, int device_id, bool set_affinity,
150151
nvml::SetCPUAffinity(core);
151152
}
152153
#endif
153-
} catch (std::exception &e) {
154-
std::lock_guard lock(error_mutex_);
155-
tl_errors_[thread_id].push(e.what());
156154
} catch (...) {
157-
std::lock_guard lock(error_mutex_);
158-
tl_errors_[thread_id].push("Caught unknown exception");
155+
tl_errors_[thread_id].push(std::current_exception());
159156
}
160157

161158
while (running_) {
@@ -179,12 +176,8 @@ void ThreadPool::ThreadMain(int thread_id, int device_id, bool set_affinity,
179176
// in the threads and return an error if one occurred.
180177
try {
181178
work(thread_id);
182-
} catch (std::exception &e) {
183-
std::lock_guard lock(error_mutex_);
184-
tl_errors_[thread_id].push(e.what());
185179
} catch (...) {
186-
std::lock_guard lock(error_mutex_);
187-
tl_errors_[thread_id].push("Caught unknown exception");
180+
tl_errors_[thread_id].push(std::current_exception());
188181
}
189182

190183
// The task is now complete - we can atomically decrement the number of outstanding work.

dali/pipeline/util/thread_pool.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <queue>
2525
#include <thread>
2626
#include <vector>
27+
#include <stdexcept>
2728
#include <string>
2829
#include "dali/core/common.h"
2930
#if NVML_ENABLED
@@ -94,11 +95,11 @@ class DLL_PUBLIC ThreadPool {
9495
bool running_ = true;
9596
bool started_ = false;
9697
alignas(64) std::atomic_int outstanding_work_{0};
97-
std::mutex error_mutex_, completed_mutex_;
98+
std::mutex completed_mutex_;
9899
std::condition_variable completed_;
99100

100-
// Stored error strings for each thread
101-
vector<std::queue<string>> tl_errors_;
101+
// Stored errors for each thread
102+
vector<std::queue<std::exception_ptr>> tl_errors_;
102103
#if NVML_ENABLED
103104
nvml::NvmlInstance nvml_handle_;
104105
#endif

0 commit comments

Comments
 (0)