Skip to content

Commit 3f3e10d

Browse files
author
twata
committed
[WIP] [pfto] Add ppe.map test case
1 parent ef49153 commit 3f3e10d

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import torch
44

5+
import pytorch_pfn_extras as ppe
56
from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test
67

78

@@ -389,3 +390,18 @@ def forward(self, x):
389390
assert len(model.graph.input) == 1
390391
model = run_model_test(Model(), (torch.rand((1,)),), keep_initializers_as_inputs=True)
391392
assert len(model.graph.input) == (2 if persistent else 1)
393+
394+
395+
def test_ppe_map():
396+
torch.manual_seed(100)
397+
398+
class Net(torch.nn.Module):
399+
def __init__(self):
400+
super(Net, self).__init__()
401+
self.conv = torch.nn.Conv2d(1, 1, 3)
402+
403+
def forward(self, x):
404+
y = self.conv(x)
405+
return list(ppe.map(lambda u: u + 1, y))[0]
406+
407+
run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03)

0 commit comments

Comments
 (0)