Commit 4dd935f

Merge pull request #568 from asi1024/fix-mypy-failure
Fix mypy failures for torch 1.12
2 parents: a27e3d4 + 7070bcd

5 files changed: +27 -26 lines changed

pytorch_pfn_extras/nn/modules/lazy.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def _lazy_load_hook( # type: ignore[no-untyped-def]
 
 class UninitializedParameter(torch.nn.Parameter):
 
-    def __repr__(self) -> str:
+    def __repr__(self) -> str:  # type: ignore[override]
         return 'Uninitialized lazy parameter'
 
     def share_memory_(self) -> 'UninitializedParameter':
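
For context on the one-line change above: mypy's [override] error code fires when a subclass method declares a narrower signature than the base class version, which is presumably what the torch 1.12 stubs now report for this __repr__ override. The sketch below reproduces the diagnostic with hypothetical Base/Child classes; the names are illustrative only, not library code.

# Hypothetical classes reproducing mypy's [override] diagnostic.
class Base:
    # The base method accepts an extra optional keyword-only argument.
    def describe(self, *, verbose: bool = False) -> str:
        return "base"


class Child(Base):
    # Narrower signature than Base.describe, so mypy reports:
    #   Signature of "describe" incompatible with supertype "Base"  [override]
    # The trailing comment suppresses exactly that error code.
    def describe(self) -> str:  # type: ignore[override]
        return "child"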

pytorch_pfn_extras/nn/parallel/distributed.py

Lines changed: 2 additions & 2 deletions
@@ -233,10 +233,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
     def load_state_dict(
             self,
-            state_dict: 'OrderedDict[str, torch.Tensor]',
+            state_dict: 'Mapping[str, torch.Tensor]',
             strict: bool = True,
     ) -> None:
-        self.module.load_state_dict(state_dict, strict=strict)
+        self.module.load_state_dict(state_dict, strict=strict)  # type: ignore[arg-type]
 
     T_destination = TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
 
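
The change above widens the accepted annotation to Mapping and silences the remaining argument-type mismatch when forwarding to the wrapped module. A minimal sketch of that pattern, using hypothetical _Inner/_Wrapper classes rather than the real torch signatures:

from typing import Dict, Mapping

import torch


class _Inner:
    # Hypothetical module whose annotation insists on a Dict argument.
    def load_state_dict(self, state_dict: Dict[str, torch.Tensor],
                        strict: bool = True) -> None:
        pass


class _Wrapper:
    def __init__(self, module: _Inner) -> None:
        self.module = module

    # Accept the more general Mapping from callers; forwarding it to the
    # stricter inner signature is what needs the [arg-type] ignore.
    def load_state_dict(self, state_dict: Mapping[str, torch.Tensor],
                        strict: bool = True) -> None:
        self.module.load_state_dict(state_dict, strict=strict)  # type: ignore[arg-type]
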
pytorch_pfn_extras/onnx/_constants.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 if pytorch_pfn_extras.requires("1.12.0"):
     from torch.onnx._constants import *  # NOQA
 else:
-    onnx_default_opset = torch.onnx.symbolic_helper._default_onnx_opset_version
-    onnx_main_opset = torch.onnx.symbolic_helper._onnx_main_opset
-    onnx_stable_opsets = torch.onnx.symbolic_helper._onnx_stable_opsets
+    onnx_default_opset = torch.onnx.symbolic_helper._default_onnx_opset_version  # type: ignore
+    onnx_main_opset = torch.onnx.symbolic_helper._onnx_main_opset  # type: ignore
+    onnx_stable_opsets = torch.onnx.symbolic_helper._onnx_stable_opsets  # type: ignore
     onnx_constant_folding_opsets = torch.onnx.symbolic_helper._constant_folding_opset_versions if pytorch_pfn_extras.requires("1.11.0") else torch.onnx.constant_folding_opset_versions  # type: ignore[attr-defined]
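
The branch above re-exports the public constants from torch.onnx._constants on torch >= 1.12 and falls back to the older private symbolic_helper attributes otherwise; the blanket # type: ignore comments cover attributes that newer stubs no longer declare. A rough sketch of the same version-gating idea, with a hypothetical _torch_at_least() helper standing in for pytorch_pfn_extras.requires():

import torch
from packaging.version import Version


def _torch_at_least(version: str) -> bool:
    # Hypothetical helper: compare the installed torch release against a
    # minimum version, ignoring any local build suffix such as "+cu117".
    return Version(torch.__version__.split("+")[0]) >= Version(version)


if _torch_at_least("1.12.0"):
    # Newer torch exposes the opset constants publicly.
    from torch.onnx._constants import *  # NOQA
else:
    # Older torch only has the private attribute; a mypy run against newer
    # stubs cannot see it, hence the ignores in the diff above.
    import torch.onnx.symbolic_helper
    onnx_default_opset = torch.onnx.symbolic_helper._default_onnx_opset_version  # type: ignore
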
Lines changed: 20 additions & 19 deletions
@@ -1,29 +1,30 @@
 import pytorch_pfn_extras
 import torch
-import torch.onnx.symbolic_helper
 from typing import Optional
 
 
-class _InternalGlobalsBeforeTorch1_11:
-    @property
-    def export_onnx_opset_version(self) -> int:
-        return torch.onnx.symbolic_helper._export_onnx_opset_version
+if pytorch_pfn_extras.requires("1.12.0"):
+    import torch.onnx._globals
+    GLOBALS = torch.onnx._globals.GLOBALS
 
-    @property
-    def operator_export_type(self) -> Optional[torch._C._onnx.OperatorExportTypes]:
-        return torch.onnx.symbolic_helper._operator_export_type
+else:
+    import torch.onnx.symbolic_helper as symhel
 
-    @property
-    def training_mode(self) -> Optional[torch._C._onnx.TrainingMode]:
-        return torch.onnx.symbolic_helper._training_mode
+    class _InternalGlobalsBeforeTorch1_11:
+        @property
+        def export_onnx_opset_version(self) -> int:
+            return symhel._export_onnx_opset_version  # type: ignore
 
-    @property
-    def onnx_shape_inference(self) -> bool:
-        return torch.onnx.symbolic_helper._onnx_shape_inference
+        @property
+        def operator_export_type(self) -> Optional[torch._C._onnx.OperatorExportTypes]:
+            return symhel._operator_export_type  # type: ignore
 
+        @property
+        def training_mode(self) -> Optional[torch._C._onnx.TrainingMode]:
+            return symhel._training_mode  # type: ignore
 
-if pytorch_pfn_extras.requires("1.12.0"):
-    import torch.onnx._globals
-    GLOBALS = torch.onnx._globals.GLOBALS
-else:
-    GLOBALS = _InternalGlobalsBeforeTorch1_11()
+        @property
+        def onnx_shape_inference(self) -> bool:
+            return symhel._onnx_shape_inference  # type: ignore
+
+    GLOBALS = _InternalGlobalsBeforeTorch1_11()  # type: ignore[assignment]
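
The reshuffle above keeps the pre-1.12 shim, and its torch.onnx.symbolic_helper import, entirely inside the else: branch, so a mypy run against torch 1.12 only ever sees the torch.onnx._globals.GLOBALS assignment. The shim itself is just an object whose read-only properties forward to wherever the legacy values live; a hypothetical, self-contained version of that pattern:

class _LegacyGlobals:
    # Hypothetical shim: properties forward reads to a legacy namespace so
    # callers can keep using the same GLOBALS.<name> access everywhere.
    def __init__(self, legacy_namespace: object) -> None:
        self._ns = legacy_namespace

    @property
    def export_onnx_opset_version(self) -> int:
        # Looked up lazily so it tracks the exporter's current state.
        return getattr(self._ns, "_export_onnx_opset_version")


class _FakeSymbolicHelper:
    # Stand-in for the legacy module, for demonstration only.
    _export_onnx_opset_version = 13


GLOBALS = _LegacyGlobals(_FakeSymbolicHelper())
print(GLOBALS.export_onnx_opset_version)  # prints 13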

pytorch_pfn_extras/onnx/export_testcase.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def _export_util(
         operator_export_type = OperatorExportTypes.ONNX_ATEN if\
             aten else OperatorExportTypes.RAW  # type: ignore
     elif operator_export_type is None:
-        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
+        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:  # type: ignore[attr-defined]
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
         else:
            operator_export_type = OperatorExportTypes.ONNX
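
The added # type: ignore[attr-defined] is the narrowest suppression for this case: it only hides the "module has no attribute" complaint that newer stubs raise for PYTORCH_ONNX_CAFFE2_BUNDLE. A hedged sketch of an alternative that avoids the static attribute access entirely (not what this commit does):

import torch

# Hypothetical alternative: query the flag dynamically and default to False
# when the attribute is absent, so neither mypy nor newer torch releases
# complain about the lookup.
caffe2_bundle = bool(getattr(torch.onnx, "PYTORCH_ONNX_CAFFE2_BUNDLE", False))

if caffe2_bundle:
    print("torch.onnx was built with the Caffe2 bundle")
else:
    print("no Caffe2 bundle; use the plain ONNX export path")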

0 commit comments