Skip to content

Commit 0f36330

Browse files
committed
Convert DType arguments to Type ID enum.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent d8457b5 commit 0f36330

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

dali/python/nvidia/dali/experimental/dali2/_op_builder.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import warnings
2121
from . import ops
2222
from . import fn
23+
from . import _type
2324
import types
2425
import copy
2526
import sys
@@ -35,6 +36,16 @@ def is_external(x):
3536
return False
3637

3738

39+
def _scalar_decay(x):
40+
if isinstance(x, _device.Device):
41+
return x.device_type
42+
if isinstance(x, _type.DType):
43+
return x.type_id
44+
if x is str:
45+
return types.STRING
46+
return x
47+
48+
3849
def _get_input_device(x):
3950
with nvtx.annotate("get_input_device", domain="op_builder"):
4051
if x is None:
@@ -189,6 +200,7 @@ def build_constructor(schema, op_class):
189200
header = f"__init__({', '.join(header_args)})"
190201

191202
def init(self, max_batch_size, name, **kwargs):
203+
kwargs = {k: _scalar_decay(v) for k, v in kwargs.items()}
192204
op_class.__base__.__init__(self, max_batch_size, name, **kwargs)
193205
if stateful:
194206
self._call_id = 0
@@ -406,12 +418,12 @@ def fn_call(*inputs, batch_size=None, device=None, **raw_kwargs):
406418
break
407419
max_batch_size = _next_pow2(batch_size or 1)
408420
init_args = {
409-
arg: raw_kwargs[arg]
421+
arg: _scalar_decay(raw_kwargs[arg])
410422
for arg in fixed_args
411423
if arg != "max_batch_size" and arg in raw_kwargs and raw_kwargs[arg] is not None
412424
}
413425
call_args = {
414-
arg: raw_kwargs[arg]
426+
arg: _scalar_decay(raw_kwargs[arg])
415427
for arg in tensor_args
416428
if arg in raw_kwargs and raw_kwargs[arg] is not None
417429
}

0 commit comments

Comments
 (0)