2323#include < vector>
2424#include " dali/dali.h"
2525#include " dali/pipeline/data/tensor_list.h"
26+ #include " dali/pipeline/data/copy_to_external.h"
2627#include " dali/c_api_2/ref_counting.h"
2728#include " dali/c_api_2/validation.h"
2829
@@ -43,6 +44,15 @@ struct _DALITensor {
4344namespace dali {
4445namespace c_api {
4546
47+ constexpr mm::memory_kind_id GetMemoryKind (const daliBufferPlacement_t &placement) {
48+ if (placement.device_type == DALI_STORAGE_GPU) {
49+ return mm::memory_kind_id::device;
50+ } else {
51+ assert (placement.device_type == DALI_STORAGE_CPU);
52+ return placement.pinned ? mm::memory_kind_id::pinned : mm::memory_kind_id::host;
53+ }
54+ }
55+
4656// ////////////////////////////////////////////////////////////////////////////
4757// Interfaces
4858// ////////////////////////////////////////////////////////////////////////////
@@ -88,10 +98,21 @@ class ITensor : public _DALITensor, public RefCountedObject {
8898
8999 virtual const TensorShape<> &GetShape () const & = 0;
90100
101+ virtual size_t GetByteSize () const = 0;
102+
103+ virtual daliDataType_t GetDType () const = 0;
104+
91105 virtual const char *GetSourceInfo () const & = 0;
92106
93107 virtual void SetSourceInfo (const char *source_info) = 0;
94108
109+ virtual void CopyOut (
110+ void *dst_buffer,
111+ daliBufferPlacement_t dst_buffer_placement,
112+ std::optional<cudaStream_t> stream,
113+ daliCopyFlags_t flags) = 0;
114+
115+
95116 /* * Retrieves the underlying DALI Tensor<Backend> pointer.
96117 *
97118 * Returns a shared pointer to the underlying DALI object. If the backend doesn't match,
@@ -155,12 +176,22 @@ class ITensorList : public _DALITensorList, public RefCountedObject {
155176
156177 virtual const TensorListShape<> &GetShape () const & = 0;
157178
179+ virtual size_t GetByteSize () const = 0;
180+
181+ virtual daliDataType_t GetDType () const = 0;
182+
158183 virtual RefCountedPtr<ITensor> ViewAsTensor () const = 0;
159184
160185 virtual const char *GetSourceInfo (int sample) const & = 0;
161186
162187 virtual void SetSourceInfo (int sample, const char *source_info) = 0;
163188
189+ virtual void CopyOut (
190+ void *dst_buffer,
191+ daliBufferPlacement_t dst_buffer_placement,
192+ std::optional<cudaStream_t> stream,
193+ daliCopyFlags_t flags) = 0;
194+
164195 /* * Retrieves the underlying DALI TensorList<Backend> pointer.
165196 *
166197 * Returns a shared pointer to the underlying DALI object. If the backend doesn't match,
@@ -334,6 +365,14 @@ class TensorWrapper : public ITensor {
334365 return t_->shape ();
335366 }
336367
368+ size_t GetByteSize () const override {
369+ return t_->nbytes ();
370+ }
371+
372+ daliDataType_t GetDType () const override {
373+ return t_->type ();
374+ }
375+
337376 const char *GetSourceInfo () const & override {
338377 const char *info = t_->GetMeta ().GetSourceInfo ().c_str ();
339378 if (info && !*info)
@@ -345,6 +384,19 @@ class TensorWrapper : public ITensor {
345384 t_->SetSourceInfo (source_info ? source_info : " " );
346385 }
347386
387+ void CopyOut (
388+ void *dst_buffer,
389+ daliBufferPlacement_t dst_buffer_placement,
390+ std::optional<cudaStream_t> stream,
391+ daliCopyFlags_t flags) override {
392+ Validate (dst_buffer_placement);
393+ AccessOrder order = stream ? *stream : t_->order ();
394+ mm::memory_kind_id mem_kind = GetMemoryKind (dst_buffer_placement);
395+ CopyToExternal (dst_buffer, mem_kind, *t_, order, flags & DALI_COPY_USE_KERNEL);
396+ if (flags & DALI_COPY_SYNC)
397+ AccessOrder::host ().wait (order);
398+ }
399+
348400 const auto &NativePtr () const & {
349401 return t_;
350402 }
@@ -646,6 +698,14 @@ class TensorListWrapper : public ITensorList {
646698 return tl_->shape ();
647699 }
648700
701+ size_t GetByteSize () const override {
702+ return tl_->nbytes ();
703+ }
704+
705+ daliDataType_t GetDType () const override {
706+ return tl_->type ();
707+ }
708+
649709 const char *GetSourceInfo (int sample) const & override {
650710 ValidateSampleIdx (sample);
651711 const char *info = tl_->GetMeta (sample).GetSourceInfo ().c_str ();
@@ -684,6 +744,19 @@ class TensorListWrapper : public ITensorList {
684744 return Wrap (std::move (t));
685745 }
686746
747+ void CopyOut (
748+ void *dst_buffer,
749+ daliBufferPlacement_t dst_buffer_placement,
750+ std::optional<cudaStream_t> stream,
751+ daliCopyFlags_t flags) override {
752+ Validate (dst_buffer_placement);
753+ AccessOrder order = stream ? *stream : tl_->order ();
754+ mm::memory_kind_id mem_kind = GetMemoryKind (dst_buffer_placement);
755+ CopyToExternal (dst_buffer, mem_kind, *tl_, order, flags & DALI_COPY_USE_KERNEL);
756+ if (flags & DALI_COPY_SYNC)
757+ AccessOrder::host ().wait (order);
758+ }
759+
687760 const auto &NativePtr () const & {
688761 return tl_;
689762 }
0 commit comments