Skip to content

Commit 262f821

Browse files
authored
C API changes required for TF plugin. (#5898)
* Tensor and TensorList copy-out functions * Query for Tensor and TensorList DType (including 0-sample tensor lists) * Query for Tensor and TensorList byte size * Utility macro to set optional field in param structures. --------- Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 83ac37a commit 262f821

File tree

8 files changed

+453
-49
lines changed

8 files changed

+453
-49
lines changed

dali/c_api_2/data_objects.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ daliResult_t daliTensorGetShape(
222222
DALI_EPILOG();
223223
}
224224

225+
daliResult_t daliTensorGetByteSize(
226+
daliTensor_h tensor,
227+
size_t *out_size) {
228+
DALI_PROLOG();
229+
auto *ptr = ToPointer(tensor);
230+
CHECK_OUTPUT(out_size);
231+
*out_size = ptr->GetByteSize();
232+
DALI_EPILOG();
233+
}
234+
235+
daliResult_t daliTensorGetDType(
236+
daliTensor_h tensor,
237+
daliDataType_t *out_dtype) {
238+
DALI_PROLOG();
239+
auto *ptr = ToPointer(tensor);
240+
CHECK_OUTPUT(out_dtype);
241+
*out_dtype = ptr->GetDType();
242+
DALI_EPILOG();
243+
}
244+
225245
daliResult_t daliTensorGetSourceInfo(
226246
daliTensor_h tensor,
227247
const char **out_source_info) {
@@ -391,6 +411,26 @@ daliResult_t daliTensorListGetShape(
391411
DALI_EPILOG();
392412
}
393413

414+
daliResult_t daliTensorListGetByteSize(
415+
daliTensorList_h tensor_list,
416+
size_t *out_size) {
417+
DALI_PROLOG();
418+
auto *ptr = ToPointer(tensor_list);
419+
CHECK_OUTPUT(out_size);
420+
*out_size = ptr->GetByteSize();
421+
DALI_EPILOG();
422+
}
423+
424+
daliResult_t daliTensorListGetDType(
425+
daliTensorList_h tensor_list,
426+
daliDataType_t *out_dtype) {
427+
DALI_PROLOG();
428+
auto *ptr = ToPointer(tensor_list);
429+
CHECK_OUTPUT(out_dtype);
430+
*out_dtype = ptr->GetDType();
431+
DALI_EPILOG();
432+
}
433+
394434
daliResult_t daliTensorListGetTensorDesc(
395435
daliTensorList_h tensor_list,
396436
daliTensorDesc_t *out_tensor,
@@ -432,3 +472,16 @@ daliResult_t daliTensorListViewAsTensor(
432472
*out_tensor = t.release(); // no throwing allowed after this line
433473
DALI_EPILOG();
434474
}
475+
476+
daliResult_t daliTensorListCopyOut(
477+
daliTensorList_h tensor_list,
478+
void *dst_buffer,
479+
daliBufferPlacement_t dst_buffer_placement,
480+
const cudaStream_t *stream,
481+
daliCopyFlags_t flags) {
482+
DALI_PROLOG();
483+
auto *ptr = ToPointer(tensor_list);
484+
CHECK_OUTPUT(dst_buffer)
485+
ptr->CopyOut(dst_buffer, dst_buffer_placement, ToOptional(stream), flags);
486+
DALI_EPILOG();
487+
}

dali/c_api_2/data_objects.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <vector>
2424
#include "dali/dali.h"
2525
#include "dali/pipeline/data/tensor_list.h"
26+
#include "dali/pipeline/data/copy_to_external.h"
2627
#include "dali/c_api_2/ref_counting.h"
2728
#include "dali/c_api_2/validation.h"
2829

@@ -43,6 +44,15 @@ struct _DALITensor {
4344
namespace dali {
4445
namespace c_api {
4546

47+
constexpr mm::memory_kind_id GetMemoryKind(const daliBufferPlacement_t &placement) {
48+
if (placement.device_type == DALI_STORAGE_GPU) {
49+
return mm::memory_kind_id::device;
50+
} else {
51+
assert(placement.device_type == DALI_STORAGE_CPU);
52+
return placement.pinned ? mm::memory_kind_id::pinned : mm::memory_kind_id::host;
53+
}
54+
}
55+
4656
//////////////////////////////////////////////////////////////////////////////
4757
// Interfaces
4858
//////////////////////////////////////////////////////////////////////////////
@@ -88,10 +98,21 @@ class ITensor : public _DALITensor, public RefCountedObject {
8898

8999
virtual const TensorShape<> &GetShape() const & = 0;
90100

101+
virtual size_t GetByteSize() const = 0;
102+
103+
virtual daliDataType_t GetDType() const = 0;
104+
91105
virtual const char *GetSourceInfo() const & = 0;
92106

93107
virtual void SetSourceInfo(const char *source_info) = 0;
94108

109+
virtual void CopyOut(
110+
void *dst_buffer,
111+
daliBufferPlacement_t dst_buffer_placement,
112+
std::optional<cudaStream_t> stream,
113+
daliCopyFlags_t flags) = 0;
114+
115+
95116
/** Retrieves the underlying DALI Tensor<Backend> pointer.
96117
*
97118
* Returns a shared pointer to the underlying DALI object. If the backend doesn't match,
@@ -155,12 +176,22 @@ class ITensorList : public _DALITensorList, public RefCountedObject {
155176

156177
virtual const TensorListShape<> &GetShape() const & = 0;
157178

179+
virtual size_t GetByteSize() const = 0;
180+
181+
virtual daliDataType_t GetDType() const = 0;
182+
158183
virtual RefCountedPtr<ITensor> ViewAsTensor() const = 0;
159184

160185
virtual const char *GetSourceInfo(int sample) const & = 0;
161186

162187
virtual void SetSourceInfo(int sample, const char *source_info) = 0;
163188

189+
virtual void CopyOut(
190+
void *dst_buffer,
191+
daliBufferPlacement_t dst_buffer_placement,
192+
std::optional<cudaStream_t> stream,
193+
daliCopyFlags_t flags) = 0;
194+
164195
/** Retrieves the underlying DALI TensorList<Backend> pointer.
165196
*
166197
* Returns a shared pointer to the underlying DALI object. If the backend doesn't match,
@@ -334,6 +365,14 @@ class TensorWrapper : public ITensor {
334365
return t_->shape();
335366
}
336367

368+
size_t GetByteSize() const override {
369+
return t_->nbytes();
370+
}
371+
372+
daliDataType_t GetDType() const override {
373+
return t_->type();
374+
}
375+
337376
const char *GetSourceInfo() const & override {
338377
const char *info = t_->GetMeta().GetSourceInfo().c_str();
339378
if (info && !*info)
@@ -345,6 +384,19 @@ class TensorWrapper : public ITensor {
345384
t_->SetSourceInfo(source_info ? source_info : "");
346385
}
347386

387+
void CopyOut(
388+
void *dst_buffer,
389+
daliBufferPlacement_t dst_buffer_placement,
390+
std::optional<cudaStream_t> stream,
391+
daliCopyFlags_t flags) override {
392+
Validate(dst_buffer_placement);
393+
AccessOrder order = stream ? *stream : t_->order();
394+
mm::memory_kind_id mem_kind = GetMemoryKind(dst_buffer_placement);
395+
CopyToExternal(dst_buffer, mem_kind, *t_, order, flags & DALI_COPY_USE_KERNEL);
396+
if (flags & DALI_COPY_SYNC)
397+
AccessOrder::host().wait(order);
398+
}
399+
348400
const auto &NativePtr() const & {
349401
return t_;
350402
}
@@ -646,6 +698,14 @@ class TensorListWrapper : public ITensorList {
646698
return tl_->shape();
647699
}
648700

701+
size_t GetByteSize() const override {
702+
return tl_->nbytes();
703+
}
704+
705+
daliDataType_t GetDType() const override {
706+
return tl_->type();
707+
}
708+
649709
const char *GetSourceInfo(int sample) const & override {
650710
ValidateSampleIdx(sample);
651711
const char *info = tl_->GetMeta(sample).GetSourceInfo().c_str();
@@ -684,6 +744,19 @@ class TensorListWrapper : public ITensorList {
684744
return Wrap(std::move(t));
685745
}
686746

747+
void CopyOut(
748+
void *dst_buffer,
749+
daliBufferPlacement_t dst_buffer_placement,
750+
std::optional<cudaStream_t> stream,
751+
daliCopyFlags_t flags) override {
752+
Validate(dst_buffer_placement);
753+
AccessOrder order = stream ? *stream : tl_->order();
754+
mm::memory_kind_id mem_kind = GetMemoryKind(dst_buffer_placement);
755+
CopyToExternal(dst_buffer, mem_kind, *tl_, order, flags & DALI_COPY_USE_KERNEL);
756+
if (flags & DALI_COPY_SYNC)
757+
AccessOrder::host().wait(order);
758+
}
759+
687760
const auto &NativePtr() const & {
688761
return tl_;
689762
}

0 commit comments

Comments
 (0)