Skip to content

Commit b882ee6

Browse files
[mypy] several files from src/nncf/torch (#3866)
### Changes Add files to mypy check: ``` "src/custom_version.py", "src/nncf/torch/engine.py", "src/nncf/torch/functions.py", "src/nncf/torch/model_creation.py", "src/nncf/torch/node_utils.py", "src/nncf/torch/strip.py", "src/nncf/torch/utils.py", ``` Remove unused functions: ``` BaseQuantizer.apply_minmax_init get_flat_tensor_contents_string ``` Moved `sum_like` from `nncf.torch.utils` to `nncf.torch.quantization.reference` as it is used only there
1 parent de366af commit b882ee6

File tree

9 files changed

+75
-125
lines changed

9 files changed

+75
-125
lines changed

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,17 @@ strict = true
7676
# https://github.com/hauntsaninja/no_implicit_optional
7777
implicit_optional = true
7878
files = [
79+
"src/custom_version.py",
7980
"src/nncf/api",
8081
"src/nncf/data",
8182
"src/nncf/common",
8283
"src/nncf/torch/function_hook",
84+
"src/nncf/torch/engine.py",
85+
"src/nncf/torch/functions.py",
86+
"src/nncf/torch/model_creation.py",
87+
"src/nncf/torch/node_utils.py",
88+
"src/nncf/torch/strip.py",
89+
"src/nncf/torch/utils.py",
8390
"src/nncf/quantization/*py",
8491
"src/nncf/telemetry/",
8592
"src/nncf/tensor/",

src/custom_version.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ def __getattr__(name: str) -> str:
104104
# Rewrite version.py to pass custom version to package
105105
if os.environ.get("_PYPROJECT_HOOKS_BUILD_BACKEND"):
106106
content = Path(NNCF_VERSION_FILE).read_text()
107-
version_str = re.search(r"^__version__ = ['\"][^'\"]*['\"]", content, re.M).group(0)
107+
version_match = re.search(r"^__version__ = ['\"][^'\"]*['\"]", content, re.M)
108+
if version_match is None:
109+
msg = "Unable to find version string."
110+
raise RuntimeError(msg)
111+
version_str = version_match.group(0)
108112
content = content.replace(version_str, f'__version__ = "{version}"')
109113
Path(NNCF_VERSION_FILE).write_text(content)
110114

src/nncf/torch/engine.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Any, Union
12+
from typing import Any
1313

14-
import torch
1514
from torch import nn
1615

1716
from nncf.common.engine import Engine
@@ -34,9 +33,7 @@ def __init__(self, model: nn.Module):
3433
if get_backend(model) == BackendType.TORCH:
3534
self._model.eval()
3635

37-
def infer(
38-
self, input_data: Union[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]
39-
) -> Union[torch.Tensor, dict[str, Any]]:
36+
def infer(self, input_data: Any) -> Any:
4037
"""
4138
Runs Torch model on the provided input.
4239

src/nncf/torch/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414

1515

16-
def clamp(x, low, high):
16+
def clamp(x: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
1717
return torch.max(torch.min(x, high), low)
1818

1919

@@ -28,5 +28,5 @@ def forward(ctx: Any, input_: torch.Tensor, threshold: float = 0.5) -> torch.Ten
2828
return output
2929

3030
@staticmethod
31-
def backward(ctx: Any, *grad_outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, None]:
31+
def backward(ctx: Any, *grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
3232
return grad_outputs[0], None

src/nncf/torch/node_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
2929
raise nncf.InternalError(msg)
3030

3131
if node.metatype not in [PTMatMulMetatype, PTAddmmMetatype]:
32+
if not isinstance(node.metatype.output_channel_axis, int):
33+
msg = f"Node metatype {node.metatype} does not have defined output channel axis"
34+
raise nncf.InternalError(msg)
3235
return node.metatype.output_channel_axis
3336

3437
if port_id == 0:
@@ -38,5 +41,5 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
3841
# W(port:0) * X(port:1): [... , C_OUT, C_IN] * [... , C_IN, ...]
3942
return -2
4043

41-
msg = f"Port id for a {node.metatype} operation is expected to be in [0, 1], {port_id} recieved"
44+
msg = f"Port id for a {node.metatype} operation is expected to be in [0, 1], {port_id} received"
4245
raise nncf.InternalError(msg)

src/nncf/torch/quantization/layers.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import nncf
2525
from nncf.common.graph import NNCFNodeName
26-
from nncf.common.logging import nncf_logger
2726
from nncf.common.quantization.quantizer_setup import QuantizationPointId
2827
from nncf.common.quantization.quantizer_setup import QuantizerSetupBase
2928
from nncf.common.quantization.quantizers import calculate_asymmetric_level_ranges
@@ -34,7 +33,6 @@
3433
from nncf.common.quantization.structs import QuantizerSpec
3534
from nncf.common.utils.debug import is_debug
3635
from nncf.common.utils.registry import Registry
37-
from nncf.torch.functions import clamp
3836
from nncf.torch.graph.transformations.commands import PTTargetPoint
3937
from nncf.torch.graph.transformations.commands import TargetType
4038
from nncf.torch.layer_utils import COMPRESSION_MODULES
@@ -56,8 +54,6 @@
5654
from nncf.torch.quantization.quantize_functions import unpack_uint4
5755
from nncf.torch.return_types import maybe_get_values_from_torch_return_type
5856
from nncf.torch.return_types import maybe_wrap_to_torch_return_type
59-
from nncf.torch.utils import get_flat_tensor_contents_string
60-
from nncf.torch.utils import get_model_device
6157
from nncf.torch.utils import is_tracing_state
6258
from nncf.torch.utils import no_jit_trace
6359

@@ -464,29 +460,6 @@ def reset_call_counter(self):
464460
def get_trainable_params(self) -> dict[str, torch.Tensor]:
465461
return {}
466462

467-
def apply_minmax_init(self, min_values: torch.Tensor, max_values: torch.Tensor, log_module_name: str = None):
468-
"""min_values and max_values must have the same shape as specified in self.scale_shape"""
469-
if self.initialized:
470-
nncf_logger.debug(f"Skipped initializing {log_module_name} - loaded from checkpoint")
471-
return
472-
473-
if torch.all(torch.isinf(min_values)) or torch.all(torch.isinf(max_values)):
474-
msg = f"Statistics are not collected for {log_module_name}"
475-
raise ValueError(msg)
476-
477-
if torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf)):
478-
msg = f"Some of the values in statistics have infinite value for {log_module_name}"
479-
raise ValueError(msg)
480-
481-
own_device = get_model_device(self)
482-
min_values = min_values.to(own_device)
483-
max_values = max_values.to(own_device)
484-
self._apply_minmax_init(min_values, max_values, log_module_name)
485-
486-
@abstractmethod
487-
def _apply_minmax_init(self, min_values: torch.Tensor, max_values: torch.Tensor, log_module_name: str = None):
488-
pass
489-
490463
@abstractmethod
491464
def set_levels(self):
492465
"""
@@ -795,26 +768,6 @@ def quantize(self, x, execute_traced_op_as_identity: bool = False):
795768
def get_trainable_params(self) -> dict[str, torch.Tensor]:
796769
return {self.SCALE_PARAM_NAME: self.scale}
797770

798-
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
799-
sign = torch.any(torch.lt(min_values, 0))
800-
if self._signedness_to_force is not None and sign != self._signedness_to_force:
801-
nncf_logger.debug(f"Forcing signed to {self._signedness_to_force} for module {log_module_name}")
802-
sign = self._signedness_to_force
803-
self.signed = sign
804-
805-
abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
806-
SCALE_LOWER_THRESHOLD = 0.1
807-
mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
808-
self._scale_param_storage.data = torch.where(
809-
mask, abs_max, SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage)
810-
)
811-
if self._is_using_log_scale_storage:
812-
self._scale_param_storage.data.log_()
813-
814-
nncf_logger.debug(
815-
f"Set sign: {self.signed} and scale: {get_flat_tensor_contents_string(self.scale)} for {log_module_name}"
816-
)
817-
818771
def broadcast_initialized_params(self, src: int = 0):
819772
super().broadcast_initialized_params(src)
820773
distributed.broadcast(self._scale_param_storage, src=src)
@@ -996,22 +949,6 @@ def get_trainable_params(self) -> dict[str, torch.Tensor]:
996949
self.INPUT_RANGE_PARAM_NAME: self.input_range,
997950
}
998951

999-
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
1000-
ranges = max_values - min_values
1001-
max_range = torch.max(max_values - min_values)
1002-
eps = 1e-2
1003-
correction = (clamp(ranges, low=eps * max_range, high=max_range) - ranges) * 0.5
1004-
self._input_range_param_storage.data = (ranges + 2 * correction).data
1005-
if self._is_using_log_scale_storage:
1006-
self._input_range_param_storage.data.log_()
1007-
1008-
self.input_low.data = (min_values - correction).data
1009-
1010-
nncf_logger.debug(
1011-
f"Set input_low: {get_flat_tensor_contents_string(self.input_low)} "
1012-
f"and input_range: {get_flat_tensor_contents_string(self.input_range)} for {log_module_name}"
1013-
)
1014-
1015952
def broadcast_initialized_params(self, src: int = 0):
1016953
super().broadcast_initialized_params(src)
1017954
distributed.broadcast(self.input_low, src)

src/nncf/torch/quantization/reference.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,38 @@
1717

1818
import nncf
1919
from nncf.torch.utils import CompilationWrapper
20-
from nncf.torch.utils import sum_like
2120

2221
GeneralizedTensor = TypeVar("GeneralizedTensor", torch.Tensor, np.ndarray)
2322

2423

24+
def fp32_accum_wrapper(func):
25+
def wrapper(tensor_to_sum, ret_tensor):
26+
half = tensor_to_sum.dtype == np.float16
27+
if half:
28+
tensor_to_sum = tensor_to_sum.astype(np.float32)
29+
retval = func(tensor_to_sum, ret_tensor)
30+
if half:
31+
retval = retval.astype(np.float16)
32+
return retval
33+
34+
return wrapper
35+
36+
37+
@fp32_accum_wrapper
38+
def sum_like(tensor_to_sum, ref_tensor):
39+
"""Warning: may modify tensor_to_sum"""
40+
if ref_tensor.size == 1:
41+
return tensor_to_sum.sum()
42+
43+
for dim, size in enumerate(ref_tensor.shape):
44+
if size == 1:
45+
if isinstance(tensor_to_sum, np.ndarray):
46+
tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True)
47+
else:
48+
tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True)
49+
return tensor_to_sum
50+
51+
2552
class ReferenceBackendType(Enum):
2653
NUMPY = "numpy"
2754
TORCH = "torch"

src/nncf/torch/utils.py

Lines changed: 25 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from contextlib import contextmanager
1212
from typing import Any, Callable, Generator
1313

14-
import numpy as np
1514
import torch
1615
from torch.nn import Module
1716

@@ -20,57 +19,33 @@
2019
from nncf.common.utils.os import is_windows
2120

2221

23-
def is_tracing_state():
24-
return torch._C._get_tracing_state() is not None
25-
26-
27-
class no_jit_trace:
28-
def __enter__(self):
29-
self.state = torch._C._get_tracing_state()
30-
torch._C._set_tracing_state(None)
31-
32-
def __exit__(self, *args):
33-
torch._C._set_tracing_state(self.state)
34-
self.state = None
35-
36-
37-
def fp32_accum_wrapper(func):
38-
def wrapper(tensor_to_sum, ret_tensor):
39-
half = tensor_to_sum.dtype == np.float16
40-
if half:
41-
tensor_to_sum = tensor_to_sum.astype(np.float32)
42-
retval = func(tensor_to_sum, ret_tensor)
43-
if half:
44-
retval = retval.astype(np.float16)
45-
return retval
46-
47-
return wrapper
48-
22+
def is_tracing_state() -> bool:
23+
"""
24+
Checks whether the current execution context is being traced by torch.jit.
4925
50-
@fp32_accum_wrapper
51-
def sum_like(tensor_to_sum, ref_tensor):
52-
"""Warning: may modify tensor_to_sum"""
53-
if ref_tensor.size == 1:
54-
return tensor_to_sum.sum()
26+
:return: True if the current thread is being traced, False otherwise.
27+
"""
28+
return torch._C._get_tracing_state() is not None
5529

56-
for dim, size in enumerate(ref_tensor.shape):
57-
if size == 1:
58-
if isinstance(tensor_to_sum, np.ndarray):
59-
tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True)
60-
else:
61-
tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True)
62-
return tensor_to_sum
6330

31+
@contextmanager
32+
def no_jit_trace() -> Generator[None, None, None]:
33+
"""
34+
Context manager and decorator to temporarily disable PyTorch JIT tracing.
6435
65-
def get_flat_tensor_contents_string(input_tensor):
66-
retval = "["
67-
for idx, el in enumerate(input_tensor.view(-1)):
68-
if idx >= 10:
69-
retval += f"... (first 10/{len(input_tensor.view(-1))} elements shown only) "
70-
break
71-
retval += f"{el.item():.4f}, "
72-
retval += "]"
73-
return retval
36+
When used, any operations performed within this scope will not be recorded
37+
in the TorchScript graph, even if the code is currently being executed
38+
via `torch.jit.trace`.
39+
"""
40+
# Capture the original state
41+
original_state = torch._C._get_tracing_state()
42+
try:
43+
# Disable tracing
44+
torch._C._set_tracing_state(None) # type: ignore[attr-defined]
45+
yield
46+
finally:
47+
# Restore state regardless of whether an error occurred
48+
torch._C._set_tracing_state(original_state) # type: ignore[attr-defined]
7449

7550

7651
class _ModuleState:
@@ -97,7 +72,7 @@ def save_module_state(module: Module) -> _ModuleState:
9772
return _ModuleState(module)
9873

9974

100-
def load_module_state(base_module: Module, state: _ModuleState, strict=False) -> None:
75+
def load_module_state(base_module: Module, state: _ModuleState, strict: bool = False) -> None:
10176
for name, module in base_module.named_modules():
10277
try:
10378
module.train(state.training_state[name])
@@ -114,7 +89,7 @@ def load_module_state(base_module: Module, state: _ModuleState, strict=False) ->
11489

11590

11691
@contextmanager
117-
def training_mode_switcher(model: Module, is_training: bool = True):
92+
def training_mode_switcher(model: Module, is_training: bool = True) -> Generator[None, None, None]:
11893
saved_state = save_module_state(model)
11994
model.train(is_training)
12095
try:

tests/common/test_statistic_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_duplicated_statistics_are_merged():
150150
target_inputs = TensorCollector.get_tensor_collector_inputs(outputs, output_info)
151151
collector.register_inputs(target_inputs)
152152

153-
# Check aggregators recieved inputs as expected
153+
# Check aggregators received inputs as expected
154154
assert aggregators[0]._collected_samples == 1
155155
for aggregator in aggregators[1:]:
156156
assert aggregator._collected_samples == 0
@@ -160,7 +160,7 @@ def test_duplicated_statistics_are_merged():
160160

161161
statistics = collector.get_statistics()
162162

163-
# Check aggregators recieved correct inputs
163+
# Check aggregators received correct inputs
164164
assert len(statistics) == 6
165165
for k in "ABC":
166166
assert statistics[k] == Tensor(np.array(5))

0 commit comments

Comments
 (0)