|
1 | | -# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +# Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 | import nvidia.dali.ops as ops |
| 17 | +import nvidia.dali.fn as fn |
17 | 18 | import nvidia.dali.plugin.tf as dali_tf |
18 | 19 | import nvidia.dali.types as types |
19 | 20 | import os.path |
@@ -67,7 +68,7 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus): |
67 | 68 |
|
68 | 69 | def define_graph(self): |
69 | 70 | images, labels = self.input() |
70 | | - return self.base_define_graph(images, labels) |
| 71 | + return self.base_define_graph(images, fn.reshape(labels, shape=[])) |
71 | 72 |
|
72 | 73 |
|
73 | 74 | def get_batch_dali(batch_size, pipe_type, label_type, num_gpus=1): |
@@ -118,6 +119,47 @@ def test_dali_tf_op(pipe_type=CaffeReadPipeline, batch_size=16, iterations=32): |
118 | 119 | assert (label <= 999).all() |
119 | 120 |
|
120 | 121 |
|
| 122 | +def get_batch_dali_sparse(batch_size, pipe_type, label_type): |
| 123 | + pipe = pipe_type(batch_size=batch_size, num_threads=2, device_id=None, num_gpus=1) |
| 124 | + |
| 125 | + daliop = dali_tf.DALIIterator() |
| 126 | + with tf.device("/cpu"): |
| 127 | + image, label = daliop( |
| 128 | + pipeline=pipe, |
| 129 | + shapes=[(batch_size, 3, 227, 227), ()], |
| 130 | + dtypes=[tf.int32, label_type], |
| 131 | + device_id=None, |
| 132 | + sparse=[False, True], |
| 133 | + ) |
| 134 | + |
| 135 | + return image, label |
| 136 | + |
| 137 | + |
| 138 | +def test_dali_tf_op_sparse(pipe_type=CaffeReadPipeline, batch_size=16, iterations=32): |
| 139 | + test_batch = get_batch_dali_sparse(batch_size, pipe_type, tf.int32) |
| 140 | + try: |
| 141 | + from tensorflow.compat.v1 import GPUOptions |
| 142 | + from tensorflow.compat.v1 import ConfigProto |
| 143 | + from tensorflow.compat.v1 import Session |
| 144 | + except ImportError: |
| 145 | + # Older TF versions don't have compat.v1 layer |
| 146 | + from tensorflow import GPUOptions |
| 147 | + from tensorflow import ConfigProto |
| 148 | + from tensorflow import Session |
| 149 | + |
| 150 | + gpu_options = GPUOptions(per_process_gpu_memory_fraction=0.5) |
| 151 | + config = ConfigProto(gpu_options=gpu_options) |
| 152 | + with Session(config=config) as sess: |
| 153 | + for i in range(iterations): |
| 154 | + _, label = sess.run(test_batch) |
| 155 | + # Testing correctness of labels |
| 156 | + # labels need to be integers |
| 157 | + label = label.values |
| 158 | + assert np.equal(np.mod(label, 1), 0).all() |
| 159 | + assert (label >= 0).all() |
| 160 | + assert (label <= 999).all() |
| 161 | + |
| 162 | + |
121 | 163 | class PythonOperatorPipeline(Pipeline): |
122 | 164 | def __init__(self): |
123 | 165 | super().__init__(1, 1, 0, 0) |
|
0 commit comments