diff --git a/CLAUDE.md b/CLAUDE.md
index 3d2e569a7..ed3045c5d 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -8,3 +8,7 @@
 - To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
 - Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
 - Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
+
+# Code Style
+- Always add type annotations to function signatures.
+- Always add docstrings to functions, using Google-style docstring format.
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index d8db97f16..65c91cf05 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -2921,6 +2921,44 @@ def cleanup_training_resources(
     logger.info("✅ Process group destroyed")
 
 
+@Timer("[Main Thread] 🗡️ Saving checkpoint state")
+def maybe_save_checkpoint(
+    args: Args,
+    policy_group: ModelGroup,
+    training_step: int,
+    episode: int,
+    num_total_tokens: int,
+    iter_dataloader: ShufflingIterator | None,
+) -> None:
+    """Save checkpoint state if checkpoint frequency conditions are met.
+
+    Args:
+        args: Training configuration arguments.
+        policy_group: Group of policy models to checkpoint.
+        training_step: Current training step number.
+        episode: Current episode count.
+        num_total_tokens: Total number of tokens processed.
+        iter_dataloader: Data iterator to save state from, or None.
+    """
+    if not (
+        args.checkpoint_state_freq > 0
+        and training_step % args.checkpoint_state_freq == 0
+        and args.checkpoint_state_dir is not None
+    ):
+        return
+
+    client_state = {"training_step": training_step, "episode": episode, "num_total_tokens": num_total_tokens}
+
+    if iter_dataloader is not None:
+        client_state["shuffling_iterator_state"] = iter_dataloader.get_state()
+
+    ray_get_with_progress(
+        [model.save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) for model in policy_group.models],
+        desc=f"Saving checkpoint state at step {training_step}",
+    )
+    logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+
+
 def run_training(
     args,
     tokenizer,
@@ -3098,52 +3136,37 @@ def health_check_fn():
             iter_dataloader,
         )
 
-        logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
-        weight_sync_trigger_event.set()
+        async_futures = []
 
-        # Checkpoint after one_training_step (or even if it was skipped)
-        # This ensures we checkpoint progress even if the exact checkpoint step has no data
-        if (
-            args.checkpoint_state_freq > 0
-            and training_step % args.checkpoint_state_freq == 0
-            and args.checkpoint_state_dir is not None
-        ):
-            with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
-                # Save comprehensive client state including ShufflingIterator state
-                client_state = {
-                    "training_step": training_step,
-                    "episode": episode,
-                    "num_total_tokens": num_total_tokens,
-                }
-
-                # Save ShufflingIterator state
-                if iter_dataloader is not None:
-                    client_state["shuffling_iterator_state"] = iter_dataloader.get_state()
-
-                ray_get_with_progress(
-                    [
-                        policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
-                        for i in range(args.world_size)
-                    ],
-                    desc=f"Saving checkpoint state at step {training_step}",
-                )
-                logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+        async_futures.append(
+            executor.submit(
+                maybe_save_checkpoint, args, policy_group, training_step, episode, num_total_tokens, iter_dataloader
+            )
+        )
 
-        maybe_evaluate(
-            args,
-            training_step,
-            evaluation_inference_results_Q,
-            tokenizer,
-            reward_fn,
-            episode,
-            eval_pending_queries_map,
-            generation_configs["eval"],
-            generate_metrics_Q,
-            len(eval_dataset) if eval_dataset else 0,
-            model_dims,
-            actor_manager,
+        async_futures.append(
+            executor.submit(
+                maybe_evaluate,
+                args,
+                training_step,
+                evaluation_inference_results_Q,
+                tokenizer,
+                reward_fn,
+                episode,
+                eval_pending_queries_map,
+                generation_configs["eval"],
+                generate_metrics_Q,
+                len(eval_dataset) if eval_dataset else 0,
+                model_dims,
+                actor_manager,
+            )
         )
+        futures.wait(async_futures)
+
+        logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
+        weight_sync_trigger_event.set()
+
     if resume_training_step > args.num_training_steps:
         raise ValueError(f"Training didn't run since {resume_training_step=} > {args.num_training_steps=}")
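The refactor above overlaps checkpointing and evaluation by submitting both to an executor and waiting on the resulting futures before triggering the weight sync. Below is a minimal, self-contained sketch of that submit-and-wait pattern using the standard `concurrent.futures` module; the stub functions and the plain `ThreadPoolExecutor` setup are illustrative stand-ins, not the executor wiring used in `grpo_fast.py`.

```python
# Minimal sketch of the submit/wait pattern (illustrative stand-ins only).
import time
from concurrent import futures


def save_checkpoint_stub(step: int) -> str:
    """Pretend to save a checkpoint and return its path."""
    time.sleep(0.1)  # simulate I/O
    return f"/tmp/ckpt_{step}"


def evaluate_stub(step: int) -> float:
    """Pretend to run evaluation and return a score."""
    time.sleep(0.1)  # simulate compute
    return 0.5


with futures.ThreadPoolExecutor(max_workers=2) as executor:
    async_futures = [
        executor.submit(save_checkpoint_stub, 10),
        executor.submit(evaluate_stub, 10),
    ]
    # Block until both background tasks finish before proceeding
    # (in the training loop above, this precedes the weight sync trigger).
    futures.wait(async_futures)
    ckpt_path, score = (f.result() for f in async_futures)
    print(ckpt_path, score)
```

Note that `futures.wait` only blocks until the submitted callables finish; any exception raised inside one of them surfaces later, when `.result()` is called on the corresponding future.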