4 changes: 4 additions & 0 deletions CLAUDE.md
@@ -8,3 +8,7 @@
- To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
- Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.

# Code Style
- Always add type annotations to function signatures.
- Always add docstrings to functions, using Google-style docstring format.
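
For reference, a minimal example of the convention these new CLAUDE.md lines require (the function itself is hypothetical, not part of the repository):

```python
def normalize_rewards(rewards: list[float], scale: float = 1.0) -> list[float]:
    """Normalize a batch of rewards to zero mean and unit variance.

    Args:
        rewards: Raw reward values for one training batch.
        scale: Multiplier applied to each reward after normalization.

    Returns:
        The normalized and scaled rewards, in the original order.
    """
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5 or 1.0
    return [scale * (r - mean) / std for r in rewards]
```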
107 changes: 65 additions & 42 deletions open_instruct/grpo_fast.py
@@ -2921,6 +2921,44 @@ def cleanup_training_resources(
logger.info("✅ Process group destroyed")


@Timer("[Main Thread] 🗡️ Saving checkpoint state")
def maybe_save_checkpoint(
args: Args,
policy_group: ModelGroup,
training_step: int,
episode: int,
num_total_tokens: int,
iter_dataloader: ShufflingIterator | None,
) -> None:
"""Save checkpoint state if checkpoint frequency conditions are met.

Args:
args: Training configuration arguments.
policy_group: Group of policy models to checkpoint.
training_step: Current training step number.
episode: Current episode count.
num_total_tokens: Total number of tokens processed.
iter_dataloader: Data iterator to save state from, or None.
"""
if not (
args.checkpoint_state_freq > 0
and training_step % args.checkpoint_state_freq == 0
and args.checkpoint_state_dir is not None
):
return

client_state = {"training_step": training_step, "episode": episode, "num_total_tokens": num_total_tokens}

if iter_dataloader is not None:
client_state["shuffling_iterator_state"] = iter_dataloader.get_state()

ray_get_with_progress(
[model.save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) for model in policy_group.models],
desc=f"Saving checkpoint state at step {training_step}",
)
logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")


def run_training(
args,
tokenizer,
@@ -3098,52 +3136,37 @@ def health_check_fn():
iter_dataloader,
)

logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
weight_sync_trigger_event.set()
async_futures = []

# Checkpoint after one_training_step (or even if it was skipped)
# This ensures we checkpoint progress even if the exact checkpoint step has no data
if (
args.checkpoint_state_freq > 0
and training_step % args.checkpoint_state_freq == 0
and args.checkpoint_state_dir is not None
):
with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
# Save comprehensive client state including ShufflingIterator state
client_state = {
"training_step": training_step,
"episode": episode,
"num_total_tokens": num_total_tokens,
}

# Save ShufflingIterator state
if iter_dataloader is not None:
client_state["shuffling_iterator_state"] = iter_dataloader.get_state()

ray_get_with_progress(
[
policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
for i in range(args.world_size)
],
desc=f"Saving checkpoint state at step {training_step}",
)
logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
async_futures.append(
executor.submit(
maybe_save_checkpoint, args, policy_group, training_step, episode, num_total_tokens, iter_dataloader
)
)

maybe_evaluate(
args,
training_step,
evaluation_inference_results_Q,
tokenizer,
reward_fn,
episode,
eval_pending_queries_map,
generation_configs["eval"],
generate_metrics_Q,
len(eval_dataset) if eval_dataset else 0,
model_dims,
actor_manager,
async_futures.append(
executor.submit(
maybe_evaluate,
args,
training_step,
evaluation_inference_results_Q,
tokenizer,
reward_fn,
episode,
eval_pending_queries_map,
generation_configs["eval"],
generate_metrics_Q,
len(eval_dataset) if eval_dataset else 0,
model_dims,
actor_manager,
)
)

futures.wait(async_futures)

logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
weight_sync_trigger_event.set()

if resume_training_step > args.num_training_steps:
raise ValueError(f"Training didn't run since {resume_training_step=} > {args.num_training_steps=}")

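
Taken together, the restructured loop submits checkpointing and evaluation to a thread pool, waits for both, and only then triggers the weight sync. A stripped-down sketch of that ordering, with hypothetical stub tasks standing in for `maybe_save_checkpoint` and `maybe_evaluate` and assuming the executor is created outside the loop as in `run_training`:

```python
import threading
from concurrent import futures


def save_checkpoint_stub(step: int) -> None:
    """Stand-in for maybe_save_checkpoint; only logs the step."""
    print(f"checkpointing at step {step}")


def evaluate_stub(step: int) -> None:
    """Stand-in for maybe_evaluate; only logs the step."""
    print(f"evaluating at step {step}")


def training_loop(num_steps: int, weight_sync_trigger_event: threading.Event) -> None:
    """Run checkpoint/eval stubs concurrently each step, then signal weight sync.

    Args:
        num_steps: Number of training steps to simulate.
        weight_sync_trigger_event: Event consumed by a separate weight-sync thread.
    """
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
        for step in range(1, num_steps + 1):
            async_futures = [
                executor.submit(save_checkpoint_stub, step),
                executor.submit(evaluate_stub, step),
            ]
            # Block until both background tasks finish before triggering the sync.
            futures.wait(async_futures)
            weight_sync_trigger_event.set()


if __name__ == "__main__":
    training_loop(num_steps=3, weight_sync_trigger_event=threading.Event())
```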