Skip to content

Commit 47042ce

Browse files
authored
Fix handling of scalars in TF sparse tensors. (#5887)
The TF code assumed that 0-dim tensors are empty and therefore constructed incorrect sparse tensors. We now support 0D tensors (scalars), and this change adds proper handling for them. Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent eeb1b87 commit 47042ce

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

dali/test/python/test_dali_tf_plugin_run.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import nvidia.dali.ops as ops
17+
import nvidia.dali.fn as fn
1718
import nvidia.dali.plugin.tf as dali_tf
1819
import nvidia.dali.types as types
1920
import os.path
@@ -67,7 +68,7 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus):
6768

6869
def define_graph(self):
6970
images, labels = self.input()
70-
return self.base_define_graph(images, labels)
71+
return self.base_define_graph(images, fn.reshape(labels, shape=[]))
7172

7273

7374
def get_batch_dali(batch_size, pipe_type, label_type, num_gpus=1):
@@ -118,6 +119,47 @@ def test_dali_tf_op(pipe_type=CaffeReadPipeline, batch_size=16, iterations=32):
118119
assert (label <= 999).all()
119120

120121

122+
def get_batch_dali_sparse(batch_size, pipe_type, label_type):
    """Build TF ops that fetch an image batch and sparse labels from DALI.

    Instantiates ``pipe_type`` with two threads, ``device_id=None`` and
    ``num_gpus=1``, then wraps it in a ``dali_tf.DALIIterator`` op created
    under the ``/cpu`` TensorFlow device scope.

    Args:
        batch_size: Number of samples per batch; also used as the leading
            dimension of the declared image shape.
        pipe_type: Callable accepting ``batch_size``, ``num_threads``,
            ``device_id`` and ``num_gpus`` keyword arguments and returning
            the pipeline to run.
        label_type: TensorFlow dtype requested for the label output.

    Returns:
        An ``(image, label)`` pair of outputs. The first output is requested
        as dense, the second as sparse with an empty declared shape.
    """
    pipeline = pipe_type(batch_size=batch_size, num_threads=2, device_id=None, num_gpus=1)
    dali_iterator = dali_tf.DALIIterator()

    # Per-output settings, in (image, label) order. The label is declared
    # with an empty shape and is the only output requested as sparse.
    output_shapes = [(batch_size, 3, 227, 227), ()]
    output_dtypes = [tf.int32, label_type]
    sparse_outputs = [False, True]

    with tf.device("/cpu"):
        image, label = dali_iterator(
            pipeline=pipeline,
            shapes=output_shapes,
            dtypes=output_dtypes,
            device_id=None,
            sparse=sparse_outputs,
        )

    return image, label
136+
137+
138+
def test_dali_tf_op_sparse(pipe_type=CaffeReadPipeline, batch_size=16, iterations=32):
    """Check the labels that the DALI TF op returns as a sparse tensor.

    Runs the ops built by ``get_batch_dali_sparse`` for ``iterations``
    steps and verifies, for every batch, that the sparse label output
    carries one integral value per sample in the range ``[0, 999]``.

    Args:
        pipe_type: Pipeline class handed to ``get_batch_dali_sparse``.
        batch_size: Number of samples per batch.
        iterations: Number of batches to fetch and validate.
    """
    test_batch = get_batch_dali_sparse(batch_size, pipe_type, tf.int32)
    try:
        from tensorflow.compat.v1 import GPUOptions
        from tensorflow.compat.v1 import ConfigProto
        from tensorflow.compat.v1 import Session
    except ImportError:
        # Older TF versions don't have compat.v1 layer
        from tensorflow import GPUOptions
        from tensorflow import ConfigProto
        from tensorflow import Session

    gpu_options = GPUOptions(per_process_gpu_memory_fraction=0.5)
    config = ConfigProto(gpu_options=gpu_options)
    with Session(config=config) as sess:
        for _ in range(iterations):
            _image, label = sess.run(test_batch)
            values = label.values
            # Every check below is an `.all()` reduction, which is vacuously
            # true for an empty array. Require one value per sample first, so
            # a sparse tensor that silently drops the scalar labels fails.
            assert values.shape == (batch_size,), (
                f"expected {batch_size} sparse label values, got shape {values.shape}"
            )
            # Testing correctness of labels
            # labels need to be integers
            assert np.equal(np.mod(values, 1), 0).all()
            assert (values >= 0).all()
            assert (values <= 999).all()
161+
162+
121163
class PythonOperatorPipeline(Pipeline):
122164
def __init__(self):
123165
super().__init__(1, 1, 0, 0)

dali_tf_plugin/daliop.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -289,14 +289,6 @@ class DaliOp : public tf::OpKernel {
289289
for (unsigned n = 0; n < elms; ++n) {
290290
TF_DALI_CALL(data_output_shape = DaliToShape(AutoCPtr<int64_t>(
291291
daliShapeAtSample(&pipe_handle_, i, n))));
292-
// it seems that num_elements() return 1 for empty tensors
293-
if (data_output_shape.dims() == 0) {
294-
continue;
295-
}
296-
// squeeze
297-
if (data_output_shape.dim_size(data_output_shape.dims() - 1) == 1) {
298-
data_output_shape.RemoveLastDims(1);
299-
}
300292
for (unsigned elm_idx = 0; elm_idx < data_output_shape.num_elements(); ++elm_idx) {
301293
unsigned idx_val = elm_idx;
302294
// first value of indices is tensor index

0 commit comments

Comments
 (0)