213 changes: 127 additions & 86 deletions src/llmcompressor/utils/metric_logging.py
@@ -7,88 +7,24 @@
statistics and performance metrics.
"""

import os
import time
from typing import List, Tuple
from collections import namedtuple
from enum import Enum
from typing import List

import torch
from loguru import logger
from torch.nn import Module

__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"]
__all__ = ["CompressionLogger"]

GPUMemory = namedtuple("GPUMemory", ["id", "pct_used", "total"])

def get_GPU_memory_usage() -> List[Tuple[float, float]]:
if torch.version.hip:
return get_GPU_usage_amd()
else:
return get_GPU_usage_nv()


def get_GPU_usage_nv() -> List[Tuple[float, float]]:
"""
get gpu usage for Nvidia GPUs using nvml lib
"""
try:
import pynvml
from pynvml import NVMLError

try:
pynvml.nvmlInit()
except NVMLError as _err:
logger.warning(f"Pynml library error:\n {_err}")
return []

device_count = pynvml.nvmlDeviceGetCount()
usage = [] # [(percentage, total_memory_MB)]

# Iterate through all GPUs
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
memory_usage_percentage = mem_info.used / mem_info.total
total_memory_gb = mem_info.total / (1e9)
usage.append(
(memory_usage_percentage, total_memory_gb),
)
pynvml.nvmlShutdown()
return usage

except ImportError:
logger.warning("Failed to obtain GPU usage from pynvml")
return []


def get_GPU_usage_amd() -> List[Tuple[float, float]]:
"""
get gpu usage for AMD GPUs using amdsmi lib
"""
usage = []
try:
import amdsmi

try:
amdsmi.amdsmi_init()
devices = amdsmi.amdsmi_get_processor_handles()

for device in devices:
vram_memory_usage = amdsmi.amdsmi_get_gpu_memory_usage(
device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM
)
vram_memory_total = amdsmi.amdsmi_get_gpu_memory_total(
device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM
)

memory_percentage = vram_memory_usage / vram_memory_total
usage.append(
(memory_percentage, vram_memory_total / (1e9)),
)
amdsmi.amdsmi_shut_down()
except amdsmi.AmdSmiException as error:
logger.warning(f"amdsmi library error:\n {error}")
except ImportError:
logger.warning("Failed to obtain GPU usage from amdsmi")

return usage
class GPUType(Enum):
nv = "nv"
amd = "amd"


def get_layer_size_mb(module: Module) -> float:
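Aside (not part of the diff): the hunk above replaces the module-level GPU helpers with a GPUMemory namedtuple and a GPUType enum, so callers read named fields instead of positional tuple entries. A minimal sketch of the difference, with illustrative values:

from collections import namedtuple
from enum import Enum

GPUMemory = namedtuple("GPUMemory", ["id", "pct_used", "total"])

class GPUType(Enum):
    nv = "nv"
    amd = "amd"

# Old return shape: positional (percentage, total_GB) tuples
usage_old = [(0.42, 80.0)]
pct = usage_old[0][0] * 100  # which field is which is easy to confuse

# New return shape: named fields are self-documenting
usage_new = [GPUMemory(id=0, pct_used=0.42, total=80.0)]
pct = usage_new[0].pct_used * 100
gpu_type = GPUType.nv  # GPUType.amd is chosen when torch.version.hip is set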
@@ -111,14 +47,39 @@ class CompressionLogger:
"""
Log metrics related to compression algorithm

:param start_tick: time when algorithm started"
:param start_tick: time when algorithm started
:param losses: loss as result of algorithm
:param gpu_type: device manufacturer (e.g. Nvidia, AMD)
:param visible_ids: list of device ids visible to current process
"""

def __init__(self, module: torch.nn.Module):
self.module = module
self.start_tick = None
self.loss = None
self.gpu_type = GPUType.amd if torch.version.hip else GPUType.nv

# Parse appropriate env var for visible devices to monitor
# If env var is unset, default to all devices
self.visible_ids = []
visible_devices_env_var = (
"CUDA_VISIBLE_DEVICES"
if self.gpu_type == GPUType.nv
else "AMD_VISIBLE_DEVICES"
)
visible_devices_str = os.environ.get(visible_devices_env_var, "")
try:
self.visible_ids = list(
map(
int,
visible_devices_str.lstrip("[").rstrip("]").split(","),
)
)
except Exception:
logger.bind(log_once=True).warning(
f"Failed to parse {visible_devices_env_var}. "
"All devices will be monitored"
)

def set_loss(self, loss: float):
self.loss = loss
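
Aside (not part of the diff): a minimal sketch of how the visible-device parsing in __init__ above behaves, using the same strip-and-split approach; the env var value is illustrative:

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"  # e.g. set by the launcher
visible_devices_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
visible_ids = []
try:
    visible_ids = list(
        map(int, visible_devices_str.lstrip("[").rstrip("]").split(","))
    )
except Exception:
    # An unset or malformed value (e.g. "") fails int() and leaves
    # visible_ids empty, which the logger treats as "monitor all devices".
    pass

print(visible_ids)  # [0, 2]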
@@ -137,18 +98,98 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb):
if self.loss is not None:
patch.log("METRIC", f"error {self.loss:.2f}")

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
patch.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)
gpu_usage: List[GPUMemory] = self.get_GPU_memory_usage()
for gpu in gpu_usage:
perc = gpu.pct_used * 100
patch.log(
"METRIC",
(
f"GPU {gpu.id} | usage: {perc:.2f}%"
f" | total memory: {gpu.total:.1f} GB"
),
)

compressed_size = get_layer_size_mb(self.module)
patch.log("METRIC", f"Compressed module size: {compressed_size} MB")

def get_GPU_memory_usage(self) -> List[GPUMemory]:
if self.gpu_type == GPUType.amd:
return self._get_GPU_usage_amd(self.visible_ids)
else:
return self._get_GPU_usage_nv(self.visible_ids)

@staticmethod
def _get_GPU_usage_nv(visible_ids: List[int]) -> List[GPUMemory]:
"""
get gpu usage for visible Nvidia GPUs using nvml lib

:param visible_ids: list of GPUs to monitor.
If unset or zero length, defaults to all
"""
try:
import pynvml
from pynvml import NVMLError

try:
pynvml.nvmlInit()
except NVMLError as _err:
logger.warning(f"Pynml library error:\n {_err}")
return []

usage: List[GPUMemory] = []

if len(visible_ids) == 0:
visible_ids = range(pynvml.nvmlDeviceGetCount())

for id in visible_ids:
handle = pynvml.nvmlDeviceGetHandleByIndex(id)
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
memory_usage_percentage = mem_info.used / mem_info.total
total_memory_gb = mem_info.total / (1e9)
usage.append(GPUMemory(id, memory_usage_percentage, total_memory_gb))
pynvml.nvmlShutdown()
return usage

except ImportError:
logger.warning("Failed to obtain GPU usage from pynvml")
return []

@staticmethod
def _get_GPU_usage_amd(visible_ids: List[int]) -> List[GPUMemory]:
"""
get gpu usage for AMD GPUs using amdsmi lib

:param visible_ids: list of GPUs to monitor.
If unset or zero length, defaults to all
"""
usage: List[GPUMemory] = []
try:
import amdsmi

try:
amdsmi.amdsmi_init()
devices = amdsmi.amdsmi_get_processor_handles()

if len(visible_ids) == 0:
visible_ids = range(len(devices))

for id in visible_ids:
device = devices[id]
vram_memory_usage = amdsmi.amdsmi_get_gpu_memory_usage(
device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM
)
vram_memory_total = amdsmi.amdsmi_get_gpu_memory_total(
device, amdsmi.amdsmi_interface.AmdSmiMemoryType.VRAM
)

memory_percentage = vram_memory_usage / vram_memory_total
usage.append(
GPUMemory(id, memory_percentage, vram_memory_total / (1e9)),
)
amdsmi.amdsmi_shut_down()
except amdsmi.AmdSmiException as error:
logger.warning(f"amdsmi library error:\n {error}")
except ImportError:
logger.warning("Failed to obtain GPU usage from amdsmi")

return usage
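
Aside (not part of the diff): a minimal usage sketch of the reworked CompressionLogger, assuming __enter__ returns the logger instance (not shown in this hunk) and a hypothetical compression step that yields a loss value:

import torch

from llmcompressor.utils.metric_logging import CompressionLogger

layer = torch.nn.Linear(1024, 1024)

# On exit the context manager logs elapsed time, the recorded loss,
# per-GPU memory usage for the visible devices, and the compressed size.
with CompressionLogger(layer) as comp_logger:
    loss = run_compression_step(layer)  # hypothetical compression call
    comp_logger.set_loss(loss)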