2 changes: 2 additions & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -46,6 +46,8 @@ checkpointer:
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset and Sampler
dataset:
2 changes: 2 additions & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset and Sampler
dataset:
2 changes: 2 additions & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset and Sampler
dataset:
2 changes: 2 additions & 0 deletions recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -43,6 +43,8 @@ checkpointer:
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset and Sampler
dataset:
2 changes: 2 additions & 0 deletions recipes/configs/llama3_2/3B_lora_single_device.yaml
@@ -45,6 +45,8 @@ checkpointer:
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset and Sampler
dataset:
2 changes: 2 additions & 0 deletions recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
@@ -48,6 +48,8 @@ checkpointer:
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
save_last_epoch_only: False
epochs_to_save: 'all'

# Dataset
dataset:
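The six config diffs above all add the same two keys with the same defaults: save_last_epoch_only stays off and epochs_to_save keeps every epoch, so existing behavior is unchanged. As a rough illustration of how an override to these keys resolves, here is a minimal sketch using OmegaConf (the config library behind torchtune's YAML configs and CLI overrides); the key names mirror the YAML above, while the concrete values and the dotlist override are illustrative assumptions, not taken from this PR.

# Minimal sketch: resolving the new keys with an OmegaConf-style override.
from omegaconf import OmegaConf

defaults = OmegaConf.create(
    {"epochs": 3, "save_last_epoch_only": False, "epochs_to_save": "all"}
)
# A command-line style override, e.g. passing epochs_to_save=[0,2] to the recipe.
overrides = OmegaConf.from_dotlist(["epochs_to_save=[0,2]"])
cfg = OmegaConf.merge(defaults, overrides)

print(cfg.epochs_to_save)        # [0, 2]
print(cfg.save_last_epoch_only)  # False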
41 changes: 30 additions & 11 deletions recipes/lora_finetune_single_device.py
@@ -170,6 +170,20 @@ def __init__(self, cfg: DictConfig) -> None:
self.global_step = 0
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
if cfg.get("save_last_epoch_only", False) and cfg.get("epochs_to_save", None):
utils.log_rank_zero(
self._logger,
"Both save_last_epoch_only and epochs_to_save are set. "
"save_last_epoch_only takes precedence, but it is deprecated and will be removed in a future release.",
)
self._save_last_epoch_only = cfg.get("save_last_epoch_only", False)
self._epochs_to_save = (
[self.total_epochs - 1]
if self._save_last_epoch_only
else cfg.get("epochs_to_save", "all")
)
if self._epochs_to_save == "all":
self._epochs_to_save = list(range(self.total_epochs))
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._clip_grad_norm = cfg.get("clip_grad_norm", None)

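To make the precedence above easier to trace, here is a small standalone sketch of how the two settings could resolve to the set of epoch indices that get checkpointed. This is not the recipe code: the helper name is hypothetical, and the handling of 'none' and comma-separated strings (read as 1-indexed epoch numbers, as the new tests below imply) is an assumption about intent rather than logic shown in this hunk.

# Hypothetical helper (not part of the recipe) sketching the resolution logic.
def resolve_epochs_to_save(total_epochs, epochs_to_save="all", save_last_epoch_only=False):
    # The legacy flag wins, mirroring the precedence in __init__ above.
    if save_last_epoch_only:
        return [total_epochs - 1]
    # "all" expands to every epoch, matching the code above.
    if epochs_to_save == "all":
        return list(range(total_epochs))
    # Assumed handling of the string forms exercised by the tests:
    # "none" saves nothing; "1,3" is treated as 1-indexed epoch numbers.
    if epochs_to_save == "none":
        return []
    if isinstance(epochs_to_save, str):
        return [int(e) - 1 for e in epochs_to_save.split(",")]
    return list(epochs_to_save)

# With 3 epochs: last epoch only, every epoch, or epochs 1 and 3
# (which map to output folders epoch_2, epoch_0/1/2, and epoch_0/epoch_2).
assert resolve_epochs_to_save(3, save_last_epoch_only=True) == [2]
assert resolve_epochs_to_save(3, "all") == [0, 1, 2]
assert resolve_epochs_to_save(3, "1,3") == [0, 2]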
@@ -718,17 +732,22 @@ def train(self) -> None:
break

self.epochs_run += 1
start_save_checkpoint = time.perf_counter()
self._logger.info("Starting checkpoint save...")

# Save final non-distributed ckpt
self.save_checkpoint(epoch=curr_epoch, full_tensors=True)
self._logger.info(
"Checkpoint saved in {:.2f} seconds.".format(
time.perf_counter() - start_save_checkpoint
)
)

if curr_epoch in self._epochs_to_save:
start_save_checkpoint = time.perf_counter()
self._logger.info(
f"Starting checkpoint save for epoch {curr_epoch}..."
)
self.save_checkpoint(epoch=curr_epoch)
self._logger.info(
"Checkpoint saved in {:.2f} seconds.".format(
time.perf_counter() - start_save_checkpoint
)
)
else:
self._logger.info(
f"Skipping checkpoint save for epoch {curr_epoch}..."
)

def cleanup(self) -> None:
self._metric_logger.close()

264 changes: 264 additions & 0 deletions tests/recipes/test_lora_finetune_single_device.py
@@ -366,6 +366,270 @@ def test_training_state_on_resume_with_async_checkpointing(
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
)

@pytest.mark.parametrize(
"epochs_to_save, expected_folders",
[
("all", ["epoch_0", "epoch_1", "epoch_2"]),
("none", []),
("1,3", ["epoch_0", "epoch_2"]),
],
)
@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_epochs_to_save(
self, tmpdir, monkeypatch, epochs_to_save, expected_folders
):
"""Test that epochs_to_save parameter controls which epoch folders are saved.
The test checks if the specified epochs are saved after training a model for 3 epochs.
"""

ckpt = "llama3_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for three epochs
cmd = f"""
tune run lora_finetune_single_device \
--config llama3/8B_lora_single_device \
batch_size=8 \
gradient_accumulation_steps=1 \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
epochs_to_save={epochs_to_save} \
save_last_epoch_only=False \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=False \
""".split()

model_config = MODEL_TEST_CONFIGS["llama3_lora"]

cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Verify the checkpointing behavior
# Check if the expected epoch folders are created
saved_epoch_folders = sorted(
[f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
)

assert (
saved_epoch_folders == expected_folders
), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"

@pytest.mark.parametrize(
"epochs_to_save, expected_folders",
[
("all", ["epoch_0", "epoch_1", "epoch_2"]),
("none", []),
("1,3", ["epoch_0", "epoch_2"]),
],
)
@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_epochs_to_save_with_async_checkpointing(
self, tmpdir, monkeypatch, epochs_to_save, expected_folders
):
"""Test that epochs_to_save parameter controls which epoch folders are saved with async checkpointing.
The test checks if the specified epochs are saved after training a model for 3 epochs.
"""

ckpt = "llama3_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for three epochs
cmd = f"""
tune run lora_finetune_single_device \
--config llama3/8B_lora_single_device \
batch_size=8 \
gradient_accumulation_steps=1 \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
epochs_to_save={epochs_to_save} \
save_last_epoch_only=False \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama3_lora"]

cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Verify the checkpointing behavior
# Check if the expected epoch folders are created
saved_epoch_folders = sorted(
[f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
)

assert (
saved_epoch_folders == expected_folders
), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"

@pytest.mark.parametrize(
"save_last_epoch_only, expected_folders",
[
(True, ["epoch_2"]),
(False, ["epoch_0", "epoch_1"]),
],
)
@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_save_last_epoch_only(
self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders
):
"""Test that save_last_epoch_only parameter controls checkpoint saving behavior.
The test checks if the last epoch is saved when save_last_epoch_only is True
after training a model for 3 epochs and if it correctly overrides epochs_to_save.
"""

ckpt = "llama3_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for three epochs
cmd = f"""
tune run lora_finetune_single_device \
--config llama3/8B_lora_single_device \
batch_size=8 \
gradient_accumulation_steps=1 \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
epochs_to_save='1,2' \
save_last_epoch_only={save_last_epoch_only} \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=False \
""".split()

model_config = MODEL_TEST_CONFIGS["llama3_lora"]

cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Verify the checkpointing behavior
# Check if the expected epoch folders are created
saved_epoch_folders = sorted(
[f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
)

assert (
saved_epoch_folders == expected_folders
), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"

@pytest.mark.parametrize(
"save_last_epoch_only, expected_folders",
[
(True, ["epoch_2"]),
(False, ["epoch_0", "epoch_1"]),
],
)
@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_save_last_epoch_only_with_async_checkpointing(
self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders
):
"""Test that save_last_epoch_only parameter controls checkpoint saving behavior with async checkpointing.
The test checks if the last epoch is saved when save_last_epoch_only is True
after training a model for 3 epochs and if it correctly overrides epochs_to_save.
"""

ckpt = "llama3_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for three epochs
cmd = f"""
tune run lora_finetune_single_device \
--config llama3/8B_lora_single_device \
batch_size=8 \
gradient_accumulation_steps=1 \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
epochs_to_save='1,2' \
save_last_epoch_only={save_last_epoch_only} \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama3_lora"]

cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Verify the checkpointing behavior
# Check if the expected epoch folders are created
saved_epoch_folders = sorted(
[f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
)

assert (
saved_epoch_folders == expected_folders
), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}"

@pytest.mark.parametrize("use_dora", [False, True])
@pytest.mark.integration_test
@gpu_test(gpu_count=1)