
Commit cdbb538

feat: log disk space usage info, warn if close to exhaustion
Signed-off-by: Ihar Hrachyshka <[email protected]>
1 parent 959a41a commit cdbb538

2 files changed: +93, -15 lines


src/instructlab/training/model.py

Lines changed: 9 additions & 0 deletions

@@ -55,6 +55,7 @@ def __init__(
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self._last_checkpoint_size: int | None = None
         bnb_config = None
         if lora_config and lora_config.r > 0 and lora_quant_bits == 4:
             # Third Party
@@ -76,6 +77,14 @@ def __init__(
         if flash_enabled:
             self.base_model_args["attn_implementation"] = "flash_attention_2"
 
+    @property
+    def last_checkpoint_size(self) -> int | None:
+        return self._last_checkpoint_size
+
+    @last_checkpoint_size.setter
+    def last_checkpoint_size(self, value: int):
+        self._last_checkpoint_size = value
+
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
         self.reconcile_tokenizer()
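
Note: a minimal, self-contained sketch of the property semantics above (`_ModelStub` is a hypothetical stand-in for the real `Model` class, not part of this commit): the size reads as None until a save records it.

    class _ModelStub:
        def __init__(self):
            self._last_checkpoint_size: int | None = None

        @property
        def last_checkpoint_size(self) -> int | None:
            return self._last_checkpoint_size

        @last_checkpoint_size.setter
        def last_checkpoint_size(self, value: int):
            self._last_checkpoint_size = value

    m = _ModelStub()
    assert m.last_checkpoint_size is None  # nothing saved yet
    m.last_checkpoint_size = 5 * 1024**3   # hypothetical 5 GiB checkpoint
    print(m.last_checkpoint_size)          # 5368709120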

src/instructlab/training/utils.py

Lines changed: 84 additions & 15 deletions

@@ -11,6 +11,7 @@
 import logging
 import os
 import random
+import shutil
 import subprocess
 import sys
 import time
@@ -462,15 +463,29 @@ def get_caller(num_frames=1):
     return f"In {file_name}, line {line_number}"
 
 
-def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
+def log_rank_0(
+    msg, include_caller=False, rank=None, to_print=False, level=logging.INFO
+) -> None:
     if rank is None:
         rank = get_rank() if is_initialized() else 0
-    if rank <= 0:
-        if include_caller:
-            msg = f"{get_caller(num_frames=2)}: {msg}"
-        if to_print:
-            print(msg)
-        else:
+    if rank > 0:
+        return
+
+    if include_caller:
+        msg = f"{get_caller(num_frames=2)}: {msg}"
+
+    if to_print:
+        print(msg)
+        return
+
+    match level:
+        case logging.WARNING:
+            logger.warning(msg)
+        case logging.ERROR:
+            logger.error(msg)
+        case logging.DEBUG:
+            logger.debug(msg)
+        case _:
             logger.info(msg)
 
 
@@ -511,6 +526,13 @@ def skip_precheck_loops():
     accelerator.get_state_dict = old_get_state
 
 
+def _get_checkpoint_dir(args, samples_seen) -> Path:
+    subdir = (
+        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
+    )
+    return Path(args.output_dir) / "hf_format" / subdir
+
+
 def save_hf_format_accelerate(
     args,
     model,
@@ -519,20 +541,15 @@ def save_hf_format_accelerate(
     samples_seen,
     is_lora=False,
 ):
-    # Build the subdirectory name
-    subdir = (
-        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
-    )
+    # Build the final output directory path
+    final_output_dir = _get_checkpoint_dir(args, samples_seen)
 
     log_rank_0(
-        f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
+        f"\033[93mSaving model in huggingface format at: {final_output_dir}\033[0m",
         to_print=True,
     )
     start = time.time()
 
-    # Build the final output directory path
-    final_output_dir = Path(args.output_dir) / "hf_format" / subdir
-
     output_dir = final_output_dir
 
     CONFIG_NAME = "config.json"
@@ -611,6 +628,48 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)
 
 
+def _get_checkpoint_dir_size(checkpoint_dir) -> int:
+    total = 0
+    for dirpath, _, filenames in os.walk(checkpoint_dir):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total += os.path.getsize(fp)
+    return total
+
+
+def check_disk_space_for_next_checkpoint(
+    model: Model, output_dir: Path, warn_steps_ahead: int = 3
+) -> None:
+    checkpoint_size = model.last_checkpoint_size
+    if checkpoint_size is None:
+        # No previous checkpoint size to estimate, do nothing.
+        return
+
+    def _mb_size(num_bytes):
+        return f"{num_bytes / 1024 / 1024:.2f} MB"
+
+    try:
+        stat = shutil.disk_usage(output_dir)
+        free_bytes = stat.free
+        needed_bytes = checkpoint_size * warn_steps_ahead
+
+        log_rank_0(
+            f"Disk space info: free={_mb_size(free_bytes)}, last_checkpoint_size={_mb_size(checkpoint_size)} (output_dir={output_dir})"
+        )
+        if free_bytes < needed_bytes:
+            log_rank_0(
+                f"Estimated free disk space ({_mb_size(free_bytes)}) is less than the estimated size of the next {warn_steps_ahead} checkpoints ({_mb_size(needed_bytes)}). "
+                "The next checkpoint(s) may fail due to insufficient disk space.",
+                level=logging.WARNING,
+            )
+    except Exception as e:
+        log_rank_0(
+            f"Could not check disk space after checkpoint: {e}",
+            level=logging.ERROR,
+        )
+
+
 def save_checkpoint(
     args,
     accelerator: Accelerator,
@@ -622,6 +681,10 @@
     hf_format: bool = True,
     full_state: bool = False,
 ) -> None:
+    # Warn if disk space is low.
+    output_dir = Path(args.output_dir)
+    check_disk_space_for_next_checkpoint(model, output_dir, warn_steps_ahead=3)
+
     if hf_format:
         save_hf_format_accelerate(
             args=args,
@@ -641,6 +704,12 @@
             samples_seen=samples_seen,
         )
 
+    # Track last checkpoint size.
+    if hf_format:
+        checkpoint_dir = _get_checkpoint_dir(args, samples_seen)
+        if checkpoint_dir.exists():
+            model.last_checkpoint_size = _get_checkpoint_dir_size(checkpoint_dir)
+
 
 def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
     """
