From 37dd201820cd7c0b0e85d79566e7202e86c2eb90 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 23 Sep 2025 17:33:20 +0000 Subject: [PATCH 1/3] clear deprecation warnings in logs Signed-off-by: Brian Dellabetta --- setup.py | 6 +++++- src/llmcompressor/modifiers/pruning/magnitude/base.py | 10 ++++++---- src/llmcompressor/observers/base.py | 3 +-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 7196f2991..cab239772 100644 --- a/setup.py +++ b/setup.py @@ -133,7 +133,11 @@ def localversion_func(version: ScmVersion) -> str: if BUILD_TYPE == "release" else "accelerate>=1.6.0" ), - ("pynvml>=11.5.3,<=13.0.1" if BUILD_TYPE == "release" else "pynvml>=11.5.3"), + ( + "nvidia-ml-py>=12.560.30,<=13.580.82" + if BUILD_TYPE == "release" + else "nvidia-ml-py>=12.560.30" + ), ("pillow>=10.4.0,<=11.3.0" if BUILD_TYPE == "release" else "pillow>=10.4.0"), ( "compressed-tensors==0.11.0" diff --git a/src/llmcompressor/modifiers/pruning/magnitude/base.py b/src/llmcompressor/modifiers/pruning/magnitude/base.py index 1a218d0e3..7d36d6123 100644 --- a/src/llmcompressor/modifiers/pruning/magnitude/base.py +++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py @@ -40,10 +40,12 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking): @field_validator("leave_enabled") def validate_leave_enabled(value: bool) -> bool: - warnings.warn( - "MagnitudePruningModifier.leave_enable has been deprecated", - DeprecationWarning, - ) + if value: + warnings.warn( + "MagnitudePruningModifier.leave_enabled has been deprecated " + "and will be set to False.", + DeprecationWarning, + ) return False def on_initialize(self, state: State, **kwargs) -> bool: diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index aa9e1caab..6ca6e203c 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -10,7 +10,6 @@ ) from compressed_tensors.quantization.utils import is_fp4 from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import safe_permute from loguru import logger from torch import FloatTensor, IntTensor, Tensor @@ -169,7 +168,7 @@ def get_qparams( group_sizes = group_sizes[torch.argsort(group_indices)] perm = torch.argsort(g_idx) - observed = safe_permute(observed, perm, dim=1) + observed = observed.index_select(dim=1, index=perm) # TODO: experiment with vectorizing for loop for performance end = 0 From 110fd4a2558e7116d7b42befb787c45b976a1825 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 23 Sep 2025 19:57:04 +0000 Subject: [PATCH 2/3] clean up CompressionLogger verbosity Signed-off-by: Brian Dellabetta --- src/llmcompressor/utils/metric_logging.py | 213 +++++++++++++--------- 1 file changed, 127 insertions(+), 86 deletions(-) diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index 646544688..c5c5198b7 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -7,88 +7,24 @@ statistics and performance metrics. 
""" +import os import time -from typing import List, Tuple +from collections import namedtuple +from enum import Enum +from typing import List import torch from loguru import logger from torch.nn import Module -__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"] +__all__ = ["CompressionLogger"] +GPUMemory = namedtuple("GPUMemory", ["id", "pct_used", "total"]) -def get_GPU_memory_usage() -> List[Tuple[float, float]]: - if torch.version.hip: - return get_GPU_usage_amd() - else: - return get_GPU_usage_nv() - -def get_GPU_usage_nv() -> List[Tuple[float, float]]: - """ - get gpu usage for Nvidia GPUs using nvml lib - """ - try: - import pynvml - from pynvml import NVMLError - - try: - pynvml.nvmlInit() - except NVMLError as _err: - logger.warning(f"Pynml library error:\n {_err}") - return [] - - device_count = pynvml.nvmlDeviceGetCount() - usage = [] # [(percentage, total_memory_MB)] - - # Iterate through all GPUs - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - memory_usage_percentage = mem_info.used / mem_info.total - total_memory_gb = mem_info.total / (1e9) - usage.append( - (memory_usage_percentage, total_memory_gb), - ) - pynvml.nvmlShutdown() - return usage - - except ImportError: - logger.warning("Failed to obtain GPU usage from pynvml") - return [] - - -def get_GPU_usage_amd() -> List[Tuple[float, float]]: - """ - get gpu usage for AMD GPUs using amdsmi lib - """ - usage = [] - try: - import amdsmi - - try: - amdsmi.amdsmi_init() - devices = amdsmi.amdsmi_get_processor_handles() - - for device in devices: - vram_memory_usage = amdsmi.amdsmi_get_gpu_memory_usage( - device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM - ) - vram_memory_total = amdsmi.amdsmi_get_gpu_memory_total( - device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM - ) - - memory_percentage = vram_memory_usage / vram_memory_total - usage.append( - (memory_percentage, vram_memory_total / (1e9)), - ) - amdsmi.amdsmi_shut_down() - except amdsmi.AmdSmiException as error: - logger.warning(f"amdsmi library error:\n {error}") - except ImportError: - logger.warning("Failed to obtain GPU usage from amdsmi") - - return usage +class GPUType(Enum): + nv = "nv" + amd = "amd" def get_layer_size_mb(module: Module) -> float: @@ -111,14 +47,39 @@ class CompressionLogger: """ Log metrics related to compression algorithm - :param start_tick: time when algorithm started" + :param start_tick: time when algorithm started :param losses: loss as result of algorithm + :param gpu_type: device manufacturer (e.g. Nvidia, AMD) + :param visible_ids: list of device ids visible to current process """ def __init__(self, module: torch.nn.Module): self.module = module self.start_tick = None self.loss = None + self.gpu_type = GPUType.amd if torch.version.hip else GPUType.nv + + # For nvidia, parse CUDA_VISIBLE_DEVICES for visible devices to monitor + # If unset, default to all devices + self.visible_ids = [] + visible_devices_env_var = ( + "CUDA_VISIBLE_DEVICES" + if self.gpu_type == GPUType.nv + else "AMD_VISIBLE_DEVICES" + ) + visible_devices_str = os.environ.get(visible_devices_env_var, "") + try: + self.visible_ids = list( + map( + int, + visible_devices_str.lstrip("[").rstrip("]").split(","), + ) + ) + except Exception: + logger.bind(log_once=True).warning( + f"Failed to parse {visible_devices_env_var}. 
" + "All devices will be monitored" + ) def set_loss(self, loss: float): self.loss = loss @@ -137,18 +98,98 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): if self.loss is not None: patch.log("METRIC", f"error {self.loss:.2f}") - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) + gpu_usage: List[GPUMemory] = self.get_GPU_memory_usage() + for gpu in gpu_usage: + perc = gpu.pct_used * 100 + patch.log( + "METRIC", + ( + f"GPU {gpu.id} | usage: {perc:.2f}%" + f" | total memory: {gpu.total:.1f} GB" + ), + ) compressed_size = get_layer_size_mb(self.module) patch.log("METRIC", f"Compressed module size: {compressed_size} MB") + + def get_GPU_memory_usage(self) -> List[GPUMemory]: + if self.gpu_type == GPUType.amd: + return self._get_GPU_usage_amd(self.visible_ids) + else: + return self._get_GPU_usage_nv(self.visible_ids) + + @staticmethod + def _get_GPU_usage_nv(visible_ids: List[int]) -> List[GPUMemory]: + """ + get gpu usage for visible Nvidia GPUs using nvml lib + + :param visible_ids: list of GPUs to monitor. + If unset or zero length, defaults to all + """ + try: + import pynvml + from pynvml import NVMLError + + try: + pynvml.nvmlInit() + except NVMLError as _err: + logger.warning(f"Pynml library error:\n {_err}") + return [] + + usage: List[GPUMemory] = [] + + if len(visible_ids) == 0: + visible_ids = range(pynvml.nvmlDeviceGetCount()) + + for id in visible_ids: + handle = pynvml.nvmlDeviceGetHandleByIndex(id) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + memory_usage_percentage = mem_info.used / mem_info.total + total_memory_gb = mem_info.total / (1e9) + usage.append(GPUMemory(id, memory_usage_percentage, total_memory_gb)) + pynvml.nvmlShutdown() + return usage + + except ImportError: + logger.warning("Failed to obtain GPU usage from pynvml") + return [] + + @staticmethod + def _get_GPU_usage_amd(visible_ids: List[int]) -> List[GPUMemory]: + """ + get gpu usage for AMD GPUs using amdsmi lib + + :param visible_ids: list of GPUs to monitor. 
+ If unset or zero length, defaults to all + """ + usage: List[GPUMemory] = [] + try: + import amdsmi + + try: + amdsmi.amdsmi_init() + devices = amdsmi.amdsmi_get_processor_handles() + + if len(visible_ids) == 0: + visible_ids = range(len(devices)) + + for id in visible_ids: + device = devices[id] + vram_memory_usage = amdsmi.amdsmi_get_gpu_memory_usage( + device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM + ) + vram_memory_total = amdsmi.amdsmi_get_gpu_memory_total( + device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM + ) + + memory_percentage = vram_memory_usage / vram_memory_total + usage.append( + GPUMemory(id, memory_percentage, vram_memory_total / (1e9)), + ) + amdsmi.amdsmi_shut_down() + except amdsmi.AmdSmiException as error: + logger.warning(f"amdsmi library error:\n {error}") + except ImportError: + logger.warning("Failed to obtain GPU usage from amdsmi") + + return usage From a17f39080cd6b37935cf3071313af06f81e197e3 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 23 Sep 2025 20:31:01 +0000 Subject: [PATCH 3/3] comment update Signed-off-by: Brian Dellabetta --- src/llmcompressor/utils/metric_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index c5c5198b7..59d5012a0 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -59,8 +59,8 @@ def __init__(self, module: torch.nn.Module): self.loss = None self.gpu_type = GPUType.amd if torch.version.hip else GPUType.nv - # For nvidia, parse CUDA_VISIBLE_DEVICES for visible devices to monitor - # If unset, default to all devices + # Parse appropriate env var for visible devices to monitor + # If env var is unset, default to all devices self.visible_ids = [] visible_devices_env_var = ( "CUDA_VISIBLE_DEVICES"
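
The observers/base.py hunk in the first patch swaps the removed compressed-tensors safe_permute helper for a plain index_select along the column dimension. A minimal standalone sketch (not part of the patches; the tensor shapes and variable names are illustrative only) showing that the new call reorders columns the same way the old permutation-based indexing did:

    import torch

    # illustrative (rows, columns) activation tensor and per-column group ids,
    # mirroring the g_idx / perm setup in Observer.get_qparams
    observed = torch.randn(4, 8)
    g_idx = torch.randint(0, 4, (8,))
    perm = torch.argsort(g_idx)

    # index_select along dim=1 yields the same column ordering as advanced
    # indexing with the permutation
    assert torch.equal(observed.index_select(dim=1, index=perm), observed[:, perm])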
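
The CompressionLogger changes in the second and third patches read the device filter from CUDA_VISIBLE_DEVICES (Nvidia) or AMD_VISIBLE_DEVICES (AMD) and fall back to monitoring all devices when parsing fails. A small sketch of the parsing performed in __init__ (parse_visible_ids is a hypothetical helper introduced here only for illustration; the patch inlines this logic):

    def parse_visible_ids(value: str) -> list:
        # strip optional surrounding brackets, split on commas, convert to ints
        return list(map(int, value.lstrip("[").rstrip("]").split(",")))

    assert parse_visible_ids("0,2") == [0, 2]
    assert parse_visible_ids("[1, 3]") == [1, 3]  # bracketed form also parses

An empty string (the unset case) makes int("") raise ValueError, so the patched __init__ catches the exception, logs the warning once, and leaves visible_ids empty, which the usage helpers treat as "monitor all devices".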