Skip to content

Commit e4ffc5c

Browse files
committed
Combine setup into one task
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent ac7be11 commit e4ffc5c

File tree

1 file changed: +163 additions, −119 deletions

dali/operators/imgcodec/image_decoder.h

Lines changed: 163 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "dali/pipeline/operator/common.h"
3333
#include "dali/pipeline/operator/operator.h"
3434

35+
3536
#if not(WITH_DYNAMIC_NVIMGCODEC_ENABLED)
3637
nvimgcodecStatus_t get_libjpeg_turbo_extension_desc(nvimgcodecExtensionDesc_t *ext_desc);
3738
nvimgcodecStatus_t get_libtiff_extension_desc(nvimgcodecExtensionDesc_t *ext_desc);
@@ -674,58 +675,79 @@ class ImageDecoder : public StatelessOperator<Backend> {
674675
TensorListShape<> out_shape(nsamples, 3);
675676

676677
const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8;
677-
auto setup_block = [&](int block_idx, int nblocks, int tid) {
678-
int i_start = nsamples * block_idx / nblocks;
679-
int i_end = nsamples * (block_idx + 1) / nblocks;
680-
DomainTimeRange tr("Setup #" + std::to_string(block_idx) + "/" + std::to_string(nblocks),
681-
DomainTimeRange::kOrange);
682-
for (int i = i_start; i < i_end; i++) {
683-
auto *st = state_[i].get();
684-
st->image_info.buffer = nullptr;
685-
assert(st != nullptr);
686-
const auto &input_sample = input[i];
687-
688-
auto src_info = input.GetMeta(i).GetSourceInfo();
689-
if (use_cache && cache_->IsInCache(src_info)) {
690-
auto cached_shape = cache_->CacheImageShape(src_info);
691-
auto roi = GetRoi(spec_, ws, i, cached_shape);
692-
if (!roi.use_roi()) {
693-
out_shape.set_tensor_shape(i, st->out_shape);
694-
st->load_from_cache = true;
695-
continue;
696-
}
697-
}
698-
st->load_from_cache = false;
699-
ParseSample(st->parsed_sample,
700-
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
701-
volume(input_sample.shape())});
702-
st->sub_encoded_stream.reset();
703-
st->out_shape = st->parsed_sample.dali_img_info.shape;
704-
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
705-
if (use_orientation_ &&
706-
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
707-
std::swap(st->out_shape[0], st->out_shape[1]);
678+
auto setup_sample = [&](int sample_idx, int tid) {
679+
auto *st = state_[sample_idx].get();
680+
st->image_info.buffer = nullptr;
681+
assert(st != nullptr);
682+
const auto &input_sample = input[sample_idx];
683+
684+
auto src_info = input.GetMeta(sample_idx).GetSourceInfo();
685+
if (use_cache && cache_->IsInCache(src_info)) {
686+
auto cached_shape = cache_->CacheImageShape(src_info);
687+
auto roi = GetRoi(spec_, ws, sample_idx, cached_shape);
688+
if (!roi.use_roi()) {
689+
out_shape.set_tensor_shape(sample_idx, st->out_shape);
690+
st->load_from_cache = true;
691+
return;
708692
}
693+
}
694+
st->load_from_cache = false;
695+
ParseSample(st->parsed_sample,
696+
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
697+
volume(input_sample.shape())});
698+
st->sub_encoded_stream.reset();
699+
st->out_shape = st->parsed_sample.dali_img_info.shape;
700+
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
701+
if (use_orientation_ &&
702+
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
703+
std::swap(st->out_shape[0], st->out_shape[1]);
704+
}
709705

710-
ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape);
711-
if (roi.use_roi()) {
712-
auto roi_sh = roi.shape();
713-
if (roi.end.size() >= 2) {
714-
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && 0 <= roi.end[1] &&
715-
roi.end[1] <= st->out_shape[1],
716-
"ROI end must fit within the image bounds");
717-
}
718-
if (roi.begin.size() >= 2) {
719-
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
720-
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
721-
"ROI begin must fit within the image bounds");
722-
}
723-
st->out_shape[0] = roi_sh[0];
724-
st->out_shape[1] = roi_sh[1];
706+
ROI &roi = rois_[sample_idx] = GetRoi(spec_, ws, sample_idx, st->out_shape);
707+
if (roi.use_roi()) {
708+
auto roi_sh = roi.shape();
709+
if (roi.end.size() >= 2) {
710+
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && 0 <= roi.end[1] &&
711+
roi.end[1] <= st->out_shape[1],
712+
"ROI end must fit within the image bounds");
713+
}
714+
if (roi.begin.size() >= 2) {
715+
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
716+
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
717+
"ROI begin must fit within the image bounds");
725718
}
726-
out_shape.set_tensor_shape(i, st->out_shape);
727-
PrepareOutput(*state_[i], rois_[i], ws);
728-
assert(!ws.has_stream() || ws.stream() == st->image_info.cuda_stream);
719+
st->out_shape[0] = roi_sh[0];
720+
st->out_shape[1] = roi_sh[1];
721+
}
722+
out_shape.set_tensor_shape(sample_idx, st->out_shape);
723+
PrepareOutput(*state_[sample_idx], rois_[sample_idx], ws);
724+
assert(!ws.has_stream() || ws.stream() == st->image_info.cuda_stream);
725+
};
726+
727+
// The image descriptors are created in parallel, in block-wise fashion.
728+
auto init_desc_task = [&](int sample_idx) {
729+
auto &st = *state_[sample_idx];
730+
if (use_cache && st.load_from_cache) {
731+
return;
732+
}
733+
if (!st.need_processing) {
734+
st.image_info.buffer = output.raw_mutable_tensor(sample_idx);
735+
}
736+
st.image = NvImageCodecImage::Create(instance_, &st.image_info);
737+
if (rois_[sample_idx].use_roi()) {
738+
auto &roi = rois_[sample_idx];
739+
nvimgcodecCodeStreamView_t cs_view = {
740+
NVIMGCODEC_STRUCTURE_TYPE_CODE_STREAM_VIEW,
741+
sizeof(nvimgcodecCodeStreamView_t),
742+
nullptr,
743+
0, // image_idx
744+
{NVIMGCODEC_STRUCTURE_TYPE_REGION, sizeof(nvimgcodecRegion_t), nullptr, 2}};
745+
cs_view.region.start[0] = roi.begin[0];
746+
cs_view.region.start[1] = roi.begin[1];
747+
cs_view.region.end[0] = roi.end[0];
748+
cs_view.region.end[1] = roi.end[1];
749+
st.sub_encoded_stream = NvImageCodecCodeStream::FromSubCodeStream(
750+
st.parsed_sample.encoded_stream.get(), &cs_view);
729751
}
730752
};
731753

@@ -734,91 +756,70 @@ class ImageDecoder : public StatelessOperator<Backend> {
734756
int ntasks = std::min<int>(nblocks, std::min<int>(8, tp_->NumThreads() + 1));
735757

736758
if (ntasks < 2) {
759+
// run all in current thread
737760
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
738-
setup_block(0, 1, -1); // run all in current thread
761+
{
762+
DomainTimeRange tr("Parse", DomainTimeRange::kOrange);
763+
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
764+
setup_sample(sample_idx, -1);
765+
}
766+
}
767+
{
768+
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
769+
output.Resize(out_shape);
770+
}
771+
{
772+
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
773+
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
774+
init_desc_task(sample_idx);
775+
}
776+
}
739777
} else {
778+
// run in parallel
740779
int block_idx = 0;
741-
atomic_idx_.store(0);
742-
auto setup_task = [&, nblocks](int tid) {
780+
// relaxed, only need atomicity, not ordering
781+
atomic_idx_.store(0, std::memory_order_relaxed);
782+
parse_barrier_.Reset(ntasks);
783+
alloc_output_barrier_.Reset(ntasks);
784+
auto setup_task = [&](int tid) {
785+
int sample_idx;
743786
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
744-
int block_idx;
745-
while ((block_idx = atomic_idx_.fetch_add(1)) < nblocks) {
746-
setup_block(block_idx, nblocks, tid);
787+
{
788+
DomainTimeRange tr("Parse", DomainTimeRange::kOrange);
789+
while ((sample_idx = atomic_idx_.fetch_add(1, std::memory_order_relaxed)) < nsamples) {
790+
setup_sample(sample_idx, tid);
791+
}
747792
}
748-
};
749-
750-
for (int task_idx = 0; task_idx < ntasks - 1; task_idx++) {
751-
tp_->AddWork(setup_task, -task_idx);
752-
}
753-
assert(ntasks >= 2);
754-
tp_->RunAll(false); // start work but not wait
755-
setup_task(-1); // last task in current thread
756-
tp_->WaitForWork(); // wait for the other threads
757-
}
758-
759-
// Allocate the memory for the outputs...
760-
{
761-
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
762-
output.Resize(out_shape);
763-
}
764-
// ... and create image descriptors.
793+
parse_barrier_.ArriveAndWait(); // wait until parsing is done
765794

766-
// The image descriptors are created in parallel, in block-wise fashion.
767-
auto init_desc_task = [&](int start_sample, int end_sample) {
768-
DomainTimeRange tr(
769-
"Create images " + std::to_string(start_sample) + ".." + std::to_string(end_sample),
770-
DomainTimeRange::kOrange);
771-
for (int orig_idx = start_sample; orig_idx < end_sample; orig_idx++) {
772-
auto &st = *state_[orig_idx];
773-
if (use_cache && st.load_from_cache) {
774-
continue;
775-
}
776-
if (!st.need_processing) {
777-
st.image_info.buffer = output.raw_mutable_tensor(orig_idx);
778-
}
779-
st.image = NvImageCodecImage::Create(instance_, &st.image_info);
780-
if (rois_[orig_idx].use_roi()) {
781-
auto &roi = rois_[orig_idx];
782-
nvimgcodecCodeStreamView_t cs_view = {
783-
NVIMGCODEC_STRUCTURE_TYPE_CODE_STREAM_VIEW,
784-
sizeof(nvimgcodecCodeStreamView_t),
785-
nullptr,
786-
0, // image_idx
787-
{NVIMGCODEC_STRUCTURE_TYPE_REGION, sizeof(nvimgcodecRegion_t), nullptr, 2}};
788-
cs_view.region.start[0] = roi.begin[0];
789-
cs_view.region.start[1] = roi.begin[1];
790-
cs_view.region.end[0] = roi.end[0];
791-
cs_view.region.end[1] = roi.end[1];
792-
st.sub_encoded_stream = NvImageCodecCodeStream::FromSubCodeStream(
793-
st.parsed_sample.encoded_stream.get(), &cs_view);
795+
if (tid == -1) {
796+
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
797+
output.Resize(out_shape);
798+
atomic_idx_.store(0, std::memory_order_relaxed);
799+
alloc_output_barrier_.Arrive(); // No need to wait here, we are in the main thread
800+
} else {
801+
alloc_output_barrier_.ArriveAndWait(); // wait until allocation is done
794802
}
795-
}
796-
};
797803

798-
// Just one task? Run it in this thread!
799-
if (ntasks < 2) {
800-
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
801-
init_desc_task(0, nsamples);
802-
} else {
803-
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
804-
// Many tasks? Run in thread pool.
805-
int block_idx = 0;
806-
atomic_idx_.store(0);
807-
auto create_images_task = [&, nblocks](int tid) {
808-
int block_idx;
809-
while ((block_idx = atomic_idx_.fetch_add(1)) < nblocks) {
810-
int64_t start = nsamples * block_idx / nblocks;
811-
int64_t end = nsamples * (block_idx + 1) / nblocks;
812-
init_desc_task(start, end);
804+
// Create image descriptors
805+
{
806+
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
807+
while ((sample_idx = atomic_idx_.fetch_add(16, std::memory_order_relaxed)) < nsamples) {
808+
int sample_start = sample_idx;
809+
int sample_end = std::min(sample_idx + 16, nsamples);
810+
for (int i = sample_start; i < sample_end; i++) {
811+
init_desc_task(i);
812+
}
813+
}
813814
}
814815
};
815816

816817
for (int task_idx = 0; task_idx < ntasks - 1; task_idx++) {
817-
tp_->AddWork(create_images_task, -task_idx);
818+
tp_->AddWork(setup_task, -task_idx);
818819
}
819820
assert(ntasks >= 2);
820821
tp_->RunAll(false); // start work but not wait
821-
create_images_task(-1);
822+
setup_task(-1); // last task in current thread
822823
tp_->WaitForWork(); // wait for the other threads
823824
}
824825

@@ -985,6 +986,49 @@ class ImageDecoder : public StatelessOperator<Backend> {
985986
std::vector<nvimgcodecExtension_t> extensions_;
986987

987988
std::vector<std::function<void(int)>> nvimgcodec_scheduled_tasks_;
989+
990+
class ThreadBarrier {
991+
public:
992+
explicit ThreadBarrier(std::size_t count) : count_(count), current_(count) {}
993+
void Arrive() {
994+
std::unique_lock<std::mutex> lock(lock_);
995+
if (current_ == 0) {
996+
throw std::logic_error("barrier is already completed");
997+
}
998+
current_--;
999+
if (current_ == 0) {
1000+
cv_.notify_all();
1001+
}
1002+
}
1003+
void ArriveAndWait(bool reset = false) {
1004+
std::unique_lock<std::mutex> lock(lock_);
1005+
if (current_ == 0) {
1006+
throw std::logic_error("barrier is already completed");
1007+
}
1008+
current_--;
1009+
if (current_ == 0 || count_ == 0) {
1010+
if (reset)
1011+
current_ = count_;
1012+
cv_.notify_all();
1013+
} else {
1014+
cv_.wait(lock, [this] { return current_ == 0; });
1015+
}
1016+
}
1017+
void Reset(std::size_t count) {
1018+
std::lock_guard<std::mutex> lock(lock_);
1019+
count_ = count;
1020+
current_ = count;
1021+
}
1022+
1023+
private:
1024+
std::mutex lock_;
1025+
std::condition_variable cv_;
1026+
size_t count_;
1027+
size_t current_;
1028+
};
1029+
1030+
ThreadBarrier parse_barrier_{0};
1031+
ThreadBarrier alloc_output_barrier_{0};
9881032
};
9891033

9901034
} // namespace imgcodec

0 commit comments

Comments (0)