Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/paddle/base/dygraph/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,18 +578,19 @@ def requires_grad(self: Tensor, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool) -> None:
def requires_grad_(self, requires_grad: bool = True) -> Tensor:
"""
Set whether this Tensor requires gradient computation.

Args:
value (bool): True to enable gradient computation, False to disable.
"""
if not isinstance(value, bool):
if not isinstance(requires_grad, bool):
raise TypeError(
f"requires_grad must be bool, but got {type(value)}"
f"requires_grad must be bool, but got {type(requires_grad)}"
)
self.stop_gradient = not value
self.stop_gradient = not requires_grad
return self

@property
def itemsize(self: Tensor) -> int:
Expand Down
11 changes: 7 additions & 4 deletions python/paddle/base/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import inspect
import warnings

from paddle import Tensor
Copy link
Member

@SigureMo SigureMo Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考 python/paddle/pir/math_op_patch.pyTYPE_CHECKING 下 import


from .. import core
from ..dygraph.base import in_to_static_mode
from ..framework import (
Expand Down Expand Up @@ -598,18 +600,19 @@ def requires_grad(self, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool) -> None:
def requires_grad_(self, requires_grad: bool = True) -> Tensor:
"""
Set whether this Tensor requires gradient computation.

Args:
value (bool): True to enable gradient computation, False to disable.
"""
if not isinstance(value, bool):
if not isinstance(requires_grad, bool):
raise TypeError(
f"requires_grad must be bool, but got {type(value)}"
f"requires_grad must be bool, but got {type(requires_grad)}"
)
self.stop_gradient = not value
self.stop_gradient = not requires_grad
return self

def _scalar_add_(var, value):
return _scalar_op_(var, 1.0, value)
Expand Down
11 changes: 6 additions & 5 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from . import Value

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing import DTypeLike, PlaceLike, ShapeLike


_already_patch_value = False

_supported_int_dtype_ = [
Expand Down Expand Up @@ -1456,18 +1456,19 @@ def requires_grad(self, value: bool) -> None:
)
self.stop_gradient = not value

def requires_grad_(self, value: bool) -> None:
def requires_grad_(self, requires_grad: bool = True) -> Tensor:
"""
Set whether this Tensor requires gradient computation.

Args:
value (bool): True to enable gradient computation, False to disable.
"""
if not isinstance(value, bool):
if not isinstance(requires_grad, bool):
raise TypeError(
f"requires_grad must be bool, but got {type(value)}"
f"requires_grad must be bool, but got {type(requires_grad)}"
)
self.stop_gradient = not value
self.stop_gradient = not requires_grad
return self

@property
def itemsize(self) -> int:
Expand Down
Loading