Skip to content

Commit b5d5e48

Browse files
committed
Use DLPack to wrap GPU and CPU external tensors.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 739fa82 commit b5d5e48

File tree

4 files changed

+34
-7
lines changed

4 files changed

+34
-7
lines changed

dali/python/backend_impl.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ void FillTensorFromDlPack(py::capsule capsule, SourceDataType<SrcBackend> *batch
212212
shape[i] = dl_tensor.shape[i];
213213
}
214214

215-
CheckContiguousTensor(dl_tensor.strides, dl_tensor.ndim, dl_tensor.shape, dl_tensor.ndim, 1);
215+
if (dl_tensor.strides)
216+
CheckContiguousTensor(dl_tensor.strides, dl_tensor.ndim, dl_tensor.shape, dl_tensor.ndim, 1);
217+
216218
size_t bytes = volume(shape) * dali_type.size();
217219

218220
auto typed_shape = ConvertShape(shape, batch);

dali/python/nvidia/dali/experimental/dali2/_eval_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def __enter__(self):
4747
def __exit__(self, exc_type, exc_value, traceback):
    # Leave this evaluation context: remove it from the (presumably
    # thread-local) context stack. NOTE(review): `_tls` is defined outside
    # this view — confirm it is a threading.local holding a list `stack`.
    # Exceptions are not suppressed (implicitly returns None).
    _tls.stack.pop()
4949

50+
@property
def device(self):
    """Return the device this evaluation context is associated with.

    Read-only accessor for the ``_device`` attribute; the attribute is
    set elsewhere (constructor not visible in this diff).
    """
    return self._device
53+
5054
@property
def cuda_stream(self):
    """Return the CUDA stream used by this evaluation context.

    Read-only accessor for the ``_cuda_stream`` attribute; callers in
    this commit use ``.handle`` of the returned object when importing
    DLPack tensors on the same device.
    """
    return self._cuda_stream

dali/python/nvidia/dali/experimental/dali2/_tensor.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _backend_device(backend: Union[_backend.TensorCPU, _backend.TensorGPU]) -> D
3535
if isinstance(backend, _backend.TensorCPU):
3636
return Device("cpu")
3737
elif isinstance(backend, _backend.TensorGPU):
38-
return Device("gpu", backend.device())
38+
return Device("gpu", backend.device_id())
3939
else:
4040
raise ValueError(f"Unsupported backend type: {type(backend)}")
4141

@@ -121,8 +121,22 @@ def __init__(
121121
self.assign(fn.cast(data, dtype, device=device).evaluate())
122122
elif isinstance(data, TensorSlice):
123123
self._slice = data
124-
elif hasattr(data, "__dlpack__"):
125-
self._backend = _backend.TensorCPU(data, layout)
124+
elif hasattr(data, "__dlpack_device__"):
125+
dl_device_type, device_id = data.__dlpack_device__()
126+
if int(dl_device_type) == 1: # CPU
127+
self._backend = _backend.TensorCPU(data.__dlpack__(), layout)
128+
elif int(dl_device_type) == 2: # GPU
129+
# If the current context is on the same device, use the same stream.
130+
ctx = _EvalContext.get()
131+
if ctx.device.device_id == device_id:
132+
stream = ctx.cuda_stream
133+
args = {"stream": stream.handle}
134+
else:
135+
# TODO(michalz): Come up with better stream semantics
136+
args = {}
137+
self._backend = _backend.TensorGPU(data.__dlpack__(**args), layout)
138+
else:
139+
raise ValueError(f"Unsupported device type: {dl_device_type}")
126140
self._wraps_external_data = True
127141
elif hasattr(data, "__array__"):
128142
self._backend = _backend.TensorCPU(data, layout)
@@ -159,10 +173,14 @@ def __init__(
159173
copied = True
160174
self._wraps_external_data = False
161175

162-
if device is not None:
163-
device = self._device = device if isinstance(device, Device) else Device(device)
176+
if self._backend is not None:
177+
self._device = _backend_device(self._backend)
178+
if device is None:
179+
device = self._device
164180
else:
165-
device = self._device = Device("cpu")
181+
if device is None:
182+
device = Device("cpu")
183+
self._device = device
166184

167185
if self._backend is not None:
168186
self._shape = self._backend.shape()
@@ -214,6 +232,8 @@ def to_device(self, device: Device, force_copy: bool = False) -> "Tensor":
214232
return fn.copy(self, device=device.device_type)
215233

216234
def assign(self, other: "Tensor"):
235+
if other is self:
236+
return
217237
self._device = other._device
218238
self._shape = other._shape
219239
self._dtype = other._dtype

dali/python/nvidia/dali/experimental/dali2/_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_type2id = {}
2020
_name2type = {}
2121

22+
2223
class DType:
2324
class Kind(Enum):
2425
signed = auto()

0 commit comments

Comments
 (0)