Skip to content

Commit 65ff9b4

Browse files
committed
Combine setup into one task
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent ac7be11 commit 65ff9b4

File tree

1 file changed: +147 −119 lines changed

1 file changed: +147 −119 lines changed

dali/operators/imgcodec/image_decoder.h

Lines changed: 147 additions & 119 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#include <vector>
2020
#include "dali/core/call_at_exit.h"
2121
#include "dali/core/mm/memory.h"
22+
#include "dali/core/semaphore.h"
2223
#include "dali/operators.h"
2324
#include "dali/operators/decoder/cache/cached_decoder_impl.h"
2425
#include "dali/operators/generic/slice/slice_attr.h"
@@ -674,58 +675,79 @@ class ImageDecoder : public StatelessOperator<Backend> {
674675
TensorListShape<> out_shape(nsamples, 3);
675676

676677
const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8;
677-
auto setup_block = [&](int block_idx, int nblocks, int tid) {
678-
int i_start = nsamples * block_idx / nblocks;
679-
int i_end = nsamples * (block_idx + 1) / nblocks;
680-
DomainTimeRange tr("Setup #" + std::to_string(block_idx) + "/" + std::to_string(nblocks),
681-
DomainTimeRange::kOrange);
682-
for (int i = i_start; i < i_end; i++) {
683-
auto *st = state_[i].get();
684-
st->image_info.buffer = nullptr;
685-
assert(st != nullptr);
686-
const auto &input_sample = input[i];
687-
688-
auto src_info = input.GetMeta(i).GetSourceInfo();
689-
if (use_cache && cache_->IsInCache(src_info)) {
690-
auto cached_shape = cache_->CacheImageShape(src_info);
691-
auto roi = GetRoi(spec_, ws, i, cached_shape);
692-
if (!roi.use_roi()) {
693-
out_shape.set_tensor_shape(i, st->out_shape);
694-
st->load_from_cache = true;
695-
continue;
696-
}
697-
}
698-
st->load_from_cache = false;
699-
ParseSample(st->parsed_sample,
700-
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
701-
volume(input_sample.shape())});
702-
st->sub_encoded_stream.reset();
703-
st->out_shape = st->parsed_sample.dali_img_info.shape;
704-
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
705-
if (use_orientation_ &&
706-
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
707-
std::swap(st->out_shape[0], st->out_shape[1]);
678+
auto setup_sample = [&](int sample_idx, int tid) {
679+
auto *st = state_[sample_idx].get();
680+
st->image_info.buffer = nullptr;
681+
assert(st != nullptr);
682+
const auto &input_sample = input[sample_idx];
683+
684+
auto src_info = input.GetMeta(sample_idx).GetSourceInfo();
685+
if (use_cache && cache_->IsInCache(src_info)) {
686+
auto cached_shape = cache_->CacheImageShape(src_info);
687+
auto roi = GetRoi(spec_, ws, sample_idx, cached_shape);
688+
if (!roi.use_roi()) {
689+
out_shape.set_tensor_shape(sample_idx, st->out_shape);
690+
st->load_from_cache = true;
691+
return;
708692
}
693+
}
694+
st->load_from_cache = false;
695+
ParseSample(st->parsed_sample,
696+
span<const uint8_t>{static_cast<const uint8_t *>(input_sample.raw_data()),
697+
volume(input_sample.shape())});
698+
st->sub_encoded_stream.reset();
699+
st->out_shape = st->parsed_sample.dali_img_info.shape;
700+
st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]);
701+
if (use_orientation_ &&
702+
(st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) {
703+
std::swap(st->out_shape[0], st->out_shape[1]);
704+
}
709705

710-
ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape);
711-
if (roi.use_roi()) {
712-
auto roi_sh = roi.shape();
713-
if (roi.end.size() >= 2) {
714-
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && 0 <= roi.end[1] &&
715-
roi.end[1] <= st->out_shape[1],
716-
"ROI end must fit within the image bounds");
717-
}
718-
if (roi.begin.size() >= 2) {
719-
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
720-
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
721-
"ROI begin must fit within the image bounds");
722-
}
723-
st->out_shape[0] = roi_sh[0];
724-
st->out_shape[1] = roi_sh[1];
706+
ROI &roi = rois_[sample_idx] = GetRoi(spec_, ws, sample_idx, st->out_shape);
707+
if (roi.use_roi()) {
708+
auto roi_sh = roi.shape();
709+
if (roi.end.size() >= 2) {
710+
DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && 0 <= roi.end[1] &&
711+
roi.end[1] <= st->out_shape[1],
712+
"ROI end must fit within the image bounds");
725713
}
726-
out_shape.set_tensor_shape(i, st->out_shape);
727-
PrepareOutput(*state_[i], rois_[i], ws);
728-
assert(!ws.has_stream() || ws.stream() == st->image_info.cuda_stream);
714+
if (roi.begin.size() >= 2) {
715+
DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] &&
716+
0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1],
717+
"ROI begin must fit within the image bounds");
718+
}
719+
st->out_shape[0] = roi_sh[0];
720+
st->out_shape[1] = roi_sh[1];
721+
}
722+
out_shape.set_tensor_shape(sample_idx, st->out_shape);
723+
PrepareOutput(*state_[sample_idx], rois_[sample_idx], ws);
724+
assert(!ws.has_stream() || ws.stream() == st->image_info.cuda_stream);
725+
};
726+
727+
// The image descriptors are created in parallel, in block-wise fashion.
728+
auto init_desc_task = [&](int sample_idx) {
729+
auto &st = *state_[sample_idx];
730+
if (use_cache && st.load_from_cache) {
731+
return;
732+
}
733+
if (!st.need_processing) {
734+
st.image_info.buffer = output.raw_mutable_tensor(sample_idx);
735+
}
736+
st.image = NvImageCodecImage::Create(instance_, &st.image_info);
737+
if (rois_[sample_idx].use_roi()) {
738+
auto &roi = rois_[sample_idx];
739+
nvimgcodecCodeStreamView_t cs_view = {
740+
NVIMGCODEC_STRUCTURE_TYPE_CODE_STREAM_VIEW,
741+
sizeof(nvimgcodecCodeStreamView_t),
742+
nullptr,
743+
0, // image_idx
744+
{NVIMGCODEC_STRUCTURE_TYPE_REGION, sizeof(nvimgcodecRegion_t), nullptr, 2}};
745+
cs_view.region.start[0] = roi.begin[0];
746+
cs_view.region.start[1] = roi.begin[1];
747+
cs_view.region.end[0] = roi.end[0];
748+
cs_view.region.end[1] = roi.end[1];
749+
st.sub_encoded_stream = NvImageCodecCodeStream::FromSubCodeStream(
750+
st.parsed_sample.encoded_stream.get(), &cs_view);
729751
}
730752
};
731753

@@ -734,91 +756,64 @@ class ImageDecoder : public StatelessOperator<Backend> {
734756
int ntasks = std::min<int>(nblocks, std::min<int>(8, tp_->NumThreads() + 1));
735757

736758
if (ntasks < 2) {
759+
// run all in current thread
737760
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
738-
setup_block(0, 1, -1); // run all in current thread
761+
{
762+
DomainTimeRange tr("Parse", DomainTimeRange::kOrange);
763+
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
764+
setup_sample(sample_idx, -1);
765+
}
766+
}
767+
{
768+
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
769+
output.Resize(out_shape);
770+
}
771+
{
772+
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
773+
for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) {
774+
init_desc_task(sample_idx);
775+
}
776+
}
739777
} else {
778+
// run in parallel
740779
int block_idx = 0;
741-
atomic_idx_.store(0);
742-
auto setup_task = [&, nblocks](int tid) {
780+
// relaxed, only need atomicity, not ordering
781+
atomic_idx_.store(0, std::memory_order_relaxed);
782+
parse_barrier_.Reset(ntasks);
783+
alloc_output_barrier_.Reset(ntasks);
784+
auto setup_task = [&](int tid) {
785+
int sample_idx;
743786
DomainTimeRange tr("Setup", DomainTimeRange::kOrange);
744-
int block_idx;
745-
while ((block_idx = atomic_idx_.fetch_add(1)) < nblocks) {
746-
setup_block(block_idx, nblocks, tid);
787+
{
788+
DomainTimeRange tr("Parse", DomainTimeRange::kOrange);
789+
while ((sample_idx = atomic_idx_.fetch_add(1, std::memory_order_relaxed)) < nsamples) {
790+
setup_sample(sample_idx, tid);
791+
}
747792
}
748-
};
793+
parse_barrier_.Wait(); // wait until parsing is done
749794

750-
for (int task_idx = 0; task_idx < ntasks - 1; task_idx++) {
751-
tp_->AddWork(setup_task, -task_idx);
752-
}
753-
assert(ntasks >= 2);
754-
tp_->RunAll(false); // start work but not wait
755-
setup_task(-1); // last task in current thread
756-
tp_->WaitForWork(); // wait for the other threads
757-
}
758-
759-
// Allocate the memory for the outputs...
760-
{
761-
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
762-
output.Resize(out_shape);
763-
}
764-
// ... and create image descriptors.
765-
766-
// The image descriptors are created in parallel, in block-wise fashion.
767-
auto init_desc_task = [&](int start_sample, int end_sample) {
768-
DomainTimeRange tr(
769-
"Create images " + std::to_string(start_sample) + ".." + std::to_string(end_sample),
770-
DomainTimeRange::kOrange);
771-
for (int orig_idx = start_sample; orig_idx < end_sample; orig_idx++) {
772-
auto &st = *state_[orig_idx];
773-
if (use_cache && st.load_from_cache) {
774-
continue;
775-
}
776-
if (!st.need_processing) {
777-
st.image_info.buffer = output.raw_mutable_tensor(orig_idx);
778-
}
779-
st.image = NvImageCodecImage::Create(instance_, &st.image_info);
780-
if (rois_[orig_idx].use_roi()) {
781-
auto &roi = rois_[orig_idx];
782-
nvimgcodecCodeStreamView_t cs_view = {
783-
NVIMGCODEC_STRUCTURE_TYPE_CODE_STREAM_VIEW,
784-
sizeof(nvimgcodecCodeStreamView_t),
785-
nullptr,
786-
0, // image_idx
787-
{NVIMGCODEC_STRUCTURE_TYPE_REGION, sizeof(nvimgcodecRegion_t), nullptr, 2}};
788-
cs_view.region.start[0] = roi.begin[0];
789-
cs_view.region.start[1] = roi.begin[1];
790-
cs_view.region.end[0] = roi.end[0];
791-
cs_view.region.end[1] = roi.end[1];
792-
st.sub_encoded_stream = NvImageCodecCodeStream::FromSubCodeStream(
793-
st.parsed_sample.encoded_stream.get(), &cs_view);
795+
if (tid == -1) {
796+
DomainTimeRange tr("Alloc output", DomainTimeRange::kOrange);
797+
output.Resize(out_shape);
798+
atomic_idx_.store(0, std::memory_order_relaxed);
794799
}
795-
}
796-
};
797800

798-
// Just one task? Run it in this thread!
799-
if (ntasks < 2) {
800-
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
801-
init_desc_task(0, nsamples);
802-
} else {
803-
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
804-
// Many tasks? Run in thread pool.
805-
int block_idx = 0;
806-
atomic_idx_.store(0);
807-
auto create_images_task = [&, nblocks](int tid) {
808-
int block_idx;
809-
while ((block_idx = atomic_idx_.fetch_add(1)) < nblocks) {
810-
int64_t start = nsamples * block_idx / nblocks;
811-
int64_t end = nsamples * (block_idx + 1) / nblocks;
812-
init_desc_task(start, end);
801+
alloc_output_barrier_.Wait(); // wait until allocation is done
802+
// Create image descriptors
803+
{
804+
DomainTimeRange tr("Create images", DomainTimeRange::kOrange);
805+
while ((sample_idx = atomic_idx_.fetch_add(1, std::memory_order_relaxed)) < nsamples) {
806+
init_desc_task(sample_idx);
807+
}
813808
}
814809
};
815810

816811
for (int task_idx = 0; task_idx < ntasks - 1; task_idx++) {
817-
tp_->AddWork(create_images_task, -task_idx);
812+
tp_->AddWork(setup_task, -task_idx);
818813
}
819814
assert(ntasks >= 2);
820815
tp_->RunAll(false); // start work but not wait
821-
create_images_task(-1);
816+
setup_task(-1); // last task in current thread
822817
tp_->WaitForWork(); // wait for the other threads
823818
}
824819

@@ -985,6 +980,39 @@ class ImageDecoder : public StatelessOperator<Backend> {
985980
std::vector<nvimgcodecExtension_t> extensions_;
986981

987982
std::vector<std::function<void(int)>> nvimgcodec_scheduled_tasks_;
983+
984+
class ThreadBarrier {
985+
public:
986+
explicit ThreadBarrier(std::size_t count) : count_(count), current_(count) {}
987+
void Wait(bool reset = false) {
988+
std::unique_lock<std::mutex> lock(mutex_);
989+
if (current_ == 0) {
990+
throw std::logic_error("barrier is already completed");
991+
}
992+
current_--;
993+
if (current_ == 0 || count_ == 0) {
994+
if (reset)
995+
current_ = count_;
996+
cv_.notify_all();
997+
} else {
998+
cv_.wait(lock, [this] { return current_ == 0; });
999+
}
1000+
}
1001+
void Reset(std::size_t count) {
1002+
std::lock_guard<std::mutex> lock(mutex_);
1003+
count_ = count;
1004+
current_ = count;
1005+
}
1006+
1007+
private:
1008+
std::mutex mutex_;
1009+
std::condition_variable cv_;
1010+
size_t count_;
1011+
size_t current_;
1012+
};
1013+
1014+
ThreadBarrier parse_barrier_{0};
1015+
ThreadBarrier alloc_output_barrier_{0};
9881016
};
9891017

9901018
} // namespace imgcodec

0 commit comments

Comments (0)