From 6789673472580f5d155003f71619a6bde5763c6f Mon Sep 17 00:00:00 2001 From: Donggyu Ban Date: Wed, 18 Jun 2025 02:06:39 -0400 Subject: [PATCH 1/5] last epoch --- .../configs/llama2/7B_lora_single_device.yaml | 1 + .../configs/llama3/8B_lora_single_device.yaml | 1 + .../llama3_1/8B_lora_single_device.yaml | 1 + .../llama3_2/1B_lora_single_device.yaml | 2 +- .../llama3_2/3B_lora_single_device.yaml | 2 +- .../11B_lora_single_device.yaml | 1 + recipes/lora_finetune_single_device.py | 21 ++++++++++++------- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index afa215ae37..1304e632c6 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -46,6 +46,7 @@ checkpointer: model_type: LLAMA2 resume_from_checkpoint: False save_adapter_weights_only: False +save_last_epoch_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 77cfef59e2..6d6ae18cb2 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -48,6 +48,7 @@ checkpointer: model_type: LLAMA3 resume_from_checkpoint: False save_adapter_weights_only: False +save_last_epoch_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 46b3f767ee..51e197f814 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -48,6 +48,7 @@ checkpointer: model_type: LLAMA3 resume_from_checkpoint: False save_adapter_weights_only: False +save_last_epoch_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index a5479fa724..89c67bbc94 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -43,7 +43,7 @@ checkpointer: model_type: LLAMA3_2 resume_from_checkpoint: False save_adapter_weights_only: False - +save_last_epoch_only: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index 4f54caed9f..d42a246d56 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -45,7 +45,7 @@ checkpointer: model_type: LLAMA3_2 resume_from_checkpoint: False save_adapter_weights_only: False - +save_last_epoch_only: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 6b434aa499..8b0ad7723c 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -48,6 +48,7 @@ checkpointer: model_type: LLAMA3_VISION resume_from_checkpoint: False save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. 
+save_last_epoch_only: False # Dataset dataset: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index a56a4df269..0938253fc8 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -156,6 +156,7 @@ def __init__(self, cfg: DictConfig) -> None: self.global_step = 0 self._resume_from_checkpoint = cfg.resume_from_checkpoint self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._save_last_epoch_only = cfg.get("save_last_epoch_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) @@ -688,14 +689,20 @@ def train(self) -> None: break self.epochs_run += 1 - start_save_checkpoint = time.perf_counter() - self._logger.info("Starting checkpoint save...") - self.save_checkpoint(epoch=curr_epoch) - self._logger.info( - "Checkpoint saved in {:.2f} seconds.".format( - time.perf_counter() - start_save_checkpoint + + # If self._save_last_epoch_only is true, only save checkpoint on the final epoch to save disk space + if not self._save_last_epoch_only or curr_epoch == self.total_epochs - 1: + start_save_checkpoint = time.perf_counter() + self._logger.info("Starting checkpoint save...") + self.save_checkpoint(epoch=curr_epoch) + log.info( + "Checkpoint saved in {:.2f} seconds.".format( + time.perf_counter() - start_save_checkpoint + ) ) - ) + else: + log.info( + f"Skipping checkpoint save for epoch {curr_epoch + 1}..") def cleanup(self) -> None: self._metric_logger.close() From 2da8bdc291364a4207d3beb3861b4258d4cfb06c Mon Sep 17 00:00:00 2001 From: Donggyu Ban Date: Wed, 18 Jun 2025 17:57:56 -0400 Subject: [PATCH 2/5] revision --- recipes/configs/llama3_2/1B_lora_single_device.yaml | 1 + recipes/configs/llama3_2/3B_lora_single_device.yaml | 1 + recipes/lora_finetune_single_device.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index 89c67bbc94..eabb12a986 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -44,6 +44,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False + # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index d42a246d56..cf89c9c4d2 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -46,6 +46,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False + # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 0938253fc8..d684c127ac 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -695,13 +695,13 @@ def train(self) -> None: start_save_checkpoint = time.perf_counter() self._logger.info("Starting checkpoint save...") self.save_checkpoint(epoch=curr_epoch) - log.info( + self._logger.info( "Checkpoint saved in {:.2f} seconds.".format( time.perf_counter() - start_save_checkpoint ) ) else: - log.info( + self._logger.info( f"Skipping checkpoint save for epoch {curr_epoch + 1}..") def 
cleanup(self) -> None: From f35865d7721e2d74e4ab67986a4042510cb7dedb Mon Sep 17 00:00:00 2001 From: Donggyu Ban Date: Fri, 27 Jun 2025 17:03:01 -0400 Subject: [PATCH 3/5] format-change --- recipes/lora_finetune_single_device.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index d684c127ac..14e5ec201c 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -691,7 +691,10 @@ def train(self) -> None: self.epochs_run += 1 # If self._save_last_epoch_only is true, only save checkpoint on the final epoch to save disk space - if not self._save_last_epoch_only or curr_epoch == self.total_epochs - 1: + if ( + not self._save_last_epoch_only + or curr_epoch == self.total_epochs - 1 + ): start_save_checkpoint = time.perf_counter() self._logger.info("Starting checkpoint save...") self.save_checkpoint(epoch=curr_epoch) @@ -702,7 +705,8 @@ def train(self) -> None: ) else: self._logger.info( - f"Skipping checkpoint save for epoch {curr_epoch + 1}..") + f"Skipping checkpoint save for epoch {curr_epoch + 1}.." + ) def cleanup(self) -> None: self._metric_logger.close() From 35a9043cc9d72247e3d4d95ce28612b902d8b5a2 Mon Sep 17 00:00:00 2001 From: Donggyu Ban Date: Fri, 27 Jun 2025 19:20:30 -0400 Subject: [PATCH 4/5] added save_last_epoch_only feature test in recipe test --- .../test_lora_finetune_single_device.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index a42c298466..da99d1ba82 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -364,6 +364,136 @@ def test_training_state_on_resume_with_async_checkpointing( loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 ) + +@pytest.mark.parametrize("save_last_epoch_only", [False, True]) +@pytest.mark.integration_test +@gpu_test(gpu_count=1) +def test_save_last_epoch_only(self, tmpdir, monkeypatch, save_last_epoch_only): + """Test that save_last_epoch_only parameter controls checkpoint saving behavior. + The test checks if the last epoch is saved when save_last_epoch_only is True + after training a model for 3 epochs. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for three epochs + cmd = f""" + tune run lora_finetune_single_device \ + --config llama3/8B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=False \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ + tokenizer.prompt_template=null \ + save_last_epoch_only={save_last_epoch_only} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=False \ + enable_async_checkpointing=False \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Verify the checkpointing behavior + # Check if the expected epoch folders are created + epoch_folders = [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] + + if save_last_epoch_only: + expected_epoch_folders = 1 + assert ( + len(epoch_folders) == expected_epoch_folders + ), f"With save_last_epoch_only=True, expected {expected_epoch_folders} epoch folder, got {len(epoch_folders)}" + assert "epoch_2" in epoch_folders, "Final epoch checkpoint should exist" + else: + expected_epoch_folders = 3 + assert ( + len(epoch_folders) == expected_epoch_folders + ), f"With save_last_epoch_only=False, expected {expected_epoch_folders} epoch folders, got {len(epoch_folders)}" + + +@pytest.mark.parametrize("save_last_epoch_only", [False, True]) +@pytest.mark.integration_test +@gpu_test(gpu_count=1) +def test_save_last_epoch_only_with_async_checkpointing( + self, tmpdir, monkeypatch, save_last_epoch_only +): + """Test that save_last_epoch_only parameter controls checkpoint saving behavior with asunc checkpointing. + The test checks if the last epoch is saved when save_last_epoch_only is True + after training a model for 3 epochs. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for three epochs + cmd = f""" + tune run lora_finetune_single_device \ + --config llama3/8B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=False \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ + tokenizer.prompt_template=null \ + save_last_epoch_only={save_last_epoch_only} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=False \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Verify the checkpointing behavior + # Check if the expected epoch folders are created + epoch_folders = [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] + + if save_last_epoch_only: + expected_epoch_folders = 1 + assert ( + len(epoch_folders) == expected_epoch_folders + ), f"With save_last_epoch_only=True, expected {expected_epoch_folders} epoch folder, got {len(epoch_folders)}" + assert "epoch_2" in epoch_folders, "Final epoch checkpoint should exist" + else: + expected_epoch_folders = 3 + assert ( + len(epoch_folders) == expected_epoch_folders + ), f"With save_last_epoch_only=False, expected {expected_epoch_folders} epoch folders, got {len(epoch_folders)}" + @pytest.mark.parametrize("use_dora", [False, True]) @pytest.mark.integration_test @gpu_test(gpu_count=1) From e5da1b459146fea55eaf810940feab900ac1641c Mon Sep 17 00:00:00 2001 From: Donggyu Ban Date: Fri, 18 Jul 2025 20:39:03 -0400 Subject: [PATCH 5/5] update with epochs_to_save --- .../configs/llama2/7B_lora_single_device.yaml | 1 + .../configs/llama3/8B_lora_single_device.yaml | 1 + .../llama3_1/8B_lora_single_device.yaml | 1 + .../llama3_2/1B_lora_single_device.yaml | 1 + .../llama3_2/3B_lora_single_device.yaml | 1 + .../11B_lora_single_device.yaml | 1 + recipes/lora_finetune_single_device.py | 28 +- .../test_lora_finetune_single_device.py | 382 ++++++++++++------ 8 files changed, 283 insertions(+), 133 deletions(-) diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index 1304e632c6..1b0a094d50 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -47,6 +47,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False +epochs_to_save: 'all' # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 6d6ae18cb2..72f42db691 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -49,6 +49,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False +epochs_to_save: 'all' # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 
51e197f814..c09ac23307 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -49,6 +49,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False +epochs_to_save: 'all' # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index eabb12a986..3446743be2 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -44,6 +44,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False +epochs_to_save: 'all' # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index cf89c9c4d2..1c29f0a242 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -46,6 +46,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False save_last_epoch_only: False +epochs_to_save: 'all' # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 8b0ad7723c..994b0e6be6 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -49,6 +49,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. save_last_epoch_only: False +epochs_to_save: 'all' # Dataset dataset: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 14e5ec201c..1340b31dc1 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -156,7 +156,20 @@ def __init__(self, cfg: DictConfig) -> None: self.global_step = 0 self._resume_from_checkpoint = cfg.resume_from_checkpoint self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + if cfg.save_last_epoch_only and cfg.epochs_to_save: + utils.log_rank_zero( + self._logger, + "Both save_last_epoch_only and epochs_to_save are in use. " + "The value for save_last_epoch_only takes precedence but will be removed in a future release.", + ) self._save_last_epoch_only = cfg.get("save_last_epoch_only", False) + self._epochs_to_save = ( + [self.total_epochs - 1] + if self._save_last_epoch_only + else cfg.get("epochs_to_save", "all") + ) + if self._epochs_to_save == "all": + self._epochs_to_save = list(range(self.total_epochs)) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) @@ -689,14 +702,11 @@ def train(self) -> None: break self.epochs_run += 1 - - # If self._save_last_epoch_only is true, only save checkpoint on the final epoch to save disk space - if ( - not self._save_last_epoch_only - or curr_epoch == self.total_epochs - 1 - ): + if curr_epoch in self._epochs_to_save: start_save_checkpoint = time.perf_counter() - self._logger.info("Starting checkpoint save...") + self._logger.info( + f"Starting checkpoint save for epoch {curr_epoch}..." 
+                )
                 self.save_checkpoint(epoch=curr_epoch)
                 self._logger.info(
                     "Checkpoint saved in {:.2f} seconds.".format(
@@ -704,8 +714,8 @@ def train(self) -> None:
                     )
                 )
             else:
-                self._logger.info(
-                    f"Skipping checkpoint save for epoch {curr_epoch + 1}.."
+                self._logger.info(
+                    f"Skipping checkpoint save for epoch {curr_epoch}..."
                 )
 
     def cleanup(self) -> None:
diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py
index da99d1ba82..de45d976e6 100644
--- a/tests/recipes/test_lora_finetune_single_device.py
+++ b/tests/recipes/test_lora_finetune_single_device.py
@@ -364,135 +364,269 @@ def test_training_state_on_resume_with_async_checkpointing(
             loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )
+    @pytest.mark.parametrize(
+        "epochs_to_save, expected_folders",
+        [
+            ("all", ["epoch_0", "epoch_1", "epoch_2"]),
+            ("none", []),
+            ("1,3", ["epoch_0", "epoch_2"]),
+        ],
+    )
+    @pytest.mark.integration_test
+    @gpu_test(gpu_count=1)
+    def test_epochs_to_save(
+        self, tmpdir, monkeypatch, epochs_to_save, expected_folders
+    ):
+        """Test that epochs_to_save parameter controls which epoch folders are saved.
+        The test checks if the specified epochs are saved after training a model for 3 epochs.
+        """
+
+        ckpt = "llama3_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for three epochs
+        cmd = f"""
+        tune run lora_finetune_single_device \
+            --config llama3/8B_lora_single_device \
+            batch_size=8 \
+            gradient_accumulation_steps=1 \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}] \
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA3 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
+            tokenizer.prompt_template=null \
+            epochs_to_save={epochs_to_save} \
+            save_last_epoch_only=False \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=False \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+        cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Verify the checkpointing behavior
+        # Check if the expected epoch folders are created
+        saved_epoch_folders = sorted(
+            [f for f in os.listdir(tmpdir) if f.startswith("epoch_")]
+        )
-@pytest.mark.parametrize("save_last_epoch_only", [False, True])
-@pytest.mark.integration_test
-@gpu_test(gpu_count=1)
-def test_save_last_epoch_only(self, tmpdir, monkeypatch, save_last_epoch_only):
-    """Test that save_last_epoch_only parameter controls checkpoint saving behavior.
-    The test checks if the last epoch is saved when save_last_epoch_only is True
-    after training a model for 3 epochs.
-    """
-
-    ckpt = "llama3_tune"
-    ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
-    ckpt_dir = ckpt_path.parent
-    log_file = gen_log_file_name(tmpdir)
-
-    # Config file needed for model conversion.
- write_hf_ckpt_config(ckpt_dir) - write_hf_ckpt_config(tmpdir) - - # Train for three epochs - cmd = f""" - tune run lora_finetune_single_device \ - --config llama3/8B_lora_single_device \ - batch_size=8 \ - gradient_accumulation_steps=1 \ - output_dir={tmpdir} \ - model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ - model.apply_lora_to_mlp=False \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}] \ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ - tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ - tokenizer.prompt_template=null \ - save_last_epoch_only={save_last_epoch_only} \ - enable_activation_checkpointing=True \ - enable_activation_offloading=False \ - enable_async_checkpointing=False \ - """.split() - - model_config = MODEL_TEST_CONFIGS["llama3_lora"] - - cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config - monkeypatch.setattr(sys, "argv", cmd) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Verify the checkpointing behavior - # Check if the expected epoch folders are created - epoch_folders = [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] - - if save_last_epoch_only: - expected_epoch_folders = 1 assert ( - len(epoch_folders) == expected_epoch_folders - ), f"With save_last_epoch_only=True, expected {expected_epoch_folders} epoch folder, got {len(epoch_folders)}" - assert "epoch_2" in epoch_folders, "Final epoch checkpoint should exist" - else: - expected_epoch_folders = 3 + saved_epoch_folders == expected_folders + ), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}" + + @pytest.mark.parametrize( + "epochs_to_save, expected_folders", + [ + ("all", ["epoch_0", "epoch_1", "epoch_2"]), + ("none", []), + ("1,3", ["epoch_0", "epoch_2"]), + ], + ) + @pytest.mark.integration_test + @gpu_test(gpu_count=1) + def test_epochs_to_save_with_async_checkpointing( + self, tmpdir, monkeypatch, epochs_to_save, expected_folders + ): + """Test that epochs_to_save parameter controls which epoch folders are saved with async checkpointing. + The test checks if the specified epochs are saved after training a model for 3 epochs. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for three epochs + cmd = f""" + tune run lora_finetune_single_device \ + --config llama3/8B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=False \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ + tokenizer.prompt_template=null \ + epochs_to_save={epochs_to_save} \ + save_last_epoch_only=False \ + enable_activation_checkpointing=True \ + enable_activation_offloading=False \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Verify the checkpointing behavior + # Check if the expected epoch folders are created + saved_epoch_folders = sorted( + [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] + ) + assert ( - len(epoch_folders) == expected_epoch_folders - ), f"With save_last_epoch_only=False, expected {expected_epoch_folders} epoch folders, got {len(epoch_folders)}" - - -@pytest.mark.parametrize("save_last_epoch_only", [False, True]) -@pytest.mark.integration_test -@gpu_test(gpu_count=1) -def test_save_last_epoch_only_with_async_checkpointing( - self, tmpdir, monkeypatch, save_last_epoch_only -): - """Test that save_last_epoch_only parameter controls checkpoint saving behavior with asunc checkpointing. - The test checks if the last epoch is saved when save_last_epoch_only is True - after training a model for 3 epochs. - """ - - ckpt = "llama3_tune" - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - ckpt_dir = ckpt_path.parent - log_file = gen_log_file_name(tmpdir) - - # Config file needed for model conversion. 
- write_hf_ckpt_config(ckpt_dir) - write_hf_ckpt_config(tmpdir) - - # Train for three epochs - cmd = f""" - tune run lora_finetune_single_device \ - --config llama3/8B_lora_single_device \ - batch_size=8 \ - gradient_accumulation_steps=1 \ - output_dir={tmpdir} \ - model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ - model.apply_lora_to_mlp=False \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}] \ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ - tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ - tokenizer.prompt_template=null \ - save_last_epoch_only={save_last_epoch_only} \ - enable_activation_checkpointing=True \ - enable_activation_offloading=False \ - enable_async_checkpointing=True \ - """.split() - - model_config = MODEL_TEST_CONFIGS["llama3_lora"] - - cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config - monkeypatch.setattr(sys, "argv", cmd) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Verify the checkpointing behavior - # Check if the expected epoch folders are created - epoch_folders = [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] - - if save_last_epoch_only: - expected_epoch_folders = 1 + saved_epoch_folders == expected_folders + ), f"With epochs_to_save={epochs_to_save}, expected epoch folders {expected_folders}, got {saved_epoch_folders}" + + @pytest.mark.parametrize( + "save_last_epoch_only, expected_folders", + [ + (True, ["epoch_2"]), + (False, ["epoch_0", "epoch_1"]), + ], + ) + @pytest.mark.integration_test + @gpu_test(gpu_count=1) + def test_save_last_epoch_only( + self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders + ): + """Test that save_last_epoch_only parameter controls checkpoint saving behavior. + The test checks if the last epoch is saved when save_last_epoch_only is True + after training a model for 3 epochs and if it correctly overrides epochs_to_save. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for three epochs + cmd = f""" + tune run lora_finetune_single_device \ + --config llama3/8B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=False \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ + tokenizer.prompt_template=null \ + epochs_to_save='1,2' \ + save_last_epoch_only={save_last_epoch_only} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=False \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Verify the checkpointing behavior + # Check if the expected epoch folders are created + saved_epoch_folders = sorted( + [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] + ) + assert ( - len(epoch_folders) == expected_epoch_folders - ), f"With save_last_epoch_only=True, expected {expected_epoch_folders} epoch folder, got {len(epoch_folders)}" - assert "epoch_2" in epoch_folders, "Final epoch checkpoint should exist" - else: - expected_epoch_folders = 3 + saved_epoch_folders == expected_folders + ), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}" + + @pytest.mark.parametrize( + "save_last_epoch_only, expected_folders", + [ + (True, ["epoch_2"]), + (False, ["epoch_0", "epoch_1"]), + ], + ) + @pytest.mark.integration_test + @gpu_test(gpu_count=1) + def test_save_last_epoch_only_with_async_checkpointing( + self, tmpdir, monkeypatch, save_last_epoch_only, expected_folders + ): + """Test that save_last_epoch_only parameter controls checkpoint saving behavior with async checkpointing. + The test checks if the last epoch is saved when save_last_epoch_only is True + after training a model for 3 epochs and if it correctly overrides epochs_to_save. + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for three epochs + cmd = f""" + tune run lora_finetune_single_device \ + --config llama3/8B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=False \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \ + tokenizer.prompt_template=null \ + epochs_to_save='1,2' \ + save_last_epoch_only={save_last_epoch_only} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=False \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides(epochs=3) + model_config + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Verify the checkpointing behavior + # Check if the expected epoch folders are created + saved_epoch_folders = sorted( + [f for f in os.listdir(tmpdir) if f.startswith("epoch_")] + ) + assert ( - len(epoch_folders) == expected_epoch_folders - ), f"With save_last_epoch_only=False, expected {expected_epoch_folders} epoch folders, got {len(epoch_folders)}" + saved_epoch_folders == expected_folders + ), f"With save_last_epoch_only={save_last_epoch_only}, expected epoch folders {expected_folders}, got {saved_epoch_folders}" @pytest.mark.parametrize("use_dora", [False, True]) @pytest.mark.integration_test
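The interaction between the two options introduced in this series can be summarized with a small standalone sketch. It mirrors the resolution logic added to the recipe in PATCH 5 (save_last_epoch_only takes precedence, 'all' expands to every epoch, and the save is gated by "curr_epoch in self._epochs_to_save"), but the resolve_epochs_to_save helper and the normalization of string values such as 'none' or '1,3' (1-indexed, as in the test parametrizations above) are assumptions made for illustration rather than torchtune's actual API.

# Illustrative sketch only: mirrors the checkpoint-gating behavior added in PATCH 5.
# resolve_epochs_to_save and the handling of "none"/"1,3" strings are assumptions
# made for this example, not helpers that exist in the torchtune recipe.
from typing import List, Union


def resolve_epochs_to_save(
    epochs_to_save: Union[str, List[int]],
    total_epochs: int,
    save_last_epoch_only: bool = False,
) -> List[int]:
    """Turn the config values into the 0-indexed epochs that get a checkpoint."""
    if save_last_epoch_only:
        # save_last_epoch_only takes precedence over epochs_to_save, as in PATCH 5.
        return [total_epochs - 1]
    if epochs_to_save == "all":
        return list(range(total_epochs))
    if epochs_to_save == "none":
        return []
    if isinstance(epochs_to_save, str):
        # Assumed normalization: "1,3" (1-indexed, as in the tests) -> [0, 2].
        return [int(e) - 1 for e in epochs_to_save.split(",")]
    return list(epochs_to_save)


if __name__ == "__main__":
    total_epochs = 3
    epochs = resolve_epochs_to_save("1,3", total_epochs)
    for curr_epoch in range(total_epochs):
        if curr_epoch in epochs:
            print(f"Starting checkpoint save for epoch {curr_epoch}...")
        else:
            print(f"Skipping checkpoint save for epoch {curr_epoch}...")

With epochs_to_save='1,3' and three epochs, the sketch saves at epoch 0 and epoch 2 and skips epoch 1, matching the ("1,3", ["epoch_0", "epoch_2"]) case in test_epochs_to_save.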