Skip to content

Commit 0738da5

Browse files
committed
changed view() -> reshape(), because there was a RuntimeError
1 parent 6cd6703 commit 0738da5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

inclearn/lib/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def accuracy(output, targets, topk=1):
181181
pred = pred.t()
182182
correct = pred.eq(targets.view(1, -1).expand_as(pred))
183183

184-
correct_k = correct[:topk].view(-1).float().sum(0).item()
184+
correct_k = correct[:topk].reshape(-1).float().sum(0).item()
185185
return round(correct_k / batch_size, 3)
186186

187187

0 commit comments

Comments
 (0)