Skip to content

Commit 7e4bba8

Browse files
VijayVignesh1rittik9
authored andcommitted
Adding upper_face_maps index check, removing plot lower bound, making device agnostic test implementation and correcting docstring mistakes
1 parent 494cbfb commit 7e4bba8

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

src/torchmetrics/functional/multimodal/fdd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def upper_face_dynamics_deviation(
2727
2828
The Upper Face Dynamics Deviation (FDD) metric evaluates the quality of facial expressions in the upper
2929
face region for 3D talking head models. It quantifies the deviation in vertex motion dynamics between the
30-
predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex squared displacements from the neutral template.
30+
predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex
31+
squared displacements from the neutral template.
3132
3233
The metric is defined as:
3334
@@ -58,7 +59,7 @@ def upper_face_dynamics_deviation(
5859
If the number of dimensions of `vertices_pred` or `vertices_gt` is not 3.
5960
If vertex dimensions (V) or coordinate dimensions (3) don't match.
6061
If ``upper_face_map`` is empty or contains invalid indices.
61-
If there are at least two frames to compute face dynamics deviation.
62+
If there are fewer than two frames to compute face dynamics deviation.
6263
6364
Example:
6465
>>> import torch
@@ -90,7 +91,6 @@ def upper_face_dynamics_deviation(
9091
f"upper_face_map contains invalid vertex indices. Max index {max(upper_face_map)} is larger than "
9192
f"number of vertices {vertices_pred.shape[1]}."
9293
)
93-
9494
min_frames = min(vertices_pred.shape[0], vertices_gt.shape[0])
9595
pred = vertices_pred[:min_frames, upper_face_map, :] # (T, M, 3)
9696
gt = vertices_gt[:min_frames, upper_face_map, :]

src/torchmetrics/multimodal/fdd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ class UpperFaceDynamicsDeviation(Metric):
5555
V is the number of vertices, and 3 represents XYZ coordinates.
5656
- ``target`` (:class:`~torch.Tensor`): Ground truth vertices tensor of shape (T, V, 3) where T is the number of
5757
frames, V is the number of vertices, and 3 represents XYZ coordinates.
58-
- ``upper_face_map`` (:class:`list`): List of vertex indices corresponding to the upper-face region.
59-
- ``template`` (:class:`~torch.Tensor`): Template mesh tensor of shape (V, 3) representing the neutral face.
6058
6159
As output of ``forward`` and ``compute``, the metric returns the following output:
6260
@@ -89,7 +87,6 @@ class UpperFaceDynamicsDeviation(Metric):
8987
is_differentiable: bool = True
9088
higher_is_better: bool = False
9189
full_state_update: bool = False
92-
plot_lower_bound: float = 0.0
9390

9491
vertices_pred_list: List[Tensor]
9592
vertices_gt_list: List[Tensor]
@@ -108,7 +105,12 @@ def __init__(
108105
raise ValueError(f"Expected template to have shape (V, 3) but got {template.shape}.")
109106
if not self.upper_face_map:
110107
raise ValueError("upper_face_map cannot be empty.")
111-
108+
if min(self.upper_face_map) < 0 or max(self.upper_face_map) >= self.template.shape[0]:
109+
raise ValueError(
110+
f"upper_face_map contains invalid vertex indices. "
111+
f"Valid indices are between 0 and {self.template.shape[0] - 1}, "
112+
f"but got min index {min(self.upper_face_map)}, max index {max(self.upper_face_map)}."
113+
)
112114
self.add_state("vertices_pred_list", default=[], dist_reduce_fx=None)
113115
self.add_state("vertices_gt_list", default=[], dist_reduce_fx=None)
114116

tests/unittests/multimodal/test_fdd.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def _generate_vertices(batch_size: int = 1) -> _InputVertices:
4646

4747
def _reference_fdd(vertices_pred, vertices_gt, template, upper_face_map):
4848
"""Reference implementation for FDD metric using numpy."""
49-
pred = vertices_pred[:, upper_face_map, :].numpy() # (T, M, 3)
50-
gt = vertices_gt[:, upper_face_map, :].numpy() # (T, M, 3)
51-
template = template[upper_face_map, :].numpy() # (M, 3)
49+
pred = vertices_pred[:, upper_face_map, :].detach().cpu().numpy() # (T, M, 3)
50+
gt = vertices_gt[:, upper_face_map, :].detach().cpu().numpy() # (T, M, 3)
51+
template = template[upper_face_map, :].detach().cpu().numpy() # (M, 3)
5252

5353
displacements_gt = gt - template # (T, V, 3)
5454
displacements_pred = pred - template
@@ -128,9 +128,8 @@ def test_error_on_empty_upper_face_map(self):
128128

129129
def test_error_on_invalid_upper_face_indices(self):
130130
"""Test that an error is raised if upper_face_map has invalid indices."""
131-
metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[98, 99, 100])
132131
with pytest.raises(ValueError, match="upper_face_map contains invalid vertex indices.*"):
133-
metric(torch.randn(10, 50, 3), torch.randn(10, 50, 3))
132+
UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[98, 99, 100])
134133

135134
def test_plot_method(self):
136135
"""Test the plot method of FDD."""

0 commit comments

Comments
 (0)