Skip to content

Commit 0fb3f03

Browse files
committed
Addressed comments and made some cleanup in tests.
Signed-off-by: meetkuma <[email protected]>
1 parent 1d5c858 commit 0fb3f03

File tree

3 files changed

+41
-21
lines changed

3 files changed

+41
-21
lines changed

QEfficient/finetune/experimental/core/logger.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,21 @@
1212
from typing import Optional
1313
from transformers.utils.logging import get_logger as hf_get_logger
1414

15-
from .utils.dist_utils import get_rank
15+
from .utils.dist_utils import get_local_rank
16+
17+
18+
# -----------------------------------------------------------------------------
19+
# Logger usage:
20+
# Initialize logger:
21+
# logger = Logger("my_logger", log_file="logs/output.log", level=logging.DEBUG)
22+
# Log messages:
23+
# logger.info("This is an info message")
24+
# logger.error("This is an error message")
25+
# logger.log_rank_zero("This message is logged only on rank 0")
26+
# logger.log_exception("An error occurred", exception, raise_exception=False)
27+
# Attach file handler later if needed:
28+
# logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")
29+
# -----------------------------------------------------------------------------
1630

1731

1832
class Logger:
@@ -86,7 +100,7 @@ def log_rank_zero(self, message: str, level: int = logging.INFO) -> None:
86100
message: Message to log
87101
level: Logging level
88102
"""
89-
if get_rank() == 0:
103+
if get_local_rank() == 0:
90104
self.logger.log(level, message)
91105

92106
def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None:
@@ -104,13 +118,13 @@ def log_exception(self, message: str, exception: Exception, raise_exception: boo
104118
if raise_exception:
105119
raise exception
106120

107-
def prepare_for_logs(self, output_dir: str, save_metrics: bool = True, log_level: str = "INFO") -> None:
121+
def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "INFO") -> None:
108122
"""
109-
Prepare logger for training logs.
123+
Prepare existing logger to log to both console and file with specified
124+
output directory and log level.
110125
111126
Args:
112127
output_dir: Output directory for logs
113-
save_metrics: Whether to save metrics to file
114128
log_level: Logging level as string
115129
"""
116130
# Convert string log level to logging constant
@@ -122,7 +136,7 @@ def prepare_for_logs(self, output_dir: str, save_metrics: bool = True, log_level
122136
handler.setLevel(level)
123137

124138
# Add file handler if saving metrics
125-
if save_metrics:
139+
if output_dir:
126140
log_file = Path(output_dir) / "training.log"
127141
log_file.parent.mkdir(parents=True, exist_ok=True)
128142

QEfficient/finetune/experimental/core/utils/dist_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@ def is_dist_available_and_initialized() -> bool:
1414

1515

1616
def get_rank() -> int:
17-
"""Get the rank of the current process in distributed training."""
17+
"""Return the global rank of the current process, else 0."""
1818
if not is_dist_available_and_initialized():
1919
return 0
2020
return dist.get_rank()
2121

2222

23+
def get_local_rank() -> int:
24+
"""Return the local rank of the current process on its node, else 0."""
25+
if not is_dist_available_and_initialized():
26+
return 0
27+
return dist.get_node_local_rank()
28+
29+
2330
def get_world_size() -> int:
2431
"""Get the total number of processes in distributed training."""
2532
if not is_dist_available_and_initialized():

QEfficient/finetune/experimental/tests/test_logger.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,21 @@ def test_log_levels(self, caplog):
6363
assert "Error message" in caplog.text
6464
assert "Critical message" in caplog.text
6565

66-
@patch("QEfficient.finetune.experimental.core.logger.get_rank")
67-
def test_log_rank_zero(self, mock_get_rank, caplog):
66+
@patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
67+
def test_log_rank_zero_positive_case(self, mock_get_local_rank, caplog):
6868
"""Test rank zero logging functionality"""
69-
mock_get_rank.return_value = 0
69+
mock_get_local_rank.return_value = 0
7070
logger = Logger("rank_test_logger")
7171

7272
with caplog.at_level(logging.INFO):
7373
logger.log_rank_zero("Rank zero message")
7474

7575
assert "Rank zero message" in caplog.text
7676

77-
@patch("QEfficient.finetune.experimental.core.logger.get_rank")
78-
def test_log_rank_zero_not_zero(self, mock_get_rank, caplog):
79-
"""Test that non-rank zero messages are not logged"""
80-
mock_get_rank.return_value = 1
77+
@patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
78+
def test_log_rank_zero_negative_case(self, mock_get_local_rank, caplog):
79+
"""Test to verify that only rank‑zero messages are logged"""
80+
mock_get_local_rank.return_value = 1
8181
logger = Logger("rank_test_logger")
8282

8383
with caplog.at_level(logging.INFO):
@@ -112,7 +112,7 @@ def test_prepare_for_logs(self, tmp_path):
112112
logger = Logger("prepare_test_logger")
113113

114114
# Prepare for logs
115-
logger.prepare_for_logs(str(output_dir), save_metrics=True, log_level="DEBUG")
115+
logger.prepare_for_logs(str(output_dir), log_level="DEBUG")
116116

117117
# Check file handler was added
118118
file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
@@ -125,13 +125,12 @@ def test_prepare_for_logs(self, tmp_path):
125125
# Check log level was updated
126126
assert logger.logger.level == logging.DEBUG
127127

128-
def test_prepare_for_logs_no_metrics(self, tmp_path):
129-
"""Test preparing logger without saving metrics"""
130-
output_dir = tmp_path / "output"
128+
def test_prepare_for_logs_no_file_handler(self):
129+
"""Test preparing logger without saving to file"""
131130
logger = Logger("prepare_test_logger")
132131

133132
# Prepare for logs without saving metrics
134-
logger.prepare_for_logs(str(output_dir), save_metrics=False, log_level="INFO")
133+
logger.prepare_for_logs(log_level="INFO")
135134

136135
# Check no file handler was added
137136
file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
@@ -149,7 +148,7 @@ def test_prepare_for_logs_already_has_file_handler(self, tmp_path):
149148
logger.logger.addHandler(file_handler)
150149

151150
# Prepare for logs again
152-
logger.prepare_for_logs(str(output_dir), save_metrics=True, log_level="INFO")
151+
logger.prepare_for_logs(str(output_dir), log_level="INFO")
153152

154153
# Should still have only one file handler
155154
file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
@@ -204,7 +203,7 @@ def test_complete_workflow(self, tmp_path, caplog):
204203
logger.log_exception("Caught exception", e, raise_exception=False)
205204

206205
# Test rank zero logging
207-
with patch("QEfficient.finetune.experimental.core.logger.get_rank") as mock_rank:
206+
with patch("QEfficient.finetune.experimental.core.logger.get_local_rank") as mock_rank:
208207
mock_rank.return_value = 0
209208
logger.log_rank_zero("Rank zero test")
210209

0 commit comments

Comments
 (0)