From c1d3b097dc6bb11eec789fb531847b19719da628 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 26 Nov 2025 12:19:56 +0000 Subject: [PATCH 1/4] add softmax --- python/paddle/compat/nn/__init__.py | 143 ++++++++++++++++++++++-- test/legacy_test/test_compat_softmax.py | 80 +++++++++++++ 2 files changed, 215 insertions(+), 8 deletions(-) create mode 100644 test/legacy_test/test_compat_softmax.py diff --git a/python/paddle/compat/nn/__init__.py b/python/paddle/compat/nn/__init__.py index 2c0241fe5377f7..54afb8e6dd20ba 100644 --- a/python/paddle/compat/nn/__init__.py +++ b/python/paddle/compat/nn/__init__.py @@ -40,6 +40,7 @@ __all__ = [ 'Unfold', 'Linear', + 'Softmax', 'AvgPool1D', 'AvgPool2D', 'AvgPool3D', @@ -400,9 +401,6 @@ def __setstate__(self, state): self.__dict__.setdefault("count_include_pad", True) -__all__ = ['Unfold', 'Linear', 'MultiheadAttention'] - - class Unfold(nn.Unfold): """ A compatible version of paddle.nn.Unfold: @@ -441,7 +439,13 @@ class Unfold(nn.Unfold): strides: Size2 @ForbidKeywordsDecorator( - illegal_keys={"kernel_sizes", "dilations", "paddings", "strides"}, + illegal_keys={ + "kernel_sizes", + "dilations", + "paddings", + "strides", + "name", + }, func_name="paddle.compat.nn.Unfold", correct_name="paddle.nn.Unfold", ) @@ -466,7 +470,6 @@ def to_list_if_necessary(x): strides=to_list_if_necessary(self.strides), paddings=to_list_if_necessary(self.paddings), dilations=to_list_if_necessary(self.dilations), - name=self.name, ) @@ -613,6 +616,130 @@ def reset_parameters(self) -> None: nn.init.uniform_(self.bias, -bound, bound) -AvgPool1d = AvgPool1D -AvgPool2d = AvgPool2D -AvgPool3d = AvgPool3D +class Softmax(nn.Layer): + r""" + Softmax Activation. + + This operator implements the softmax layer. The calculation process is as follows: + + 1. The dimension :attr:`dim` of ``input`` will be permuted to the last. + + 2. Then ``input`` will be logically flattened to a 2-D matrix. 
The matrix's second + dimension(row length) is the same as the dimension :attr:`dim` of ``input``, + and the first dimension(column length) is the product of all other dimensions + of ``input``. For each row of the matrix, the softmax operator squashes the + K-dimensional(K is the width of the matrix, which is also the size of ``input``'s + dimension :attr:`dim`) vector of arbitrary real values to a K-dimensional + vector of real values in the range [0, 1] that add up to 1. + + 3. After the softmax operation is completed, the inverse operations of steps 1 and 2 + are performed to restore the two-dimensional matrix to the same dimension as the ``input`` . + + It computes the exponential of the given dimension and the sum of exponential + values of all the other dimensions in the K-dimensional vector input. + Then the ratio of the exponential of the given dimension and the sum of + exponential values of all the other dimensions is the output of the softmax + operator. + + For each row :math:`i` and each column :math:`j` in the matrix, we have: + + .. math:: + + Softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j \exp(x[i, j])} + + Example: + + .. 
code-block:: text + + Case 1: + Input: + x.shape = [2, 3, 4] + x.data = [[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]] + + Attrs: + dim = -1 + + Output: + out.shape = [2, 3, 4] + out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.07232949, 0.19661193, 0.19661193, 0.53444665]], + [[0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426], + [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]] + + Case 2: + Input: + x.shape = [2, 3, 4] + x.data = [[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]] + Attrs: + dim = 1 + + Output: + out.shape = [2, 3, 4] + out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783], + [0.01786798, 0.01786798, 0.04661262, 0.04661262], + [0.97555875, 0.97555875, 0.93623955, 0.93623955]], + [[0.00490169, 0.00490169, 0.00490169, 0.00490169], + [0.26762315, 0.26762315, 0.26762315, 0.26762315], + [0.72747516, 0.72747516, 0.72747516, 0.72747516]]] + + Parameters: + dim (int, optional): The dim along which to perform softmax + calculations. It should be in range [-D, D), where D is the + number of dimensions of ``input`` . If ``dim`` < 0, it works the same way as + :math:`dim + D` . Default is None. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([[[2.0, 3.0, 4.0, 5.0], + ... [3.0, 4.0, 5.0, 6.0], + ... [7.0, 8.0, 8.0, 9.0]], + ... [[1.0, 2.0, 3.0, 4.0], + ... [5.0, 6.0, 7.0, 8.0], + ... 
[6.0, 7.0, 8.0, 9.0]]], dtype='float32') + >>> m = paddle.compat.nn.Softmax() + >>> out = m(x) + >>> print(out) + Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[0.73105854, 0.73105854, 0.73105854, 0.73105854], + [0.11920292, 0.11920292, 0.11920292, 0.11920292], + [0.73105854, 0.73105854, 0.50000000, 0.50000000]], + [[0.26894143, 0.26894143, 0.26894143, 0.26894143], + [0.88079703, 0.88079703, 0.88079703, 0.88079703], + [0.26894143, 0.26894143, 0.50000000, 0.50000000]]]) + + """ + + @ForbidKeywordsDecorator( + illegal_keys={"axis", "name"}, + func_name="paddle.compat.nn.Softmax", + correct_name="paddle.nn.Softmax", + ) + def __init__(self, dim: int | None = None) -> None: + super().__init__() + self._dim = dim + self._dtype = None + + def forward(self, input: Tensor) -> Tensor: + return functional.softmax(input, self._dim) + + def extra_repr(self) -> str: + return f"dim={self._dim}" diff --git a/test/legacy_test/test_compat_softmax.py b/test/legacy_test/test_compat_softmax.py new file mode 100644 index 00000000000000..1f34455dfd71fc --- /dev/null +++ b/test/legacy_test/test_compat_softmax.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np + +import paddle + + +class TestCompatSoftmax(unittest.TestCase): + def _compare_with_origin(self, input_tensor, axis): + softmax_compat = paddle.compat.nn.Softmax(dim=axis) + softmax_origin = paddle.nn.Softmax(axis=axis) + + expected_res = softmax_origin(input_tensor).numpy() + np.testing.assert_allclose( + softmax_compat(input_tensor).numpy(), + expected_res, + rtol=1e-6, + atol=1e-6, + ) + + def test_compare_with_origin(self): + input_shape = (3, 4) + input_tensor = paddle.randn(input_shape, dtype=paddle.float32) + self._compare_with_origin(input_tensor, axis=0) + self._compare_with_origin(input_tensor, axis=1) + self._compare_with_origin(input_tensor, axis=-1) + + input_shape = (2, 3, 4) + input_tensor = paddle.randn(input_shape, dtype=paddle.float64) + self._compare_with_origin(input_tensor, axis=0) + self._compare_with_origin(input_tensor, axis=1) + self._compare_with_origin(input_tensor, axis=2) + self._compare_with_origin(input_tensor, axis=-1) + + input_shape = (2, 3, 4, 5) + input_tensor = paddle.randn(input_shape, dtype=paddle.float32) + self._compare_with_origin(input_tensor, axis=1) + self._compare_with_origin(input_tensor, axis=-2) + + input_tensor = paddle.randn((2, 3), dtype=paddle.float32) + softmax_compat = paddle.compat.nn.Softmax() + softmax_origin = paddle.nn.Softmax() + expected_res = softmax_origin(input_tensor).numpy() + np.testing.assert_allclose( + softmax_compat(input_tensor).numpy(), + expected_res, + rtol=1e-6, + atol=1e-6, + ) + + def test_error_handling(self): + x = paddle.randn([3, 9, 5]) + + msg_gt_1 = "paddle.compat.nn.Softmax() received unexpected keyword argument 'axis'. \nDid you mean to use paddle.nn.Softmax() instead?" + msg_gt_2 = "paddle.compat.nn.Softmax() received unexpected keyword arguments 'axis', 'name'. \nDid you mean to use paddle.nn.Softmax() instead?" 
+ + with self.assertRaises(TypeError) as cm: + softmax = paddle.compat.nn.Softmax(axis=1) + self.assertEqual(str(cm.exception), msg_gt_1) + + with self.assertRaises(TypeError) as cm: + softmax = paddle.compat.nn.Softmax(axis=1, name="softmax") + self.assertEqual(str(cm.exception), msg_gt_2) + + +if __name__ == "__main__": + unittest.main() From 0ea957eea090359347602534720ae202ce08ab84 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 26 Nov 2025 13:28:37 +0000 Subject: [PATCH 2/4] fix --- python/paddle/compat/nn/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/compat/nn/__init__.py b/python/paddle/compat/nn/__init__.py index 54afb8e6dd20ba..07710892d7c689 100644 --- a/python/paddle/compat/nn/__init__.py +++ b/python/paddle/compat/nn/__init__.py @@ -439,13 +439,7 @@ class Unfold(nn.Unfold): strides: Size2 @ForbidKeywordsDecorator( - illegal_keys={ - "kernel_sizes", - "dilations", - "paddings", - "strides", - "name", - }, + illegal_keys={"kernel_sizes", "dilations", "paddings", "strides"}, func_name="paddle.compat.nn.Unfold", correct_name="paddle.nn.Unfold", ) @@ -470,6 +464,7 @@ def to_list_if_necessary(x): strides=to_list_if_necessary(self.strides), paddings=to_list_if_necessary(self.paddings), dilations=to_list_if_necessary(self.dilations), + name=self.name, ) @@ -743,3 +738,8 @@ def forward(self, input: Tensor) -> Tensor: def extra_repr(self) -> str: return f"dim={self._dim}" + + +AvgPool1d = AvgPool1D +AvgPool2d = AvgPool2D +AvgPool3d = AvgPool3D From bf5688edb62259de0cfd0922e6d118d0e4199d1c Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 26 Nov 2025 13:30:42 +0000 Subject: [PATCH 3/4] fix --- python/paddle/compat/nn/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/compat/nn/__init__.py b/python/paddle/compat/nn/__init__.py index 07710892d7c689..b420ff2ac47ec0 100644 --- a/python/paddle/compat/nn/__init__.py +++ 
b/python/paddle/compat/nn/__init__.py @@ -464,7 +464,6 @@ def to_list_if_necessary(x): strides=to_list_if_necessary(self.strides), paddings=to_list_if_necessary(self.paddings), dilations=to_list_if_necessary(self.dilations), - name=self.name, ) @@ -724,7 +723,7 @@ class Softmax(nn.Layer): """ @ForbidKeywordsDecorator( - illegal_keys={"axis", "name"}, + illegal_keys={"axis"}, func_name="paddle.compat.nn.Softmax", correct_name="paddle.nn.Softmax", ) From 8fddbe98595942e487108d2b569628b3c1be11da Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Wed, 26 Nov 2025 13:52:02 +0000 Subject: [PATCH 4/4] fix ut --- test/legacy_test/test_compat_softmax.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/legacy_test/test_compat_softmax.py b/test/legacy_test/test_compat_softmax.py index 1f34455dfd71fc..0fd618de8e2bd3 100644 --- a/test/legacy_test/test_compat_softmax.py +++ b/test/legacy_test/test_compat_softmax.py @@ -65,16 +65,11 @@ def test_error_handling(self): x = paddle.randn([3, 9, 5]) msg_gt_1 = "paddle.compat.nn.Softmax() received unexpected keyword argument 'axis'. \nDid you mean to use paddle.nn.Softmax() instead?" - msg_gt_2 = "paddle.compat.nn.Softmax() received unexpected keyword arguments 'axis', 'name'. \nDid you mean to use paddle.nn.Softmax() instead?" with self.assertRaises(TypeError) as cm: softmax = paddle.compat.nn.Softmax(axis=1) self.assertEqual(str(cm.exception), msg_gt_1) - with self.assertRaises(TypeError) as cm: - softmax = paddle.compat.nn.Softmax(axis=1, name="softmax") - self.assertEqual(str(cm.exception), msg_gt_2) - if __name__ == "__main__": unittest.main()