Skip to content

Commit c40cd02

Browse files
author
emcastillo
authored
Merge pull request #570 from asi1024/fix-test-torch-112
Small test fixes for torch 1.12
2 parents 4095ed2 + 9b78950 commit c40cd02

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

tests/pytorch_pfn_extras_tests/onnx_tests/test_annotate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,8 @@ def forward(self, *xs):
374374
anchor_node, pre_node, next_node = named_nodes['Anchor_0_start']
375375
# anchor_attrs = [a.name for a in anchor_node.attribute]
376376
assert pre_node is None
377-
assert next_node.name == 'Concat_4'
377+
assert next_node.name.startswith('Concat_')
378378
anchor_node, pre_node, next_node = named_nodes['Anchor_0_end']
379379
# anchor_attrs = [a.name for a in anchor_node.attribute]
380-
assert pre_node.name == 'Split_10'
380+
assert pre_node.name.startswith('Split_')
381381
assert next_node is None

tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,20 @@
22
import torch
33
import torchvision
44

5+
import pytorch_pfn_extras
56
from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test
67

78

9+
if pytorch_pfn_extras.requires("1.12.0"):
10+
resnet18_kwargs = {'weights': None}
11+
else:
12+
resnet18_kwargs = {'pretrained': True}
13+
14+
815
def test_eval_resnet18():
916
torch.manual_seed(100)
1017
run_model_test(
11-
torchvision.models.resnet.resnet18(pretrained=True),
18+
torchvision.models.resnet.resnet18(**resnet18_kwargs),
1219
(torch.rand(1, 3, 224, 224),),
1320
rtol=1e-03,
1421
use_gpu=True,
@@ -19,7 +26,7 @@ def test_eval_resnet18():
1926
@pytest.mark.xfail
2027
def test_train_resnet18():
2128
run_model_test(
22-
torchvision.models.resnet.resnet18(pretrained=True),
29+
torchvision.models.resnet.resnet18(**resnet18_kwargs),
2330
(torch.rand(1, 3, 224, 224),),
2431
rtol=1e-03,
2532
use_gpu=True,

0 commit comments

Comments
 (0)