Skip to content

Commit ab54781

Browse files
committed
Rebase on reinterpret.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 9667466 commit ab54781

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

dali/python/nvidia/dali/experimental/dali2/_batch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typing import Any, Optional, Tuple, List, Union
1616
from ._type import DType, dtype as _dtype, type_id as _type_id
17-
from ._tensor import Tensor, _is_full_slice, _is_tensor_type
17+
from ._tensor import Tensor, _is_full_slice, _is_tensor_type, _try_convert_enums
1818
import nvidia.dali.backend as _backend
1919
from ._eval_context import EvalContext as _EvalContext
2020
from ._device import Device
@@ -164,16 +164,21 @@ def broadcast(sample, batch_size: int, device: Optional[Device] = None) -> "Batc
164164
import numpy as np
165165
with nvtx.annotate("to numpy and stack", domain="batch"):
166166
arr = np.array(sample)
167+
converted_dtype_id = None
167168
if arr.dtype == np.float64:
168169
arr = arr.astype(np.float32)
169170
elif arr.dtype == np.int64:
170171
arr = arr.astype(np.int32)
171172
elif arr.dtype == np.uint64:
172173
arr = arr.astype(np.uint32)
174+
elif arr.dtype == object:
175+
arr, converted_dtype_id = _try_convert_enums(arr)
173176
arr = np.repeat(arr[np.newaxis], batch_size, axis=0)
174177

175178
with nvtx.annotate("to backend", domain="batch"):
176179
tl = _backend.TensorListCPU(arr)
180+
if converted_dtype_id is not None:
181+
tl.reinterpret(converted_dtype_id)
177182
with nvtx.annotate("create batch", domain="batch"):
178183
return Batch(tl, device=device)
179184

dali/python/nvidia/dali/experimental/dali2/_op_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _get_batch_size(x):
9595
return x.batch_size
9696
if isinstance(x, (_b.TensorListCPU, _b.TensorListGPU)):
9797
return len(x)
98-
if isinstance(x, list) and any(_is_tensor_type(t, True) for t in x):
98+
if isinstance(x, list) and any(isinstance(t, Tensor) for t in x):
9999
return len(x)
100100
return None
101101

dali/python/nvidia/dali/experimental/dali2/_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _is_tensor_type(x, nested_list_warning=False):
3535
from . import _batch
3636

3737
if isinstance(x, _batch.Batch):
38-
raise ValueError("A list of Batchs is not a valid argument type")
38+
raise ValueError("A list of Batch objects is not a valid argument type")
3939
if isinstance(x, Tensor):
4040
return True
4141
if hasattr(x, "__array__"):
@@ -177,7 +177,7 @@ def __init__(
177177
(arr, converted_dtype_id) = _try_convert_enums(arr)
178178
self._backend = _backend.TensorCPU(arr, layout, False)
179179
if converted_dtype_id is not None:
180-
self._backend.reinterpret_as(converted_dtype_id)
180+
self._backend.reinterpret(converted_dtype_id)
181181
copied = True
182182
self._wraps_external_data = False
183183

0 commit comments

Comments
 (0)