|
23 | 23 |
|
24 | 24 | import nncf |
25 | 25 | from nncf.common.graph import NNCFNodeName |
26 | | -from nncf.common.logging import nncf_logger |
27 | 26 | from nncf.common.quantization.quantizer_setup import QuantizationPointId |
28 | 27 | from nncf.common.quantization.quantizer_setup import QuantizerSetupBase |
29 | 28 | from nncf.common.quantization.quantizers import calculate_asymmetric_level_ranges |
|
34 | 33 | from nncf.common.quantization.structs import QuantizerSpec |
35 | 34 | from nncf.common.utils.debug import is_debug |
36 | 35 | from nncf.common.utils.registry import Registry |
37 | | -from nncf.torch.functions import clamp |
38 | 36 | from nncf.torch.graph.transformations.commands import PTTargetPoint |
39 | 37 | from nncf.torch.graph.transformations.commands import TargetType |
40 | 38 | from nncf.torch.layer_utils import COMPRESSION_MODULES |
|
56 | 54 | from nncf.torch.quantization.quantize_functions import unpack_uint4 |
57 | 55 | from nncf.torch.return_types import maybe_get_values_from_torch_return_type |
58 | 56 | from nncf.torch.return_types import maybe_wrap_to_torch_return_type |
59 | | -from nncf.torch.utils import get_flat_tensor_contents_string |
60 | | -from nncf.torch.utils import get_model_device |
61 | 57 | from nncf.torch.utils import is_tracing_state |
62 | 58 | from nncf.torch.utils import no_jit_trace |
63 | 59 |
|
@@ -464,29 +460,6 @@ def reset_call_counter(self): |
464 | 460 | def get_trainable_params(self) -> dict[str, torch.Tensor]: |
465 | 461 | return {} |
466 | 462 |
|
467 | | - def apply_minmax_init(self, min_values: torch.Tensor, max_values: torch.Tensor, log_module_name: str = None): |
468 | | - """min_values and max_values must have the same shape as specified in self.scale_shape""" |
469 | | - if self.initialized: |
470 | | - nncf_logger.debug(f"Skipped initializing {log_module_name} - loaded from checkpoint") |
471 | | - return |
472 | | - |
473 | | - if torch.all(torch.isinf(min_values)) or torch.all(torch.isinf(max_values)): |
474 | | - msg = f"Statistics are not collected for {log_module_name}" |
475 | | - raise ValueError(msg) |
476 | | - |
477 | | - if torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf)): |
478 | | - msg = f"Some of the values in statistics have infinite value for {log_module_name}" |
479 | | - raise ValueError(msg) |
480 | | - |
481 | | - own_device = get_model_device(self) |
482 | | - min_values = min_values.to(own_device) |
483 | | - max_values = max_values.to(own_device) |
484 | | - self._apply_minmax_init(min_values, max_values, log_module_name) |
485 | | - |
486 | | - @abstractmethod |
487 | | - def _apply_minmax_init(self, min_values: torch.Tensor, max_values: torch.Tensor, log_module_name: str = None): |
488 | | - pass |
489 | | - |
490 | 463 | @abstractmethod |
491 | 464 | def set_levels(self): |
492 | 465 | """ |
@@ -795,26 +768,6 @@ def quantize(self, x, execute_traced_op_as_identity: bool = False): |
795 | 768 | def get_trainable_params(self) -> dict[str, torch.Tensor]: |
796 | 769 | return {self.SCALE_PARAM_NAME: self.scale} |
797 | 770 |
|
798 | | - def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None): |
799 | | - sign = torch.any(torch.lt(min_values, 0)) |
800 | | - if self._signedness_to_force is not None and sign != self._signedness_to_force: |
801 | | - nncf_logger.debug(f"Forcing signed to {self._signedness_to_force} for module {log_module_name}") |
802 | | - sign = self._signedness_to_force |
803 | | - self.signed = sign |
804 | | - |
805 | | - abs_max = torch.max(torch.abs(max_values), torch.abs(min_values)) |
806 | | - SCALE_LOWER_THRESHOLD = 0.1 |
807 | | - mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD) |
808 | | - self._scale_param_storage.data = torch.where( |
809 | | - mask, abs_max, SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage) |
810 | | - ) |
811 | | - if self._is_using_log_scale_storage: |
812 | | - self._scale_param_storage.data.log_() |
813 | | - |
814 | | - nncf_logger.debug( |
815 | | - f"Set sign: {self.signed} and scale: {get_flat_tensor_contents_string(self.scale)} for {log_module_name}" |
816 | | - ) |
817 | | - |
818 | 771 | def broadcast_initialized_params(self, src: int = 0): |
819 | 772 | super().broadcast_initialized_params(src) |
820 | 773 | distributed.broadcast(self._scale_param_storage, src=src) |
@@ -996,22 +949,6 @@ def get_trainable_params(self) -> dict[str, torch.Tensor]: |
996 | 949 | self.INPUT_RANGE_PARAM_NAME: self.input_range, |
997 | 950 | } |
998 | 951 |
|
999 | | - def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None): |
1000 | | - ranges = max_values - min_values |
1001 | | - max_range = torch.max(max_values - min_values) |
1002 | | - eps = 1e-2 |
1003 | | - correction = (clamp(ranges, low=eps * max_range, high=max_range) - ranges) * 0.5 |
1004 | | - self._input_range_param_storage.data = (ranges + 2 * correction).data |
1005 | | - if self._is_using_log_scale_storage: |
1006 | | - self._input_range_param_storage.data.log_() |
1007 | | - |
1008 | | - self.input_low.data = (min_values - correction).data |
1009 | | - |
1010 | | - nncf_logger.debug( |
1011 | | - f"Set input_low: {get_flat_tensor_contents_string(self.input_low)} " |
1012 | | - f"and input_range: {get_flat_tensor_contents_string(self.input_range)} for {log_module_name}" |
1013 | | - ) |
1014 | | - |
1015 | 952 | def broadcast_initialized_params(self, src: int = 0): |
1016 | 953 | super().broadcast_initialized_params(src) |
1017 | 954 | distributed.broadcast(self.input_low, src) |
|
0 commit comments