Skip to content

Commit ebef6eb

Browse files
committed
[WIP] Tensor stream
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent b5d5e48 commit ebef6eb

File tree

1 file changed

+7
-4
lines changed
  • dali/python/nvidia/dali/experimental/dali2

1 file changed

+7
-4
lines changed

dali/python/nvidia/dali/experimental/dali2/_tensor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,14 @@ def __init__(
130130
ctx = _EvalContext.get()
131131
if ctx.device.device_id == device_id:
132132
stream = ctx.cuda_stream
133-
args = {"stream": stream.handle}
134133
else:
135-
# TODO(michalz): Come up with better stream semantics
136-
args = {}
137-
self._backend = _backend.TensorGPU(data.__dlpack__(**args), layout)
134+
stream = backend.Stream(device_id)
135+
args = {"stream": stream.handle}
136+
self._backend = _backend.TensorGPU(
137+
data.__dlpack__(**args),
138+
layout=layout,
139+
stream=stream,
140+
)
138141
else:
139142
raise ValueError(f"Unsupported device type: {dl_device_type}")
140143
self._wraps_external_data = True

0 commit comments

Comments
 (0)