Skip to content

Commit 88d15e1

Browse files
authored
Add dynamic API for math functions + tests. (#6066)
* Add math functions and tests. * Move arithm op arguments to GPU if there's any GPU argument. * Forbid mixing of DALI CPU and GPU tensors/batches in arithmetic ops and math functions. --------- Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 1775903 commit 88d15e1

File tree

4 files changed

+438
-1
lines changed

4 files changed

+438
-1
lines changed

dali/python/nvidia/dali/experimental/dynamic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from . import _fn
2727
from . import ops
28+
from . import math # noqa: F401
2829

2930
ops._initialize()
3031
_fn._initialize()

dali/python/nvidia/dali/experimental/dynamic/_batch.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,30 @@ def _arithm_op(name, *args, **kwargs):
101101
from . import arithmetic_generic_op
102102

103103
argsstr = " ".join(f"&{i}" for i in range(len(args)))
104-
return arithmetic_generic_op(*args, expression_desc=f"{name}({argsstr})")
104+
gpu = False
105+
new_args = [None] * len(args)
106+
for i, a in enumerate(args):
107+
if isinstance(a, (Batch, Tensor)):
108+
if a.device.device_type == "gpu":
109+
gpu = True
110+
else:
111+
# TODO(michalz): We might use some caching here for common values.
112+
if new_args is None:
113+
new_args = list(args)
114+
if gpu:
115+
new_args[i] = _as_tensor(a, device="gpu")
116+
else:
117+
new_args[i] = _as_tensor(a)
118+
if new_args[i].device.device_type == "gpu":
119+
gpu = True
120+
121+
for i in range(len(args)):
122+
if new_args[i] is None:
123+
if (args[i].device.device_type == "gpu") != gpu:
124+
raise ValueError("Cannot mix GPU and CPU inputs.")
125+
new_args[i] = args[i]
126+
127+
return arithmetic_generic_op(*new_args, expression_desc=f"{name}({argsstr})")
105128

106129

107130
class _TensorList:
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ._batch import _arithm_op
16+
17+
18+
def sqrt(input):
19+
"""Computes square root of values in `input`.
20+
21+
:rtype: Tensor or Batch of sqrt(input). If input is an integer, the result will be float,
22+
otherwise the type is preserved.
23+
"""
24+
return _arithm_op("sqrt", input)
25+
26+
27+
def rsqrt(input):
28+
"""Computes reciprocal of the square root of values in `input`.
29+
30+
:rtype: Tensor or Batch of rsqrt(input). If input is an integer, the result will be float,
31+
otherwise the type is preserved.
32+
"""
33+
return _arithm_op("rsqrt", input)
34+
35+
36+
def cbrt(input):
37+
"""Computes cube root of values in `input`.
38+
39+
:rtype: Tensor or Batch of cbrt(input). If input is an integer, the result will be float,
40+
otherwise the type is preserved.
41+
"""
42+
return _arithm_op("cbrt", input)
43+
44+
45+
def exp(input):
46+
"""Computes exponential of values in `input`.
47+
48+
:rtype: Tensor or Batch of exp(input). If input is an integer, the result will be float,
49+
otherwise the type is preserved.
50+
"""
51+
return _arithm_op("exp", input)
52+
53+
54+
def log(input):
55+
"""Computes natural logarithm (base-e) of values in `input`.
56+
57+
:rtype: Tensor or Batch of log(input). If input is an integer, the result will be float,
58+
otherwise the type is preserved.
59+
"""
60+
return _arithm_op("log", input)
61+
62+
63+
def log2(input):
64+
"""Computes logarithm (base-2) of values in `input`.
65+
66+
:rtype: Tensor or Batch of log2(input). If input is an integer, the result will be float,
67+
otherwise the type is preserved.
68+
"""
69+
return _arithm_op("log2", input)
70+
71+
72+
def log10(input):
73+
"""Computes logarithm (base-10) of values in `input`.
74+
75+
:rtype: Tensor or Batch of log10(input). If input is an integer, the result will be float,
76+
otherwise the type is preserved.
77+
"""
78+
return _arithm_op("log10", input)
79+
80+
81+
def abs(input):
82+
"""Computes absolute value of values in `input`.
83+
84+
:rtype: Tensor or Batch of abs(input). The type is preserved.
85+
"""
86+
return _arithm_op("abs", input)
87+
88+
89+
def fabs(input):
90+
"""Computes float absolute value of values in `input`.
91+
92+
:rtype: Tensor or Batch of fabs(input). If input is an integer, the result will be float,
93+
otherwise the type is preserved.
94+
"""
95+
return _arithm_op("fabs", input)
96+
97+
98+
def floor(input):
99+
"""Computes floor of values in `input`.
100+
101+
:rtype: Tensor or Batch of floor(input). If input is an integer, the result will be float,
102+
otherwise the type is preserved.
103+
"""
104+
return _arithm_op("floor", input)
105+
106+
107+
def ceil(input):
108+
"""Computes ceil of values in `input`.
109+
110+
:rtype: Tensor or Batch of ceil(input). If input is an integer, the result will be float,
111+
otherwise the type is preserved.
112+
"""
113+
return _arithm_op("ceil", input)
114+
115+
116+
def sin(input):
117+
"""Computes sine of values in `input`.
118+
119+
:rtype: Tensor or Batch of sin(input). If input is an integer, the result will be float,
120+
otherwise the type is preserved.
121+
"""
122+
return _arithm_op("sin", input)
123+
124+
125+
def cos(input):
126+
"""Computes cosine of values in `input`.
127+
128+
:rtype: Tensor or Batch of cos(input). If input is an integer, the result will be float,
129+
otherwise the type is preserved.
130+
"""
131+
return _arithm_op("cos", input)
132+
133+
134+
def tan(input):
135+
"""Computes tangent of values in `input`.
136+
137+
:rtype: Tensor or Batch of tan(input). If input is an integer, the result will be float,
138+
otherwise the type is preserved.
139+
"""
140+
return _arithm_op("tan", input)
141+
142+
143+
def asin(input):
144+
"""Computes arcus sine of values in `input`.
145+
146+
:rtype: Tensor or Batch of asin(input). If input is an integer, the result will be float,
147+
otherwise the type is preserved.
148+
"""
149+
return _arithm_op("asin", input)
150+
151+
152+
def acos(input):
153+
"""Computes arcus cosine of values in `input`.
154+
155+
:rtype: Tensor or Batch of acos(input). If input is an integer, the result will be float,
156+
otherwise the type is preserved.
157+
"""
158+
return _arithm_op("acos", input)
159+
160+
161+
def atan(input):
162+
"""Computes arcus tangent of values in `input`.
163+
164+
:rtype: Tensor or Batch of atan(input). If input is an integer, the result will be float,
165+
otherwise the type is preserved.
166+
"""
167+
return _arithm_op("atan", input)
168+
169+
170+
def sinh(input):
171+
"""Computes hyperbolic sine of values in `input`.
172+
173+
:rtype: Tensor or Batch of sinh(input). If input is an integer, the result will be float,
174+
otherwise the type is preserved.
175+
"""
176+
return _arithm_op("sinh", input)
177+
178+
179+
def cosh(input):
180+
"""Computes hyperbolic cosine of values in `input`.
181+
182+
:rtype: Tensor or Batch of cosh(input). If input is an integer, the result will be float,
183+
otherwise the type is preserved.
184+
"""
185+
return _arithm_op("cosh", input)
186+
187+
188+
def tanh(input):
189+
"""Computes hyperbolic tangent of values in `input`.
190+
191+
:rtype: Tensor or Batch of tanh(input). If input is an integer, the result will be float,
192+
otherwise the type is preserved.
193+
"""
194+
return _arithm_op("tanh", input)
195+
196+
197+
def asinh(input):
198+
"""Computes inverse hyperbolic sine of values in `input`.
199+
200+
:rtype: Tensor or Batch of asinh(input). If input is an integer, the result will be float,
201+
otherwise the type is preserved.
202+
"""
203+
return _arithm_op("asinh", input)
204+
205+
206+
def acosh(input):
207+
"""Computes inverse hyperbolic cosine of values in `input`.
208+
209+
:rtype: Tensor or Batch of acosh(input). If input is an integer, the result will be float,
210+
otherwise the type is preserved.
211+
"""
212+
return _arithm_op("acosh", input)
213+
214+
215+
def atanh(input):
216+
"""Computes inverse hyperbolic tangent of values in `input`.
217+
218+
:rtype: Tensor or Batch of atanh(input). If input is an integer, the result will be float,
219+
otherwise the type is preserved.
220+
"""
221+
return _arithm_op("atanh", input)
222+
223+
224+
def min(left, right):
225+
"""Computes minima of corresponding values in `left` and `right`.
226+
227+
:rtype: Tensor or Batch of the type that is calculated based on the type promotion rules.
228+
"""
229+
return _arithm_op("min", left, right)
230+
231+
232+
def max(left, right):
233+
"""Computes maxima of corresponding values in `left` and `right`.
234+
235+
:rtype: Tensor or Batch of the type that is calculated based on the type promotion rules.
236+
"""
237+
return _arithm_op("max", left, right)
238+
239+
240+
def pow(base, exponent):
241+
"""Computes base to the power of exponents, that is base ** exponent.
242+
243+
:rtype: Tensor or Batch of pow(base, exponent). Type is calculated based on the type
244+
promotion rules.
245+
"""
246+
return _arithm_op("pow", base, exponent)
247+
248+
249+
def fpow(base, exponent):
250+
"""Computes base to the power of exponents as floating point numbers.
251+
252+
:rtype: Tensor or Batch of pow(base, exponent). If all inputs are integers, the result
253+
will be float, otherwise the type is preserved.
254+
"""
255+
return _arithm_op("fpow", base, exponent)
256+
257+
258+
def atan2(x, y):
259+
"""Computes arcus tangent of corresponding values in x / y.
260+
261+
:rtype: Tensor or Batch of atan2(x, y). If all inputs are integers, the result will be float,
262+
otherwise the type is preserved.
263+
"""
264+
return _arithm_op("atan2", x, y)
265+
266+
267+
def clamp(value, lo, hi):
268+
"""Produces a tensor of values from `value` clamped to the range ``[lo, hi]``.
269+
270+
:rtype: Tensor or Batch of the type that is calculated based on the type promotion rules.
271+
"""
272+
return _arithm_op("clamp", value, lo, hi)

0 commit comments

Comments
 (0)