
Commit 935bf3f

[API Compatibility] Add paddle.compat.nn.Softmax (#76637)
* add softmax
* fix
* fix
* fix ut
1 parent bafb572 commit 935bf3f

File tree

2 files changed: +205 -1 lines changed
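
The change in a nutshell: paddle.compat.nn gains a PyTorch-style Softmax layer whose constructor takes dim instead of Paddle's usual axis, plus a unit test comparing it against paddle.nn.Softmax. A minimal usage sketch of the new layer (my own illustration, assuming a Paddle build that includes this commit):

    import paddle

    x = paddle.randn([2, 3, 4])
    m_compat = paddle.compat.nn.Softmax(dim=-1)  # compat API takes `dim`, as torch.nn.Softmax does
    m_origin = paddle.nn.Softmax(axis=-1)        # original Paddle API takes `axis`
    assert paddle.allclose(m_compat(x), m_origin(x))  # both layers compute the same softmax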

python/paddle/compat/nn/__init__.py

Lines changed: 130 additions & 1 deletion
@@ -40,6 +40,7 @@
 __all__ = [
     'Unfold',
     'Linear',
+    'Softmax',
     'AvgPool1D',
     'AvgPool2D',
     'AvgPool3D',
@@ -463,7 +464,6 @@ def to_list_if_necessary(x):
             strides=to_list_if_necessary(self.strides),
             paddings=to_list_if_necessary(self.paddings),
             dilations=to_list_if_necessary(self.dilations),
-            name=self.name,
         )
 
 
@@ -610,6 +610,135 @@ def reset_parameters(self) -> None:
         nn.init.uniform_(self.bias, -bound, bound)
 
 
+class Softmax(nn.Layer):
+    r"""
+    Softmax Activation.
+
+    This operator implements the softmax layer. The calculation process is as follows:
+
+    1. The dimension :attr:`dim` of ``input`` will be permuted to the last.
+
+    2. Then ``input`` will be logically flattened to a 2-D matrix. The matrix's second
+       dimension (row length) is the same as the dimension :attr:`dim` of ``input``,
+       and the first dimension (column length) is the product of all other dimensions
+       of ``input``. For each row of the matrix, the softmax operator squashes the
+       K-dimensional (K is the width of the matrix, which is also the size of ``input``'s
+       dimension :attr:`dim`) vector of arbitrary real values to a K-dimensional
+       vector of real values in the range [0, 1] that add up to 1.
+
+    3. After the softmax operation is completed, the inverse operations of steps 1 and 2
+       are performed to restore the two-dimensional matrix to the same shape as ``input``.
+
+    The operator exponentiates each element along the given dimension and divides it
+    by the sum of those exponentials, so the output is the ratio of each element's
+    exponential to the sum of exponentials along dimension :attr:`dim`.
+
+    For each row :math:`i` and each column :math:`j` in the matrix, we have:
+
+    .. math::
+
+        Softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j \exp(x[i, j])}
+
+    Example:
+
+    .. code-block:: text
+
+        Case 1:
+            Input:
+                x.shape = [2, 3, 4]
+                x.data = [[[2.0, 3.0, 4.0, 5.0],
+                           [3.0, 4.0, 5.0, 6.0],
+                           [7.0, 8.0, 8.0, 9.0]],
+                          [[1.0, 2.0, 3.0, 4.0],
+                           [5.0, 6.0, 7.0, 8.0],
+                           [6.0, 7.0, 8.0, 9.0]]]
+
+            Attrs:
+                dim = -1
+
+            Output:
+                out.shape = [2, 3, 4]
+                out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
+                             [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
+                             [0.07232949, 0.19661193, 0.19661193, 0.53444665]],
+                            [[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
+                             [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
+                             [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]
+
+        Case 2:
+            Input:
+                x.shape = [2, 3, 4]
+                x.data = [[[2.0, 3.0, 4.0, 5.0],
+                           [3.0, 4.0, 5.0, 6.0],
+                           [7.0, 8.0, 8.0, 9.0]],
+                          [[1.0, 2.0, 3.0, 4.0],
+                           [5.0, 6.0, 7.0, 8.0],
+                           [6.0, 7.0, 8.0, 9.0]]]
+
+            Attrs:
+                dim = 1
+
+            Output:
+                out.shape = [2, 3, 4]
+                out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
+                             [0.01786798, 0.01786798, 0.04661262, 0.04661262],
+                             [0.97555875, 0.97555875, 0.93623955, 0.93623955]],
+                            [[0.00490169, 0.00490169, 0.00490169, 0.00490169],
+                             [0.26762315, 0.26762315, 0.26762315, 0.26762315],
+                             [0.72747516, 0.72747516, 0.72747516, 0.72747516]]]
+
+    Parameters:
+        dim (int, optional): The dimension along which to perform softmax
+            calculations. It should be in range [-D, D), where D is the
+            number of dimensions of ``input``. If ``dim`` < 0, it works the
+            same way as :math:`dim + D`. Default is None.
+
+    Shape:
+        - input: Tensor with any shape.
+        - output: Tensor with the same shape as input.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> x = paddle.to_tensor([[[2.0, 3.0, 4.0, 5.0],
+            ...                        [3.0, 4.0, 5.0, 6.0],
+            ...                        [7.0, 8.0, 8.0, 9.0]],
+            ...                       [[1.0, 2.0, 3.0, 4.0],
+            ...                        [5.0, 6.0, 7.0, 8.0],
+            ...                        [6.0, 7.0, 8.0, 9.0]]], dtype='float32')
+            >>> m = paddle.compat.nn.Softmax()
+            >>> out = m(x)
+            >>> print(out)
+            Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[[0.73105854, 0.73105854, 0.73105854, 0.73105854],
+              [0.11920292, 0.11920292, 0.11920292, 0.11920292],
+              [0.73105854, 0.73105854, 0.50000000, 0.50000000]],
+             [[0.26894143, 0.26894143, 0.26894143, 0.26894143],
+              [0.88079703, 0.88079703, 0.88079703, 0.88079703],
+              [0.26894143, 0.26894143, 0.50000000, 0.50000000]]])
+
+    """
+
+    @ForbidKeywordsDecorator(
+        illegal_keys={"axis"},
+        func_name="paddle.compat.nn.Softmax",
+        correct_name="paddle.nn.Softmax",
+    )
+    def __init__(self, dim: int | None = None) -> None:
+        super().__init__()
+        self._dim = dim
+        self._dtype = None
+
+    def forward(self, input: Tensor) -> Tensor:
+        return functional.softmax(input, self._dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self._dim}"
+
+
 AvgPool1d = AvgPool1D
 AvgPool2d = AvgPool2D
 AvgPool3d = AvgPool3D
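
The docstring above describes softmax as a three-step procedure: permute dim to the last position, flatten to a 2-D matrix, apply row-wise softmax, then invert the first two steps. A short NumPy sketch of exactly those steps (my own illustration; the function name and the shift-by-max trick are mine, not part of this commit), which reproduces the Case 1 numbers:

    import numpy as np

    def softmax_by_steps(x: np.ndarray, dim: int) -> np.ndarray:
        dim = dim % x.ndim                                  # normalize a negative dim
        moved = np.moveaxis(x, dim, -1)                     # step 1: permute `dim` to the last position
        flat = moved.reshape(-1, moved.shape[-1])           # step 2: flatten to a 2-D matrix
        e = np.exp(flat - flat.max(axis=1, keepdims=True))  # shift by the row max for numerical stability
        rows = e / e.sum(axis=1, keepdims=True)             # row-wise softmax
        return np.moveaxis(rows.reshape(moved.shape), -1, dim)  # step 3: undo the flatten and permute

    x = np.array([[2.0, 3.0, 4.0, 5.0]])
    print(softmax_by_steps(x, -1))  # [[0.0320586  0.08714432 0.23688282 0.64391426]]

Note also that the doctest output in the docstring (pairs like 0.731/0.269 and 0.119/0.881) is what softmax over dim 0 of that 3-D input produces, which suggests that a None dim falls back to a torch-style default-dimension heuristic rather than Paddle's usual axis=-1.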
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestCompatSoftmax(unittest.TestCase):
+    def _compare_with_origin(self, input_tensor, axis):
+        softmax_compat = paddle.compat.nn.Softmax(dim=axis)
+        softmax_origin = paddle.nn.Softmax(axis=axis)
+
+        expected_res = softmax_origin(input_tensor).numpy()
+        np.testing.assert_allclose(
+            softmax_compat(input_tensor).numpy(),
+            expected_res,
+            rtol=1e-6,
+            atol=1e-6,
+        )
+
+    def test_compare_with_origin(self):
+        input_shape = (3, 4)
+        input_tensor = paddle.randn(input_shape, dtype=paddle.float32)
+        self._compare_with_origin(input_tensor, axis=0)
+        self._compare_with_origin(input_tensor, axis=1)
+        self._compare_with_origin(input_tensor, axis=-1)
+
+        input_shape = (2, 3, 4)
+        input_tensor = paddle.randn(input_shape, dtype=paddle.float64)
+        self._compare_with_origin(input_tensor, axis=0)
+        self._compare_with_origin(input_tensor, axis=1)
+        self._compare_with_origin(input_tensor, axis=2)
+        self._compare_with_origin(input_tensor, axis=-1)
+
+        input_shape = (2, 3, 4, 5)
+        input_tensor = paddle.randn(input_shape, dtype=paddle.float32)
+        self._compare_with_origin(input_tensor, axis=1)
+        self._compare_with_origin(input_tensor, axis=-2)
+
+        input_tensor = paddle.randn((2, 3), dtype=paddle.float32)
+        softmax_compat = paddle.compat.nn.Softmax()
+        softmax_origin = paddle.nn.Softmax()
+        expected_res = softmax_origin(input_tensor).numpy()
+        np.testing.assert_allclose(
+            softmax_compat(input_tensor).numpy(),
+            expected_res,
+            rtol=1e-6,
+            atol=1e-6,
+        )
+
+    def test_error_handling(self):
+        x = paddle.randn([3, 9, 5])
+
+        msg_gt_1 = "paddle.compat.nn.Softmax() received unexpected keyword argument 'axis'. \nDid you mean to use paddle.nn.Softmax() instead?"
+
+        with self.assertRaises(TypeError) as cm:
+            softmax = paddle.compat.nn.Softmax(axis=1)
+        self.assertEqual(str(cm.exception), msg_gt_1)
+
+
+if __name__ == "__main__":
+    unittest.main()
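
test_error_handling pins down the exact TypeError text that ForbidKeywordsDecorator raises when a caller passes the original-API keyword. The decorator's real implementation lives elsewhere in the Paddle tree; the following is only a minimal sketch of the behavior this test implies, with hypothetical names throughout:

    import functools

    # Hypothetical reconstruction of a keyword-forbidding decorator; only the
    # error-message format is taken from the test above.
    def forbid_keywords(illegal_keys, func_name, correct_name):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                bad = illegal_keys & kwargs.keys()  # any forbidden keyword used?
                if bad:
                    raise TypeError(
                        f"{func_name}() received unexpected keyword argument "
                        f"'{sorted(bad)[0]}'. \nDid you mean to use {correct_name}() instead?"
                    )
                return fn(*args, **kwargs)
            return wrapper
        return decorator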
