
Commit 9543827

mps_test
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent 47042ce commit 9543827

5 files changed (+485 -103 lines)
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["DALIPipelineRunner"]

# from torch.cuda import nvtx as _nvtx
import torch.multiprocessing as _mp
from torch.utils import data as _torchdata
from torch.utils.data._utils.collate import default_collate_fn_map as _default_collate_fn_map
from nvidia.dali import Pipeline as _Pipeline
from nvidia.dali.external_source import ExternalSource as _ExternalSource
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_tensor
from inspect import Parameter, Signature

def _modify_signature(new_sig):
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                bound_args = new_sig.bind(*args, **kwargs)
                bound_args.apply_defaults()
                return func(*bound_args.args, **bound_args.kwargs)
            except Exception as err:
                args_str = ", ".join([f"{type(arg)}" for arg in args])
                kwargs_str = ", ".join([f"{key}={type(arg)}" for key, arg in kwargs.items()])
                raise ValueError(
                    f"Expected signature is: {new_sig}. "
                    f"Got: args=({args_str}), kwargs={kwargs_str}, error: {err}"
                )

        wrapper.__signature__ = new_sig
        return wrapper

    return decorator

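For context, `_modify_signature` rebinds a wrapper's arguments against a caller-supplied `inspect.Signature`, so the decorated function both advertises and validates that signature. The sketch below is illustrative only (not part of this commit); the names `example_sig` and `forward` are hypothetical.

from inspect import Parameter, Signature

# Hypothetical: expose a generic forwarder under an explicit (self, images) signature.
example_sig = Signature(
    [
        Parameter("self", Parameter.POSITIONAL_OR_KEYWORD),
        Parameter("images", Parameter.POSITIONAL_OR_KEYWORD),
    ]
)


@_modify_signature(example_sig)
def forward(*args, **kwargs):
    # Receives arguments already bound to example_sig.
    return args, kwargs


# inspect.signature(forward) now reports (self, images); a call that does not
# match that signature raises the ValueError built inside the wrapper above.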
def _external_source_node_names(pipeline):
    """
    Extract the names of all the ExternalSource nodes in the pipeline.
    """
    # TODO(janton): Add a native function to query those names, so that we can do it
    # also on deserialized pipelines
    if pipeline._deserialized:
        raise RuntimeError(
            "Not able to find the external source "
            "operator names, since the pipeline was deserialized"
        )
    if not pipeline._py_graph_built:
        pipeline._build_graph()
    input_node_names = []
    for op in pipeline._ops:
        if isinstance(op._op, _ExternalSource):
            input_node_names.append(op.name)
    return input_node_names

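The names collected here are the `name` arguments of the `fn.external_source` nodes in the user's pipeline definition; they become the runner's inputs and the targets of `feed_input`. A minimal sketch of such a pipeline (illustrative only, not part of this commit; `example_pipeline` and `"images"` are hypothetical):

from nvidia.dali import pipeline_def, fn


@pipeline_def
def example_pipeline():
    # One named external_source per runner input; "images" is what
    # _external_source_node_names would report for this pipeline.
    images = fn.external_source(name="images", device="cpu")
    return images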
class DALIOutputSampleRef:
    """
    Reference for a single sample output bound to a pipeline run.
    """

    def __init__(self, pipe, output_idx, sample_idx):
        """
        Args:
            pipe (DALIPipelineRunner): The pipeline runner this reference is bound to.
            output_idx (int): The index of the output in the pipeline.
            sample_idx (int): The index of the sample within the batch.
        """
        self.pipe = pipe
        self.output_idx = output_idx
        self.sample_idx = sample_idx

    def __repr__(self):
        return (
            f"DALIOutputSampleRef(pipe={self.pipe}, "
            f"output_idx={self.output_idx}, sample_idx={self.sample_idx})"
        )


class DALIOutputBatchRef:
    """
    Reference for a batched output bound to a pipeline run.
    """

    def __init__(self, pipe, output_idx):
        """
        Args:
            pipe (DALIPipelineRunner): The pipeline runner this reference is bound to.
            output_idx (int): The index of the output in the pipeline.
        """
        self.pipe = pipe
        self.output_idx = output_idx

    def __repr__(self):
        return f"DALIOutputBatchRef(pipe={self.pipe}, output_idx={self.output_idx})"

def _collate_dali_output_sample_ref_fn(samples, *, collate_fn_map=None):
    """
    Special collate function that schedules a DALI iteration for execution.
    """
    assert len(samples) > 0
    pipe = samples[0].pipe
    output_idx = samples[0].output_idx
    for i, sample in enumerate(samples):
        if sample.pipe != pipe or sample.output_idx != output_idx:
            raise RuntimeError("All samples should belong to the same batch")

        if sample.sample_idx != i:
            raise RuntimeError("Unexpected sample order")

    return pipe._complete_batch()[output_idx]


# In-place modify `default_collate_fn_map` to handle DALIOutputSampleRef
_default_collate_fn_map.update({DALIOutputSampleRef: _collate_dali_output_sample_ref_fn})

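PyTorch's `default_collate` dispatches on element type through `default_collate_fn_map`, so after this update a batch made of `DALIOutputSampleRef` placeholders is collated by `_collate_dali_output_sample_ref_fn`, which is what actually runs the pipeline. A short sketch of that dispatch (illustrative only, assuming a PyTorch version that exposes `default_collate_fn_map`, as the import above already does):

from torch.utils.data._utils.collate import default_collate, default_collate_fn_map

# After the update above, default_collate routes a list of DALIOutputSampleRef
# elements to _collate_dali_output_sample_ref_fn instead of trying to stack them.
assert default_collate_fn_map[DALIOutputSampleRef] is _collate_dali_output_sample_ref_fn

# Conceptually, for refs produced by a runner that has gathered a full batch:
#   batch_tensor = default_collate([ref_0, ref_1, ref_2, ref_3])
# ends up calling ref_0.pipe._complete_batch()[ref_0.output_idx].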
class DALIPipelineRunner:
    def __init__(self, pipeline_fn, pipeline_kwargs):
        # Function that creates the DALI pipeline
        self._pipeline_fn = pipeline_fn
        # Keyword arguments passed to the pipeline function
        self._pipeline_kwargs = pipeline_kwargs
        # Pipeline instance, created lazily by init_pipeline()
        self._pipe = None
        self._signature = None
        self._num_outputs = None
        # Inputs accumulated for the current batch
        self._curr_batch_params = {}
        # Whether the current batch is complete
        self._batch_complete = False
        # Index of the current batch and of the next sample within it
        self._batch_idx = None
        self._batch_sample_idx = None
        # Outputs of the current batch
        self._batch_outputs = []

        self._callable = None

    def init_pipeline(self):
        if self._pipe is not None:
            return self._pipe

        self._pipe = self._pipeline_fn(**self._pipeline_kwargs)

        # Override callable signature
        self._dali_input_names = _external_source_node_names(self._pipe)
        num_inputs = len(self._dali_input_names)
        if num_inputs == 0:
            raise RuntimeError("The provided pipeline doesn't have any inputs")

        parameters = [Parameter("self", Parameter.POSITIONAL_OR_KEYWORD)]
        parameter_kind = (
            Parameter.POSITIONAL_OR_KEYWORD if num_inputs == 1 else Parameter.KEYWORD_ONLY
        )
        for input_name in self._dali_input_names:
            parameters.append(Parameter(input_name, parameter_kind))
        return_annotation = tuple(DALIOutputSampleRef for _ in range(self._pipe.num_outputs))
        self._signature = Signature(parameters, return_annotation=return_annotation)
        self._num_outputs = self._pipe.num_outputs
        return self._pipe

    def _add_sample(self, inputs):
        """
        Adds a sample to the current batch. The collate function marks the batch as
        complete and submits it for execution; when a sample arrives after the current
        batch has been completed, a new batch is started.
        """
        if self._batch_idx is None or self._batch_complete:
            self._batch_idx = self._batch_idx + 1 if self._batch_idx is not None else 0
            self._batch_sample_idx = 0
            self._curr_batch_params = {}
            self._batch_complete = False

        for name, value in inputs.items():
            # we want to transfer only the arguments to the caller side, not the self reference
            if name == "self":
                continue
            if name not in self._curr_batch_params:
                self._curr_batch_params[name] = []
            self._curr_batch_params[name].append(value)

        ret = tuple(
            DALIOutputSampleRef(self, output_idx=i, sample_idx=self._batch_sample_idx)
            for i in range(self._num_outputs)
        )

        # unpack single element tuple
        if len(ret) == 1:
            ret = ret[0]
        self._batch_sample_idx += 1
        return ret

    def _complete_batch(self):
        """
        Complete the current batch and submit it for execution.
        """
        if not self._batch_complete:
            self._batch_complete = True
            for key, value in self._curr_batch_params.items():
                self._pipe.feed_input(key, value)
            self._pipe._run_once()
            dali_outputs = self._pipe.outputs()
            self._batch_outputs = tuple(
                to_torch_tensor(out.as_tensor(), not self._pipe.exec_dynamic)
                for out in dali_outputs
            )
        return self._batch_outputs

    def __call__(self, *args, **kwargs):
        self.init_pipeline()
        bound_args = self._signature.bind(self, *args, **kwargs)
        return self._add_sample(bound_args.arguments)
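Putting it together: a dataset's `__getitem__` calls the runner once per sample and returns `DALIOutputSampleRef` placeholders; the `DataLoader`'s default collate then turns a batch of placeholders into real tensors by running the pipeline once per batch. Below is a minimal single-process sketch (illustrative only, not part of this commit; `toy_pipeline`, `RefDataset`, the input shapes and the pipeline arguments are hypothetical, and `device_id=0` assumes a GPU is available):

import numpy as np
from torch.utils.data import Dataset, DataLoader
from nvidia.dali import pipeline_def, fn


@pipeline_def
def toy_pipeline():
    # One named input, fed by the runner via feed_input("raw", ...).
    raw = fn.external_source(name="raw", device="cpu")
    return raw * 2


runner = DALIPipelineRunner(toy_pipeline, dict(batch_size=4, num_threads=2, device_id=0))


class RefDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        # Returns a DALIOutputSampleRef placeholder, not real data.
        return runner(np.array([idx], dtype=np.int32))


# default_collate sees DALIOutputSampleRef elements and dispatches to the
# collate function registered above, which completes the DALI batch.
loader = DataLoader(RefDataset(), batch_size=4, num_workers=0)
for batch in loader:
    print(batch.shape)  # torch tensor produced by the DALI pipeline

The sketch deliberately stays single-process (`num_workers=0`); running the same scheme from DataLoader worker processes is where the `torch.multiprocessing` import comes into play, and that setup is not covered here.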
