Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/torchmetrics/functional/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Optional, Union

import torch
Expand All @@ -25,7 +26,7 @@ def _psnr_compute(
num_obs: Tensor,
data_range: Tensor,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Compute peak signal-to-noise ratio.

Expand Down Expand Up @@ -58,7 +59,7 @@ def _psnr_compute(
def _psnr_update(
preds: Tensor,
target: Tensor,
dim: Optional[Union[int, tuple[int, ...]]] = None,
dim: Optional[Union[int, Sequence[int]]] = None,
) -> tuple[Tensor, Tensor]:
"""Update and return variables required to compute peak signal-to-noise ratio.

Expand Down Expand Up @@ -95,18 +96,18 @@ def _psnr_update(
def peak_signal_noise_ratio(
preds: Tensor,
target: Tensor,
data_range: Union[float, tuple[float, float]],
data_range: Union[float, Sequence[float]],
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, tuple[int, ...]]] = None,
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
dim: Optional[Union[int, Sequence[int]]] = None,
) -> Tensor:
"""Compute the peak signal-to-noise ratio.

Args:
preds: estimated signal
target: groun truth signal
data_range:
the range of the data. If a tuple is provided then the range is calculated as the difference and
the range of the data. If a Sequence is provided then the range is calculated as the difference and
input is clamped between the values.
base: a base of a logarithm to use
reduction: a method to reduce metric score over labels.
Expand Down Expand Up @@ -136,7 +137,7 @@ def peak_signal_noise_ratio(
if dim is None and reduction != "elementwise_mean":
rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")

if isinstance(data_range, tuple):
if isinstance(data_range, Sequence):
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
target = torch.clamp(target, min=data_range[0], max=data_range[1])
data_range_val = tensor(data_range[1] - data_range[0])
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Sequence
from typing import Union

import torch
Expand Down Expand Up @@ -102,7 +103,7 @@ def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[T
def peak_signal_noise_ratio_with_blocked_effect(
preds: Tensor,
target: Tensor,
data_range: Union[float, tuple[float, float]],
data_range: Union[float, Sequence[float]],
block_size: int = 8,
) -> Tensor:
r"""Computes `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metrics.
Expand All @@ -115,7 +116,7 @@ def peak_signal_noise_ratio_with_blocked_effect(
Args:
preds: estimated signal
target: ground truth signal
data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and
data_range: the range of the data. If a Sequence is provided then the range is calculated as the difference and
input is clamped between the values.
block_size: integer indication the block size

Expand All @@ -131,7 +132,7 @@ def peak_signal_noise_ratio_with_blocked_effect(
tensor(7.8402)

"""
if isinstance(data_range, tuple):
if isinstance(data_range, Sequence):
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
target = torch.clamp(target, min=data_range[0], max=data_range[1])
data_range_val = tensor(data_range[1] - data_range[0])
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PeakSignalNoiseRatio(Metric):

Args:
data_range:
the range of the data. If a tuple is provided, then the range is calculated as the difference and
the range of the data. If a Sequence is provided, then the range is calculated as the difference and
input is clamped between the values.
base: a base of a logarithm to use.
reduction: a method to reduce metric score over labels.
Expand Down Expand Up @@ -80,10 +80,10 @@ class PeakSignalNoiseRatio(Metric):

def __init__(
self,
data_range: Union[float, tuple[float, float]],
data_range: Union[float, Sequence[float]],
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, tuple[int, ...]]] = None,
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
dim: Optional[Union[int, Sequence[int]]] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -99,15 +99,15 @@ def __init__(
self.add_state("total", default=[], dist_reduce_fx="cat")

self.clamping_fn = None
if isinstance(data_range, tuple):
if isinstance(data_range, Sequence):
self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
self.clamping_fn = partial(torch.clamp, min=data_range[0], max=data_range[1])
else:
self.add_state("data_range", default=tensor(float(data_range)), dist_reduce_fx="mean")

self.base = base
self.reduction = reduction
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
self.dim = dim

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):
- ``psnrb`` (:class:`~torch.Tensor`): float scalar tensor with aggregated PSNRB value

Args:
data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and
data_range: the range of the data. If a Sequence is provided then the range is calculated as the difference and
input is clamped between the values.
block_size: integer indication the block size
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Expand All @@ -74,7 +74,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):

def __init__(
self,
data_range: Union[float, tuple[float, float]],
data_range: Union[float, Sequence[float]],
block_size: int = 8,
**kwargs: Any,
) -> None:
Expand All @@ -87,7 +87,7 @@ def __init__(
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.add_state("bef", default=tensor(0.0), dist_reduce_fx="sum")

if isinstance(data_range, tuple):
if isinstance(data_range, Sequence):
self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
self.clamping_fn = lambda x: torch.clamp(x, min=data_range[0], max=data_range[1])
else:
Expand Down
Loading