Skip to content

Commit 808209e

Browse files
committed
changed type of key
1 parent 768835a commit 808209e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/accmt/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def training_step(self, batch: Any) -> torch.Tensor:
9191
"""Defines the training logic. Must return a loss tensor (scalar)."""
9292

9393
@override
94-
def validation_step(self, key: Any, batch: Any) -> dict:
94+
def validation_step(self, key: str, batch: Any) -> dict:
9595
"""
9696
Defines the validation logic. Must return a dictionary containing
9797
each metric with predictions and targets, and also the loss value in the dictionary.

0 commit comments

Comments
 (0)