Commit 1d5c858

Added logger and its test cases. Also added dist_utils, which provides utility functions for distributed training.
Signed-off-by: meetkuma <[email protected]>
1 parent ea26341

3 files changed: +416 -0 lines changed

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


import logging
import sys
from pathlib import Path
from typing import Optional
from transformers.utils.logging import get_logger as hf_get_logger

from .utils.dist_utils import get_rank


class Logger:
    """Custom logger with console and file logging capabilities."""

    def __init__(
        self,
        name: str = "transformers",  # We are using "transformers" as default to align with HF logs
        log_file: Optional[str] = None,
        level: int = logging.INFO,
    ):
        """
        Initialize the logger.

        Args:
            name: Logger name
            log_file: Path to log file (if None, log only to console)
            level: Logging level
        """
        self.logger = hf_get_logger(name)
        self.logger.setLevel(level)

        # Clear any existing handlers
        self.logger.handlers.clear()

        # Create formatter
        self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)
        console_handler.setFormatter(self.formatter)
        self.logger.addHandler(console_handler)

        # File handler (if log_file is provided)
        if log_file:
            # Create directory if it doesn't exist
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)

            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(level)
            file_handler.setFormatter(self.formatter)
            self.logger.addHandler(file_handler)

    def debug(self, message: str) -> None:
        """Log debug message."""
        self.logger.debug(message)

    def info(self, message: str) -> None:
        """Log info message."""
        self.logger.info(message)

    def warning(self, message: str) -> None:
        """Log warning message."""
        self.logger.warning(message)

    def error(self, message: str) -> None:
        """Log error message."""
        self.logger.error(message)

    def critical(self, message: str) -> None:
        """Log critical message."""
        self.logger.critical(message)

    def log_rank_zero(self, message: str, level: int = logging.INFO) -> None:
        """
        Log message only on rank 0 process.

        Args:
            message: Message to log
            level: Logging level
        """
        if get_rank() == 0:
            self.logger.log(level, message)

    def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None:
        """
        Log exception message and optionally raise the exception.

        Args:
            message: Custom message to log
            exception: Exception to log
            raise_exception: Whether to raise the exception after logging
        """
        error_message = f"{message}: {str(exception)}"
        self.logger.error(error_message)

        if raise_exception:
            raise exception

    def prepare_for_logs(self, output_dir: str, save_metrics: bool = True, log_level: str = "INFO") -> None:
        """
        Prepare logger for training logs.

        Args:
            output_dir: Output directory for logs
            save_metrics: Whether to save metrics to file
            log_level: Logging level as string
        """
        # Convert string log level to logging constant
        level = getattr(logging, log_level.upper(), logging.INFO)
        self.logger.setLevel(level)

        # Update existing handlers' levels
        for handler in self.logger.handlers:
            handler.setLevel(level)

        # Add file handler if saving metrics
        if save_metrics:
            log_file = Path(output_dir) / "training.log"
            log_file.parent.mkdir(parents=True, exist_ok=True)

            # Check if file handler already exists
            file_handler_exists = any(isinstance(handler, logging.FileHandler) for handler in self.logger.handlers)

            if not file_handler_exists:
                file_handler = logging.FileHandler(log_file)
                file_handler.setLevel(level)
                file_handler.setFormatter(self.formatter)
                self.logger.addHandler(file_handler)


# Global logger instance
_logger: Optional[Logger] = None


def get_logger(log_file: Optional[str] = None) -> Logger:
    """
    Get or create a logger instance.

    Args:
        log_file: Path to log file (if None, log only to console)

    Returns:
        Logger instance
    """
    global _logger
    if _logger is None:
        _logger = Logger(log_file=log_file)
    return _logger
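
For context, a minimal usage sketch of this API. The diff does not show this module's file name, so the import path below is an assumption based on the package layout visible in this commit; the calls themselves mirror the methods defined above.

# Usage sketch only. The logger module's import path is assumed from the
# package layout in this commit; it is not confirmed by the diff.
from QEfficient.finetune.experimental.core.logger import get_logger

logger = get_logger(log_file="outputs/run.log")  # singleton: later calls reuse it
logger.info("Starting fine-tuning")
logger.log_rank_zero("Printed once per job, not once per process")
logger.prepare_for_logs(output_dir="outputs", save_metrics=True, log_level="DEBUG")
try:
    raise ValueError("bad config")
except ValueError as e:
    logger.log_exception("Config validation failed", e, raise_exception=False)

Note that because get_logger() caches a single Logger instance, the log_file argument only takes effect on the first call; later calls return the existing instance unchanged.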

QEfficient/finetune/experimental/core/utils/dist_utils.py

Lines changed: 26 additions & 0 deletions
@@ -4,3 +4,29 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch.distributed as dist


def is_dist_available_and_initialized() -> bool:
    """Check if distributed training is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Get the rank of the current process in distributed training."""
    if not is_dist_available_and_initialized():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Get the total number of processes in distributed training."""
    if not is_dist_available_and_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """Check if the current process is the main process (rank 0)."""
    return get_rank() == 0
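
A short sketch of the intended usage pattern: gating rank-0-only side effects (e.g. writing files) in a distributed job. The helper save_metrics_once and the barrier placement are illustrative, not part of this commit; the dist_utils import path is the one shown above.

# Sketch only: gate rank-0-only work with the helpers above.
import torch.distributed as dist

from QEfficient.finetune.experimental.core.utils.dist_utils import (
    get_rank,
    get_world_size,
    is_main_process,
)


def save_metrics_once(metrics: dict, path: str) -> None:
    """Write metrics from rank 0 only, then sync all ranks."""
    if is_main_process():
        print(f"[rank {get_rank()}/{get_world_size()}] writing {path}: {metrics}")
        # ... actual file write would go here ...
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # other ranks wait for rank 0 to finish I/O

Because every helper degrades gracefully when torch.distributed is unavailable or uninitialized (rank 0, world size 1), the same code runs unchanged in single-process and torchrun launches.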
