diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 13ee119ee..dcb140b39 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -36,11 +36,11 @@ with contextlib.suppress(Exception): import deepspeed -from open_instruct import utils +from open_instruct import streaming_data_loader, utils +from open_instruct.streaming_data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn # isort: on import asyncio -import json import logging import math import random @@ -50,7 +50,7 @@ import time from argparse import Namespace from collections import defaultdict -from collections.abc import Callable, Iterator +from collections.abc import Callable from dataclasses import asdict, dataclass, field from datetime import timedelta from queue import Empty, Full, Queue @@ -73,17 +73,13 @@ from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint -from tqdm import tqdm from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig from open_instruct import logger_utils, vllm_utils from open_instruct.actor_manager import ActorManager from open_instruct.dataset_transformation import ( - GROUND_TRUTHS_KEY, INPUT_IDS_PROMPT_KEY, - RAW_PROMPT_KEY, - VERIFIER_SOURCE_KEY, TokenizerConfig, get_cached_dataset_tulu, visualize_token, @@ -94,7 +90,6 @@ soft_format_reward_func, ) from open_instruct.model_utils import ( - Batch, ModelConfig, apply_verifiable_reward, disable_dropout_in_model, @@ -105,8 +100,8 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.queue_types import ShutdownSentinel +from open_instruct.rl_utils import PackedSequences, Timer from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -114,7 +109,6 @@ _z3_params_to_fetch, calibrate_checkpoint_state_dir, clean_last_n_checkpoints_deepspeed, - combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, get_eval_ds_config, @@ -128,7 +122,6 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, ray_get_with_progress, - repeat_each, sync_gs_bucket, ) @@ -137,10 +130,6 @@ INVALID_LOGPROB = 1.0 -class ShutdownSentinel: - """Sentinel value to signal thread shutdown via queue.""" - - @dataclass class Args: # Dataset @@ -166,8 +155,6 @@ class Args: """Whether to skip the cache.""" shuffle_eval_dataset: bool = False """Whether to shuffle the evaluation dataset.""" - max_prompt_token_length: int = 256 - """The maximum prompt token length to use for the dataset""" system_prompt_override_file: str | None = None """Path to a text file containing a system prompt to override the dataset's system prompts""" @@ -216,14 +203,10 @@ class Args: """Timeout for inference/training backends in minutes. 
Default is 2 hours (120 min).""" # Generation - response_length: int = 256 - """the length of the response""" temperature: float = 0.7 """the sampling temperature""" num_unique_prompts_rollout: int = 16 """The number of unique prompts during rollout""" - num_samples_per_prompt_rollout: int = 4 - """the number of samples to generate per prompt during rollout, useful for easy-star""" stop_strings: list[str] | None = None """List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.""" @@ -231,8 +214,6 @@ class Args: """Whether to use fp8 kv cache. This is useful for larger models or olmo.""" # Algorithm - async_steps: int = 1 - """Number of steps ahead to generate responses. Set to 0 to make the code synchronous. Values greater than 0 learn from a policy up to async_steps old like Cleanba (https://arxiv.org/abs/2310.00036)""" num_epochs: int = 1 """the number of epochs to train""" num_mini_batches: int = 1 @@ -249,8 +230,6 @@ class Args: """Enable immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption""" kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3" """the KL estimator to use""" - pack_length: int = 512 - """the length of the pack (you should prob set to the max length of the model)""" masked_mean_axis: int | None = None """the axis to compute the mean of the masked values""" masked_mean_denominator: float | None = None @@ -262,19 +241,6 @@ class Args: """ ref_policy_update_freq: int | None = None """How many training steps to take before updating the reference policy.""" - advantage_normalization_type: Literal["standard", "centered"] = "standard" - """The type of advantage normalization to use. Standard normalization is the default: it subtracts the mean and - divides by the standard deviation. Centered normalization is the same but subtracts the mean only (e.g., used in - DR.GRPO https://arxiv.org/pdf/2503.20783).""" - mask_truncated_completions: bool = False - """Whether to mask out truncated completions. Also called overlong filtering, from DAPO (https://arxiv.org/abs/2503.14476).""" - - active_sampling: bool = False - """Whether to continue sampling responses until you get a full batch.""" - filter_zero_std_samples: bool = True - """Whether to filter out prompts with zero reward std (all samples have the same score).""" - no_resampling_pass_rate: float | None = None - """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" @@ -460,9 +426,6 @@ def __post_init__(self): assert self.masked_mean_denominator > 0, ( f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" ) - assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" - if self.num_samples_per_prompt_rollout == 1: - logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") assert self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty, ( "At least one reward must be applied!" 
) @@ -476,12 +439,6 @@ def __post_init__(self): # Initialize stop_strings if None if self.stop_strings is None: self.stop_strings = [] - if self.inference_batch_size is None: - total_prompts = self.num_samples_per_prompt_rollout * self.num_unique_prompts_rollout - self.inference_batch_size = max(1, math.ceil(total_prompts / self.vllm_num_engines)) - assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( - "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" - ) if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir is None: raise ValueError("`checkpoint_state_dir` must be provided if `checkpoint_state_freq` is greater than 0!") if self.checkpoint_state_dir is not None and self.checkpoint_state_freq == -1: @@ -523,22 +480,6 @@ def __post_init__(self): if self.apply_r1_style_format_reward and self.additive_format_reward: self.max_possible_score += self.r1_style_format_reward - if self.active_sampling: - assert self.async_steps > 1, ( - "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " - "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " - "prompt will cause the trainer to stall waiting for more data . " - ) - assert self.filter_zero_std_samples, ( - "filter_zero_std_samples must be True when active_sampling is True. " - "Active sampling requires filtering to work correctly." - ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: - raise ValueError( - "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " - "as the reward standard deviation will always be 0, causing all samples to be filtered." - ) - def masked_mean( values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None @@ -549,13 +490,6 @@ def masked_mean( return (numerator / denom).mean() -def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: - padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory: - padded_tensor = padded_tensor.pin_memory() - return padded_tensor - - @Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker") def prepare_collated_data_for_workers( packed_sequences: PackedSequences, @@ -645,72 +579,52 @@ def to_device_inplace(tensors_list: list[torch.Tensor], device: torch.device): tensors_list[i] = tensors_list[i].to(device, non_blocking=True) -class ShufflingIterator: - def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None): - self.data = data.copy() - self.batch_size = batch_size - self.index = 0 - self.epoch_number = 0 - self.rng = np.random.default_rng(seed) - self.rng.shuffle(self.data) - self.exclude_list = [] - - self._update_effective_size() - - def __iter__(self) -> Iterator[list[int]]: - return self - - def __next__(self) -> list[int] | int: - """Return a list of next indices or a single index if batch size is 1""" - if self.index >= self.effective_size: - self.index = 0 - self._update_effective_size() - self.epoch_number += 1 - self.rng.shuffle(self.data) - - end_index = self.index + self.batch_size - batch = self.data[self.index : end_index].tolist() - if self.batch_size == 1: - batch = batch[0] - self.index = end_index - - return batch - - def get_state(self) -> dict[str, Any]: - """Get the current state of the iterator for checkpointing.""" - return { - 
"index": self.index, - "epoch_number": self.epoch_number, - "data": self.data.copy(), - "rng_state": self.rng.bit_generator.state, - "exclude_list": self.exclude_list.copy(), - } - - def set_state(self, state: dict[str, Any]) -> None: - """Restore the iterator state from a checkpoint.""" - self.index = state["index"] - self.epoch_number = state.get("epoch_number", 0) - self.data = state["data"].copy() - self.rng.bit_generator.state = state["rng_state"] - self.exclude_list = state.get("exclude_list", []) - self._update_effective_size() - - def exclude_index(self, index: int) -> None: - """Exclude provided data points from future sampling.""" - self.exclude_list.append(index) - - def _update_effective_size(self) -> None: - """Ensure the effective dataset size is divisible by batch_size and filter out all the indices excluded in the last epoch""" - if self.exclude_list: - mask = ~np.isin(self.data, self.exclude_list) - self.data = self.data[mask] - self.exclude_list = [] - - self.effective_size = len(self.data) - (len(self.data) % self.batch_size) - - @ray.remote(num_gpus=1) class PolicyTrainerRayProcess(RayProcess): + def __init__( + self, + world_size: int, + rank: int, + local_rank: int, + master_addr: str | None, + master_port: int | None, + args: Args, + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + generation_config, + actor_manager, + model_dims: utils.ModelDims, + ): + super().__init__(world_size, rank, local_rank, master_addr, master_port) + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + self.num_mini_batches = args.num_mini_batches + dataset = dataset.shard(num_shards=world_size, index=rank) + self.dataloader = data_loader_config.build( + dataset=dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + tokenizer=tokenizer, + generation_config=generation_config, + dp_rank=self.local_rank, + fs_local_rank=self.local_rank, + num_training_steps=args.num_training_steps, + seed=args.seed, + per_device_train_batch_size=args.per_device_train_batch_size, + verbose=args.verbose, + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=world_size, + max_possible_score=args.max_possible_score, + actor_manager=actor_manager, + model_dims=model_dims, + ) + def from_pretrained( self, args: Args, @@ -718,7 +632,7 @@ def from_pretrained( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, - ): + ) -> int: # ------------------------------------------------------------ # Monkey patch to load checkpoints with `weights_only=False` # otherwise it errors out with: @@ -750,7 +664,11 @@ def load(self, path: str, map_location=None): np.random.seed(worker_seed) random.seed(worker_seed) + logger.info( + f"[DEBUG] Rank {self.rank}: Initializing DeepSpeed distributed (timeout={args.backend_timeout} minutes)..." 
+ ) deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout)) + logger.info(f"[DEBUG] Rank {self.rank}: DeepSpeed distributed initialized successfully") ds_config = get_train_ds_config( offload=args.deepspeed_offload_param, @@ -888,7 +806,7 @@ def load(self, path: str, map_location=None): else: self.ref_policy.load_state_dict(state_dict) logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") - self.local_metrics = utils.MetricsTracker(device=self.device) + self.local_metrics = utils.MetricsTracker(max_metrics=64, device=self.device) return optimization_steps_done def forward( @@ -925,8 +843,11 @@ def forward( return logprob, entropy def setup_model_update_group(self, vllm_engines): + logger = logger_utils.setup_logger(__name__) + logger.info(f"[DEBUG] Rank {self.rank}: Entered setup_model_update_group") self.vllm_engines = vllm_engines if self.rank == 0: + logger.info(f"[DEBUG] Rank 0: Initializing process group for {len(vllm_engines)} vLLM engines") master_address = ray._private.services.get_node_ip_address() with socket.socket() as sock: sock.bind(("", 0)) @@ -937,6 +858,10 @@ def setup_model_update_group(self, vllm_engines): ) world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 backend = self.args.vllm_sync_backend + logger.info( + f"[DEBUG] Rank 0: master_address={master_address}, master_port={master_port}, " + f"world_size={world_size}, backend={backend}" + ) refs = [ engine.init_process_group.remote( master_address, @@ -957,8 +882,15 @@ def setup_model_update_group(self, vllm_engines): group_name="openrlhf", timeout=timedelta(minutes=self.args.backend_timeout), ) + logger.info( + f"[DEBUG] Rank 0: Waiting for {len(refs)} vLLM engines to initialize process groups (timeout=600s)..." 
+ ) ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=600) + logger.info("[DEBUG] Rank 0: All vLLM engines initialized, approaching barrier") + else: + logger.info(f"[DEBUG] Rank {self.rank}: Approaching barrier") torch.distributed.barrier() + logger.info(f"[DEBUG] Rank {self.rank}: Passed barrier successfully") def broadcast_to_vllm(self): # avoid OOM @@ -1013,18 +945,17 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def train( - self, - collated_query_responses, - collated_tool_masks, - collated_attention_masks, - collated_position_ids, - collated_advantages, - collated_response_masks, - collated_vllm_logprobs, - pad_token_id: int, - num_mini_batches: int, - ): + def step(self): + batch_data = next(self.dataloader) + batch_metrics = batch_data["metrics"] + collated_query_responses = batch_data["collated_query_responses"] + collated_tool_masks = batch_data["collated_tool_masks"] + collated_attention_masks = batch_data["collated_attention_masks"] + collated_position_ids = batch_data["collated_position_ids"] + collated_advantages = batch_data["collated_advantages"] + collated_response_masks = batch_data["collated_response_masks"] + collated_vllm_logprobs = batch_data["collated_vllm_logprobs"] + args = self.args to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_tool_masks, self.device) @@ -1034,7 +965,11 @@ def train( to_device_inplace(collated_response_masks, self.device) to_device_inplace(collated_vllm_logprobs, self.device) # accumulation steps should always be at least 1 - accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1) + accumulation_steps = max(math.ceil(len(collated_query_responses) / self.num_mini_batches - 0.5), 1) + # Sync accumulation_steps across ranks so all learners call allreduce on the same iterations + accumulation_steps_tensor = torch.tensor([accumulation_steps], device=self.device, dtype=torch.int32) + torch.distributed.all_reduce(accumulation_steps_tensor, op=torch.distributed.ReduceOp.MIN) + accumulation_steps = int(accumulation_steps_tensor.item()) leftover = len(collated_query_responses) % accumulation_steps if leftover > 0: collated_query_responses = collated_query_responses[0:-leftover] @@ -1044,7 +979,7 @@ def train( collated_advantages = collated_advantages[0:-leftover] collated_response_masks = collated_response_masks[0:-leftover] collated_vllm_logprobs = collated_vllm_logprobs[0:-leftover] - logger.warning(f"{leftover} samples are dropped due to batch size {num_mini_batches}") + logger.warning(f"{leftover} samples are dropped due to batch size {self.num_mini_batches}") # recalculate the "real" number of mini-batches num_mini_batches = len(collated_query_responses) // accumulation_steps @@ -1063,7 +998,7 @@ def train( query_response, attention_mask, position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=False, ) @@ -1093,7 +1028,7 @@ def train( query_response, attention_mask, position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=False, ) @@ -1145,7 +1080,7 @@ def train( mb_query_responses, mb_attention_mask, mb_position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=args.record_entropy, ) @@ -1329,7 +1264,15 @@ def train( if args.record_entropy: self.local_metrics["policy/entropy_avg"] = entropy_stats.mean() self.local_metrics["lr"] = self.scheduler.get_last_lr()[0] - return 
self.local_metrics.get_metrics_list() + array_metrics = {} + for key, value in batch_metrics.items(): + if value is None: + continue + if isinstance(value, (int, float, np.floating, np.integer)): + self.local_metrics[key] = value + else: + array_metrics[key] = value + return self.local_metrics.get_metrics_list(), array_metrics def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: args = self.args @@ -1463,7 +1406,21 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url class ModelGroup: def __init__( - self, pg: PlacementGroup, ray_process_cls: RayProcess, num_gpus_per_node: list[int], single_gpu_mode: bool + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: list[int], + single_gpu_mode: bool, + args: Args, + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + generation_config, + actor_manager, + model_dims: utils.ModelDims, ): self.pg = pg self.ray_process_cls = ray_process_cls @@ -1478,7 +1435,23 @@ def __init__( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=self.pg, placement_group_bundle_index=0 ), - ).remote(world_size, 0, 0, None, None) + ).remote( + world_size, + 0, + 0, + None, + None, + args, + data_loader_config, + dataset, + reward_fn, + inference_results_Q, + param_prompt_Q, + tokenizer, + generation_config, + actor_manager, + model_dims, + ) self.models.append(master_policy) results, _ = ray_get_with_progress( @@ -1511,74 +1484,26 @@ def get_bundle_index(rank, num_gpus_per_node): num_cpus=self.num_cpus_per_actor, num_gpus=self.num_gpus_per_actor, scheduling_strategy=scheduling_strategy, - ).remote(world_size, rank, 0, master_addr, master_port) + ).remote( + world_size, + rank, + 0, + master_addr, + master_port, + args, + data_loader_config, + dataset, + reward_fn, + inference_results_Q, + param_prompt_Q, + tokenizer, + generation_config, + actor_manager, + model_dims, + ) self.models.append(worker_policy) -class PendingQueriesMap: - """Thread-safe map for tracking pending queries with reference counting.""" - - def __init__(self): - self._map = {} # dataset_idx -> (query, ground_truth, dataset, count) - self._lock = threading.Lock() - - def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): - """Insert or increment count for a dataset index.""" - with self._lock: - if dataset_idx in self._map: - # Already exists - just increment count - existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ - dataset_idx - ] - self._map[dataset_idx] = ( - existing_query, - existing_ground_truth, - existing_dataset, - existing_raw_query, - count + 1, - ) - else: - # New entry - count starts at 1 - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) - - def pop(self, dataset_idx): - """Retrieve data and decrement count. 
Removes entry when count reaches 0.""" - with self._lock: - if dataset_idx not in self._map: - raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") - - query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] - - if count > 1: - # More results expected - just decrement - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) - else: - # Last result - remove entry - del self._map[dataset_idx] - - return query, ground_truth, dataset, raw_query - - def __len__(self): - """Return the number of entries in the map.""" - with self._lock: - return len(self._map) - - def __contains__(self, dataset_idx): - """Check if a dataset index is in the map.""" - with self._lock: - return dataset_idx in self._map - - def __getitem__(self, dataset_idx): - """Get the value for a dataset index.""" - with self._lock: - return self._map[dataset_idx] - - def keys(self): - """Return a view of the keys in the map.""" - with self._lock: - return list(self._map.keys()) - - def calculate_utilization_metrics( model_dims: utils.ModelDims, prompt_lengths: list[int], @@ -1636,526 +1561,7 @@ def calculate_utilization_metrics( return utilization_metrics -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - no_resampled_prompts: int - total_prompts: int - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - pending_queries_map: PendingQueriesMap, - args: Args, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - reward_fn: Callable, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, - iter_dataloader: ShufflingIterator | None = None, - prompt_dataset: Dataset = None, - param_prompt_Q: ray_queue.Queue | None = None, - training_step: int = None, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - """Accumulate multiple inference results into a single training batch. - - Args: - inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) - pending_queries_map: PendingQueriesMap instance for thread-safe query tracking - args: Arguments containing vllm_num_engines and batch size info - generation_config: Generation config containing n (number of samples per prompt) - num_prompts: Number of prompts to accumulate - timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std - filter_zero_std_samples: Whether to filter samples with zero reward std - replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result - no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate - and exclude them from further sampling - iter_dataloader: Optional, used for no_resampling_pass_rate - param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts - - Raises: - queue.Empty: If timeout is specified and no data is available within timeout. 
- - Returns: - Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths) - or (ShutdownSentinel, None, None, None) if shutdown signal received - """ - if no_resampling_pass_rate is not None: - assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" - - if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" - ) - - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not args.verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - # Validate that each individual result has the expected number of responses - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" - ) - - query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index) - - # Replenish generation queue with new prompt - if replenish_prompts: - dataset_index = next(iter_dataloader) - add_prompt_to_generator( - prompt_dataset[dataset_index], - dataset_index, - iter_dataloader.epoch_number, - training_step, - pending_queries_map, - param_prompt_Q, - generation_config, - is_eval=False, - ) - - # TODO(finbarrtimbers): Move this to LLMRayActor. - for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method. 
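# --- Illustrative sketch, not part of the diff: assumed semantics of the
# `repeat_each` helper (imported from open_instruct.utils in the pre-diff code)
# that the removed lines below use to expand per-prompt metadata into one entry
# per sampled response. The real implementation may differ; this only shows the intent.
def repeat_each(items: list, n: int) -> list:
    # [a, b] with n=2 -> [a, a, b, b]: each prompt-level value (query, ground
    # truth, dataset name, raw query) is duplicated so it lines up index-for-index
    # with the generation_config.n responses sampled for that prompt.
    return [item for item in items for _ in range(n)]

# e.g. with generation_config.n == 4:
# repeat_each(["gt"], 4) == ["gt", "gt", "gt", "gt"]
# ---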
- k_queries = repeat_each([query], generation_config.n) - k_ground_truths = repeat_each([ground_truth], generation_config.n) - k_datasets = repeat_each([dataset_name], generation_config.n) - k_raw_queries = repeat_each([raw_query], generation_config.n) - - scores, reward_metrics = asyncio.run( - reward_fn( - result.responses, - decoded_responses, - k_ground_truths, - k_datasets, - result.finish_reasons, - result.request_info, - k_raw_queries, - ) - ) - - percent_solved = np.mean(scores).item() / args.max_possible_score - # Don't resample prompt that was solved at more than no_resample_positive_rate - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - iter_dataloader.exclude_index(result.dataset_index) - total_no_resampled += 1 - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" - ) - - # Filter out zero std prompts - if filter_zero_std_samples and np.std(scores) == 0: - # If we're not active sampling, still count this as a sample - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if scores[0] == 0: - filtered_prompt_zero += 1 - elif scores[0] == args.max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - all_queries.extend(k_queries) - all_ground_truths.extend(k_ground_truths) - all_datasets.extend(k_datasets) - all_raw_queries.extend(k_raw_queries) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(scores) - all_reward_metrics.append(reward_metrics) - all_percent_solved.append(percent_solved) - - # Combine all results into a single GenerationResult - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - # Use the maximum generation time across engines since they work in parallel - # This avoids including 
queue overhead and accumulation time in MFU/MBU calculations - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - # Create combined RequestInfo - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - # Create combined GenerationResult - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - epoch_number=results[0].epoch_number, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - # Note: We don't have dataset_indices here, but they're not needed for the returned batch - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, # Not meaningful for combined results - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - logging.info( - f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" - ) - - return combined_result, batch, combined_reward_metrics, batch_stats - - -def data_preparation_thread( - reward_fn: Callable, - inference_results_Q: ray_queue.Queue, # Ray queue - param_prompt_Q: ray_queue.Queue, - packed_sequences_Q: Queue, - pending_queries_map: dict, - args: Args, - tokenizer: PreTrainedTokenizer, - num_training_steps: int, - generation_config, - resume_training_step: int, - iter_dataloader: ShufflingIterator, - train_dataset: Dataset, - actor_manager=None, - model_dims: utils.ModelDims = None, -): - for training_step in range(resume_training_step, num_training_steps + 1): - # Streaming accumulation: collect results as they arrive - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - inference_results_Q, - pending_queries_map, - args, - generation_config, - num_prompts=args.num_unique_prompts_rollout, - model_dims=model_dims, - tokenizer=tokenizer, - reward_fn=reward_fn, - actor_manager=actor_manager, - active_sampling=args.active_sampling, - filter_zero_std_samples=args.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=args.no_resampling_pass_rate, - iter_dataloader=iter_dataloader, - prompt_dataset=train_dataset, - param_prompt_Q=param_prompt_Q, - training_step=training_step, - ) - if isinstance(result, ShutdownSentinel): - logger.info("[Data Preparation Thread] 
Received shutdown sentinel, exiting") - return - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - if args.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif args.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") - - if args.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("ðŸ“Ķ [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=args.pack_length, - pad_token_id=tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - ) - num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses) - # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value - # and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - # if we have less batches than world size, we need to pad out so each world is fine - # ideally, you should avoid this since its wasting computation. - if args.allow_world_padding: - with Timer("ðŸĪš [Data Preparation Thread] Padding sequences for world size"): - shortfall = args.world_size - len(packed_sequences.query_responses) - if shortfall > 0: - logger.warning( - f"Padding {shortfall} sequences for world size. In future, you should adjust your compute this." 
- ) - # construct "dummy" sequences for padding out the world size - dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long) - dummy_tool_mask = torch.zeros_like(dummy_qr) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - # pad out the world size - for _ in range(shortfall): - packed_sequences.query_responses.append(dummy_qr) - packed_sequences.tool_masks.append(dummy_tool_mask) - packed_sequences.attention_masks.append(dummy_attention) - packed_sequences.position_ids.append(dummy_position_ids) - packed_sequences.response_masks.append(dummy_response_mask) - packed_sequences.advantages.append(dummy_advantage) - - collated_data = prepare_collated_data_for_workers( - packed_sequences, args.world_size, args.per_device_train_batch_size, tokenizer.pad_token_id - ) - B = len(packed_sequences.query_responses) // args.world_size - - # Create a result package with metrics and data - if len(result.responses) == 0: - # Handle empty responses case - # in this case, we won't log metrics, so it should be fine. - metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = args.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout - - unsolved_num_responses = (scores < args.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) if np.all(scores == 0) else np.array(sequence_lengths[scores == args.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) if np.all(scores == args.max_possible_score) else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / args.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array([len(item) > 0 for item in result.request_info.tool_errors]).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - 
"val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - if args.save_traces: - traces = { - "scores": scores.tolist(), - "finish_reasons": result.finish_reasons, - "responses": result.responses, - "training_step": training_step, - **asdict(batch), # Unpack all batch fields - **reward_metrics, - } - os.makedirs(args.output_dir, exist_ok=True) - with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f: - json.dump(traces, f) - f.write("\n") - - # Put the packed sequences and metrics into the output queue - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, # for debugging purposes - "collated_data": collated_data, - "metrics": metrics, - "responses_count": len(result.responses), - "num_new_tokens": num_new_tokens, - "B": B, - "prompt_lengths": batch_stats.prompt_lengths, - "response_lengths": batch_stats.response_lengths, - "num_filtered_prompts": batch_stats.filtered_prompts, - } - ) - - -def setup_runtime_variables(args: Args) -> Args: +def setup_runtime_variables(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" args.output_dir = os.path.join(args.output_dir, args.run_name) @@ -2164,8 +1570,11 @@ def setup_runtime_variables(args: Args) -> Args: args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache" args.world_size = sum(args.num_learners_per_node) args.num_training_steps = args.total_episodes // ( - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout ) + if args.inference_batch_size is None: + total_prompts = streaming_config.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout + args.inference_batch_size = max(1, math.ceil(total_prompts / args.vllm_num_engines)) args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job() if args.push_to_hub: if args.hf_repo_id is None: # auto-generate one @@ -2209,7 +1618,12 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod return beaker_config, wandb_url -def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer): +def setup_datasets( + args: Args, + tc: TokenizerConfig, + tokenizer: PreTrainedTokenizer, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, +): """Set up training and evaluation datasets.""" system_prompt_override = None if args.system_prompt_override_file is not None: @@ -2220,7 +1634,7 @@ def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokeniz transform_fn_args = [ {"system_prompt_override": system_prompt_override}, - {"max_prompt_token_length": args.max_prompt_token_length}, + {"max_prompt_token_length": streaming_config.max_prompt_token_length}, ] train_dataset = get_cached_dataset_tulu( dataset_mixer_list=args.dataset_mixer_list, @@ -2236,6 +1650,7 @@ def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokeniz 
system_prompt_override=system_prompt_override, ) train_dataset = train_dataset.shuffle(seed=args.seed) + train_dataset = train_dataset.map(lambda example, idx: {**example, "index": idx}, with_indices=True) eval_dataset = None if len(args.dataset_mixer_eval_list) > 0: @@ -2270,22 +1685,19 @@ def create_model_and_optimizer( inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, evaluation_inference_results_Q: ray_queue.Queue, -) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + train_dataset: Dataset, + reward_fn: Callable, + generation_config, +) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int, ActorManager, utils.ModelDims]: """Create the model, optimizer, and vLLM engines.""" # Create placement group bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node] pg = placement_group(bundles, strategy="STRICT_SPREAD") ray_get_with_progress([pg.ready()], desc="Waiting for placement group") - inits = [] - policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode) - wandb_url = wandb.run.get_url() if args.with_tracking else None - inits.extend( - model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) - for model in policy_group.models - ) # Set up tools - max_len = args.max_prompt_token_length + args.response_length + max_len = data_loader_config.max_prompt_token_length + data_loader_config.response_length tool_objects = {} if args.tools: for tool in args.tools: @@ -2342,39 +1754,79 @@ def create_model_and_optimizer( use_fp8_kv_cache=args.use_fp8_kv_cache, inflight_updates=args.inflight_updates, ) + logger.info(f"[DEBUG] Created {len(vllm_engines)} vLLM engines") - results, _ = ray_get_with_progress(inits, desc="Initializing models") - resume_training_step = results[0] + 1 - episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout - logger.info("======== ✅ all models and vLLM engines initialized =========") + # Get model dimensions from vLLM engine + logger.info("[DEBUG] Fetching model dimensions from first vLLM engine...") + model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) + logger.info("======== ✅ vLLM engines and actor_manager initialized =========") # Get and set KV cache max concurrency from the first engine (all engines have the same config) # fp8 kv cache for now forces v0 engine and breaks this. 
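# --- Illustrative sketch, not part of the diff: how the "index" column added in
# setup_datasets above combines with the per-rank Dataset.shard() call in
# PolicyTrainerRayProcess.__init__. The toy dataset and world size are hypothetical.
import datasets

toy = datasets.Dataset.from_dict({"prompt": [f"p{i}" for i in range(8)]})
toy = toy.map(lambda example, idx: {**example, "index": idx}, with_indices=True)

world_size = 2
shards = [toy.shard(num_shards=world_size, index=rank) for rank in range(world_size)]
# Each rank's streaming data loader sees a disjoint subset, while the "index"
# column keeps referring to positions in the original (pre-shard) dataset, so
# prompts can still be tracked globally. (Whether each shard is contiguous or
# strided depends on the `contiguous` default of your datasets version.)
# ---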
+ logger.info("[DEBUG] Setting up KV cache configuration...") if vllm_engines and not args.use_fp8_kv_cache: kv_cache_max_concurrency = ray.get(vllm_engines[0].get_kv_cache_info.remote()) ray.get(actor_manager.set_kv_cache_max_concurrency.remote(kv_cache_max_concurrency)) else: # dummy value ray.get(actor_manager.set_kv_cache_max_concurrency.remote(-1)) + logger.info("[DEBUG] KV cache configuration complete") + + # Now create policy actors with all dependencies + logger.info("[DEBUG] Creating ModelGroup with policy actors...") + wandb_url = wandb.run.get_url() if args.with_tracking else None + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.num_learners_per_node, + args.single_gpu_mode, + args=args, + data_loader_config=data_loader_config, + dataset=train_dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + tokenizer=tokenizer, + generation_config=generation_config, + actor_manager=actor_manager, + model_dims=model_dims, + ) + logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") + + logger.info("[DEBUG] Starting model initialization across all ranks...") + inits = [ + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) + for model in policy_group.models + ] + + results, _ = ray_get_with_progress(inits, desc="Initializing models") + resume_training_step = results[0] + 1 + episode = ( + (resume_training_step - 1) + * args.num_unique_prompts_rollout + * data_loader_config.num_samples_per_prompt_rollout + ) + logger.info("======== ✅ all models initialized =========") + logger.info("[DEBUG] Setting up model update group across all ranks...") ray_get_with_progress( [m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models], desc="Setting up model update group", ) logger.info("======== ✅ model update group setup successfully =========") - return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager + return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims -def create_generation_configs(args: Args): +def create_generation_configs(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig): """Create generation configs for training and evaluation.""" generation_config = vllm.SamplingParams( temperature=args.temperature, top_p=args.vllm_top_p, # prevent rare out-of-vocab tokens with qwen - max_tokens=args.response_length, + max_tokens=streaming_config.response_length, include_stop_str_in_output=True, skip_special_tokens=False, - n=args.num_samples_per_prompt_rollout, + n=streaming_config.num_samples_per_prompt_rollout, stop=args.stop_strings, seed=args.seed, logprobs=1, # Enable logprobs to compare with local calculations @@ -2390,74 +1842,6 @@ def create_generation_configs(args: Args): return {"train": generation_config, "eval": eval_generation_config} -def add_prompt_to_generator( - example: dict[str, Any], - example_index: int, - epoch_number: int, - training_step: int, - pending_queries_map: PendingQueriesMap, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, -) -> None: - """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" - query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - pending_queries_map.insert(example_index, query, 
ground_truth, dataset_name, raw_query) - - param_prompt_Q.put( - PromptRequest( - prompt=query, - generation_config=generation_config, - epoch_number=epoch_number, - training_step=training_step, - dataset_index=example_index, - is_eval=is_eval, - ) - ) - - -def load_data_from_packing_thread( - packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None] -) -> tuple[list[dict[str, list[torch.Tensor]]] | None, dict[str, Any], int, int, list[int] | None, list[int] | None]: - """Get the packed sequences with advantages from the packing thread.""" - with Timer("[Main Thread] ðŸ“Ķ Getting packed sequences from thread") as timer: - while True: - if stop_event.is_set(): - logger.warning("[Main Thread] Stop event detected while waiting for packed sequences") - return None, {}, num_total_tokens, 0, None, None, 0 - try: - packed_data = packed_sequences_Q.get(timeout=30.0) - break - except Empty: - health_check_fn() - logger.warning("[Main Thread] Timeout waiting for packed sequences. Retrying...") - data_thread_metrics = packed_data["metrics"] - B = packed_data["B"] - collated_data = packed_data["collated_data"] - num_step_tokens = packed_data["num_new_tokens"] - num_total_tokens += num_step_tokens - prompt_lengths = packed_data["prompt_lengths"] - response_lengths = packed_data["response_lengths"] - num_filtered_prompts = packed_data["num_filtered_prompts"] - - data_thread_metrics["time/trainer_idling"] = timer.duration - if B == 0: - logger.warning("[Main Thread] ðŸĪĄ After packing, there is not enough data to train") - return None, data_thread_metrics, num_total_tokens, 0, None, None, 0 - return ( - collated_data, - data_thread_metrics, - num_total_tokens, - num_step_tokens, - prompt_lengths, - response_lengths, - num_filtered_prompts, - ) - - def weight_sync_thread( args: Args, stop_event: threading.Event, @@ -2521,37 +1905,30 @@ def weight_sync_thread( def one_training_step( args: Args, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, policy_group: ModelGroup, - collated_data: list[dict[str, list[torch.Tensor]]], tokenizer: PreTrainedTokenizer, data_thread_metrics: dict[str, Any], episode: int, training_step: int, num_total_tokens: int, - num_step_tokens: int, start_time: float, train_dataset: datasets.Dataset, training_start_time: float, wandb_url: str, chat_template_name: str, model_dims: utils.ModelDims, - prompt_lengths: list[int], - response_lengths: list[int], actor_manager: ActorManager | None = None, - iter_dataloader: Iterator | None = None, ) -> None: """Train the model for one step.""" update_ref_policy_future = [] with Timer("[Main Thread] ðŸ—Ąïļ Training") as train_timer: - metrics_list, _ = ray_get_with_progress( - [ - policy_group.models[i].train.remote( - **collated_data[i], pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches - ) - for i in range(args.world_size) - ], + results, _ = ray_get_with_progress( + [policy_group.models[i].step.remote() for i in range(args.world_size)], desc=f"Running training step {training_step}", ) + metrics_list = [r[0] for r in results] + array_metrics_list = [r[1] for r in results] if ( args.ref_policy_update_freq is not None and training_step % args.ref_policy_update_freq == 0 @@ -2588,18 +1965,29 @@ def one_training_step( ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) - average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} + all_keys = set() + for m in metrics_list: + 
all_keys.update(m.keys()) + average_metrics = {} + for k in all_keys: + values = [m[k] for m in metrics_list if k in m] + average_metrics[k] = sum(values) / len(values) + for key, value in array_metrics_list[0].items(): + average_metrics[key] = value step_time = time.perf_counter() - start_time total_training_time = time.perf_counter() - training_start_time - total_generation_time = data_thread_metrics["time/getting_response"] + total_generation_time = average_metrics["time/getting_response"] + prompt_lengths = array_metrics_list[0]["batch/prompt_lengths"] + response_lengths = array_metrics_list[0]["batch/response_lengths"] + num_step_tokens = sum(prompt_lengths) + sum(response_lengths) utilization_metrics = calculate_utilization_metrics( model_dims=model_dims, prompt_lengths=prompt_lengths, response_lengths=response_lengths, total_generation_time=total_generation_time, - samples_per_prompt=args.num_samples_per_prompt_rollout, + samples_per_prompt=streaming_config.num_samples_per_prompt_rollout, num_engines=args.vllm_num_engines, num_gpus_per_engine=args.vllm_tensor_parallel_size, training_time=train_timer.duration, @@ -2612,7 +2000,7 @@ def one_training_step( "training_step": training_step, "val/num_total_tokens": num_total_tokens, "val/num_step_tokens": num_step_tokens, - "epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset), + "epoch": episode / streaming_config.num_samples_per_prompt_rollout / len(train_dataset), "learner_tokens_per_second_overall": num_total_tokens / total_training_time, "learner_tokens_per_second_step": num_step_tokens / step_time, "time/total": step_time, @@ -2641,7 +2029,7 @@ def maybe_evaluate( tokenizer, reward_fn, episode, - eval_pending_queries_map: PendingQueriesMap, + eval_dataset, eval_generation_config, generate_metrics_Q: Queue, num_eval_prompts: int, @@ -2657,13 +2045,12 @@ def maybe_evaluate( # Accumulate evaluation results from all vLLM engines eval_result, eval_batch, eval_reward_metrics, _ = accumulate_inference_batches( evaluation_inference_results_Q, - eval_pending_queries_map, - args, eval_generation_config, num_prompts=num_eval_prompts, model_dims=model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=eval_dataset, actor_manager=actor_manager, timeout=timeout, active_sampling=False, @@ -2920,13 +2307,13 @@ def cleanup_training_resources( def run_training( args, + streaming_config, tokenizer, train_dataset, eval_dataset, policy_group, vllm_engines, generation_configs, - iter_dataloader, reward_fn, resume_training_step, episode, @@ -2938,8 +2325,6 @@ def run_training( param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - pending_queries_map, - eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, actor_manager: ActorManager, @@ -2967,46 +2352,16 @@ def run_training( [engine.ready.remote() for engine in vllm_engines], "Checking engines are ready to work", timeout=300 ) - logger.info("======== ✅ data preparation thread starts =========") - packing_future = executor.submit( - data_preparation_thread, - reward_fn, - inference_results_Q, - param_prompt_Q, - packed_sequences_Q, - pending_queries_map, - args, - tokenizer, - args.num_training_steps, - generation_configs["train"], - resume_training_step, - iter_dataloader, - train_dataset, - actor_manager, - model_dims, - ) + logger.info("======== ✅ Dataloaders already initialized in actors =========") def health_check_fn(): - [f.result() for f in [packing_future, weight_sync_thread_future] if f.done()] + [f.result() for f in 
[weight_sync_thread_future] if f.done()] ray_get_with_progress( [engine.check_background_threads.remote() for engine in vllm_engines], desc="Checking vLLM engine health", enable=False, ) - # Send initial data to ensure we have a N-step offset. - for _ in range(args.async_steps * args.num_unique_prompts_rollout): - dataset_index = next(iter_dataloader) - add_prompt_to_generator( - train_dataset[dataset_index], - dataset_index, - iter_dataloader.epoch_number, - resume_training_step, - pending_queries_map, - param_prompt_Q, - generation_configs["train"], - is_eval=False, - ) if checkpoint_state and "num_total_tokens" in checkpoint_state: num_total_tokens = checkpoint_state["num_total_tokens"] logger.info(f"Restored num_total_tokens: {num_total_tokens}") @@ -3034,16 +2389,6 @@ def health_check_fn(): health_check_fn() health_check_time = time.perf_counter() - health_check_start - ( - collated_data, - data_thread_metrics, - num_total_tokens, - num_step_tokens, - prompt_lengths, - response_lengths, - num_filtered_prompts, - ) = load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn) - if ( training_step % args.local_eval_every == 0 and eval_dataset is not None @@ -3053,18 +2398,16 @@ def health_check_fn(): add_prompt_to_generator( eval_example, eval_index, - iter_dataloader.epoch_number, + 0, training_step, - eval_pending_queries_map, param_prompt_Q, generation_configs["eval"], is_eval=True, ) - if collated_data is None: - continue - episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + episode += args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout + data_thread_metrics = {} for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]: try: data_thread_metrics |= metrics_Q.get_nowait() @@ -3075,24 +2418,20 @@ def health_check_fn(): one_training_step( args, + streaming_config, policy_group, - collated_data, tokenizer, data_thread_metrics, episode, training_step, num_total_tokens, - num_step_tokens, start_time, train_dataset, training_start_time, wandb_url, tc.chat_template_name, model_dims, - prompt_lengths, - response_lengths, actor_manager, - iter_dataloader, ) logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}") @@ -3106,17 +2445,12 @@ def health_check_fn(): and args.checkpoint_state_dir is not None ): with Timer("[Main Thread] ðŸ—Ąïļ Saving checkpoint state"): - # Save comprehensive client state including ShufflingIterator state client_state = { "training_step": training_step, "episode": episode, "num_total_tokens": num_total_tokens, } - # Save ShufflingIterator state - if iter_dataloader is not None: - client_state["shuffling_iterator_state"] = iter_dataloader.get_state() - ray_get_with_progress( [ policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) @@ -3133,7 +2467,7 @@ def health_check_fn(): tokenizer, reward_fn, episode, - eval_pending_queries_map, + eval_dataset, generation_configs["eval"], generate_metrics_Q, len(eval_dataset) if eval_dataset else 0, @@ -3147,9 +2481,14 @@ def health_check_fn(): save_final_model(args, policy_group, tokenizer, training_step, wandb_url, tc.chat_template_name) -def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): +def main( + args: Args, + tc: TokenizerConfig, + model_config: ModelConfig, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, +): tokenizer = make_tokenizer(tc, model_config) - args = setup_runtime_variables(args) + args = 
setup_runtime_variables(args, streaming_config) if args.verbose: logging.getLogger().setLevel(logging.DEBUG) @@ -3158,9 +2497,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): beaker_config, wandb_url = setup_experiment_tracking(args, tc, model_config) - train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer) + train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer, streaming_config) - if len(train_dataset) < (needed := max(args.async_steps, 1) * args.num_unique_prompts_rollout): + if len(train_dataset) < (needed := max(streaming_config.async_steps, 1) * args.num_unique_prompts_rollout): raise ValueError( f"Train dataset is too small! Is {len(train_dataset)} prompts, but {needed} are needed to have enough prompts for bsz and prefill. Try reducing async_steps or num_unique_prompts_rollout, or increasing the dataset size." ) @@ -3176,13 +2515,16 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): # Create Ray queues. # Since we now send/receive individual prompts, queue size should accommodate # all prompts from async_steps + 1 training steps - queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout + queue_size = (streaming_config.async_steps + 1) * args.num_unique_prompts_rollout inference_results_Q = ray_queue.Queue(maxsize=queue_size) param_prompt_Q = ray_queue.Queue(maxsize=queue_size) # We don't care if we ever hit the max, so we let the queue be unbounded. evaluation_inference_results_Q = ray_queue.Queue() - policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = ( + reward_fn = make_reward_fn(args) + generation_configs = create_generation_configs(args, streaming_config) + + (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( create_model_and_optimizer( args, tc, @@ -3193,14 +2535,13 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, + streaming_config, + train_dataset, + reward_fn, + generation_configs["train"], ) ) - # Get the model dimensions from one of the engines without loading weights - model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) - - generation_configs = create_generation_configs(args) - checkpoint_state = None if args.checkpoint_state_dir and os.path.exists(args.checkpoint_state_dir): # Try to load the checkpoint state from the first rank @@ -3212,21 +2553,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): episode = checkpoint_state["episode"] logger.info(f"Restored episode count: {episode}") - train_dataset_idxs = np.arange(len(train_dataset)) - iter_dataloader = ShufflingIterator(train_dataset_idxs, 1, seed=args.seed) - - if checkpoint_state and "shuffling_iterator_state" in checkpoint_state: - iter_dataloader.set_state(checkpoint_state["shuffling_iterator_state"]) - logger.info("Restored ShufflingIterator state from checkpoint") - # Create additional queues (main queues already created above) - packed_sequences_Q = Queue(maxsize=args.async_steps) - pending_queries_map = PendingQueriesMap() - eval_pending_queries_map = PendingQueriesMap() - generate_metrics_Q = Queue(maxsize=args.async_steps) - weight_sync_metrics_Q = Queue(maxsize=args.async_steps) - - reward_fn = make_reward_fn(args) + packed_sequences_Q = Queue(maxsize=streaming_config.async_steps) + generate_metrics_Q = Queue(maxsize=streaming_config.async_steps) + weight_sync_metrics_Q = 
Queue(maxsize=streaming_config.async_steps) stop_event = threading.Event() executor = futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix="grpo") @@ -3234,13 +2564,13 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): try: episode = run_training( args, + streaming_config, tokenizer, train_dataset, eval_dataset, policy_group, vllm_engines, generation_configs, - iter_dataloader, reward_fn, resume_training_step, episode, @@ -3252,8 +2582,6 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - pending_queries_map, - eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, actor_manager, @@ -3296,10 +2624,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): if __name__ == "__main__": utils.check_oe_eval_internal() - parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig)) - args, tokenizer_config, model_config = parser.parse_args_into_dataclasses() + parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig, streaming_data_loader.StreamingDataLoaderConfig)) + args, tokenizer_config, model_config, streaming_config = parser.parse_args_into_dataclasses() assert isinstance(args, Args) assert isinstance(tokenizer_config, TokenizerConfig) assert isinstance(model_config, ModelConfig) + assert isinstance(streaming_config, streaming_data_loader.StreamingDataLoaderConfig) - main(args, tokenizer_config, model_config) + main(args, tokenizer_config, model_config, streaming_config) diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py index 0cc047bca..8267f2f44 100644 --- a/open_instruct/queue_types.py +++ b/open_instruct/queue_types.py @@ -2,6 +2,10 @@ from typing import Any +class ShutdownSentinel: + """Sentinel value to signal thread shutdown via queue.""" + + @dataclass class TokenStatistics: """Container for token statistics from inference.""" diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py new file mode 100644 index 000000000..124c48f0b --- /dev/null +++ b/open_instruct/streaming_data_loader.py @@ -0,0 +1,968 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
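# --- Reviewer sketch (not part of the patch) --------------------------------
# The StreamingDataLoaderConfig dataclass added below takes over the rollout
# and packing knobs that this PR removes from Args (async_steps, pack_length,
# response_length, num_samples_per_prompt_rollout, active_sampling, ...).
# A minimal illustration of how its __post_init__ validation is expected to
# behave, assuming the import path introduced by this PR:
#
#   from open_instruct.streaming_data_loader import StreamingDataLoaderConfig
#
#   StreamingDataLoaderConfig()                                     # ok: 512 >= 256 + 256
#   StreamingDataLoaderConfig(pack_length=300)                      # AssertionError: pack_length too small
#   StreamingDataLoaderConfig(active_sampling=True, async_steps=1)  # AssertionError: needs async_steps > 1
#   StreamingDataLoaderConfig(num_samples_per_prompt_rollout=1)     # ValueError: zero-std filtering
# -----------------------------------------------------------------------------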
+
+import asyncio
+import logging
+import threading
+from abc import abstractmethod
+from collections.abc import Callable, Iterable
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from queue import Queue as StdQueue
+from typing import Any
+
+import numpy as np
+import torch
+import vllm
+from datasets import Dataset
+from ray.util import queue as ray_queue
+from tqdm import tqdm
+from transformers import PreTrainedTokenizer
+
+from open_instruct import utils
+from open_instruct.dataset_transformation import (
+    GROUND_TRUTHS_KEY,
+    INPUT_IDS_PROMPT_KEY,
+    RAW_PROMPT_KEY,
+    VERIFIER_SOURCE_KEY,
+)
+from open_instruct.model_utils import Batch
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics
+from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences
+from open_instruct.utils import combine_reward_metrics, repeat_each
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class StreamingDataLoaderConfig:
+    max_prompt_token_length: int = 256
+    response_length: int = 256
+    async_steps: int = 1
+    num_samples_per_prompt_rollout: int = 4
+    active_sampling: bool = False
+    filter_zero_std_samples: bool = True
+    no_resampling_pass_rate: float | None = None
+    advantage_normalization_type: str = "standard"
+    mask_truncated_completions: bool = False
+    pack_length: int = 512
+
+    def __post_init__(self):
+        assert self.pack_length >= self.max_prompt_token_length + self.response_length, (
+            "The `pack_length` needs to be at least the sum of `max_prompt_token_length` and `response_length`!"
+        )
+        assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
+        if self.num_samples_per_prompt_rollout == 1:
+            logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
+
+        if self.active_sampling:
+            assert self.async_steps > 1, (
+                "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. "
+                "Otherwise, the generator only produces one batch's worth of prompts, and a single filtered "
+                "prompt will cause the trainer to stall waiting for more data."
+            )
+            assert self.filter_zero_std_samples, (
+                "filter_zero_std_samples must be True when active_sampling is True. "
+                "Active sampling requires filtering to work correctly."
+            )
+        if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples:
+            raise ValueError(
+                "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, "
+                "as the reward standard deviation will always be 0, causing all samples to be filtered."
+ ) + + def build( + self, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + generation_config: Any, + dp_rank: int, + fs_local_rank: int, + num_training_steps: int, + seed: int, + per_device_train_batch_size: int, + verbose: bool, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int, + max_possible_score: float, + actor_manager=None, + model_dims: utils.ModelDims | None = None, + ) -> "StreamingDataLoader": + return StreamingDataLoader( + dataset=dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + tokenizer=tokenizer, + config=self, + generation_config=generation_config, + work_dir=work_dir, + global_batch_size=global_batch_size, + num_training_steps=num_training_steps, + seed=seed, + per_device_train_batch_size=per_device_train_batch_size, + verbose=verbose, + max_possible_score=max_possible_score, + actor_manager=actor_manager, + model_dims=model_dims, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + +class DataLoaderBase: + def __init__( + self, + *, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + self.work_dir = Path(work_dir) + self._global_batch_size = global_batch_size + self.dp_world_size = dp_world_size + self.dp_rank = dp_rank + self.fs_local_rank = fs_local_rank + self.batches_processed = 0 + self.epoch: int | None = None + + @property + def global_batch_size(self) -> int: + return self._global_batch_size + + @global_batch_size.setter + def global_batch_size(self, value: int): + self._global_batch_size = value + + @property + def rank_batch_size(self) -> int: + return self.global_batch_size // self.dp_world_size + + @property + @abstractmethod + def total_batches(self) -> int | None: + pass + + @abstractmethod + def state_dict(self) -> dict[str, Any]: + pass + + @abstractmethod + def load_state_dict(self, state_dict: dict[str, Any]): + pass + + @abstractmethod + def reshuffle(self, epoch: int | None = None, **kwargs): + pass + + @abstractmethod + def _iter_batches(self) -> Iterable[dict[str, Any]]: + pass + + @abstractmethod + def get_mock_batch(self) -> dict[str, Any]: + pass + + def __iter__(self): + return self._iter_batches() + + def __next__(self): + if not hasattr(self, "_iterator"): + self._iterator = self._iter_batches() + return next(self._iterator) + + def reset(self): + if hasattr(self, "_iterator"): + del self._iterator + self.batches_processed = 0 + + +class TextDataLoaderBase(DataLoaderBase): + def __init__( + self, + *, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + self.tokens_processed: int = 0 + + def reset(self): + super().reset() + self.tokens_processed = 0 + + def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int | None: + del batch + return self.global_batch_size + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.epoch_number = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + self.exclude_list = [] + + self._update_effective_size() + + def 
__iter__(self): + return self + + def __next__(self) -> list[int] | int: + if self.index >= self.effective_size: + self.index = 0 + self._update_effective_size() + self.epoch_number += 1 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + if self.batch_size == 1: + batch = batch[0] + self.index = end_index + + return batch + + def get_state(self) -> dict[str, Any]: + return { + "index": self.index, + "epoch_number": self.epoch_number, + "data": self.data.copy(), + "rng_state": self.rng.bit_generator.state, + "exclude_list": self.exclude_list.copy(), + } + + def set_state(self, state: dict[str, Any]) -> None: + self.index = state["index"] + self.epoch_number = state.get("epoch_number", 0) + self.data = state["data"].copy() + self.rng.bit_generator.state = state["rng_state"] + self.exclude_list = state.get("exclude_list", []) + self._update_effective_size() + + def exclude_index(self, index: int) -> None: + self.exclude_list.append(index) + + def _update_effective_size(self) -> None: + if self.exclude_list: + mask = ~np.isin(self.data, self.exclude_list) + self.data = self.data[mask] + self.exclude_list = [] + + self.effective_size = len(self.data) - (len(self.data) % self.batch_size) + + +class StreamingDataLoader(TextDataLoaderBase): + def __init__( + self, + *, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + config: StreamingDataLoaderConfig, + generation_config: Any, + work_dir: Path | str, + global_batch_size: int, + num_training_steps: int = 0, + seed: int, + per_device_train_batch_size: int, + verbose: bool, + max_possible_score: float, + actor_manager=None, + model_dims: utils.ModelDims = None, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + self.dataset = dataset + self.reward_fn = reward_fn + self.inference_results_Q = inference_results_Q + self.param_prompt_Q = param_prompt_Q + self.tokenizer = tokenizer + self.config = config + self.config.max_possible_score = max_possible_score + self.generation_config = generation_config + self.num_training_steps = num_training_steps + self.actor_manager = actor_manager + self.model_dims = model_dims + + self.per_device_train_batch_size = per_device_train_batch_size + self.verbose = verbose + + self.training_step = 0 + self.current_epoch = 0 + + dataset_indices = np.arange(len(dataset)) + self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed + dp_rank) + + self.local_queue = StdQueue(maxsize=config.async_steps) + self.background_thread = None + self.shutdown_requested = False + + @property + def total_batches(self) -> int | None: + return self.num_training_steps + + def state_dict(self) -> dict[str, Any]: + return { + "training_step": self.training_step, + "current_epoch": self.current_epoch, + "iter_dataloader_state": self.iter_dataloader.get_state(), + } + + def load_state_dict(self, state_dict: dict[str, Any]): + self.training_step = state_dict["training_step"] + self.current_epoch = state_dict.get("current_epoch", 0) + self.iter_dataloader.set_state(state_dict["iter_dataloader_state"]) + + def reshuffle(self, epoch: int | None = None, **kwargs): + if epoch is not None: + self.current_epoch = epoch + + def get_mock_batch(self) -> dict[str, Any]: + 
dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) + dummy_tool_mask = torch.zeros_like(dummy_qr) + dummy_attention = torch.tensor([1, 1], dtype=torch.long) + dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) + dummy_response_mask = torch.zeros_like(dummy_qr) + dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) + + return { + "collated_query_responses": [dummy_qr], + "collated_tool_masks": [dummy_tool_mask], + "collated_attention_masks": [dummy_attention], + "collated_position_ids": [dummy_position_ids], + "collated_advantages": [dummy_advantage], + "collated_response_masks": [dummy_response_mask], + "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], + } + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + if self.background_thread is None: + self._start_background_thread() + + while self.training_step < self.num_training_steps: + batch_data = self.local_queue.get() + self.training_step += 1 + yield batch_data + + def _start_background_thread(self): + self.shutdown_requested = False + self.background_thread = threading.Thread( + target=self._data_preparation_loop, daemon=True, name=f"DataLoader-Worker-Rank{self.dp_rank}" + ) + self.background_thread.start() + + def _data_preparation_loop(self): + for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): + local_index = next(self.iter_dataloader) + example = self.dataset[local_index] + dataset_index = example["index"] + add_prompt_to_generator( + example, + dataset_index, + self.iter_dataloader.epoch_number, + self.training_step, + self.param_prompt_Q, + self.generation_config, + is_eval=False, + ) + + for training_step in range(self.training_step, self.num_training_steps): + if self.shutdown_requested: + logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") + return + + with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: + result, batch, reward_metrics, batch_stats = accumulate_inference_batches( + self.inference_results_Q, + self.generation_config, + num_prompts=self.rank_batch_size, + model_dims=self.model_dims, + tokenizer=self.tokenizer, + reward_fn=self.reward_fn, + dataset=self.dataset, + actor_manager=self.actor_manager, + active_sampling=self.config.active_sampling, + filter_zero_std_samples=self.config.filter_zero_std_samples, + replenish_prompts=True, + no_resampling_pass_rate=self.config.no_resampling_pass_rate, + iter_dataloader=self.iter_dataloader, + param_prompt_Q=self.param_prompt_Q, + training_step=training_step, + verbose=self.verbose, + max_possible_score=self.config.max_possible_score, + ) + if isinstance(result, ShutdownSentinel): + logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") + return + + getting_response_time = timer.duration + scores = np.array(batch.scores) + + good_outputs = [ + len(result.request_info.tool_outputs[i]) > 0 + and result.request_info.tool_calleds[i] + and not result.request_info.timeouts[i] + and not result.request_info.tool_errors[i] + for i in range(len(result.request_info.tool_outputs)) + ] + scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = scores_per_prompt.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, 
self.config.num_samples_per_prompt_rollout, axis=0) + if self.config.advantage_normalization_type == "standard": + advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif self.config.advantage_normalization_type == "centered": + advantages = scores - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") + + if self.config.mask_truncated_completions: + stop_idxes = torch.tensor( + [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] + ) + num_truncated = len(result.finish_reasons) - len(stop_idxes) + if num_truncated > 0: + logger.info( + f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " + f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + ) + scores = scores[stop_idxes] + advantages = advantages[stop_idxes] + batch = batch[stop_idxes.tolist()] + result.responses = [result.responses[i] for i in stop_idxes] + result.masks = [result.masks[i] for i in stop_idxes] + result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] + result.logprobs = [result.logprobs[i] for i in stop_idxes] + + with Timer("ðŸ“Ķ [Data Preparation Thread] Packing sequences"): + packed_sequences = pack_sequences( + queries=batch.queries, + responses=result.responses, + masks=result.masks, + pack_length=self.config.pack_length, + pad_token_id=self.tokenizer.pad_token_id, + vllm_logprobs=result.logprobs, + ) + lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) + lookup_advantages[1:] = advantages + packed_advantages = [ + torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) + for packed_mask in packed_sequences.response_masks + ] + packed_sequences.advantages = packed_advantages + + collated_data = self._prepare_collated_data_for_self(packed_sequences) + + if len(result.responses) == 0: + metrics = {} + logger.warning(f"No responses in batch {training_step}.") + else: + real_num_responses = len(result.responses) + expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size + + unsolved_num_responses = (scores < self.config.max_possible_score).sum() + sequence_lengths = np.array([len(response) for response in result.responses]) + sequence_length_solved = ( + np.array([]) + if np.all(scores == 0) + else np.array(sequence_lengths[scores == self.config.max_possible_score]) + ) + sequence_length_unsolved = ( + np.array([]) + if np.all(scores == self.config.max_possible_score) + else np.array(sequence_lengths[scores == 0]) + ) + stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( + result.finish_reasons + ) + + batch_metrics = asdict(batch_stats) + batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} + + metrics = { + "scores": scores.mean(), + "real_batch_size_ratio": real_num_responses / expected_num_responses, + "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, + "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, + "val/solve_rate_hist": batch_stats.percent_solved_hist, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, + "val/sequence_lengths": sequence_lengths.mean(), + "val/sequence_lengths_min": sequence_lengths.min(), + "val/sequence_lengths_max": sequence_lengths.max(), + "val/sequence_lengths_unsolved": ( + 0 if len(sequence_length_unsolved) == 0 else 
sequence_length_unsolved.mean() + ), + "val/sequence_lengths_solved": ( + 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() + ), + "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, + "val/sequence_lengths_solved_hist": sequence_length_solved, + "val/stop_rate": stop_rate, + "val/advantages_mean": advantages.mean(), + "val/advantages_min": advantages.min(), + "val/advantages_max": advantages.max(), + "val/advantages_hist": advantages, + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array( + [len(item) > 0 for item in result.request_info.tool_errors] + ).mean(), + "val/good_outputs_rate": np.array(good_outputs).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), + "time/getting_response": getting_response_time, + **reward_metrics, + **batch_metrics_prefixed, + } + + total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens + metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + + collated_data["metrics"] = metrics + self.local_queue.put(collated_data) + + def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: + per_device_packed_query_responses = packed_sequences.query_responses + per_device_packed_tool_masks = packed_sequences.tool_masks + per_device_packed_attention_masks = packed_sequences.attention_masks + per_device_packed_position_ids = packed_sequences.position_ids + per_device_packed_advantages = packed_sequences.advantages + per_device_packed_response_masks = packed_sequences.response_masks + per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs + + b_inds = np.random.permutation(len(per_device_packed_query_responses)) + collated_query_responses = [] + collated_tool_masks = [] + collated_attention_masks = [] + collated_position_ids = [] + collated_response_masks = [] + collated_advantages = [] + collated_vllm_logprobs = [] + for j in range(0, len(per_device_packed_query_responses), self.per_device_train_batch_size): + micro_range = b_inds[j : j + self.per_device_train_batch_size] + collated_query_responses.append( + collate_fn( + [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True + ) + ) + collated_tool_masks.append(collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True)) + collated_attention_masks.append( + collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) + ) + collated_position_ids.append( + collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, True) + ) + collated_response_masks.append( + collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, True) + ) + collated_advantages.append(collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, True)) + collated_vllm_logprobs.append( + collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) + ) + + return { + "collated_query_responses": collated_query_responses, + "collated_tool_masks": collated_tool_masks, + "collated_attention_masks": collated_attention_masks, + "collated_position_ids": collated_position_ids, + "collated_advantages": collated_advantages, + "collated_response_masks": collated_response_masks, + 
"collated_vllm_logprobs": collated_vllm_logprobs, + } + + def shutdown(self): + self.shutdown_requested = True + if self.background_thread is not None: + self.background_thread.join(timeout=5.0) + + +def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: + padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) + if pin_memory: + padded_tensor = padded_tensor.pin_memory() + return padded_tensor + + +@dataclass +class BatchStatistics: + prompt_lengths: list[int] + response_lengths: list[int] + filtered_prompts: int + filtered_prompts_zero: int + filtered_prompts_solved: int + filtered_prompts_nonzero: int + percent_solved_mean: float + percent_solved_hist: np.ndarray + no_resampled_prompts: int + total_prompts: int + + +class PendingQueriesMap: + def __init__(self): + self._map = {} + self._lock = threading.Lock() + + def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): + with self._lock: + if dataset_idx in self._map: + existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ + dataset_idx + ] + self._map[dataset_idx] = ( + existing_query, + existing_ground_truth, + existing_dataset, + existing_raw_query, + count + 1, + ) + else: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) + + def pop(self, dataset_idx): + with self._lock: + if dataset_idx not in self._map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + + query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] + + if count > 1: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) + else: + del self._map[dataset_idx] + + return query, ground_truth, dataset, raw_query + + def __len__(self): + with self._lock: + return len(self._map) + + def __contains__(self, dataset_idx): + with self._lock: + return dataset_idx in self._map + + def __getitem__(self, dataset_idx): + with self._lock: + return self._map[dataset_idx] + + def keys(self): + with self._lock: + return list(self._map.keys()) + + +def add_prompt_to_generator( + example: dict[str, Any], + example_index: int, + epoch_number: int, + training_step: int, + param_prompt_Q: ray_queue.Queue, + generation_config, + is_eval: bool, +) -> None: + query = example[INPUT_IDS_PROMPT_KEY] + + param_prompt_Q.put( + PromptRequest( + prompt=query, + generation_config=generation_config, + epoch_number=epoch_number, + training_step=training_step, + dataset_index=example_index, + is_eval=is_eval, + ) + ) + + +def accumulate_inference_batches( + inference_results_Q: ray_queue.Queue, + generation_config: vllm.SamplingParams, + num_prompts: int, + model_dims: utils.ModelDims, + tokenizer: PreTrainedTokenizer, + reward_fn: Callable, + dataset: Dataset, + actor_manager=None, + timeout: float | None = None, + active_sampling: bool = False, + filter_zero_std_samples: bool = False, + replenish_prompts: bool = False, + no_resampling_pass_rate: float | None = None, + iter_dataloader: ShufflingIterator | None = None, + param_prompt_Q: ray_queue.Queue | None = None, + training_step: int = None, + verbose: bool = False, + max_possible_score: float = 1.0, +) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: + import ray + + if no_resampling_pass_rate is not None: + assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" + + if replenish_prompts: + assert param_prompt_Q is not None and iter_dataloader is not 
None and dataset is not None, ( + "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" + ) + + results = [] + all_queries = [] + all_ground_truths = [] + all_datasets = [] + all_raw_queries = [] + all_decoded_responses = [] + all_reward_metrics = [] + all_scores = [] + all_percent_solved = [] + total_filtered_prompts = 0 + filtered_prompt_zero = 0 + filtered_prompt_solved = 0 + filtered_prompt_nonzero = 0 + total_no_resampled = 0 + progress_bar = tqdm( + total=num_prompts, + desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", + bar_format="{l_bar}{bar}{r_bar}\n", + disable=not verbose, + ) + num_prompts_sampled = 0 + while num_prompts_sampled < num_prompts: + result = inference_results_Q.get(timeout=timeout) + + if isinstance(result, ShutdownSentinel): + return result, None, None, None + + assert len(result.responses) == generation_config.n, ( + f"Mismatch: individual prompt result has {len(result.responses)} responses " + f"but expected {generation_config.n} samples per prompt. " + f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" + ) + + example = dataset[result.dataset_index] + query = example[INPUT_IDS_PROMPT_KEY] + ground_truth = example[GROUND_TRUTHS_KEY] + dataset_name = example[VERIFIER_SOURCE_KEY] + raw_query = example[RAW_PROMPT_KEY] + + if replenish_prompts: + local_index = next(iter_dataloader) + example = dataset[local_index] + dataset_index = example["index"] + add_prompt_to_generator( + example, + dataset_index, + iter_dataloader.epoch_number, + training_step, + param_prompt_Q, + generation_config, + is_eval=False, + ) + + for i in range(len(result.finish_reasons)): + if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: + result.responses[i].append(tokenizer.eos_token_id) + result.masks[i].append(1) + result.logprobs[i].append(float("nan")) + + decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) + + k_queries = repeat_each([query], generation_config.n) + k_ground_truths = repeat_each([ground_truth], generation_config.n) + k_datasets = repeat_each([dataset_name], generation_config.n) + k_raw_queries = repeat_each([raw_query], generation_config.n) + + scores, reward_metrics = asyncio.run( + reward_fn( + result.responses, + decoded_responses, + k_ground_truths, + k_datasets, + result.finish_reasons, + result.request_info, + k_raw_queries, + ) + ) + + percent_solved = np.mean(scores).item() / max_possible_score + if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: + iter_dataloader.exclude_index(result.dataset_index) + total_no_resampled += 1 + logging.debug( + f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" + ) + + if filter_zero_std_samples and np.std(scores) == 0: + if not active_sampling: + num_prompts_sampled += 1 + progress_bar.update(1) + + total_filtered_prompts += 1 + if scores[0] == 0: + filtered_prompt_zero += 1 + elif scores[0] == max_possible_score: + filtered_prompt_solved += 1 + else: + filtered_prompt_nonzero += 1 + logging.debug( + f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" + ) + continue + else: + num_prompts_sampled += 1 + progress_bar.update(1) + + results.append(result) + all_queries.extend(k_queries) + all_ground_truths.extend(k_ground_truths) + all_datasets.extend(k_datasets) + all_raw_queries.extend(k_raw_queries) + 
all_decoded_responses.extend(decoded_responses) + all_scores.extend(scores) + all_reward_metrics.append(reward_metrics) + all_percent_solved.append(percent_solved) + + combined_responses = [] + combined_finish_reasons = [] + combined_masks = [] + combined_num_calls = [] + combined_timeouts = [] + combined_tool_errors = [] + combined_tool_outputs = [] + combined_tool_runtimes = [] + combined_tool_calleds = [] + combined_logprobs = [] + + earliest_start_time = float("inf") + prompt_lengths = [] + response_lengths = [] + + total_prompt_tokens = 0 + total_response_tokens = 0 + max_generation_time = 0 + + for i, result in enumerate(results): + combined_responses.extend(result.responses) + combined_finish_reasons.extend(result.finish_reasons) + combined_masks.extend(result.masks) + combined_num_calls.extend(result.request_info.num_calls) + combined_timeouts.extend(result.request_info.timeouts) + combined_tool_errors.extend(result.request_info.tool_errors) + combined_tool_outputs.extend(result.request_info.tool_outputs) + combined_tool_runtimes.extend(result.request_info.tool_runtimes) + combined_tool_calleds.extend(result.request_info.tool_calleds) + + combined_logprobs.extend(result.logprobs) + + earliest_start_time = min(earliest_start_time, result.start_time) + + prompt_lengths.append(len(all_queries[i * generation_config.n])) + + for response in result.responses: + response_lengths.append(len(response)) + + total_prompt_tokens += result.token_statistics.num_prompt_tokens + total_response_tokens += result.token_statistics.num_response_tokens + max_generation_time = max(max_generation_time, result.token_statistics.generation_time) + + total_generation_time = max_generation_time + + accumulated_stats = TokenStatistics( + num_prompt_tokens=total_prompt_tokens, + num_response_tokens=total_response_tokens, + generation_time=total_generation_time, + earliest_start_time=earliest_start_time, + ) + + combined_request_info = RequestInfo( + num_calls=combined_num_calls, + timeouts=combined_timeouts, + tool_errors=combined_tool_errors, + tool_outputs=combined_tool_outputs, + tool_runtimes=combined_tool_runtimes, + tool_calleds=combined_tool_calleds, + ) + + combined_result = GenerationResult( + responses=combined_responses, + finish_reasons=combined_finish_reasons, + masks=combined_masks, + request_info=combined_request_info, + dataset_index=None, + epoch_number=results[0].epoch_number, + token_statistics=accumulated_stats, + logprobs=combined_logprobs, + ) + + if actor_manager is not None: + ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) + + batch = Batch( + queries=all_queries, + ground_truths=all_ground_truths, + datasets=all_datasets, + raw_queries=all_raw_queries, + decoded_responses=all_decoded_responses, + indices=None, + scores=all_scores, + ) + + combined_reward_metrics = combine_reward_metrics(all_reward_metrics) + percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 + + batch_stats = BatchStatistics( + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + filtered_prompts=total_filtered_prompts, + filtered_prompts_zero=filtered_prompt_zero, + filtered_prompts_solved=filtered_prompt_solved, + filtered_prompts_nonzero=filtered_prompt_nonzero, + percent_solved_mean=percent_solved_mean, + percent_solved_hist=np.array(all_percent_solved), + no_resampled_prompts=total_no_resampled, + total_prompts=len(results), + ) + logging.info( + f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} 
seconds" + ) + + return combined_result, batch, combined_reward_metrics, batch_stats diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index d39e9208d..b010783c3 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -23,6 +23,7 @@ VERIFIER_SOURCE_KEY, ) from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics +from open_instruct.streaming_data_loader import PendingQueriesMap, ShufflingIterator from open_instruct.vllm_utils import create_vllm_engines @@ -234,7 +235,6 @@ def setup_and_add_prompts_to_generator( queue_size = max(len(queries), num_engines * 2) param_prompt_Q = ray_queue.Queue(maxsize=queue_size) inference_results_Q = ray_queue.Queue(maxsize=queue_size) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queues for cleanup self._ray_queues.extend([param_prompt_Q, inference_results_Q]) @@ -247,25 +247,29 @@ def setup_and_add_prompts_to_generator( # Calculate inference_batch_size based on number of queries and engines mock_args.inference_batch_size = max(1, len(queries) // num_engines) - for index in range(len(queries)): + # Create a mock dataset that can be indexed by dataset_index + max_index = max(indices) + 1 + mock_dataset = [{} for _ in range(max_index)] + for i, index in enumerate(indices): + mock_dataset[index] = { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + + for i in range(len(queries)): example = { - INPUT_IDS_PROMPT_KEY: queries[index], - GROUND_TRUTHS_KEY: ground_truths[index], - VERIFIER_SOURCE_KEY: datasets[index], - RAW_PROMPT_KEY: raw_queries[index], + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - 0, - training_step, - pending_queries_map, - param_prompt_Q, - mock_generation_config, - False, + example, indices[i], 0, training_step, param_prompt_Q, mock_generation_config, False ) - return param_prompt_Q, inference_results_Q, pending_queries_map + return param_prompt_Q, inference_results_Q, mock_dataset class TestGrpoFastVLLM(TestGrpoFastBase): @@ -350,13 +354,10 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) - # Verify that we have individual prompts in the map (not batches) - self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout) - # Verify that we have the expected number of items in the queue (one per prompt) self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout) @@ -383,8 +384,12 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, result = inference_results_Q.get() dataset_index = result.dataset_index - # Get query from pending_queries_map - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + # Get query from mock_dataset + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_responses.extend(result.responses) 
combined_queries.append(q) @@ -418,9 +423,6 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, self.assertEqual(len(combined_result.finish_reasons), len(queries_next)) self.assertEqual(len(combined_result.masks), len(queries_next)) - # Verify that the pending_queries_map is empty after accumulation - self.assertEqual(len(pending_queries_map), 0) - # Verify that the inference_results_Q is empty after accumulation self.assertEqual(inference_results_Q.qsize(), 0) @@ -435,7 +437,7 @@ def test_dataset_index_preservation_through_pipeline(self): ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) @@ -457,7 +459,11 @@ def test_dataset_index_preservation_through_pipeline(self): result = inference_results_Q.get() dataset_index = result.dataset_index - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_queries.append(q) combined_raw_queries.append(raw_q) combined_ground_truths.append(gt) @@ -467,7 +473,6 @@ def test_dataset_index_preservation_through_pipeline(self): self.assertEqual(combined_queries, queries_next) self.assertEqual(combined_ground_truths, ground_truths_next) self.assertEqual(combined_datasets, datasets_next) - self.assertEqual(len(pending_queries_map), 0) @parameterized.expand([(1, 16), (2, 8), (4, 4)]) def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_per_prompt: int): @@ -480,18 +485,10 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) - # For multiple samples, we need to add additional references to the pending_queries_map - # The first reference is already added by setup_and_add_prompts_to_generator - for _ in range(num_samples_per_prompt - 1): - for idx, query, ground_truth, dataset, raw_query in zip( - dataset_indices, queries_next, ground_truths_next, datasets_next, raw_queries_next - ): - pending_queries_map.insert(idx, query, ground_truth, dataset, raw_query) - # Simulate vLLM processing with multiple samples batch_idx = 0 while not param_prompt_Q.empty(): @@ -511,11 +508,12 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe result = inference_results_Q.get() dataset_index = result.dataset_index - # Pop the query data for this specific result - pop multiple times for multiple samples - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) - # Pop additional times to handle multiple samples per prompt - for _ in range(num_samples_per_prompt - 1): - pending_queries_map.pop(dataset_index) + # Get query data from mock_dataset + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_responses.extend(result.responses) combined_queries.append(q) @@ 
-542,7 +540,6 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe self.assertEqual(combined_queries, queries_next) self.assertEqual(combined_ground_truths, ground_truths_next) self.assertEqual(combined_datasets, datasets_next) - self.assertEqual(len(pending_queries_map), 0) # Verify correct number of responses expected_responses = num_unique_prompts_rollout * num_samples_per_prompt @@ -600,7 +597,7 @@ def test_out_of_order_processing(self): tokenizer, reward_fn = self.create_mock_tokenizer_and_reward_fn() # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries, ground_truths, datasets, raw_queries, indices, num_engines ) @@ -615,7 +612,6 @@ def test_out_of_order_processing(self): inference_results_Q.put(mock_result) # Accumulate results - mock_args = self.create_mock_args(num_engines, num_samples_per_prompt) # Create a mock generation config with n mock_generation_config = Mock() mock_generation_config.n = num_samples_per_prompt @@ -623,23 +619,21 @@ def test_out_of_order_processing(self): mock_model_dims = self.create_mock_model_dims() combined_result, batch, reward_metrics, batch_stats = grpo_fast.accumulate_inference_batches( inference_results_Q, - pending_queries_map, - mock_args, - generation_config=mock_generation_config, + mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=mock_dataset, ) # Verify results work correctly even with out-of-order processing self.assertEqual(len(batch.queries), num_prompts * num_samples_per_prompt) self.assertEqual(len(combined_result.responses), num_prompts * num_samples_per_prompt) - self.assertEqual(len(pending_queries_map), 0) def test_thread_safety_pending_queries_map(self): """Test concurrent access to pending_queries_map.""" - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() errors = [] num_threads = 4 entries_per_thread = 50 @@ -697,11 +691,17 @@ def test_accumulate_waits_for_all_engines(self): # Track queue for cleanup self._ray_queues.append(inference_results_Q) - pending_queries_map = grpo_fast.PendingQueriesMap() - - # Add entries to map + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}") + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: f"q_{i}", + GROUND_TRUTHS_KEY: f"t_{i}", + VERIFIER_SOURCE_KEY: f"d_{i}", + RAW_PROMPT_KEY: f"q_{i}", + } + ) # Add results from only 3 engines (missing one) # With individual prompts, we add individual results @@ -710,8 +710,6 @@ def test_accumulate_waits_for_all_engines(self): mock_result = self.create_mock_result(i, 1) inference_results_Q.put(mock_result) - mock_args = self.create_mock_args(num_engines) - completed = threading.Event() def run_accumulate(): @@ -723,13 +721,12 @@ def run_accumulate(): mock_model_dims = self.create_mock_model_dims() grpo_fast.accumulate_inference_batches( inference_results_Q, - pending_queries_map, - mock_args, - generation_config=mock_generation_config, + mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=mock_dataset, ) completed.set() except Exception: @@ -744,8 +741,6 @@ def run_accumulate(): # Queue should be empty after consuming 12 results 
self.assertEqual(inference_results_Q.qsize(), 0) - # 12 entries should be removed from the map (4 still pending) - self.assertEqual(len(pending_queries_map), 4) class TestStreamingAccumulation(TestGrpoFastBase): @@ -754,12 +749,10 @@ class TestStreamingAccumulation(TestGrpoFastBase): def test_more_engines_than_queries(self): """Test that add_prompt_to_generator handles gracefully when engines > queries.""" # More engines than queries - should handle gracefully with single-prompt batches - num_engines = 8 num_queries = 4 queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -767,10 +760,6 @@ def test_more_engines_than_queries(self): mock_generation_config = MagicMock() mock_generation_config.n = 1 - # Create mock args with inference_batch_size - mock_args = MagicMock() - mock_args.inference_batch_size = max(1, num_queries // num_engines) - for index in range(len(queries)): example = { INPUT_IDS_PROMPT_KEY: queries[index], @@ -779,14 +768,7 @@ def test_more_engines_than_queries(self): RAW_PROMPT_KEY: raw_queries[index], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - epoch_number=0, - training_step=1, - pending_queries_map=pending_queries_map, - param_prompt_Q=param_prompt_Q, - generation_config=mock_generation_config, - is_eval=False, + example, indices[index], 0, 1, param_prompt_Q, mock_generation_config, False ) # Should have 4 batches (one for each query) @@ -804,17 +786,13 @@ def test_more_engines_than_queries(self): # Should have exactly num_queries PromptRequests self.assertEqual(prompt_count, num_queries, f"Should have {num_queries} PromptRequests") - # All queries should be in the pending map - self.assertEqual(len(pending_queries_map), num_queries) def test_uneven_distribution_no_empty_batches(self): """Test that uneven query distribution doesn't create empty batches.""" - num_engines = 3 - num_queries = 7 # 7/3 = ceil(2.33) = 3, so distribution should be [3, 3, 1] + num_queries = 7 queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -822,10 +800,6 @@ def test_uneven_distribution_no_empty_batches(self): mock_generation_config = MagicMock() mock_generation_config.n = 1 - # Create mock args with inference_batch_size - mock_args = MagicMock() - mock_args.inference_batch_size = max(1, num_queries // num_engines + (1 if num_queries % num_engines else 0)) - for index in range(len(queries)): example = { INPUT_IDS_PROMPT_KEY: queries[index], @@ -834,14 +808,7 @@ def test_uneven_distribution_no_empty_batches(self): RAW_PROMPT_KEY: raw_queries[index], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - epoch_number=0, - training_step=1, - pending_queries_map=pending_queries_map, - param_prompt_Q=param_prompt_Q, - generation_config=mock_generation_config, - is_eval=False, + example, indices[index], 0, 1, param_prompt_Q, mock_generation_config, False ) # With single-prompt architecture, verify we have the right number of individual requests @@ -865,16 +832,23 @@ def test_streaming_accumulation_basic(self): # Create test data queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts) - # Create queues 
and maps + # Create queues and mock dataset inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) - # Insert data into pending_queries_map + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i]) + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + ) # Create mock results - one per prompt for i in range(num_prompts): @@ -891,14 +865,17 @@ def test_streaming_accumulation_basic(self): results_list.append(result) - # Get query for this prompt + # Get query for this prompt from dataset dataset_index = result.dataset_index - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] queries_list.append((q, gt, d, raw_q)) # Verify all results processed self.assertEqual(len(results_list), expected_results) - self.assertEqual(len(pending_queries_map), 0) # Combine in order combined_queries = [] @@ -917,17 +894,23 @@ def test_streaming_with_multiple_samples(self): # Create test data queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts) - # Create queues and maps + # Create queues and mock dataset inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) - # Insert data with reference counting for multiple samples + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - for _ in range(num_samples): - pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i]) + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + ) # Create results - one per prompt with multiple samples for i in range(num_prompts): @@ -944,14 +927,13 @@ def test_streaming_with_multiple_samples(self): self.assertEqual(len(result.responses), expected_responses) total_responses += len(result.responses) - # Pop multiple times to match the number of samples (reference counting) + # Get query from dataset (can be looked up multiple times) idx = result.dataset_index - for _ in range(num_samples): - pending_queries_map.pop(idx) + example = mock_dataset[idx] + self.assertIsNotNone(example[INPUT_IDS_PROMPT_KEY]) # Verify total responses self.assertEqual(total_responses, num_prompts * num_samples) - self.assertEqual(len(pending_queries_map), 0) class TestShufflingIterator(unittest.TestCase): @@ -962,7 +944,7 @@ def test_basic_iteration(self): data = np.arange(100) batch_size = 10 - iterator = grpo_fast.ShufflingIterator(data, batch_size, seed=42) + iterator = ShufflingIterator(data, batch_size, seed=42) # Get first batch batch1 = next(iterator) @@ -983,7 +965,7 @@ def test_state_preservation_and_restoration(self): seed = 42 # Create original iterator - iter1 = grpo_fast.ShufflingIterator(data, batch_size, seed=seed) + iter1 = ShufflingIterator(data, batch_size, seed=seed) # Get a few batches _ = next(iter1) @@ -1004,7 +986,7 @@ def 
test_state_preservation_and_restoration(self): batch5_original = next(iter1) # Create new iterator with different seed and restore state - iter2 = grpo_fast.ShufflingIterator(data, batch_size, seed=999) + iter2 = ShufflingIterator(data, batch_size, seed=999) iter2.set_state(state) # Get batches from restored iterator @@ -1022,7 +1004,7 @@ def test_epoch_boundary_state(self): batch_size = 5 # Create iterator and complete one epoch - iterator = grpo_fast.ShufflingIterator(data, batch_size, seed=123) + iterator = ShufflingIterator(data, batch_size, seed=123) for _ in range(4): # 20 / 5 = 4 batches per epoch next(iterator) @@ -1032,7 +1014,7 @@ def test_epoch_boundary_state(self): self.assertEqual(state["index"], 20) # Create new iterator and restore state - iter2 = grpo_fast.ShufflingIterator(data, batch_size, seed=456) + iter2 = ShufflingIterator(data, batch_size, seed=456) iter2.set_state(state) # Next batches should match @@ -1047,8 +1029,8 @@ def test_rng_state_preservation(self): batch_size = 50 # Create two iterators with same seed - iter1 = grpo_fast.ShufflingIterator(data, batch_size, seed=42) - _ = grpo_fast.ShufflingIterator(data, batch_size, seed=42) + iter1 = ShufflingIterator(data, batch_size, seed=42) + _ = ShufflingIterator(data, batch_size, seed=42) # Advance first iterator for _ in range(5): @@ -1056,7 +1038,7 @@ def test_rng_state_preservation(self): # Save state and create new iterator with different seed state = iter1.get_state() - iter3 = grpo_fast.ShufflingIterator(data, batch_size, seed=999) + iter3 = ShufflingIterator(data, batch_size, seed=999) # Restore state - this should override the different seed iter3.set_state(state) diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 390f295c5..6affd135b 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -831,7 +831,9 @@ def create_vllm_engines( # ensure we use bundles on the same node where possible if tp>1. bundle_indices_list = get_bundle_indices_list(pg) + logger.info(f"[DEBUG] Creating {num_engines} vLLM engines with tensor_parallel_size={tensor_parallel_size}") for i in range(num_engines): + logger.info(f"[DEBUG] Creating vLLM engine {i + 1}/{num_engines}") bundle_indices = None bundle_indices = bundle_indices_list[i * tensor_parallel_size : (i + 1) * tensor_parallel_size] @@ -880,9 +882,12 @@ def create_vllm_engines( calculate_kv_scales=use_fp8_kv_cache, ) ) + logger.info(f"[DEBUG] vLLM engine {i + 1}/{num_engines} actor created") + logger.info(f"[DEBUG] All {num_engines} vLLM engine actors created, waiting for ready() (timeout=1200s)...") ray_get_with_progress( [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=1200 ) + logger.info(f"[DEBUG] All {num_engines} vLLM engines ready!") return vllm_engines
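For reference, a minimal, self-contained sketch (not part of the patch) of the grouped advantage
normalization that StreamingDataLoader._data_preparation_loop applies before packing. The scores are
made-up numbers, num_samples_per_prompt_rollout is assumed to be 4, and both settings of
advantage_normalization_type are shown:

    import numpy as np

    # One prompt group with four samples: two solved (1.0) and two unsolved (0.0).
    scores = np.array([1.0, 0.0, 1.0, 0.0])
    grouped = scores.reshape(-1, 4)                # one row per prompt
    mean = np.repeat(grouped.mean(axis=-1), 4)     # [0.5, 0.5, 0.5, 0.5]
    std = np.repeat(grouped.std(axis=-1), 4)       # [0.5, 0.5, 0.5, 0.5]

    standard = (scores - mean) / (std + 1e-8)      # ~[+1, -1, +1, -1]
    centered = scores - mean                       # DR.GRPO-style: [+0.5, -0.5, +0.5, -0.5]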