We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 768835a commit 808209eCopy full SHA for 808209e
src/accmt/modules.py
@@ -91,7 +91,7 @@ def training_step(self, batch: Any) -> torch.Tensor:
91
"""Defines the training logic. Must return a loss tensor (scalar)."""
92
93
@override
94
- def validation_step(self, key: Any, batch: Any) -> dict:
+ def validation_step(self, key: str, batch: Any) -> dict:
95
"""
96
Defines the validation logic. Must return a dictionary containing
97
each metric with predictions and targets, and also the loss value in the dictionary.
0 commit comments