
Commit 9b78950

Merge branch 'master' into fix-test-torch-112
2 parents: a95828c + 4095ed2

File tree

7 files changed: +56 -37 lines changed

pytorch_pfn_extras/nn/modules/lazy.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def _lazy_load_hook(  # type: ignore[no-untyped-def]
 
 class UninitializedParameter(torch.nn.Parameter):
 
-    def __repr__(self) -> str:
+    def __repr__(self) -> str:  # type: ignore[override]
         return 'Uninitialized lazy parameter'
 
     def share_memory_(self) -> 'UninitializedParameter':
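For context, a minimal sketch of the pattern (assuming torch>=1.12, where Tensor.__repr__ gained a keyword-only tensor_contents parameter, so a zero-argument override no longer matches the base signature under mypy; the class name here is illustrative):

import torch

# Parameter subclass with a deliberately narrowed __repr__;
# the ignore silences mypy's override check against Tensor.__repr__.
class LazyParamSketch(torch.nn.Parameter):
    def __repr__(self) -> str:  # type: ignore[override]
        return 'Uninitialized lazy parameter'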

pytorch_pfn_extras/nn/parallel/distributed.py

Lines changed: 2 additions & 2 deletions

@@ -233,10 +233,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
     def load_state_dict(
             self,
-            state_dict: 'OrderedDict[str, torch.Tensor]',
+            state_dict: 'Mapping[str, torch.Tensor]',
             strict: bool = True,
     ) -> None:
-        self.module.load_state_dict(state_dict, strict=strict)
+        self.module.load_state_dict(state_dict, strict=strict)  # type: ignore[arg-type]
 
     T_destination = TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
pytorch_pfn_extras/onnx/_constants.py

Lines changed: 3 additions & 3 deletions

@@ -5,7 +5,7 @@
 if pytorch_pfn_extras.requires("1.12.0"):
     from torch.onnx._constants import *  # NOQA
 else:
-    onnx_default_opset = torch.onnx.symbolic_helper._default_onnx_opset_version
-    onnx_main_opset = torch.onnx.symbolic_helper._onnx_main_opset
-    onnx_stable_opsets = torch.onnx.symbolic_helper._onnx_stable_opsets
+    onnx_default_opset = torch.onnx.symbolic_helper._default_onnx_opset_version  # type: ignore
+    onnx_main_opset = torch.onnx.symbolic_helper._onnx_main_opset  # type: ignore
+    onnx_stable_opsets = torch.onnx.symbolic_helper._onnx_stable_opsets  # type: ignore
     onnx_constant_folding_opsets = torch.onnx.symbolic_helper._constant_folding_opset_versions if pytorch_pfn_extras.requires("1.11.0") else torch.onnx.constant_folding_opset_versions  # type: ignore[attr-defined]
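A hedged usage sketch: after this shim, downstream code can read the opset constants without caring which torch is installed (assuming the names are re-exported by torch.onnx._constants on torch>=1.12, as the star-import suggests):

from pytorch_pfn_extras.onnx import _constants

# Same name resolves on both sides of the version gate.
print(_constants.onnx_default_opset)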
pytorch_pfn_extras/onnx/_globals.py

Lines changed: 20 additions & 19 deletions

@@ -1,29 +1,30 @@
 import pytorch_pfn_extras
 import torch
-import torch.onnx.symbolic_helper
 from typing import Optional
 
 
-class _InternalGlobalsBeforeTorch1_11:
-    @property
-    def export_onnx_opset_version(self) -> int:
-        return torch.onnx.symbolic_helper._export_onnx_opset_version
+if pytorch_pfn_extras.requires("1.12.0"):
+    import torch.onnx._globals
+    GLOBALS = torch.onnx._globals.GLOBALS
 
-    @property
-    def operator_export_type(self) -> Optional[torch._C._onnx.OperatorExportTypes]:
-        return torch.onnx.symbolic_helper._operator_export_type
+else:
+    import torch.onnx.symbolic_helper as symhel
 
-    @property
-    def training_mode(self) -> Optional[torch._C._onnx.TrainingMode]:
-        return torch.onnx.symbolic_helper._training_mode
+    class _InternalGlobalsBeforeTorch1_11:
+        @property
+        def export_onnx_opset_version(self) -> int:
+            return symhel._export_onnx_opset_version  # type: ignore
 
-    @property
-    def onnx_shape_inference(self) -> bool:
-        return torch.onnx.symbolic_helper._onnx_shape_inference
+        @property
+        def operator_export_type(self) -> Optional[torch._C._onnx.OperatorExportTypes]:
+            return symhel._operator_export_type  # type: ignore
 
+        @property
+        def training_mode(self) -> Optional[torch._C._onnx.TrainingMode]:
+            return symhel._training_mode  # type: ignore
 
-if pytorch_pfn_extras.requires("1.12.0"):
-    import torch.onnx._globals
-    GLOBALS = torch.onnx._globals.GLOBALS
-else:
-    GLOBALS = _InternalGlobalsBeforeTorch1_11()
+        @property
+        def onnx_shape_inference(self) -> bool:
+            return symhel._onnx_shape_inference  # type: ignore
+
+    GLOBALS = _InternalGlobalsBeforeTorch1_11()  # type: ignore[assignment]
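A hedged usage sketch (assuming the shim is importable as pytorch_pfn_extras.onnx._globals, as the surrounding onnx files suggest): either branch yields an object exposing the same read-only attributes, so export code can query them uniformly:

from pytorch_pfn_extras.onnx._globals import GLOBALS

# Works on torch<1.12 (property wrapper) and torch>=1.12 (torch's GLOBALS).
opset = GLOBALS.export_onnx_opset_version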

pytorch_pfn_extras/onnx/export_testcase.py

Lines changed: 6 additions & 1 deletion

@@ -111,7 +111,12 @@ def _export_util(
         operator_export_type = OperatorExportTypes.ONNX_ATEN if\
             aten else OperatorExportTypes.RAW  # type: ignore
     elif operator_export_type is None:
-        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
+        if pytorch_pfn_extras.requires("1.12.0"):
+            use_onnx_aten_fallback = torch.onnx._CAFFE2_ATEN_FALLBACK  # type: ignore[attr-defined]
+        else:
+            use_onnx_aten_fallback = torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE  # type: ignore[attr-defined]
+
+        if use_onnx_aten_fallback:
             operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
         else:
             operator_export_type = OperatorExportTypes.ONNX
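A defensive variant of the same feature probe (a sketch, not the file's code): getattr tolerates either flag being absent and avoids the type: ignore comments:

import torch
import pytorch_pfn_extras

if pytorch_pfn_extras.requires("1.12.0"):
    # torch>=1.12 exposes the Caffe2 build flag under a different name.
    use_onnx_aten_fallback = getattr(torch.onnx, '_CAFFE2_ATEN_FALLBACK', False)
else:
    use_onnx_aten_fallback = getattr(torch.onnx, 'PYTORCH_ONNX_CAFFE2_BUNDLE', False)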

pytorch_pfn_extras/utils/comparer.py

Lines changed: 8 additions & 2 deletions

@@ -12,6 +12,7 @@
 import torch.nn
 import torch.testing
 
+import pytorch_pfn_extras
 from pytorch_pfn_extras import handler as _handler_module
 from pytorch_pfn_extras.handler import _logic
 from pytorch_pfn_extras.training import _trainer

@@ -191,8 +192,13 @@ def compare_fn(
             val1 = val1.cpu().detach()
         if isinstance(val2, torch.Tensor):
             val2 = val2.cpu().detach()
-        torch.testing.assert_allclose(
-            val1, val2, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+        if pytorch_pfn_extras.requires("1.9.0"):
+            assert_close = torch.testing.assert_close  # type: ignore[attr-defined]
+        else:
+            assert_close = torch.testing.assert_allclose  # type: ignore[assignment]
+
+        assert_close(val1, val2, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
     return compare_fn
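The selection logic in isolation, as a hedged sketch (torch.testing.assert_close is available from roughly torch 1.9, which is the version the gate checks; assert_allclose is its deprecated predecessor, and both accept rtol/atol/equal_nan keywords):

import torch
import pytorch_pfn_extras

if pytorch_pfn_extras.requires("1.9.0"):
    assert_close = torch.testing.assert_close
else:
    assert_close = torch.testing.assert_allclose

a = torch.tensor([1.0, float('nan')])
assert_close(a, a.clone(), rtol=0.0, atol=0.0, equal_nan=True)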

tests/pytorch_pfn_extras_tests/test_handler.py

Lines changed: 16 additions & 9 deletions

@@ -6,6 +6,13 @@
 import pytorch_pfn_extras as ppe
 
 
+def torch_testing_assert_close(*args, **kwargs):
+    if ppe.requires("1.10.0"):
+        torch.testing.assert_close(*args, **kwargs)
+    else:
+        torch.testing.assert_allclose(*args, **kwargs)
+
+
 class MockRuntime(ppe.runtime.BaseRuntime):
     def __init__(self, device, options):
         super().__init__(device, options)

@@ -361,8 +368,8 @@ def test_train_step(self):
         model = models['main']
         assert input.grad is not None
         # The gradient of a linear layer is its transposed weight
-        torch.testing.assert_allclose(input.grad, model.weight.T)
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(input.grad, model.weight.T)
+        torch_testing_assert_close(out, model(input))
 
     @pytest.mark.parametrize(
         'to_backprop',

@@ -403,11 +410,11 @@ def forward(self, x):
         grad = torch.zeros(1)
         for val in to_backprop:
             grad = grad + getattr(model, f'l{val}').weight.T
-        torch.testing.assert_allclose(input.grad, grad)
+        torch_testing_assert_close(input.grad, grad)
 
         # Check that logic step does not change the value of weight
         for val in original_parameters:
-            torch.testing.assert_allclose(
+            torch_testing_assert_close(
                 original_parameters[val], getattr(model, f'l{val}').weight)
 
     def test_train_step_backward_nograd(self):

@@ -461,7 +468,7 @@ def test_train_step_optimizers(self):
         w_grad = model.weight.grad.clone().detach()
         logic.train_step_optimizers(model, optimizers, 0)
         # Checks that the value was correctly updated
-        torch.testing.assert_allclose(m_weight - w_grad, model.weight.T)
+        torch_testing_assert_close(m_weight - w_grad, model.weight.T)
 
     @pytest.mark.gpu
     def test_grad_scaler(self):

@@ -473,12 +480,12 @@ def test_grad_scaler(self):
         m_weight = model.weight.clone().detach()
         w_grad = model.weight.grad.clone().detach()
         # The gradient of a linear layer is its transposed weight
-        torch.testing.assert_allclose(input.grad, scaler.scale(model.weight.T))
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(input.grad, scaler.scale(model.weight.T))
+        torch_testing_assert_close(out, model(input))
         logic.train_step_optimizers(model, optimizers, 0)
         # Checks that the value was correctly updated and gradients deescaled
         # before the update
-        torch.testing.assert_allclose(
+        torch_testing_assert_close(
             scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T))
 
     @pytest.mark.gpu

@@ -513,4 +520,4 @@ def test_eval_step(self):
         models = {'main': model}
         models['main'].eval()
         out = logic.eval_step(models, 0, input)
-        torch.testing.assert_allclose(out, model(input))
+        torch_testing_assert_close(out, model(input))
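A minimal standalone usage sketch of the new helper (hypothetical test, for illustration only):

import torch

def test_assert_close_wrapper_smoke():
    t = torch.arange(4.0)
    # Dispatches to assert_close on torch>=1.10, assert_allclose otherwise.
    torch_testing_assert_close(t, t.clone())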
