Skip to content

Commit ae814e1

Browse files
authored
[API Compatiblity] Support paddle.linalg.solve, paddle.nn.functional.normalize, paddle.quantile (PaddlePaddle#76470)
* add Decorator * fix out * fix UT * fix * fix out
1 parent 311b737 commit ae814e1

File tree

7 files changed

+342
-27
lines changed

7 files changed

+342
-27
lines changed

python/paddle/_paddle_docs.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2460,7 +2460,7 @@ def bmm(
24602460
""",
24612461
"""
24622462
def logical_and(
2463-
x: Tensor, y: Tensor, out: Tensor | None = None, name: str | None = None
2463+
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
24642464
) -> Tensor
24652465
""",
24662466
)
@@ -2511,7 +2511,7 @@ def logical_and(
25112511
""",
25122512
"""
25132513
def logical_or(
2514-
x: Tensor, y: Tensor, out: Tensor | None = None, name: str | None = None
2514+
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
25152515
) -> Tensor
25162516
""",
25172517
)
@@ -2557,7 +2557,7 @@ def logical_or(
25572557
""",
25582558
"""
25592559
def logical_not(
2560-
x: Tensor, out: Tensor | None = None, name: str | None = None
2560+
x: Tensor, name: str | None = None, *, out: Tensor | None = None
25612561
) -> Tensor
25622562
""",
25632563
)
@@ -2608,7 +2608,7 @@ def logical_not(
26082608
""",
26092609
"""
26102610
def logical_xor(
2611-
x: Tensor, y: Tensor, out: Tensor | None = None, name: str | None = None
2611+
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
26122612
) -> Tensor
26132613
""",
26142614
)
@@ -2705,7 +2705,7 @@ def dot(
27052705
""",
27062706
"""
27072707
def tanh(
2708-
x: Tensor, *, out: Tensor | None = None, name: str | None = None,
2708+
x: Tensor, name: str | None = None, *, out: Tensor | None = None
27092709
) -> Tensor
27102710
""",
27112711
)
@@ -2745,7 +2745,7 @@ def tanh(
27452745
""",
27462746
"""
27472747
def exp(
2748-
x: Tensor, *, out: Tensor | None = None, name: str | None = None
2748+
x: Tensor, name: str | None = None, *, out: Tensor | None = None
27492749
) -> Tensor
27502750
""",
27512751
)
@@ -2785,7 +2785,7 @@ def exp(
27852785
""",
27862786
"""
27872787
def expm1(
2788-
x: Tensor, *, out: Tensor | None = None, name: str | None = None
2788+
x: Tensor, name: str | None = None, *, out: Tensor | None = None
27892789
) -> Tensor
27902790
""",
27912791
)
@@ -2917,7 +2917,7 @@ def diagonal(
29172917
""",
29182918
"""
29192919
def round(
2920-
x: Tensor, decimals = 0, *, out: Tensor | None = None, name: str | None = None,
2920+
x: Tensor, decimals: int = 0, name: str | None = None, *, out: Tensor | None = None,
29212921
) -> Tensor
29222922
""",
29232923
)

python/paddle/nn/functional/norm.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
in_pir_mode,
2929
)
3030
from paddle.utils.decorator_utils import (
31+
ParamAliasDecorator,
3132
param_two_alias,
3233
)
3334

@@ -49,12 +50,15 @@
4950
__all__ = []
5051

5152

53+
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"], "epsilon": ["eps"]})
5254
def normalize(
5355
x: Tensor,
5456
p: float = 2,
5557
axis: int = 1,
5658
epsilon: float = 1e-12,
5759
name: str | None = None,
60+
*,
61+
out: Tensor | None = None,
5862
) -> Tensor:
5963
r"""
6064
Normalize ``x`` along dimension ``axis`` using :math:`L_p` norm. This layer computes
@@ -68,13 +72,16 @@ def normalize(
6872
6973
where, :math:`\sum_i{\lvert x_i \rvert^p}` is calculated along the ``axis`` dimension.
7074
71-
7275
Parameters:
7376
x (Tensor): The input tensor could be N-D tensor, and the input data type could be float32 or float64.
77+
Alias: ``input``.
7478
p (float|int, optional): The exponent value in the norm formulation. Default: 2.
7579
axis (int, optional): The axis on which to apply normalization. If `axis < 0`, the dimension to normalization is `x.ndim + axis`. -1 is the last dimension.
80+
Alias: ``dim``.
7681
epsilon (float, optional): Small float added to denominator to avoid dividing by zero. Default is 1e-12.
82+
Alias: ``esp``.
7783
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
84+
out (Tensor|None, optional): The output tensor. Default: None.
7885
7986
Returns:
8087
Tensor, the output has the same shape and data type with ``x``.
@@ -110,13 +117,13 @@ def normalize(
110117

111118
if in_dygraph_mode():
112119
eps = paddle.full(shape=[1], fill_value=epsilon, dtype=x.dtype)
113-
out = _C_ops.p_norm(x, float(p), axis, epsilon, True, False)
114-
return x / _C_ops.maximum(out, eps)
120+
ret = _C_ops.p_norm(x, float(p), axis, epsilon, True, False)
121+
ret = x / _C_ops.maximum(ret, eps)
115122

116123
elif in_pir_mode():
117124
eps = paddle.full(shape=[1], fill_value=epsilon, dtype=x.dtype)
118-
out = _C_ops.p_norm(x, float(p), axis, epsilon, True, False)
119-
return paddle.divide(x, _C_ops.maximum(out, eps), name=name)
125+
ret = _C_ops.p_norm(x, float(p), axis, epsilon, True, False)
126+
ret = paddle.divide(x, _C_ops.maximum(ret, eps), name=name)
120127

121128
else:
122129
check_type(p, 'p', (float, int), 'normalize')
@@ -142,7 +149,12 @@ def normalize(
142149
)
143150
eps = out.block.create_var(dtype=out.dtype)
144151
eps = paddle.full(shape=[1], fill_value=epsilon, dtype=out.dtype)
145-
return paddle.divide(x, paddle.maximum(out, eps), name=name)
152+
out = paddle.divide(x, paddle.maximum(out, eps), name=name)
153+
154+
if out is not None:
155+
paddle.assign(ret, out)
156+
return out
157+
return ret
146158

147159

148160
def batch_norm(

python/paddle/tensor/linalg.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from paddle.utils.decorator_utils import (
2929
ParamAliasDecorator,
3030
VariableArgsDecorator,
31+
param_two_alias,
3132
transpose_decorator,
3233
)
3334
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
@@ -4178,8 +4179,14 @@ def _transpose_last_2dim(x):
41784179
return x
41794180

41804181

4182+
@param_two_alias(["x", "A"], ["y", "B"])
41814183
def solve(
4182-
x: Tensor, y: Tensor, left: bool = True, name: str | None = None
4184+
x: Tensor,
4185+
y: Tensor,
4186+
left: bool = True,
4187+
name: str | None = None,
4188+
*,
4189+
out: Tensor | None = None,
41834190
) -> Tensor:
41844191
r"""
41854192
@@ -4199,12 +4206,13 @@ def solve(
41994206
42004207
Args:
42014208
x (Tensor): A square matrix or a batch of square matrices. Its shape should be ``[*, M, M]``, where ``*`` is zero or
4202-
more batch dimensions. Its data type should be float32 or float64.
4209+
more batch dimensions. Its data type should be float32 or float64. Alias: ``A``.
42034210
y (Tensor): A vector/matrix or a batch of vectors/matrices. Its shape should be ``[*, M, K]``, where ``*`` is zero or
4204-
more batch dimensions. Its data type should be float32 or float64.
4211+
more batch dimensions. Its data type should be float32 or float64. Alias: ``B``.
42054212
left (bool, optional): Whether to solve the system :math:`X * Out = Y` or :math:`Out * X = Y`. Default: True.
42064213
name (str|None, optional): Name for the operation (optional, default is None).
42074214
For more information, please refer to :ref:`api_guide_Name`.
4215+
out (Tensor|None, optional): The output tensor. Default: None.
42084216
42094217
Returns:
42104218
Tensor: The solution of a square system of linear equations with a unique solution for input 'x' and 'y'.
@@ -4234,7 +4242,7 @@ def solve(
42344242
y = _transpose_last_2dim(y)
42354243

42364244
if in_dynamic_or_pir_mode():
4237-
out = _C_ops.solve(x, y)
4245+
ret = _C_ops.solve(x, y)
42384246
else:
42394247
inputs = {"X": [x], "Y": [y]}
42404248
helper = LayerHelper("solve", **locals())
@@ -4247,8 +4255,10 @@ def solve(
42474255
)
42484256

42494257
if not left:
4250-
out = _transpose_last_2dim(out)
4251-
return out
4258+
ret = _transpose_last_2dim(ret)
4259+
if out is not None:
4260+
paddle.assign(ret, out)
4261+
return ret
42524262

42534263

42544264
def triangular_solve(

python/paddle/tensor/stat.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ def _compute_quantile(
717717
keepdim: bool = False,
718718
interpolation: _Interpolation = "linear",
719719
ignore_nan: bool = False,
720+
out: Tensor | None = None,
720721
) -> Tensor:
721722
"""
722723
Compute the quantile of the input along the specified axis.
@@ -742,6 +743,7 @@ def _compute_quantile(
742743
ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor.
743744
If ``ignore_nan`` is True, it will calculate nanquantile.
744745
Otherwise it will calculate quantile. Default is False.
746+
out (Tensor|None, optional): The output tensor. Default: None.
745747
746748
Returns:
747749
Tensor, results of quantile along ``axis`` of ``x``.
@@ -879,27 +881,34 @@ def _compute_index(index):
879881

880882
# TODO(chenjianye): replace the for-loop to directly take elements.
881883
for index in indices:
882-
out = _compute_index(index)
884+
ret = _compute_index(index)
883885
if not keepdim:
884-
out = paddle.squeeze(out, axis=axis)
886+
ret = paddle.squeeze(ret, axis=axis)
885887
else:
886-
out = out.reshape(out_shape)
887-
outputs.append(out)
888+
ret = ret.reshape(out_shape)
889+
outputs.append(ret)
888890

889891
if len(outputs) > 1:
890892
outputs = paddle.stack(outputs, 0)
891893
else:
892894
outputs = outputs[0]
893-
# return outputs.astype(x.dtype)
895+
896+
if out is not None:
897+
paddle.assign(outputs, out)
898+
return out
894899
return outputs
895900

896901

902+
@param_two_alias(["x", "input"], ["axis", "dim"])
897903
def quantile(
898904
x: Tensor,
899905
q: float | Sequence[float] | Tensor,
900906
axis: int | list[int] | None = None,
901907
keepdim: bool = False,
902908
interpolation: _Interpolation = "linear",
909+
name: str | None = None,
910+
*,
911+
out: Tensor | None = None,
903912
) -> Tensor:
904913
"""
905914
Compute the quantile of the input along the specified axis.
@@ -925,6 +934,8 @@ def quantile(
925934
lower, midpoint and nearest. Default is linear.
926935
name (str, optional): Name for the operation (optional, default is None).
927936
For more information, please refer to :ref:`api_guide_Name`.
937+
out (Tensor|None, optional): The output tensor. Default: None.
938+
928939
929940
Returns:
930941
Tensor, results of quantile along ``axis`` of ``x``.
@@ -975,6 +986,7 @@ def quantile(
975986
keepdim=keepdim,
976987
interpolation=interpolation,
977988
ignore_nan=False,
989+
out=out,
978990
)
979991

980992

test/legacy_test/test_normalize.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,98 @@ def test_gpu(self):
102102
self.run_static(use_gpu=True)
103103

104104

105+
class TestNormalizeAPI_Compatibility(unittest.TestCase):
106+
def setUp(self):
107+
np.random.seed(2025)
108+
self.places = ['cpu', get_device_place()]
109+
self.shape = [2, 3, 4]
110+
self.dtype = "float32"
111+
self.init_data()
112+
113+
def init_data(self):
114+
self.np_x = np.random.rand(*self.shape).astype(self.dtype)
115+
self.p = 2
116+
self.axis = 1
117+
self.epsilon = 1e-12
118+
119+
def test_dygraph_Compatibility(self):
120+
paddle.disable_static()
121+
x = paddle.to_tensor(self.np_x)
122+
paddle_dygraph_out = []
123+
# Position args (args)
124+
out1 = paddle.nn.functional.normalize(
125+
x, self.p, self.axis, self.epsilon
126+
)
127+
paddle_dygraph_out.append(out1)
128+
# Key words args (kwargs) for paddle
129+
out2 = paddle.nn.functional.normalize(
130+
x=x, p=self.p, axis=self.axis, epsilon=self.epsilon
131+
)
132+
paddle_dygraph_out.append(out2)
133+
# Key words args for torch compatibility
134+
out3 = paddle.nn.functional.normalize(
135+
input=x, p=self.p, dim=self.axis, eps=self.epsilon
136+
)
137+
paddle_dygraph_out.append(out3)
138+
# Key words args for out
139+
out4 = paddle.zeros_like(x)
140+
paddle.nn.functional.normalize(
141+
x, self.p, self.axis, self.epsilon, out=out4
142+
)
143+
paddle_dygraph_out.append(out4)
144+
# Numpy reference output
145+
ref_out = self.np_x / np.maximum(
146+
np.linalg.norm(
147+
self.np_x, ord=self.p, axis=self.axis, keepdims=True
148+
),
149+
self.epsilon,
150+
)
151+
152+
for out in paddle_dygraph_out:
153+
np.testing.assert_allclose(
154+
ref_out, out.numpy(), rtol=1e-05, atol=1e-08
155+
)
156+
paddle.enable_static()
157+
158+
def test_static_Compatibility(self):
159+
paddle.enable_static()
160+
main = paddle.static.Program()
161+
startup = paddle.static.Program()
162+
with paddle.base.program_guard(main, startup):
163+
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
164+
# Position args (args)
165+
out1 = paddle.nn.functional.normalize(
166+
x, self.p, self.axis, self.epsilon
167+
)
168+
# Key words args (kwargs) for paddle
169+
out2 = paddle.nn.functional.normalize(
170+
x=x, p=self.p, axis=self.axis, epsilon=self.epsilon
171+
)
172+
# Key words args for torch compatibility
173+
out3 = paddle.nn.functional.normalize(
174+
input=x, p=self.p, dim=self.axis, eps=self.epsilon
175+
)
176+
# Numpy reference output
177+
ref_out = self.np_x / np.maximum(
178+
np.linalg.norm(
179+
self.np_x, ord=self.p, axis=self.axis, keepdims=True
180+
),
181+
self.epsilon,
182+
)
183+
184+
fetch_list = [out1, out2, out3]
185+
for place in self.places:
186+
exe = paddle.base.Executor(place)
187+
fetches = exe.run(
188+
main,
189+
feed={"x": self.np_x},
190+
fetch_list=fetch_list,
191+
)
192+
for out in fetches:
193+
np.testing.assert_allclose(
194+
out, ref_out, rtol=1e-05, atol=1e-08
195+
)
196+
197+
105198
if __name__ == "__main__":
106199
unittest.main()

0 commit comments

Comments
 (0)