From 0f28c81ce9e8c805bc5c20858b423e74c0bece54 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 10 Jun 2025 10:54:32 -0400 Subject: [PATCH 1/3] feat: add Checkpointer class and usage Signed-off-by: Charlie Doern --- src/instructlab/training/checkpointer.py | 385 +++++++++++++++++++++++ src/instructlab/training/main_ds.py | 50 ++- src/instructlab/training/utils.py | 235 -------------- 3 files changed, 407 insertions(+), 263 deletions(-) create mode 100644 src/instructlab/training/checkpointer.py diff --git a/src/instructlab/training/checkpointer.py b/src/instructlab/training/checkpointer.py new file mode 100644 index 00000000..fb98574d --- /dev/null +++ b/src/instructlab/training/checkpointer.py @@ -0,0 +1,385 @@ +# Standard +from copy import deepcopy +from pathlib import Path +import shutil +import time +import warnings + +# Third Party +from instructlab.dolomite.hf_models import export_to_huggingface +from torch import distributed as dist +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType +import torch + +# First Party +from instructlab.training.accelerator import Accelerator +from instructlab.training.config import DistributedBackend +from instructlab.training.model import Model + +# Local +from .utils import log_rank_0, wraps + + +class Checkpointer: + def __init__( + self, + model: Model, + optimizer: torch.optim.Optimizer, + accelerator: Accelerator, + strategy="all", + ): + self.strategy = strategy.lower() + self.model = model + self.optimizer = optimizer + self.accelerator = accelerator + + # Map strategies to internal methods + self._checkpoint_fn = { + "full_state": self.save_full_state, + "hf_format": self.save_hf_format_accelerate, + "all": self.save_all_checkpoints, + }.get(self.strategy, self._no_checkpoint) + + def checkpoint(self, *args, **kwargs): + # Calls the method chosen at init + return self._checkpoint_fn(*args, **kwargs) + + # pylint: disable=unused-argument + def _no_checkpoint(self, *args, **kwargs): + print("[None] Skipping checkpointing.") + + # pylint: disable=unused-argument + def save_fsdp_lora_model( + self, + output_dir: Path, + **kwargs, + ): + """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original + model with the trained LoRA adapters merged into the copy. + + This function creates a full copy of the model being trained and stores it in CPU memory. + If encountering OOM errors on CPU, this is likely a culprit. + + Args: + args (Namespace): Args received by the ArgumentParser. + model (FSDP): FSDP model as prepared by `accelerate.Accelerator` + accelerator (Accelerator): The given accelerator object. + """ + # Third Party + from peft import LoraModel + + if self.accelerator.distributed_type != DistributedBackend.FSDP: + raise RuntimeError( + "`save_fsdp_lora_model` was called when FSDP was not being used." + ) + if not wraps(self.model, FSDP): + raise RuntimeError( + "`save_fsdp_lora_model` was called but provided model is not an FSDP model." + ) + if not wraps(self.model, LoraModel): + raise RuntimeError( + "`save_fsdp_lora_model` was called but provided model is not a LoRA model." 
+ ) + + # okay now that validation is out of the way, we are free to implement saving + sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, sd_config): + state = self.model.state_dict() + + # When training a LoRA with FSDP and Accelerate, you cannot directly merge the adapters into + # the model wrapped by FSDP. To get around this limitation, we get a copy of the state dict + # create an identical model on CPU, load the state dict into the CPU model, merge the adapters + # and save the model to disk. + if self.accelerator.is_main_process: + # Third Party + from transformers import AutoModelForCausalLM + + # remove device_map from args list so we can load the model on CPU + old_device_map = self.model.base_model_args.pop("device_map", None) + model_copy = AutoModelForCausalLM.from_pretrained( + **self.model.base_model_args, device_map="cpu" + ) + model_copy = LoraModel(model_copy, self.model.lora_config, "default") + model_copy.load_state_dict(state) + model_copy.merge_and_unload(progressbar=True) + model_copy.save_pretrained(output_dir, safe_serialization=True) + self.model.config.to_json_file(f"{output_dir}/config.json") + self.model.tokenizer.save_pretrained(output_dir) + del model_copy + if old_device_map: + # return the previous device_map so it can be used later on if needed + self.model.base_model_args["device_map"] = old_device_map + + dist.barrier() + + # pylint: disable=unused-argument + def save_full_state( + self, + output_dir, + epoch: int, + samples_seen: int, + **kwargs, + ): + """ + Saves model, optimizer, and lr_scheduler state. + TODO: save model config - decided not to do this. + TODO: save tokenizer - decided not to do this. + TODO: handle LoRA + TODO: handle granite + """ + if self.model.lora_config is not None: + raise NotImplementedError("Can't save full state for LoRA at the moment.") + + # if args.is_granite: + # raise NotImplementedError("Can't save full state for Granite models yet.") + + output_dir = Path(output_dir) / "full_state" / f"epoch_{epoch}" + log_rank_0( + f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True + ) + + # patch FSDP state dict method so it works correctly. + def _get_state_dict_patched(model, unwrap=False): + return get_state_dict_unpatched(model, unwrap=unwrap) + + if self.accelerator.distributed_framework == "fsdp": + get_state_dict_unpatched = self.accelerator.get_state_dict + self.accelerator.get_state_dict = _get_state_dict_patched + + self.accelerator.save_state( + output_dir=output_dir, + # max_shard_size="5GB", + # safe_serialization=True, + ) + + # save metadata file for current training status + if self.accelerator.is_main_process: + # TODO: should we set the global_step here rather than calculating global_step + # based on samples_seen? 
+ metadata = {"current_epoch": epoch, "samples_seen": samples_seen} + torch.save(metadata, output_dir / "training_metadata.json") + log_rank_0( + f"\033[93mSaving training state: {metadata}\033[0m", to_print=True + ) + + log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True) + + # cleanup + if self.accelerator.distributed_framework == "fsdp": + self.accelerator.get_state_dict = get_state_dict_unpatched + + # pylint: disable=unused-argument + def save_hf_format_accelerate( + self, + output_dir, + epoch: int, + samples_seen: int, + last_epoch: bool = False, + **kwargs, + ): + # Standard + from tempfile import TemporaryDirectory + + # Build the subdirectory name + subdir = "last_epoch" if last_epoch else f"samples_{samples_seen}" + + log_rank_0( + f"\033[93mSaving model in huggingface format at: {subdir}\033[0m", + to_print=True, + ) + start = time.time() + + if self.model.model_type in ("gpt_megatron", "gpt_dolomite"): + convert_dolomite = False + else: + convert_dolomite = True + + # Build the final output directory path + final_output_dir = Path(output_dir) / "hf_format" / subdir + + if self.model.model_type == "dolomite" and convert_dolomite: + tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with + output_dir = Path(tmpdir.name) + else: + output_dir = final_output_dir + + CONFIG_NAME = "config.json" + output_config_file = output_dir / CONFIG_NAME + + # XXX(osilkin): LoRA + FSDP requires a different saving path than the others + # so we set this variable and use it to avoid those paths further down. + is_fsdp_lora = ( + self.model.lora_config is not None + and self.accelerator.distributed_type == DistributedBackend.FSDP + ) + if is_fsdp_lora: + self.save_fsdp_lora_model( + model=self.model, + accelerator=self.accelerator, + output_dir=output_dir, + ) + + get_state_dict_unpatched = self.accelerator.get_state_dict + + def _get_state_dict_patched(model, unwrap=False): + return get_state_dict_unpatched(model, unwrap=unwrap) + + self.accelerator.get_state_dict = _get_state_dict_patched + + if not is_fsdp_lora and self.accelerator.is_main_process: + if self.model.lora_config is not None: + self.model.module.merge_adapter() + model_state = self.model.module.state_dict() + + output_dir.mkdir(parents=True, exist_ok=True) + if not self.model.module.config.architectures and convert_dolomite: + arch_added = False + if self.model.model_type == "llama": + self.model.module.config.architectures = ["LlamaForCausalLM"] + arch_added = True + elif self.model.model_type == "granite": + self.model.module.config.architectures = ["GraniteForCausalLM"] + arch_added = True + if arch_added: + warnings.warn( + f"Adding architectures to ckpt: {self.model.module.config.architectures}", + ) + else: + warnings.warn( + f"Converting from dolomite, but no architecture field added to config.json", + ) + self.model.module.config.to_json_file(output_config_file) + self.model.tokenizer.save_pretrained(output_dir) + + if self.model.lora_config is not None: + self.save_dict_accelerate( + self.accelerator, + model_state, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) + self.model.module.unmerge_adapter() + + if self.model.lora_config is None: + self.accelerator.save_model( + self.model, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) + + if ( + self.model.model_type == "dolomite" + and convert_dolomite + and self.accelerator.is_main_process + ): + # export doesnt like the directory to exist + if 
final_output_dir.exists(): + shutil.rmtree(final_output_dir) + export_to_huggingface( + pretrained_model_name_or_path=tmpdir.name, + save_path=final_output_dir, + model_type=self.model.model_type, + ) + tmpdir.cleanup() + + log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True) + log_rank_0(f"saving took {time.time() - start} seconds") + dist.barrier() + + self.accelerator.get_state_dict = get_state_dict_unpatched + + def save_dict_accelerate( + self, + accelerator: Accelerator, + state_to_save, + save_directory, + max_shard_size="5GB", + safe_serialization=True, + ): + old_get_state = accelerator.get_state_dict + accelerator.get_state_dict = self._copy_no_lora_dict + + def skip_precheck_loops(): + return [] + + # The save model does a loop over modules and params in order to determine how to get state dict. Since we already have the state dict directly, we want to bypass those checks. + state_to_save.modules = skip_precheck_loops + state_to_save.parameters = skip_precheck_loops + + accelerator.save_model( + state_to_save, + save_directory=save_directory, + max_shard_size=max_shard_size, + safe_serialization=safe_serialization, + ) + + accelerator.get_state_dict = old_get_state + + def _copy_no_lora_dict(self, state_dict): + # Standard + from collections import OrderedDict + + cleaned_state_dict = OrderedDict() + for param_tensor in state_dict: + if not "lora" in param_tensor: + cleaned_state_dict[ + param_tensor.replace(".base_layer", "").replace( + "basemodel.model.", "" + ) + ] = deepcopy(state_dict[param_tensor]).cpu() + return cleaned_state_dict + + def load_latest_full_state(self, output_dir: Path) -> None: + """Loads accelerator state from most recently saved checkpoint + in `output_dir/full_state`. + + Args: + output_dir: Base output directory containing the full_state subdirectory + """ + full_state_dir = output_dir / "full_state" + + if not full_state_dir.is_dir(): + return + + # picks checkpoint with the largest number of samples by splitting the "samples_NNNN" string on _ + # and comparing the number at the end of the string + checkpoint_list = sorted( + list(full_state_dir.iterdir()), + reverse=True, + key=lambda x: int(str(x).rsplit("_", maxsplit=1)[-1]), + ) + + if len(checkpoint_list) == 0: + log_rank_0( + f"\033[93mNo checkpoints to load from: {full_state_dir}\033[0m", + to_print=True, + ) + return + + latest_checkpoint = checkpoint_list[0] + log_rank_0( + f"\033[93mLoading checkpoint from: {latest_checkpoint}\033[0m", + to_print=True, + ) + self.accelerator.load_state(latest_checkpoint) + + def save_all_checkpoints( + self, + output_dir, + epoch: int, + samples_seen: int, + last_epoch: bool = False, + ): + self.save_hf_format_accelerate( + output_dir=output_dir, + epoch=epoch, + samples_seen=samples_seen, + last_epoch=last_epoch, + ) + self.save_full_state( + output_dir=output_dir, epoch=epoch, samples_seen=samples_seen + ) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 4ca638d0..6a2e69a1 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from pathlib import Path import argparse import datetime import logging @@ -42,6 +43,7 @@ # First Party from instructlab.training import config from instructlab.training.accelerator import Accelerator +from instructlab.training.checkpointer import Checkpointer from instructlab.training.config import ( DistributedBackend, ModelTypes, @@ -70,9 +72,6 @@ from 
instructlab.training.utils import ( StreamablePopen, check_valid_train_args, - load_latest_full_state, - save_checkpoint, - save_hf_format_accelerate, set_random_seed, ) import instructlab.training.data_process as dp @@ -85,6 +84,7 @@ def train( model: Model, optimizer: torch.optim.Optimizer, accelerator: Accelerator, + checkpointer: Checkpointer, ): model.train() @@ -221,14 +221,10 @@ def train( global_step * batch_size % args.save_samples == 0 ): base_logger.debug(f"Saving checkpoint at step {global_step}") - save_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, + checkpointer.checkpoint( + output_dir=args.output_dir, + epoch=epoch, samples_seen=samples_seen, - is_lora=bool(args.lora_r), - hf_format=True, ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) torch.distributed.barrier() @@ -239,28 +235,20 @@ def train( torch.cuda.empty_cache() if args.checkpoint_at_epoch: base_logger.debug(f"Saving checkpoint at epoch {epoch}") - save_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, - samples_seen=samples_seen, - is_lora=bool(args.lora_r), - full_state=args.accelerate_full_state_at_epoch, - hf_format=True, + checkpointer.checkpoint( + output_dir=args.output_dir, epoch=epoch, + samples_seen=samples_seen, ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) torch.distributed.barrier() if args.save_last: - save_hf_format_accelerate( - args, - model, - model.tokenizer, - accelerator, - samples_seen, - is_lora=bool(args.lora_r), + checkpointer.save_hf_format_accelerate( + output_dir=args.output_dir, + epoch=args.num_epochs, + samples_seen=samples_seen, + last_epoch=True, ) @@ -483,13 +471,19 @@ def main(args): optimizer = accelerator.optimizer m = accelerator.model - load_latest_full_state(args=args, accelerator=accelerator) - + strategy = "all" + if not args.accelerate_full_state_at_epoch: + strategy = "hf_format" + checkpointer = Checkpointer( + strategy=strategy, model=m, optimizer=optimizer, accelerator=accelerator + ) + checkpointer.load_latest_full_state(Path(args.output_dir)) train( args, model=m, optimizer=optimizer, accelerator=accelerator, + checkpointer=checkpointer, ) torch.distributed.barrier() diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 9472884e..37858d88 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -7,7 +7,6 @@ from copy import deepcopy from functools import partial from pathlib import Path -from tempfile import TemporaryDirectory from typing import Any, List, Optional, Tuple import importlib import inspect @@ -17,7 +16,6 @@ import shutil import subprocess import sys -import time import warnings # Third Party @@ -25,7 +23,6 @@ from accelerate import Accelerator, DistributedType from instructlab.dolomite.hf_models import ( GPTDolomiteConfig, - export_to_huggingface, import_from_huggingface, ) from torch import distributed as dist @@ -673,241 +670,9 @@ def skip_precheck_loops(): accelerator.get_state_dict = old_get_state -def save_hf_format_accelerate( - args, - model, - tokenizer, - accelerator: Accelerator, - samples_seen, - is_lora=False, -): - # Build the subdirectory name - subdir = ( - "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}" - ) - - log_rank_0( - f"\033[93mSaving model in huggingface format at: {subdir}\033[0m", - to_print=True, - ) - start = time.time() - - if args.model_type in ("gpt_megatron", "gpt_dolomite"): - 
convert_dolomite = False - else: - convert_dolomite = True - - # Build the final output directory path - final_output_dir = Path(args.output_dir) / "hf_format" / subdir - - if args.use_dolomite and convert_dolomite: - tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with - output_dir = Path(tmpdir.name) - else: - output_dir = final_output_dir - - CONFIG_NAME = "config.json" - output_config_file = output_dir / CONFIG_NAME - - # XXX(osilkin): LoRA + FSDP requires a different saving path than the others - # so we set this variable and use it to avoid those paths further down. - is_fsdp_lora = is_lora and accelerator.distributed_type == DistributedType.FSDP - if is_fsdp_lora: - save_fsdp_lora_model( - args=args, - model=model, - tokenizer=tokenizer, - accelerator=accelerator, - output_dir=output_dir, - ) - - get_state_dict_unpatched = accelerator.get_state_dict - - def _get_state_dict_patched(model, unwrap=False): - return get_state_dict_unpatched(model, unwrap=unwrap) - - accelerator.get_state_dict = _get_state_dict_patched - - if not is_fsdp_lora and accelerator.is_main_process: - if is_lora: - model.module.merge_adapter() - model_state = model.module.state_dict() - - output_dir.mkdir(parents=True, exist_ok=True) - if not model.module.config.architectures and convert_dolomite: - arch_added = False - if args.model_type == "llama": - model.module.config.architectures = ["LlamaForCausalLM"] - arch_added = True - elif args.model_type == "granite": - model.module.config.architectures = ["GraniteForCausalLM"] - arch_added = True - if arch_added: - warnings.warn( - f"Adding architectures to ckpt: {model.module.config.architectures}", - ) - else: - warnings.warn( - f"Converting from dolomite, but no architecture field added to config.json", - ) - model.module.config.to_json_file(output_config_file) - tokenizer.save_pretrained(output_dir) - - if is_lora: - save_dict_accelerate( - accelerator, - model_state, - save_directory=output_dir, - max_shard_size="5GB", - safe_serialization=True, - ) - model.module.unmerge_adapter() - - if not is_lora: - accelerator.save_model( - model, - save_directory=output_dir, - max_shard_size="5GB", - safe_serialization=True, - ) - - if args.use_dolomite and convert_dolomite and accelerator.is_main_process: - # export doesnt like the directory to exist - if final_output_dir.exists(): - shutil.rmtree(final_output_dir) - export_to_huggingface( - pretrained_model_name_or_path=tmpdir.name, - save_path=final_output_dir, - model_type=args.model_type, - ) - tmpdir.cleanup() - - log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True) - log_rank_0(f"saving took {time.time() - start} seconds") - dist.barrier() - - accelerator.get_state_dict = get_state_dict_unpatched - - def set_random_seed(seed): if seed is not None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - -def save_checkpoint( - args, - accelerator: Accelerator, - model, - tokenizer, - samples_seen, - is_lora: bool, - epoch: int = None, - hf_format: bool = True, - full_state: bool = False, -) -> None: - if hf_format: - save_hf_format_accelerate( - args=args, - model=model, - accelerator=accelerator, - tokenizer=tokenizer, - samples_seen=samples_seen, - is_lora=is_lora, - ) - - if full_state: - save_full_state( - args=args, - accelerator=accelerator, - is_lora=is_lora, - epoch=epoch, - samples_seen=samples_seen, - ) - - -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int): - """ - Saves model, 
optimizer, and lr_scheduler state. - TODO: save model config - decided not to do this. - TODO: save tokenizer - decided not to do this. - TODO: handle LoRA - TODO: handle granite - """ - if is_lora: - raise NotImplementedError("Can't save full state for LoRA at the moment.") - - # if args.is_granite: - # raise NotImplementedError("Can't save full state for Granite models yet.") - - output_dir = Path(args.output_dir) / "full_state" / f"epoch_{epoch}" - log_rank_0(f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True) - - # patch FSDP state dict method so it works correctly. - def _get_state_dict_patched(model, unwrap=False): - return get_state_dict_unpatched(model, unwrap=unwrap) - - if args.distributed_training_framework == "fsdp": - get_state_dict_unpatched = accelerator.get_state_dict - accelerator.get_state_dict = _get_state_dict_patched - - accelerator.save_state( - output_dir=output_dir, - # max_shard_size="5GB", - # safe_serialization=True, - ) - - # save metadata file for current training status - if accelerator.is_main_process: - # TODO: should we set the global_step here rather than calculating global_step - # based on samples_seen? - metadata = {"current_epoch": epoch, "samples_seen": samples_seen} - torch.save(metadata, output_dir / "training_metadata.json") - log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True) - - log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True) - - # cleanup - if args.distributed_training_framework == "fsdp": - accelerator.get_state_dict = get_state_dict_unpatched - - -def load_latest_full_state(args, accelerator) -> None: - """ - Loads accelerator state from most recently saved checkpoint - in `output_dir/full_state`. - """ - output_dir = Path(args.output_dir) / "full_state" - - if not output_dir.is_dir(): - return - - # picks checkpoint with the largest number of samples by splitting the "samples_NNNN" string on _ - # and comparing the number at the end of the string - checkpoint_list = sorted( - list(output_dir.iterdir()), - reverse=True, - key=lambda x: int(str(x).rsplit("_", maxsplit=1)[-1]), - ) - - if len(checkpoint_list) == 0: - log_rank_0( - f"\033[93mNo checkpoints to load from: {output_dir}\033[0m", to_print=True - ) - return - - latest = checkpoint_list[0] - - log_rank_0(f"\033[93mLoading state from: {latest}\033[0m", to_print=True) - accelerator.load_state(latest) - - training_metadata = torch.load(latest / "training_metadata.json") - log_rank_0( - f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True - ) - - # previous epoch is basis for current epoch. 
- args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 - args.__dict__["samples_seen"] = training_metadata["samples_seen"] From 276e9b363f0be2a131dbb0ab3b23be601d0ecfa8 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 10 Jun 2025 11:02:24 -0400 Subject: [PATCH 2/3] feat: add test_checkpointer unit test suite Signed-off-by: Charlie Doern --- tests/unit/test_checkpointer.py | 207 ++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 tests/unit/test_checkpointer.py diff --git a/tests/unit/test_checkpointer.py b/tests/unit/test_checkpointer.py new file mode 100644 index 00000000..095bb542 --- /dev/null +++ b/tests/unit/test_checkpointer.py @@ -0,0 +1,207 @@ +# Standard +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Third Party +import pytest +import torch +import torch.distributed as dist + +# First Party +from instructlab.training.accelerator import Accelerator +from instructlab.training.checkpointer import Checkpointer +from instructlab.training.config import DistributedBackend + + +@pytest.fixture(autouse=True) +def mock_distributed(): + """Mock PyTorch distributed functionality for all tests.""" + with ( + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.barrier") as mock_barrier, + patch("torch.distributed.get_rank", return_value=0), + ): + yield mock_barrier + + +@pytest.fixture +def mock_model(): + model = MagicMock() + model.lora_config = None + model.model_type = "llama" + model.module = MagicMock() + model.module.config = MagicMock() + model.tokenizer = MagicMock() + return model + + +@pytest.fixture +def mock_optimizer(): + return MagicMock() + + +@pytest.fixture +def mock_accelerator(): + accelerator = MagicMock(spec=Accelerator) + accelerator.is_main_process = True + accelerator.distributed_type = DistributedBackend.FSDP + accelerator.distributed_framework = "fsdp" + accelerator.get_state_dict = MagicMock() + # Add missing methods that are used in the checkpointer + accelerator.save_state = MagicMock() + accelerator.save_model = MagicMock() + return accelerator + + +def test_checkpointer_initialization(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="all", + ) + + assert checkpointer.model == mock_model + assert checkpointer.optimizer == mock_optimizer + assert checkpointer.accelerator == mock_accelerator + assert checkpointer.strategy == "all" + + +def test_checkpointer_no_checkpoint(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="none", + ) + + # Test that no checkpointing occurs + checkpointer.checkpoint(output_dir="test_dir", epoch=1, samples_seen=100) + mock_accelerator.save_state.assert_not_called() + + +def test_checkpointer_full_state(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="full_state", + ) + + output_dir = Path("test_dir") + full_state_dir = output_dir / "full_state" / "epoch_1" + full_state_dir.mkdir(parents=True, exist_ok=True) + + checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100) + + # Verify accelerator save_state was called + mock_accelerator.save_state.assert_called_once() + # Verify metadata was saved + assert (full_state_dir / 
"training_metadata.json").exists() + + +def test_checkpointer_hf_format(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="hf_format", + ) + + output_dir = Path("test_dir") + hf_format_dir = output_dir / "hf_format" / "samples_100" + hf_format_dir.mkdir(parents=True, exist_ok=True) + + checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100) + + # Verify model config and tokenizer were saved + mock_model.module.config.to_json_file.assert_called_once() + mock_model.tokenizer.save_pretrained.assert_called_once() + # Verify accelerator save_model was called + mock_accelerator.save_model.assert_called_once() + + +def test_checkpointer_all_strategies(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="all", + ) + + output_dir = Path("test_dir") + full_state_dir = output_dir / "full_state" / "epoch_1" + hf_format_dir = output_dir / "hf_format" / "samples_100" + full_state_dir.mkdir(parents=True, exist_ok=True) + hf_format_dir.mkdir(parents=True, exist_ok=True) + + checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100) + + # Verify both full state and HF format were saved + mock_accelerator.save_state.assert_called_once() + mock_model.module.config.to_json_file.assert_called_once() + mock_model.tokenizer.save_pretrained.assert_called_once() + mock_accelerator.save_model.assert_called_once() + + +def test_checkpointer_lora_not_supported(mock_model, mock_optimizer, mock_accelerator): + mock_model.lora_config = MagicMock() # Set lora_config to non-None + + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="full_state", + ) + + with pytest.raises(NotImplementedError): + checkpointer.checkpoint(output_dir="test_dir", epoch=1, samples_seen=100) + + +def test_checkpointer_load_latest_full_state( + mock_model, mock_optimizer, mock_accelerator +): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="all", + ) + + # Mock the output directory structure + output_dir = Path("test_dir") + checkpoint_dir = output_dir / "full_state" / "epoch_1" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Mock the accelerator's load_state method + mock_accelerator.load_state = MagicMock() + + checkpointer.load_latest_full_state(output_dir) + + # Verify accelerator load_state was called + mock_accelerator.load_state.assert_called_once() + + +def test_checkpointer_save_last_epoch(mock_model, mock_optimizer, mock_accelerator): + checkpointer = Checkpointer( + model=mock_model, + optimizer=mock_optimizer, + accelerator=mock_accelerator, + strategy="hf_format", + ) + + output_dir = Path("test_dir") + last_epoch_dir = output_dir / "hf_format" / "last_epoch" + last_epoch_dir.mkdir(parents=True, exist_ok=True) + + checkpointer.checkpoint( + output_dir=output_dir, + epoch=1, + samples_seen=100, + last_epoch=True, + ) + + # Verify model was saved in last_epoch directory + mock_model.module.config.to_json_file.assert_called_once() + mock_model.tokenizer.save_pretrained.assert_called_once() + mock_accelerator.save_model.assert_called_once() From e075ad35cbde19e8f10b020c29f9e709dfa057f2 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 10 Jun 2025 11:51:26 -0400 Subject: [PATCH 3/3] fix: associate model_conf with 
Model model_conf from `AutoConfig` has some key info we need in the checkpointer. Associate it with the model class and its subclasses Signed-off-by: Charlie Doern --- src/instructlab/training/checkpointer.py | 12 ++++++------ src/instructlab/training/main_ds.py | 1 + src/instructlab/training/model.py | 8 ++++++++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/instructlab/training/checkpointer.py b/src/instructlab/training/checkpointer.py index fb98574d..f34668d7 100644 --- a/src/instructlab/training/checkpointer.py +++ b/src/instructlab/training/checkpointer.py @@ -191,7 +191,7 @@ def save_hf_format_accelerate( ) start = time.time() - if self.model.model_type in ("gpt_megatron", "gpt_dolomite"): + if self.model.model_conf.model_type in ("gpt_megatron", "gpt_dolomite"): convert_dolomite = False else: convert_dolomite = True @@ -199,7 +199,7 @@ def save_hf_format_accelerate( # Build the final output directory path final_output_dir = Path(output_dir) / "hf_format" / subdir - if self.model.model_type == "dolomite" and convert_dolomite: + if self.model.model_conf.model_type == "dolomite" and convert_dolomite: tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with output_dir = Path(tmpdir.name) else: @@ -236,10 +236,10 @@ def _get_state_dict_patched(model, unwrap=False): output_dir.mkdir(parents=True, exist_ok=True) if not self.model.module.config.architectures and convert_dolomite: arch_added = False - if self.model.model_type == "llama": + if self.model.model_conf.model_type == "llama": self.model.module.config.architectures = ["LlamaForCausalLM"] arch_added = True - elif self.model.model_type == "granite": + elif self.model.model_conf.model_type == "granite": self.model.module.config.architectures = ["GraniteForCausalLM"] arch_added = True if arch_added: @@ -272,7 +272,7 @@ def _get_state_dict_patched(model, unwrap=False): ) if ( - self.model.model_type == "dolomite" + self.model.model_conf.model_type == "dolomite" and convert_dolomite and self.accelerator.is_main_process ): @@ -282,7 +282,7 @@ def _get_state_dict_patched(model, unwrap=False): export_to_huggingface( pretrained_model_name_or_path=tmpdir.name, save_path=final_output_dir, - model_type=self.model.model_type, + model_type=self.model.model_conf.model_type, ) tmpdir.cleanup() diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 6a2e69a1..fc25ce6e 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -362,6 +362,7 @@ def main(args): flash_enabled=flash_enabled, noise_alpha=args.NEFTune_alpha, lora_quant_bits=args.lora_quant_bits, + model_conf=model_conf, ) args.base_model_args = m.base_model_args diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 8002d2ba..49c9fb85 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -50,11 +50,13 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + model_conf=None, ): self.lora_config = lora_config self.noise_alpha = noise_alpha self.tokenizer = tokenizer self.distributed_framework = distributed_framework + self.model_conf = model_conf bnb_config = None if lora_config and lora_config.r > 0 and lora_quant_bits == 4: # Third Party @@ -385,6 +387,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + model_conf=None, ): super().__init__( model_path=model_path, @@ -394,6 +397,7 @@ def 
__init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + model_conf=model_conf, ) try: # Third Party @@ -426,6 +430,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + model_conf=None, ): super().__init__( model_path=model_path, @@ -435,6 +440,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + model_conf=model_conf, ) # Third Party from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM @@ -469,6 +475,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + model_conf=None, ): super().__init__( model_path=model_path, @@ -478,6 +485,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + model_conf=model_conf, ) # Third Party from transformers import AutoModelForCausalLM
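
Taken together, the series replaces the old save_checkpoint()/save_hf_format_accelerate()/load_latest_full_state() helpers with a single Checkpointer object. A minimal sketch of the resulting call flow in main_ds.py follows; it is illustrative only, not part of the patches. The names args, m, optimizer, accelerator, epoch, and samples_seen are assumed from the surrounding main()/train() context, and the AutoConfig.from_pretrained() call (and the args.model_name_or_path field it reads) is an assumed call site, since the diffs only show model_conf being passed through to Model.

    # Illustrative sketch only; names below are assumed from main_ds.main()/train().
    from pathlib import Path

    from transformers import AutoConfig

    from instructlab.training.checkpointer import Checkpointer

    # PATCH 3/3: a config from AutoConfig is attached to the Model (model_conf=model_conf)
    # so the checkpointer can read model_conf.model_type; the exact load site is assumed here.
    model_conf = AutoConfig.from_pretrained(args.model_name_or_path)

    # "all" saves both the accelerate full state and HF-format checkpoints;
    # "hf_format" skips the full state (mirrors the logic added to main()).
    strategy = "all" if args.accelerate_full_state_at_epoch else "hf_format"

    checkpointer = Checkpointer(
        model=m, optimizer=optimizer, accelerator=accelerator, strategy=strategy
    )

    # Resume from the highest-numbered output_dir/full_state/epoch_* directory, if any.
    checkpointer.load_latest_full_state(Path(args.output_dir))

    # Inside the training loop, a single call dispatches to whichever save method
    # was selected for the strategy at construction time.
    checkpointer.checkpoint(
        output_dir=args.output_dir, epoch=epoch, samples_seen=samples_seen
    )

    # After training, save_last writes an HF-format "last_epoch" checkpoint.
    if args.save_last:
        checkpointer.save_hf_format_accelerate(
            output_dir=args.output_dir,
            epoch=args.num_epochs,
            samples_seen=samples_seen,
            last_epoch=True,
        )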