diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py
index c5091dd2959710..68ed5439c76596 100644
--- a/python/paddle/base/dygraph/math_op_patch.py
+++ b/python/paddle/base/dygraph/math_op_patch.py
@@ -578,18 +578,22 @@ def requires_grad(self: Tensor, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
-            value (bool): True to enable gradient computation, False to disable.
+            requires_grad (bool, optional): True to enable gradient computation,
+                False to disable. Default: True.
+
+        Returns:
+            Tensor, the Tensor itself.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
     @property
     def itemsize(self: Tensor) -> int:
diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py
index 2d3050906759b6..0a8dc872ccd214 100644
--- a/python/paddle/base/layers/math_op_patch.py
+++ b/python/paddle/base/layers/math_op_patch.py
@@ -14,6 +14,7 @@
 
 import inspect
 import warnings
+from typing import TYPE_CHECKING
 
 from .. import core
 from ..dygraph.base import in_to_static_mode
@@ -24,6 +25,9 @@
     static_only,
 )
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 _supported_int_dtype_ = [
     core.VarDesc.VarType.BOOL,
     core.VarDesc.VarType.UINT8,
@@ -598,18 +602,22 @@ def requires_grad(self, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
-            value (bool): True to enable gradient computation, False to disable.
+            requires_grad (bool, optional): True to enable gradient computation,
+                False to disable. Default: True.
+
+        Returns:
+            Tensor, the Tensor itself.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
 def _scalar_add_(var, value):
     return _scalar_op_(var, 1.0, value)
diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py
index 0e61f9204b8ae7..9fe48bfeb03630 100644
--- a/python/paddle/pir/math_op_patch.py
+++ b/python/paddle/pir/math_op_patch.py
@@ -33,9 +33,9 @@
 from . import Value
 
 if TYPE_CHECKING:
+    from paddle import Tensor
     from paddle._typing import DTypeLike, PlaceLike, ShapeLike
 
-
 _already_patch_value = False
 
 _supported_int_dtype_ = [
@@ -1456,18 +1456,22 @@ def requires_grad(self, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
-            value (bool): True to enable gradient computation, False to disable.
+            requires_grad (bool, optional): True to enable gradient computation,
+                False to disable. Default: True.
+
+        Returns:
+            Tensor, the Tensor itself.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
     @property
     def itemsize(self) -> int: