python/paddle/base/dygraph/math_op_patch.py (9 changes: 5 additions, 4 deletions)
@@ -578,18 +578,19 @@ def requires_grad(self: Tensor, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
             value (bool): True to enable gradient computation, False to disable.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
     @property
     def itemsize(self: Tensor) -> int:
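The dygraph patch renames the argument to requires_grad, gives it a default of True, and returns the Tensor itself, so the method can be applied in place and chained, matching the usual convention for trailing-underscore methods. The snippet below is a usage sketch, not part of the diff; it assumes the patched method is bound to eager Tensors the way the surrounding file binds its other helpers.

# Usage sketch for the patched dygraph method (illustrative, not from the PR).
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
print(x.stop_gradient)   # True: to_tensor() disables gradients by default

# With the new default, requires_grad_() turns gradients on and returns the
# same Tensor, so it can be chained directly into an expression.
y = (x.requires_grad_() * 2.0).sum()
y.backward()
print(x.stop_gradient)   # False
print(x.grad)            # gradient of sum(2 * x) is 2 for every element

# Disabling is explicit, and non-bool arguments are still rejected.
x.requires_grad_(False)
# x.requires_grad_(1)    # TypeError: requires_grad must be bool, but got <class 'int'>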
python/paddle/base/layers/math_op_patch.py (11 changes: 7 additions, 4 deletions)
@@ -15,6 +15,8 @@
 import inspect
 import warnings
 
+from paddle import Tensor
+
 from .. import core
 from ..dygraph.base import in_to_static_mode
 from ..framework import (
@@ -598,18 +600,19 @@ def requires_grad(self, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
             value (bool): True to enable gradient computation, False to disable.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
     def _scalar_add_(var, value):
         return _scalar_op_(var, 1.0, value)
python/paddle/pir/math_op_patch.py (11 changes: 6 additions, 5 deletions)
@@ -33,9 +33,9 @@
 from . import Value
 
 if TYPE_CHECKING:
+    from paddle import Tensor
     from paddle._typing import DTypeLike, PlaceLike, ShapeLike
 
-
 _already_patch_value = False
 
 _supported_int_dtype_ = [
@@ -1456,18 +1456,19 @@ def requires_grad(self, value: bool) -> None:
             )
         self.stop_gradient = not value
 
-    def requires_grad_(self, value: bool) -> None:
+    def requires_grad_(self, requires_grad: bool = True) -> Tensor:
         """
         Set whether this Tensor requires gradient computation.
 
         Args:
             value (bool): True to enable gradient computation, False to disable.
         """
-        if not isinstance(value, bool):
+        if not isinstance(requires_grad, bool):
             raise TypeError(
-                f"requires_grad must be bool, but got {type(value)}"
+                f"requires_grad must be bool, but got {type(requires_grad)}"
             )
-        self.stop_gradient = not value
+        self.stop_gradient = not requires_grad
+        return self
 
     @property
     def itemsize(self) -> int:
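The same rename and return-self change is mirrored for legacy static-graph Variables (base/layers) and for PIR Values (pir), with Tensor imported only to spell the return annotation. Below is a static-mode sketch, again illustrative rather than part of the diff; which of the two patched classes actually receives the call depends on whether PIR or the legacy IR is active, and the snippet assumes default settings.

# Static-graph usage sketch (illustrative, not from the PR).
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # data() variables stop gradients by default; enable and chain in one call.
    h = paddle.matmul(x.requires_grad_(), paddle.ones([8, 2]))
    print(x.stop_gradient)   # False

    # The bool check behaves identically in all three patched classes.
    try:
        x.requires_grad_("yes")
    except TypeError as err:
        print(err)           # requires_grad must be bool, but got <class 'str'>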