|
14 | 14 |
|
15 | 15 | import nvidia.dali.backend as _backend |
16 | 16 | from threading import local |
17 | | - |
| 17 | +from typing import Union, Optional |
18 | 18 |
|
19 | 19 | class Device: |
20 | 20 | _thread_local = local() |
21 | 21 |
|
22 | 22 | def __init__(self, name: str, device_id: int = None): |
| 23 | + device_type, name_device_id = Device.split_device_type_and_id(name) |
| 24 | + if name_device_id is not None and device_id is not None: |
| 25 | + raise ValueError(f"Invalid device name: {name}\n" |
| 26 | + f"Ordinal ':{name_device_id}' should not appear " |
| 27 | + "in device name when device_id is provided") |
23 | 28 | if device_id is None: |
24 | | - type_and_id = name.split(":") |
25 | | - if len(type_and_id) < 1 or len(type_and_id) > 2: |
26 | | - raise ValueError(f"Invalid device name: {name}") |
27 | | - device_type = type_and_id[0] |
28 | | - if len(type_and_id) == 2: |
29 | | - device_id = int(type_and_id[1]) |
30 | | - else: |
31 | | - if ":" in name: |
32 | | - raise ValueError( |
33 | | - f"Invalid device name: {name}\n" |
34 | | - f"':' should not appear in device name when device_id is provided" |
35 | | - ) |
36 | | - device_type = name |
| 29 | + device_id = name_device_id |
| 30 | + |
| 31 | + if device_type == "cuda": |
| 32 | + device_type = "gpu" |
37 | 33 |
|
38 | 34 | Device.validate_device_type(device_type) |
39 | 35 | if device_id is not None: |
40 | 36 | Device.validate_device_id(device_id, device_type) |
41 | 37 | else: |
42 | 38 | device_id = Device.default_device_id(device_type) |
| 39 | + |
43 | 40 | self.device_type = device_type |
44 | 41 | self.device_id = device_id |
45 | 42 |
|
| 43 | + @staticmethod |
| 44 | + def split_device_type_and_id(name: str) -> tuple[str, int]: |
| 45 | + type_and_id = name.split(":") |
| 46 | + if len(type_and_id) < 1 or len(type_and_id) > 2: |
| 47 | + raise ValueError(f"Invalid device name: {name}") |
| 48 | + device_type = type_and_id[0] |
| 49 | + device_id = None |
| 50 | + if len(type_and_id) == 2: |
| 51 | + device_id = int(type_and_id[1]) |
| 52 | + return device_type, device_id |
| 53 | + |
46 | 54 | @staticmethod |
47 | 55 | def default_device_id(device_type: str) -> int: |
48 | 56 | if device_type == "cpu": |
@@ -135,3 +143,43 @@ def __exit__(self, exc_type, exc_value, traceback): |
135 | 143 |
|
136 | 144 | Device._thread_local.devices = None |
137 | 145 | Device._thread_local.previous_device_ids = None |
| 146 | + |
| 147 | + |
| 148 | +def device(obj: Union[Device, str, "torch.device"], id: Optional[int] = None) -> Device: |
| 149 | + """ |
| 150 | + Returns a Device object from various input types. |
| 151 | +
|
| 152 | + - If `obj` is already a `Device`, returns it. In this case, `id` must be `None`. |
| 153 | + - If `obj` is a `str`, parses it as a device name (e.g., `"gpu"`, `"cpu:0"`, `"cuda:1"`). In this case, `id` can be specified. |
| 154 | + Note: If the string already contains a device id and `id` is also provided, a `ValueError` is raised. |
| 155 | + - If `obj` is a `torch.device`, converts it to a `Device`. In this case, `id` must be `None`. |
| 156 | + - If `obj` is None, returns it. |
| 157 | + - If `obj` is not a `Device`, `str`, or `torch.device` or None, raises a `TypeError`. |
| 158 | + """ |
| 159 | + |
| 160 | + # None |
| 161 | + if obj is None: |
| 162 | + return obj |
| 163 | + |
| 164 | + # Device instance |
| 165 | + if isinstance(obj, Device): |
| 166 | + if id is not None: |
| 167 | + raise ValueError("Cannot specify id when passing a Device instance") |
| 168 | + return obj |
| 169 | + |
| 170 | + if isinstance(obj, str): |
| 171 | + return Device(obj, id) |
| 172 | + |
| 173 | + # torch.device detected by duck-typing |
| 174 | + is_torch_device = ( |
| 175 | + obj.__class__.__module__ == "torch" and |
| 176 | + obj.__class__.__name__ == "device" and |
| 177 | + hasattr(obj, "type") and |
| 178 | + hasattr(obj, "index")) |
| 179 | + if is_torch_device: |
| 180 | + dev_type = "gpu" if obj.type == "cuda" else obj.type |
| 181 | + if id is not None: |
| 182 | + raise ValueError("Cannot specify id when passing a torch.device") |
| 183 | + return Device(dev_type, obj.index) |
| 184 | + |
| 185 | + raise TypeError(f"Cannot convert {type(obj)} to Device") |
0 commit comments