diff --git a/mason.py b/mason.py index 4ce339eb9e..edbbd163a0 100644 --- a/mason.py +++ b/mason.py @@ -10,9 +10,12 @@ import time import beaker +import requests from rich.console import Console from rich.text import Text +from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS + console = Console() @@ -87,11 +90,6 @@ def parse_env_var(env_var_str: str) -> dict[str, str]: return {"name": name, "value": value} -WEKA_CLUSTERS = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] -GCP_CLUSTERS = ["ai2/augusta"] - -INTERCONNECT_CLUSTERS = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"] - # by default, we turn off vllm compile cache # torch compile caching seems consistently broken, but the actual compiling isn't. # Not sure why, for now we have disabled the caching (VLLM_DISABLE_COMPILE_CACHE=1). @@ -292,6 +290,7 @@ def get_env_vars( "AZURE_API_KEY", "AZURE_API_BASE", "ANTHROPIC_API_KEY", + "SLACK_WEBHOOK", ] for useful_secret in useful_secrets: if f"{whoami}_{useful_secret}" in beaker_secrets: @@ -827,8 +826,82 @@ def main(): budget=args.budget, retry=beaker.BeakerRetrySpec(allowed_task_retries=args.max_retries), ) - exp = beaker_client.experiment.create(spec=experiment_spec) - console.log(f"Kicked off Beaker job. https://beaker.org/ex/{exp.experiment.id}") + + # Increase timeout for HTTP requests and add retry logic + # The beaker library uses requests internally, and large experiment specs can take longer to process + max_retries = 3 + retry_delay = 5 # Start with 5 seconds + timeout_seconds = 300 # Increase timeout to 300 seconds (5 minutes) for large experiment specs + + # Monkey-patch requests.Session to intercept and increase timeout values + # The beaker library hardcodes a 5-second timeout, so we need to patch at the requests level + original_session_request = requests.Session.request + original_session_post = requests.Session.post + + def patched_session_request(self, method, url, **kwargs): + # Override timeout if it's set to a low value (less than our desired timeout) + if "timeout" in kwargs and kwargs["timeout"] is not None: + current_timeout = kwargs["timeout"] + # Handle tuple timeouts (connect, read) or single value + if isinstance(current_timeout, tuple): + if len(current_timeout) == 2 and current_timeout[1] < timeout_seconds: + kwargs["timeout"] = (current_timeout[0], timeout_seconds) + elif isinstance(current_timeout, (int, float)) and current_timeout < timeout_seconds: + kwargs["timeout"] = timeout_seconds + elif "timeout" not in kwargs: + kwargs["timeout"] = timeout_seconds + return original_session_request(self, method, url, **kwargs) + + def patched_session_post(self, url, **kwargs): + # Override timeout if it's set to a low value + if "timeout" in kwargs and kwargs["timeout"] is not None: + current_timeout = kwargs["timeout"] + if isinstance(current_timeout, tuple): + if len(current_timeout) == 2 and current_timeout[1] < timeout_seconds: + kwargs["timeout"] = (current_timeout[0], timeout_seconds) + elif isinstance(current_timeout, (int, float)) and current_timeout < timeout_seconds: + kwargs["timeout"] = timeout_seconds + elif "timeout" not in kwargs: + kwargs["timeout"] = timeout_seconds + return original_session_post(self, url, **kwargs) + + # Apply the patches + requests.Session.request = patched_session_request + requests.Session.post = patched_session_post + console.log(f"✅ Patched requests.Session to use minimum {timeout_seconds} second timeout") + + # Also try to increase the timeout 
on the beaker client itself + try: + if hasattr(beaker_client, "_timeout"): + beaker_client._timeout = timeout_seconds + console.log(f"✅ Set beaker client timeout to {timeout_seconds} seconds") + except Exception as e: + console.log(f"⚠️ Could not modify beaker client timeout: {e}. Will rely on retries.") + + # Retry logic with exponential backoff for timeout errors + for attempt in range(max_retries): + try: + exp = beaker_client.experiment.create(spec=experiment_spec) + console.log(f"Kicked off Beaker job. https://beaker.org/ex/{exp.experiment.id}") + break + except (requests.exceptions.ReadTimeout, requests.exceptions.Timeout): + if attempt < max_retries - 1: + wait_time = retry_delay * (2**attempt) # Exponential backoff: 5s, 10s, 20s + console.log( + f"⚠️ Timeout occurred (attempt {attempt + 1}/{max_retries}). " + f"Retrying in {wait_time} seconds... " + f"Large experiment specs may take longer to process." + ) + time.sleep(wait_time) + else: + console.log( + f"❌ Failed to create Beaker experiment after {max_retries} attempts due to timeout. " + f"The experiment spec may be too large or the Beaker API may be experiencing issues." + ) + raise + except Exception: + # For other exceptions, don't retry + raise if __name__ == "__main__": diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 96857b3f22..9d85f08d4b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -112,11 +112,11 @@ RayProcess, _z3_params_to_fetch, calibrate_checkpoint_state_dir, + clean_last_n_checkpoints, clean_last_n_checkpoints_deepspeed, combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, - get_eval_ds_config, get_optimizer_grouped_parameters, get_train_ds_config, get_wandb_tags, @@ -209,6 +209,8 @@ class Args: """Run evaluation after this many training steps. This controls in-loop evals, which reuse the generation/reward verifier setup. Set to -1 to disable.""" save_freq: int = 200 """How many train steps to save the model""" + beaker_eval_freq: int = -1 + """How many train steps to launch beaker evaluation jobs. Set to -1 to disable.""" allow_world_padding: bool = False """Whether to allow world padding. This is useful for model sweeps, but wastes compute.""" backend_timeout: int = 120 @@ -259,8 +261,6 @@ class Args: reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly used. E.g., [TD3](https://arxiv.org/abs/1802.09477) uses https://github.com/vwxyzjn/cleanrl/blob/dcc289fc6f0bda492fa7360a155262cf826b12a5/cleanrl/td3_continuous_action.py#L269 """ - ref_policy_update_freq: int | None = None - """How many training steps to take before updating the reference policy.""" advantage_normalization_type: Literal["standard", "centered"] = "standard" """The type of advantage normalization to use. Standard normalization is the default: it subtracts the mean and divides by the standard deviation. Centered normalization is the same but subtracts the mean only (e.g., used in @@ -271,7 +271,7 @@ class Args: active_sampling: bool = False """Whether to continue sampling responses until you get a full batch.""" filter_zero_std_samples: bool = True - """Whether to filter out prompts with zero reward std (all samples have the same score). 
Must be True when active_sampling is True.""" + """Whether to filter out prompts with zero reward std (all samples have the same score).""" no_resampling_pass_rate: float | None = None """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" @@ -393,8 +393,10 @@ class Args: """Whether to save learning data traces""" cache_dataset_only: bool = False """Immediately exit after caching the dataset""" - keep_last_n_checkpoints: int = 3 - """How many checkpoints to keep in the output directory. -1 for all.""" + keep_last_n_checkpoint_states: int = 3 + """How many checkpoint states (optimizer, lr scheduler, etc.) to keep in the output directory. -1 for all.""" + keep_last_n_checkpoints: int = -1 + """How many regular model checkpoints to keep in the output directory. -1 for all.""" checkpoint_state_freq: int = -1 """How often to save the model checkpoint, optimizer states, and lr scheduler states (in steps)""" checkpoint_state_dir: str | None = None @@ -419,6 +421,8 @@ class Args: """multiply the gpus used for each oe-eval task""" eval_priority: Literal["low", "normal", "high", "urgent"] = "normal" """the priority of auto-launched evaluation jobs""" + eval_workspace: str = "ai2/tulu-3-results" + """the workspace to launch evaluation jobs on""" # Evaluation behavior eval_on_step_0: bool = False @@ -484,6 +488,12 @@ def __post_init__(self): if self.checkpoint_state_dir is not None and self.checkpoint_state_freq == -1: raise ValueError("`checkpoint_state_freq` must be greater than 0 if `checkpoint_state_dir` is provided!") + if self.beaker_eval_freq > 0 and self.save_freq > 0 and self.beaker_eval_freq % self.save_freq != 0: + raise ValueError( + f"`beaker_eval_freq` (={self.beaker_eval_freq}) must be a multiple of `save_freq` (={self.save_freq}) " + "because beaker eval jobs require checkpoints to exist." + ) + if self.gs_checkpoint_state_dir is not None and not self.gs_checkpoint_state_dir.startswith("gs://"): raise ValueError(f"`gs_checkpoint_state_dir` must start with 'gs://', got: {self.gs_checkpoint_state_dir}") if self.gs_bucket_path is not None and not self.gs_bucket_path.startswith("gs://"): @@ -498,6 +508,8 @@ def __post_init__(self): self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{beaker_users}/{checkpoint_dir_name}" else: self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}" + if not checkpoint_dir_name.startswith("/filestore"): + self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}" if self.checkpoint_state_dir is not None: if self.gs_checkpoint_state_dir is not None: @@ -819,6 +831,9 @@ def load(self, path: str, map_location=None): f"Skipping loading checkpoint state from {args.checkpoint_state_dir} because it does not exist!" 
) else: + # print what is in the checkpoint state dir + print(f"Checkpoint state dir: {args.checkpoint_state_dir}") + print(os.listdir(args.checkpoint_state_dir)) path, states = self.model.load_checkpoint( args.checkpoint_state_dir, load_module_strict=True, @@ -844,58 +859,10 @@ def load(self, path: str, map_location=None): torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"]) logger.info(f"{self.rank=}: Restored RNG states from checkpoint") - - # Save reference policy path to load later (after ref_policy is initialized) - self.ref_policy_checkpoint_path = None - if states.get("ref_policy_saved", False): - ref_policy_dir = os.path.join(args.checkpoint_state_dir, "ref_policy") - model_path = os.path.join(ref_policy_dir, "pytorch_model.bin") - if os.path.exists(model_path): - self.ref_policy_checkpoint_path = model_path - logger.info(f"{self.rank=}: Will load reference policy from {model_path}") - logger.info( f"{self.rank=}: Loaded checkpoint from {args.checkpoint_state_dir} with {optimization_steps_done=}" ) self.model.train() - - # reference model - ds_config = get_eval_ds_config( - offload=args.deepspeed_offload_param, - # inference model only has stage 3 (sharding) or stage 0 (no sharding) - # stage 2 is optimizer sharding which doesn't apply to inference - stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["gradient_accumulation_steps"] = 1 - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - logger.info(f"DeepSpeed config: {dschf=}") - - self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - **({"device_map": {"": self.local_rank}} if args.deepspeed_stage != 3 else {}), - ) - disable_dropout_in_model(self.ref_policy) - self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) - self.ref_policy.eval() - - # Load reference policy checkpoint if available - if hasattr(self, "ref_policy_checkpoint_path") and self.ref_policy_checkpoint_path: - state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device) - if hasattr(self.ref_policy, "module"): - # If wrapped by DeepSpeed - self.ref_policy.module.load_state_dict(state_dict) - else: - self.ref_policy.load_state_dict(state_dict) - logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") self.local_metrics = utils.MetricsTracker(device=self.device) return optimization_steps_done @@ -1012,15 +979,6 @@ def broadcast_to_vllm(self): all_refs.extend(refss) return all_refs - def update_ref_policy(self): - for ref_param, param in zip(self.ref_policy.parameters(), self.model.parameters()): - if self.args.deepspeed_stage == 3: - with deepspeed.zero.GatheredParameters([param, ref_param], modifier_rank=0): - if deepspeed.comm.get_rank() == 0: - ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - else: - ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def train( self, collated_query_responses, @@ -1057,32 +1015,6 @@ def train( # recalculate the "real" number of mini-batches num_mini_batches = len(collated_query_responses) // accumulation_steps - # Calculate the logprob of the reference 
policy - collated_ref_logprobs = [] - with Timer("Inference Calculation", noop=self.rank != 0), torch.no_grad(): - for i in range(len(collated_query_responses)): - query_response = collated_query_responses[i] - tool_mask = collated_tool_masks[i] - attention_mask = collated_attention_masks[i] - position_id = collated_position_ids[i] - response_mask = collated_response_masks[i] - ref_logprob, _ = self.forward( - self.ref_policy, - query_response, - attention_mask, - position_id, - pad_token_id, - args.temperature, - return_entropy=False, - ) - if args.mask_tool_use and args.tool_use: - # mask logprobs for tool tokens - response_mask = response_mask.bool() & tool_mask.bool() - else: - response_mask = response_mask.bool() - ref_logprob = torch.masked_fill(ref_logprob, ~response_mask[:, 1:], INVALID_LOGPROB) - collated_ref_logprobs.append(ref_logprob) - torch.cuda.empty_cache() # if we have multiple minibatches, we need to calculate the old logprobs for each minibatch # following gtrl scripts in just doing this on the current active policy, rather than use the logprobs # from the generator (note that async mode means these are a bit diff!) @@ -1125,11 +1057,6 @@ def train( local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): - kl1_stats = torch.zeros(len(collated_query_responses)) - kl2_stats = torch.zeros(len(collated_query_responses)) - kl3_stats = torch.zeros(len(collated_query_responses)) - kl4_stats = torch.zeros(len(collated_query_responses)) - kl_loss_stats = torch.zeros(len(collated_query_responses)) pg_clipfrac_stats = torch.zeros(len(collated_query_responses)) pg_loss_stats = torch.zeros(len(collated_query_responses)) loss_stats = torch.zeros(len(collated_query_responses)) @@ -1137,7 +1064,6 @@ def train( entropy_stats = torch.zeros(len(collated_query_responses)) for epoch_idx in range(args.num_epochs): for i in range(len(collated_query_responses)): - mb_ref_logprob = collated_ref_logprobs[i] mb_query_responses = collated_query_responses[i] mb_tool_mask = collated_tool_masks[i] mb_advantages = collated_advantages[i] @@ -1250,30 +1176,9 @@ def train( pg_losses2 = pg_losses2 * tis_imp_ratio pg_loss_max = torch.max(pg_losses, pg_losses2) - - # Here we recalculate kl: we want the KL loss to backpropagate through the model - # We also clamp the KL loss to avoid numerical instability - # https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae - ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - kl1 = ref_logprobs_diff - kl2 = (ref_logprobs_diff) ** 2 / 2 - kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff # this is more numerically stable - kl4 = ratio * ref_logprobs_diff - if args.kl_estimator == "kl1": - kl = kl1 - elif args.kl_estimator == "kl2": - kl = kl2 - elif args.kl_estimator == "kl3": - kl = kl3 - elif args.kl_estimator == "kl4": - kl = kl4 - # grpo change: directly subtract KL in loss (add) loss = masked_mean( - pg_loss_max + (args.beta * kl), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, + pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator ) loss = loss / accumulation_steps # Clear CUDA cache before backward pass to free memory for reduce_scatter operations @@ -1283,27 +1188,6 @@ def train( self.model.step() local_step += 1 with torch.no_grad(): - # NOTE: in packed implementation, kl calculation are averages over response 
tokens - kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - if args.kl_estimator == "kl1": - kl_loss_stats[i] = kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - kl_loss_stats[i] = kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - kl_loss_stats[i] = kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - kl_loss_stats[i] = kl4_stats[i] * args.beta pg_clipfrac_stats[i] = masked_mean( (pg_losses2 > pg_losses).float(), mb_response_masks_bool, @@ -1324,12 +1208,7 @@ def train( ).float() with torch.no_grad(): - self.local_metrics["objective/kl_avg"] = kl1_stats.mean() - self.local_metrics["objective/kl2_avg"] = kl2_stats.mean() - self.local_metrics["objective/kl3_avg"] = kl3_stats.mean() - self.local_metrics["objective/kl4_avg"] = kl4_stats.mean() self.local_metrics["loss/policy_avg"] = pg_loss_stats.mean() - self.local_metrics["loss/kl_avg"] = kl_loss_stats.mean() self.local_metrics["loss/total_avg"] = loss_stats.mean() self.local_metrics["policy/clipfrac_avg"] = pg_clipfrac_stats.mean() self.local_metrics["val/ratio"] = ratio_stats.mean() @@ -1360,31 +1239,13 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[st client_state["rng_states"] = rng_states client_state["rank"] = self.rank - # Save reference policy checkpoint (model only, no optimizer) - if hasattr(self, "ref_policy") and self.ref_policy is not None: - ref_policy_dir = os.path.join(checkpoint_state_dir, "ref_policy") - os.makedirs(ref_policy_dir, exist_ok=True) - - # For reference policy, we save just the model weights - # We can't use save_checkpoint because it would try to save DummyOptim - # which doesn't have state_dict - if self.rank == 0: - # Only rank 0 saves the model state - model_to_save = self.ref_policy.module if hasattr(self.ref_policy, "module") else self.ref_policy - - # Save the state dict - torch.save(model_to_save.state_dict(), os.path.join(ref_policy_dir, "pytorch_model.bin")) - logger.info(f"Saved reference policy model to {ref_policy_dir}") - - client_state["ref_policy_saved"] = True - # Save the main model checkpoint with enhanced client state self.model.save_checkpoint(checkpoint_state_dir, client_state=client_state) # `save_checkpoint` needs to be called on all ranks, only rank 0 will have all the states if self.rank == 0: - if args.keep_last_n_checkpoints >= 0: - clean_last_n_checkpoints_deepspeed(checkpoint_state_dir, args.keep_last_n_checkpoints) + if args.keep_last_n_checkpoint_states >= 0: + clean_last_n_checkpoints_deepspeed(checkpoint_state_dir, args.keep_last_n_checkpoint_states) # Sync to GCS if configured (check the actual target, not just gs_bucket_path) if args.gs_checkpoint_state_dir is not None: @@ -1451,7 +1312,9 @@ def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTra self.tokenizer.save_pretrained(output_dir) # we need this because we don't know which node is rank 0 is on - def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url, training_step): + def launch_ai2_evals_on_weka_wrapper( + self, step_dir, leaderboard_name, wandb_url, 
training_step, run_eval_jobs: bool = True + ): args = self.args if self.rank == 0: ray.remote(launch_ai2_evals_on_weka).options(num_cpus=1).remote( @@ -1464,8 +1327,10 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url args.stop_strings, args.gs_bucket_path, args.eval_priority, + args.eval_workspace, args.oe_eval_beaker_image, args.oe_eval_gpu_multiplier, + run_eval_jobs, ) @@ -1668,6 +1533,7 @@ def accumulate_inference_batches( actor_manager=None, timeout: float | None = None, filter_zero_std_samples: bool = False, + active_sampling: bool = False, no_resampling_pass_rate: float | None = None, iter_dataloader: ShufflingIterator | None = None, ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: @@ -1680,7 +1546,8 @@ def accumulate_inference_batches( generation_config: Generation config containing n (number of samples per prompt) num_prompts: Number of prompts to accumulate timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - filter_zero_std_samples: Whether to filter samples with zero reward std and continue sampling + filter_zero_std_samples: Whether to filter out samples with zero reward std + active_sampling: Whether to continue sampling when a sample is filtered (try again) no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate and exclude them from further sampling iter_dataloader: Optional, used for no_resampling_pass_rate @@ -1712,8 +1579,25 @@ def accumulate_inference_batches( bar_format="{l_bar}{bar}{r_bar}\n", disable=not args.verbose, ) - while len(results) < num_prompts: - result = inference_results_Q.get(timeout=timeout) + + def process_result(result): + """Process a single result and optionally add it to results.""" + nonlocal \ + total_filtered_prompts, \ + filtered_prompt_zero, \ + filtered_prompt_solved, \ + filtered_prompt_nonzero, \ + total_no_resampled + nonlocal \ + results, \ + all_queries, \ + all_ground_truths, \ + all_datasets, \ + all_raw_queries, \ + all_decoded_responses, \ + all_scores, \ + all_reward_metrics, \ + all_percent_solved if isinstance(result, ShutdownSentinel): return result, None, None, None @@ -1775,8 +1659,10 @@ def accumulate_inference_batches( logging.debug( f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" ) - continue + # Don't add filtered samples - return early + return None + # Add the result if not filtered results.append(result) all_queries.extend(k_queries) all_ground_truths.extend(k_ground_truths) @@ -1787,6 +1673,24 @@ def accumulate_inference_batches( all_reward_metrics.append(reward_metrics) all_percent_solved.append(percent_solved) progress_bar.update(1) + return None + + if active_sampling: + # With active_sampling, use while loop to retry when filtering + while len(results) < num_prompts: + result = inference_results_Q.get(timeout=timeout) + shutdown_result = process_result(result) + if shutdown_result is not None and shutdown_result[0] is not None: + return shutdown_result + else: + # Without active_sampling, use for loop for exactly num_prompts iterations + # note we may filter out samples and end up with less than num_prompts + # if zero_std_filtering is True. 
+ for _ in range(num_prompts): + result = inference_results_Q.get(timeout=timeout) + shutdown_result = process_result(result) + if shutdown_result is not None and shutdown_result[0] is not None: + return shutdown_result # Combine all results into a single GenerationResult combined_responses = [] @@ -1860,7 +1764,7 @@ def accumulate_inference_batches( masks=combined_masks, request_info=combined_request_info, dataset_index=None, - epoch_number=results[0].epoch_number, + epoch_number=results[0].epoch_number if results else None, token_statistics=accumulated_stats, logprobs=combined_logprobs, ) @@ -1927,6 +1831,7 @@ def data_preparation_thread( reward_fn=reward_fn, actor_manager=actor_manager, filter_zero_std_samples=args.filter_zero_std_samples, + active_sampling=args.active_sampling, no_resampling_pass_rate=args.no_resampling_pass_rate, iter_dataloader=iter_dataloader, ) @@ -2423,6 +2328,7 @@ def weight_sync_thread( policy_group: ModelGroup, actor_manager: ActorManager, weight_sync_metrics_Q: Queue, + params_lock: threading.Lock, resume_training_step: int = 1, ): """Thread function that handles weight sync operations and actor manager coordination.""" @@ -2441,24 +2347,27 @@ def weight_sync_thread( with Timer("[Weight Sync]") as timer: logger.debug("[Weight Sync Thread] Starting weight sync") - # Set actors to stop - ray.get(actor_manager.set_should_stop.remote(True)) - logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync") + with params_lock: + # Set actors to stop + ray.get(actor_manager.set_should_stop.remote(True)) + logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync") - # Broadcast weights to vLLM engines - # First get the futures - weight_broadcast_futures: list[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models] + # Broadcast weights to vLLM engines + # First get the futures + weight_broadcast_futures: list[ray.ObjectRef] = [ + m.broadcast_to_vllm.remote() for m in policy_group.models + ] - # Wait for all weight updates to complete and collect individual timings - _, actor_sync_times = ray_get_with_progress( - weight_broadcast_futures, - desc="[Weight Sync Thread] Waiting for weight updates to complete", - enable=args.verbose, - ) + # Wait for all weight updates to complete and collect individual timings + _, actor_sync_times = ray_get_with_progress( + weight_broadcast_futures, + desc="[Weight Sync Thread] Waiting for weight updates to complete", + enable=args.verbose, + ) - # Allow actors to resume - ray.get(actor_manager.set_should_stop.remote(False)) - logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync") + # Allow actors to resume + ray.get(actor_manager.set_should_stop.remote(False)) + logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync") # Calculate distribution statistics sync_time_stats = { @@ -2499,7 +2408,6 @@ def one_training_step( iter_dataloader: Iterator | None = None, ) -> None: """Train the model for one step.""" - update_ref_policy_future = [] with Timer("[Main Thread] 🗡️ Training") as train_timer: metrics_list, _ = ray_get_with_progress( [ @@ -2510,14 +2418,6 @@ def one_training_step( ], desc=f"Running training step {training_step}", ) - if ( - args.ref_policy_update_freq is not None - and training_step % args.ref_policy_update_freq == 0 - and args.alpha > 0 - ): - update_ref_policy_future.extend( - [policy_group.models[i].update_ref_policy.remote() for i in range(args.world_size)] - ) save_time = 0 if args.save_freq > 0 and training_step % 
args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1): @@ -2525,25 +2425,46 @@ def one_training_step( checkpoint_dir = f"{args.output_dir}_checkpoints" step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") logger.info(f"Saving model at step {training_step} to {step_dir}") - ray_get_with_progress( - [ - policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer) - for i in range(args.world_size) - ], - desc=f"Saving model at step {training_step}", - ) - if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): - leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" - for i in range(args.world_size): - policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote( - step_dir, leaderboard_name, wandb_url, training_step + max_retries = 3 + for attempt in range(max_retries): + try: + ray_get_with_progress( + [ + policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer) + for i in range(args.world_size) + ], + desc=f"Saving model at step {training_step}", + timeout=600, ) + logger.info(f"Saved model at step {training_step} to {step_dir}") + # Clean up old checkpoints (only on rank 0, similar to checkpoint states) + if args.keep_last_n_checkpoints >= 0: + try: + clean_last_n_checkpoints(checkpoint_dir, args.keep_last_n_checkpoints) + except Exception as cleanup_error: + logger.warning(f"Failed to clean up old checkpoints: {cleanup_error}") + break + except Exception as e: + if attempt < max_retries - 1: + logger.warning( + f"Failed to save model at step {training_step} (attempt {attempt + 1}/{max_retries}): {e}. Retrying..." + ) + else: + logger.error(f"Failed to save model at step {training_step} after {max_retries} attempts: {e}") + raise + leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" + run_eval_jobs = ( + args.beaker_eval_freq > 0 + and training_step % args.beaker_eval_freq == 0 + and args.try_launch_beaker_eval_jobs_on_weka + and is_beaker_job() + ) + for i in range(args.world_size): + policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote( + step_dir, leaderboard_name, wandb_url, training_step, run_eval_jobs + ) save_time += timer.duration - if len(update_ref_policy_future) > 0: - with Timer("[Main Thread] 🔃 Updating reference policy"): - ray_get_with_progress(update_ref_policy_future, desc="Updating reference policy") - ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} @@ -2700,7 +2621,7 @@ def save_final_model( leaderboard_name = args.hf_repo_revision for i in range(args.world_size): policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote( - args.output_dir, leaderboard_name, wandb_url, training_step + args.output_dir, leaderboard_name, wandb_url, training_step, True ) @@ -2909,6 +2830,7 @@ def run_training( logger.info("======== ✅ weight sync thread starts =========") weight_sync_trigger_event = threading.Event() + params_lock = threading.Lock() weight_sync_thread_future = executor.submit( weight_sync_thread, args, @@ -2917,6 +2839,7 @@ def run_training( policy_group, actor_manager, weight_sync_metrics_Q, + params_lock, resume_training_step, ) @@ -3087,14 +3010,35 @@ def health_check_fn(): if iter_dataloader is not None: client_state["shuffling_iterator_state"] = iter_dataloader.get_state() - ray_get_with_progress( - [ - policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) - for i in 
range(args.world_size) - ], - desc=f"Saving checkpoint state at step {training_step}", - ) - logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}") + # Retry checkpoint save up to 3 times + max_retries = 3 + for attempt in range(max_retries): + try: + with params_lock: + ray_get_with_progress( + [ + policy_group.models[i].save_checkpoint_state.remote( + args.checkpoint_state_dir, client_state + ) + for i in range(args.world_size) + ], + desc=f"Saving checkpoint state at step {training_step}", + timeout=1800, + ) + logger.info( + f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}" + ) + break + except Exception as e: + if attempt < max_retries - 1: + logger.warning( + f"Failed to save checkpoint state at step {training_step} (attempt {attempt + 1}/{max_retries}): {e}. Retrying..." + ) + else: + logger.error( + f"Failed to save checkpoint state at step {training_step} after {max_retries} attempts: {e}" + ) + raise maybe_evaluate( args, @@ -3235,6 +3179,14 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): model_dims, checkpoint_state, ) + except Exception as e: + beaker_url = utils.get_beaker_experiment_url() + if beaker_url: + error_message = f" A RL job has died. Check it out: {beaker_url}. Error message: {str(e)}" + else: + error_message = f" A RL job has died. Error message: {str(e)}" + utils.send_slack_alert(error_message) + raise finally: cleanup_training_resources( stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 3d720c869a..504186ef14 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -66,9 +66,13 @@ from tqdm import tqdm from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser -import mason from open_instruct import logger_utils +WEKA_CLUSTERS = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] +GCP_CLUSTERS = ["ai2/augusta"] + +INTERCONNECT_CLUSTERS = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"] + MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -786,6 +790,7 @@ def clean_last_n_checkpoints_deepspeed(output_dir: str, keep_last_n_checkpoints: print("Remaining files:" + str(os.listdir(output_dir))) + def calibrate_checkpoint_state_dir(checkpoint_state_dir: str) -> None: """ Find the latest valid checkpoint directory and update the 'latest' file. 
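The model-save and checkpoint-state saves above wrap `ray_get_with_progress` in the same try-up-to-three-times pattern that mason.py uses around `beaker_client.experiment.create`: log a warning on intermediate failures and re-raise after the last attempt. A minimal, self-contained sketch of that pattern is below; `retry_with_backoff` and its arguments are illustrative names, not part of this patch, and the sleep-with-exponential-backoff schedule mirrors the mason.py loop (5s, 10s, 20s) rather than the save paths, which retry immediately.

```python
import logging
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)


def retry_with_backoff(fn: Callable[[], T], max_retries: int = 3, base_delay: float = 5.0) -> T:
    """Call fn(), retrying with exponential backoff (base_delay * 2**attempt) on failure.

    Illustrative helper only; the patch inlines this logic at each call site.
    """
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as e:
            if attempt == max_retries - 1:
                logger.error(f"Failed after {max_retries} attempts: {e}")
                raise
            wait = base_delay * (2**attempt)
            logger.warning(f"Attempt {attempt + 1}/{max_retries} failed: {e}. Retrying in {wait:.0f}s...")
            time.sleep(wait)
    raise RuntimeError("unreachable")


# Hypothetical usage, analogous to the mason.py retry loop:
# exp = retry_with_backoff(lambda: beaker_client.experiment.create(spec=experiment_spec))
```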
@@ -1167,11 +1172,12 @@ def launch_ai2_evals_on_weka(
     eval_workspace: str | None = "ai2/tulu-3-results",
     beaker_image: str | None = None,
     oe_eval_gpu_multiplier: int | None = None,
+    run_eval_jobs: bool = True,
 ) -> None:
     beaker_users = get_beaker_whoami()
     if gs_bucket_path is not None:
-        cluster_str = f"--cluster {' '.join(mason.GCP_CLUSTERS)}"
+        cluster_str = f"--cluster {' '.join(GCP_CLUSTERS)}"
         if beaker_users is not None:
             gs_saved_path = f"{gs_bucket_path}/{beaker_users}/{path}"
         else:
             gs_saved_path = f"{gs_bucket_path}/{path}"
@@ -1193,6 +1199,10 @@ def launch_ai2_evals_on_weka(
         path = gs_saved_path
     else:
         cluster_str = ""
+
+    if not run_eval_jobs:
+        return
+
     command = f"""\
 python scripts/submit_eval_jobs.py \
     --model_name {leaderboard_name} \
@@ -1257,6 +1267,26 @@ def wandb_url_to_run_path(url: str) -> str:
     return f"{entity}/{project}/{run_id}"
 
 
+def send_slack_alert(message: str) -> None:
+    slack_webhook_url = os.environ.get("SLACK_WEBHOOK")
+    if not slack_webhook_url:
+        logger.warning("SLACK_WEBHOOK environment variable not set. Skipping Slack alert.")
+        return
+    payload = {"text": message}
+    response = requests.post(slack_webhook_url, json=payload)
+    response.raise_for_status()
+
+
+def get_beaker_experiment_url() -> str | None:
+    try:
+        beaker_client = beaker.Beaker.from_env()
+        workload = beaker_client.workload.get(os.environ["BEAKER_WORKLOAD_ID"])
+        url = beaker_client.experiment.url(workload.experiment)
+        return url
+    except Exception:
+        return None
+
+
 # ----------------------------------------------------------------------------
 # HF utilities
diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py
index cb3f2f6e8d..449c6d8670 100755
--- a/scripts/submit_eval_jobs.py
+++ b/scripts/submit_eval_jobs.py
@@ -7,7 +7,7 @@
 
 import yaml
 
-import mason
+from open_instruct import utils
 
 ########################################
@@ -181,7 +181,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier):
 # remove nfs if asked or jupiter in cluster list.
 weka_available = False
-if all(c in mason.WEKA_CLUSTERS for c in cluster):
+if all(c in utils.WEKA_CLUSTERS for c in cluster):
     d1["tasks"][0]["datasets"].append({"mountPath": "/weka/oe-adapt-default", "source": {"weka": "oe-adapt-default"}})
     d1["tasks"][0]["datasets"].append(
         {"mountPath": "/weka/oe-training-default", "source": {"weka": "oe-training-default"}}
     )
diff --git a/scripts/train/debug/single_gpu_on_beaker.sh b/scripts/train/debug/single_gpu_on_beaker.sh
index fbbe0f5901..347d5366e6 100755
--- a/scripts/train/debug/single_gpu_on_beaker.sh
+++ b/scripts/train/debug/single_gpu_on_beaker.sh
@@ -9,7 +9,6 @@ echo "Using Beaker image: $BEAKER_IMAGE"
 uv run python mason.py \
     --cluster ai2/jupiter \
     --cluster ai2/saturn \
-    --cluster ai2/ceres \
     --image "$BEAKER_IMAGE" \
     --description "Single GPU on Beaker test script." \
     --pure_docker_mode \
diff --git a/scripts/train/olmo3/32b_rl_smoke_test.sh b/scripts/train/olmo3/32b_rl_smoke_test.sh
index 17b8db54a6..8ba36837fc 100755
--- a/scripts/train/olmo3/32b_rl_smoke_test.sh
+++ b/scripts/train/olmo3/32b_rl_smoke_test.sh
@@ -1,25 +1,58 @@
 #!/bin/bash
 
+export lr=1e-6
+export seed=1
+export exp_name=BIGTEST_test_dpo_olmo3_32b_s${seed}_lr${lr}_${RANDOM}
+export data_mix="hamishivi/math_rlvr_mixture_dpo 1.0 saurabh5/code_rlvr_mixture_dpo 1.0 allenai/IF_multi_constraints_upto5_filtered_dpo_0625_filter-keyword-filtered-topic-char-topic-filtered 30186 allenai/rlvr_general_mix-keyword-filtered-topic-chars-char-filt-topic-filtered 21387"
+export beaker_image=hamishivi/open_instruct_rl32_no_ref18
+export gs_model_name=test_dpo_olmo3_32b_s${seed}_lr${lr}_${RANDOM}
+export cluster=ai2/augusta
+# if we need the sft model, use this:
+# export model_path=/weka/oe-adapt-default/hamishi/model_checkpoints/final_olmo_32b_sft
+export model_path=/weka/oe-adapt-default/allennlp/deletable_checkpoint/scottg/olmo3-32b-DPO-8k-0.6b-200k-lucafilt-s42-7e-8__42__1762948744
-export exp_name=test_olmo3_32b_rl_run_${RANDOM}
-export data_mix="hamishivi/math_rlvr_mixture_dpo 1.0 hamishivi/code_rlvr_mixture_dpo 1.0 hamishivi/IF_multi_constraints_upto5_filtered_dpo_0625_filter 30186 allenai/rlvr_general_mix-keyword-filtered 21387"
-export beaker_image=hamishivi/open_instruct_rl32_test10
-export model_path=/weka/oe-adapt-default/hamishi/model_checkpoints/olmo3-merge-32b-1e-4-5e-5/olmo3-merge-32b-1e-4-5e-5/
-
+# annoying restart nonsense.
+#
+# 2e-6 s1 ckpt
+# export lr=2e-6
+# export model_path=/weka/oe-adapt-default/hamishi/olmo_3_emergency_ckpts/test_dpo_olmo3_32b_s1_lr2e-6_1305__1__1763092241_checkpoints_step_100
+# export exp_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+# export gs_model_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+#
+# 1e-6 s42 ckpt
+# export model_path=/weka/oe-adapt-default/hamishi/olmo_3_emergency_ckpts/test_dpo_olmo3_32b_s42_lr1e-6_31282__42__1763093382_checkpoints_step_100
+# export seed=42
+# export exp_name=dpo_olmo3_32b_res100_jup_s${seed}_lr${lr}_${RANDOM}
+# export gs_model_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+# export cluster=ai2/jupiter
+#
+# 1e-6 s1 ckpt
+# export model_path=/weka/oe-adapt-default/hamishi/olmo_3_emergency_ckpts/test_dpo_olmo3_32b_s1_lr1e-6_25849__1__1763083366_checkpoints_step_200
+# export seed=1
+# export exp_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+# export gs_model_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+#
+# 5e-7 s1 ckpt
+# export lr=5e-7
+# export model_path=/weka/oe-adapt-default/hamishi/olmo_3_emergency_ckpts/test_dpo_olmo3_32b_s1_lr5e-7_9654__1__1763092410_checkpoints_step_200
+# export seed=1
+# export exp_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
+# export gs_model_name=test_dpo_olmo3_32b_res100_s${seed}_lr${lr}_${RANDOM}
 
 python mason.py \
     --budget ai2/oe-adapt \
-    --cluster ai2/augusta \
+    --cluster ${cluster} \
     --image ${beaker_image} \
     --pure_docker_mode \
     --workspace ai2/olmo-instruct \
     --priority urgent \
-    --gs_model_name "sft_olmo3_32b_rl_run_testing" \
+    --gs_model_name "${gs_model_name}" \
     --preemptible \
-    --num_nodes 18 \
+    --num_nodes 28 \
     --gpus 8 \
     --max_retries 0 \
     --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
+    --env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
     --env LD_LIBRARY_PATH=/var/lib/tcpxo/lib64 \
     --env NCCL_LIB_DIR=/var/lib/tcpxo/lib64 \
     --env HOSTED_VLLM_API_BASE=http://ceres-cs-aus-447.reviz.ai2.in:8001/v1 \
@@ -27,10 +60,10 @@ python mason.py \
     --exp_name ${exp_name} \
     --beta 0.0 \
     --num_samples_per_prompt_rollout 8 \
-    --num_unique_prompts_rollout 64 \
+    --num_unique_prompts_rollout 128 \
     --num_mini_batches 1 \
     --num_epochs 1 \
-    --learning_rate 1e-6 \
+    --learning_rate ${lr} \
     --per_device_train_batch_size 1 \
     --output_dir /output \
     --kl_estimator kl3 \
@@ -50,15 +83,17 @@ python mason.py \
     --sft_messages_key messages \
     --total_episodes 10000000 \
     --deepspeed_stage 3 \
-    --num_learners_per_node 8 8 8 8 8 8 8 8 8 8 8 8 \
-    --vllm_num_engines 6 \
+    --num_learners_per_node 8 8 8 8 8 8 8 8 \
+    --vllm_num_engines 20 \
+    --inference_batch_size 200 \
     --gather_whole_model False \
     --vllm_tensor_parallel_size 8 \
     --lr_scheduler_type constant \
     --apply_verifiable_reward true \
-    --seed 1 \
+    --seed ${seed} \
     --local_eval_every 50 \
-    --save_freq 25 \
+    --save_freq 50 \
+    --beaker_eval_freq 50 \
     --eval_priority urgent \
     --try_launch_beaker_eval_jobs_on_weka True \
     --gradient_checkpointing \
@@ -72,15 +107,16 @@ python mason.py \
     --use_fp8_kv_cache False \
     --code_api_url https://p9f1719l7f.execute-api.us-west-2.amazonaws.com/prod/test_program \
     --code_pass_rate_reward_threshold 0.99 \
+    --code_max_execution_time 6 \
     --oe_eval_max_length 32768 \
-    --checkpoint_state_freq 100 \
+    --checkpoint_state_freq 5 \
     --backend_timeout 1200 \
     --inflight_updates true \
     --async_steps 8 \
     --active_sampling \
     --advantage_normalization_type centered \
     --truncated_importance_sampling_ratio_cap 2.0 \
-    --oe_eval_beaker_image oe-eval-beaker/oe_eval_olmo2_retrofit_auto \
-    --oe_eval_tasks mmlu:cot::hamish_zs_reasoning_deepseek,bbh:cot::hamish_zs_reasoning_deepseek_v2,gpqa:0shot_cot::qwen3-instruct,zebralogic::hamish_zs_reasoning_deepseek,agi_eval_english:0shot_cot::hamish_zs_reasoning_deepseek,omega_500:0-shot-chat_deepseek,aime:zs_cot_r1::pass_at_32_2024_deepseek,aime:zs_cot_r1::pass_at_32_2025_deepseek,codex_humanevalplus:0-shot-chat::tulu-thinker_deepseek,mbppplus:0-shot-chat::tulu-thinker_deepseek,livecodebench_codegeneration::tulu-thinker_deepseek,alpaca_eval_v3::hamish_zs_reasoning_deepseek,ifeval::hamish_zs_reasoning_deepseek \
+    --oe_eval_tasks mmlu:cot::hamish_zs_reasoning_deepseek,bbh:cot::hamish_zs_reasoning_deepseek_v2,gpqa:0shot_cot::qwen3-instruct,zebralogic::hamish_zs_reasoning_deepseek,agi_eval_english:0shot_cot::hamish_zs_reasoning_deepseek,omega_500:0-shot-chat_deepseek,aime:zs_cot_r1::pass_at_32_2024_deepseek,aime:zs_cot_r1::pass_at_32_2025_deepseek,codex_humanevalplus:0-shot-chat::tulu-thinker_deepseek,mbppplus:0-shot-chat::tulu-thinker_deepseek,livecodebench_codegeneration::tulu-thinker_deepseek_no_think_tags_lite,alpaca_eval_v3::hamish_zs_reasoning_deepseek,ifeval::hamish_zs_reasoning_deepseek \
+    --oe_eval_gpu_multiplier 2 \
     --vllm_enforce_eager \
-    --deepspeed_zpg 32
\ No newline at end of file
+    --deepspeed_zpg 1
\ No newline at end of file
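For reference, the failure-notification path added in open_instruct/utils.py and wired into grpo_fast.main() condenses to the standalone sketch below. It combines the roles of send_slack_alert() and get_beaker_experiment_url() into one hypothetical `notify_failure` helper; the message format copies the one in grpo_fast.py, while the 30-second timeout on the webhook POST is an assumption for illustration (the patch calls requests.post without one).

```python
import os

import requests


def notify_failure(error: Exception, beaker_url: str | None = None) -> None:
    """Post a training-failure alert to Slack, mirroring the patch's alert flow.

    Sketch only: the real code splits this across send_slack_alert() and
    get_beaker_experiment_url(); the timeout argument is not in the patch.
    """
    webhook = os.environ.get("SLACK_WEBHOOK")
    if not webhook:
        return  # the patch logs a warning and skips the alert when the secret is missing
    if beaker_url:
        message = f" A RL job has died. Check it out: {beaker_url}. Error message: {error}"
    else:
        message = f" A RL job has died. Error message: {error}"
    response = requests.post(webhook, json={"text": message}, timeout=30)
    response.raise_for_status()
```

The SLACK_WEBHOOK value it reads is the same per-user Beaker secret that the mason.py change adds to the whitelist of environment variables forwarded to training jobs.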