Skip to content

Commit bbbca34

Browse files
authored
Merge pull request #3 from jantonguirao/janton_imperative_mode
dali2: operator placement handling (string, Device, torch.device) and tests
2 parents 2c866f1 + ee3752d commit bbbca34

File tree

4 files changed

+480
-26
lines changed

4 files changed

+480
-26
lines changed

dali/python/nvidia/dali/experimental/dali2/_device.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,43 @@
1414

1515
import nvidia.dali.backend as _backend
1616
from threading import local
17-
17+
from typing import Union, Optional
1818

1919
class Device:
2020
_thread_local = local()
2121

2222
def __init__(self, name: str, device_id: int = None):
23+
device_type, name_device_id = Device.split_device_type_and_id(name)
24+
if name_device_id is not None and device_id is not None:
25+
raise ValueError(f"Invalid device name: {name}\n"
26+
f"Ordinal ':{name_device_id}' should not appear "
27+
"in device name when device_id is provided")
2328
if device_id is None:
24-
type_and_id = name.split(":")
25-
if len(type_and_id) < 1 or len(type_and_id) > 2:
26-
raise ValueError(f"Invalid device name: {name}")
27-
device_type = type_and_id[0]
28-
if len(type_and_id) == 2:
29-
device_id = int(type_and_id[1])
30-
else:
31-
if ":" in name:
32-
raise ValueError(
33-
f"Invalid device name: {name}\n"
34-
f"':' should not appear in device name when device_id is provided"
35-
)
36-
device_type = name
29+
device_id = name_device_id
30+
31+
if device_type == "cuda":
32+
device_type = "gpu"
3733

3834
Device.validate_device_type(device_type)
3935
if device_id is not None:
4036
Device.validate_device_id(device_id, device_type)
4137
else:
4238
device_id = Device.default_device_id(device_type)
39+
4340
self.device_type = device_type
4441
self.device_id = device_id
4542

43+
@staticmethod
44+
def split_device_type_and_id(name: str) -> tuple[str, int]:
45+
type_and_id = name.split(":")
46+
if len(type_and_id) < 1 or len(type_and_id) > 2:
47+
raise ValueError(f"Invalid device name: {name}")
48+
device_type = type_and_id[0]
49+
device_id = None
50+
if len(type_and_id) == 2:
51+
device_id = int(type_and_id[1])
52+
return device_type, device_id
53+
4654
@staticmethod
4755
def default_device_id(device_type: str) -> int:
4856
if device_type == "cpu":
@@ -135,3 +143,43 @@ def __exit__(self, exc_type, exc_value, traceback):
135143

136144
Device._thread_local.devices = None
137145
Device._thread_local.previous_device_ids = None
146+
147+
148+
def device(obj: Union[Device, str, "torch.device"], id: Optional[int] = None) -> Device:
149+
"""
150+
Returns a Device object from various input types.
151+
152+
- If `obj` is already a `Device`, returns it. In this case, `id` must be `None`.
153+
- If `obj` is a `str`, parses it as a device name (e.g., `"gpu"`, `"cpu:0"`, `"cuda:1"`). In this case, `id` can be specified.
154+
Note: If the string already contains a device id and `id` is also provided, a `ValueError` is raised.
155+
- If `obj` is a `torch.device`, converts it to a `Device`. In this case, `id` must be `None`.
156+
- If `obj` is None, returns it.
157+
- If `obj` is not a `Device`, `str`, or `torch.device` or None, raises a `TypeError`.
158+
"""
159+
160+
# None
161+
if obj is None:
162+
return obj
163+
164+
# Device instance
165+
if isinstance(obj, Device):
166+
if id is not None:
167+
raise ValueError("Cannot specify id when passing a Device instance")
168+
return obj
169+
170+
if isinstance(obj, str):
171+
return Device(obj, id)
172+
173+
# torch.device detected by duck-typing
174+
is_torch_device = (
175+
obj.__class__.__module__ == "torch" and
176+
obj.__class__.__name__ == "device" and
177+
hasattr(obj, "type") and
178+
hasattr(obj, "index"))
179+
if is_torch_device:
180+
dev_type = "gpu" if obj.type == "cuda" else obj.type
181+
if id is not None:
182+
raise ValueError("Cannot specify id when passing a torch.device")
183+
return Device(dev_type, obj.index)
184+
185+
raise TypeError(f"Cannot convert {type(obj)} to Device")

dali/python/nvidia/dali/experimental/dali2/_tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
self,
6868
data: Optional[Any] = None,
6969
dtype: Optional[Any] = None,
70-
device: Optional[Device] = None,
70+
device: Optional[Union[Device, str, "torch.device"]] = None,
7171
layout: Optional[str] = None,
7272
batch: Optional[Any] = None,
7373
index_in_batch: Optional[int] = None,
@@ -92,6 +92,9 @@ def __init__(
9292

9393
copied = False
9494

95+
from ._device import device as _to_device
96+
device = _to_device(device)
97+
9598
from . import _fn
9699

97100
if dtype is not None:

dali/python/nvidia/dali/experimental/dali2/ops.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,10 @@ def __init__(
3838
self._num_inputs = num_inputs
3939
self._call_arg_names = None if call_arg_names is None else tuple(call_arg_names)
4040
self._api_type = None
41-
if isinstance(device, str):
42-
self._device = _device.Device(
43-
name=device,
44-
device_id=kwargs.get("device_id", _device.Device.default_device_id(device)),
45-
)
46-
else:
47-
if not isinstance(device, _device.Device):
48-
raise TypeError(
49-
f"`device` must be a Device instance or a string, got {type(device)}"
50-
)
51-
self._device = device
41+
42+
from ._device import device as _to_device
43+
self._device = _to_device(device)
44+
5245
self._input_meta = []
5346
self._arg_meta = {}
5447
self._num_outputs = None

0 commit comments

Comments
 (0)