Skip to content

Commit c5b4d58

Browse files
authored
Merge pull request #768 from take-cheeze/disable_tf32_cudnn_onnx
[onnx] Disable cudnn tf32 in torchvision test
2 parents 3f023d3 + abe1248 commit c5b4d58

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414

1515
@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning")
1616
def test_eval_resnet18():
17-
run_model_test(
18-
torchvision.models.resnet.resnet18(**resnet18_kwargs),
19-
(torch.rand(1, 3, 224, 224),),
20-
rtol=1e-03,
21-
use_gpu=True,
22-
)
17+
old_allow_tf32 = torch.backends.cudnn.allow_tf32
18+
try:
19+
torch.backends.cudnn.allow_tf32 = False
20+
run_model_test(
21+
torchvision.models.resnet.resnet18(**resnet18_kwargs),
22+
(torch.rand(1, 3, 224, 224),),
23+
rtol=1e-03,
24+
use_gpu=True,
25+
)
26+
finally:
27+
torch.backends.cudnn.allow_tf32 = old_allow_tf32
2328

2429

2530
@pytest.mark.gpu

0 commit comments

Comments
 (0)