245 changes: 180 additions & 65 deletions open_instruct/grpo_fast.py
@@ -107,7 +107,7 @@
push_folder_to_hub,
)
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences
from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences
from open_instruct.utils import (
ArgumentParserPlus,
BeakerRuntimeConfig,
@@ -118,6 +118,7 @@
combine_reward_metrics,
download_latest_checkpoint_from_gs,
get_beaker_whoami,
get_denominator,
get_eval_ds_config,
get_optimizer_grouped_parameters,
get_train_ds_config,
@@ -249,10 +250,11 @@ class Args:
"""the KL estimator to use"""
pack_length: int = 512
"""the length of the pack (you should prob set to the max length of the model)"""
masked_mean_axis: int | None = None
"""the axis to compute the mean of the masked values"""
masked_mean_denominator: float | None = None
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
masked_mean_denominator: float | str | None = "token"
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum.
Special value "token" means use total_batch_tokens (computed across all ranks in distributed training).
Special value "group" means use group-level averaging (average across tokens in a group, then average across groups).
When using "token", total_batch_tokens is gathered via allreduce across all ranks."""
alpha: float = 0.6
"""The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
@@ -456,10 +458,7 @@ def __post_init__(self):
"Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
"use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
)
if self.masked_mean_denominator is not None:
assert self.masked_mean_denominator > 0, (
f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
self.masked_mean_denominator = get_denominator(self.masked_mean_denominator)
assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
if self.num_samples_per_prompt_rollout == 1:
logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -498,6 +497,8 @@ def __post_init__(self):
self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{beaker_users}/{checkpoint_dir_name}"
else:
self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
if not checkpoint_dir_name.startswith("/filestore"):
self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"

Bug: Missing path separator in checkpoint directory

The code concatenates /filestore directly with self.checkpoint_state_dir without a separator, so a checkpoint directory like my_checkpoint becomes /filestoremy_checkpoint instead of the intended /filestore/my_checkpoint. This creates invalid filesystem paths whenever the checkpoint directory doesn't already start with /filestore.
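A plausible fix, sketched only as an illustration (the PR may resolve it differently), is to insert the separator explicitly when prefixing:

if not checkpoint_dir_name.startswith("/filestore"):
    self.checkpoint_state_dir = f"/filestore/{self.checkpoint_state_dir.lstrip('/')}"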



if self.checkpoint_state_dir is not None:
if self.gs_checkpoint_state_dir is not None:
Expand Down Expand Up @@ -544,15 +545,6 @@ def __post_init__(self):
raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.")


def masked_mean(
values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
numerator = (values * mask).sum(axis=axis)
denom = mask.sum(axis=axis) if denominator is None else denominator
return (numerator / denom).mean()
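Context for the refactor: the local masked_mean deleted above now comes from open_instruct.rl_utils, and __post_init__ routes the raw setting through get_denominator imported from open_instruct.utils (see the import hunks at the top of this diff). A minimal sketch of the two helpers, with masked_mean copied from the removed definition and get_denominator assumed to only validate the configured value; the real implementations live in those modules and may differ:

import torch


def masked_mean(
    values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
) -> torch.Tensor:
    """Mean of values over masked positions; a fixed denominator replaces mask.sum when provided."""
    numerator = (values * mask).sum(axis=axis)
    denom = mask.sum(axis=axis) if denominator is None else denominator
    return (numerator / denom).mean()


def get_denominator(value: float | str | None) -> float | str | None:
    """Validation sketch (assumption): pass through None/"token"/"group", require any number to be positive."""
    if value is None or value in ("token", "group"):
        return value
    assert isinstance(value, (int, float)) and value > 0, (
        f"masked_mean_denominator (={value}) must be 'token', 'group', or a number greater than 0!"
    )
    return float(value)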


def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor:
padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id)
if pin_memory:
@@ -1053,6 +1045,75 @@ def compute_logprobs(

return collated_logprobs, collated_entropies

def calculate_token_counts(
self,
accumulation_steps: int,
collated_response_masks: list[torch.Tensor],
collated_tool_masks: list[torch.Tensor],
mode: str = "token",
) -> dict[int, float | torch.Tensor]:
accumulation_counts = {}

max_group_id = 0
if mode == "group":
# First pass to determine max group index
for i in range(len(collated_response_masks)):
mb_response_masks = collated_response_masks[i]
# group_id = (sample_index) // n
# sample_index starts at 0. response_mask contains sample_index + 1.
# So (response_mask - 1) // n
max_group_id = max(
max_group_id, ((mb_response_masks.max().item() - 1) // self.args.num_samples_per_prompt_rollout)
)
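# Illustration of the arithmetic above (made-up numbers): with num_samples_per_prompt_rollout = 4,
# response_mask values 1-4 belong to group 0 and values 5-8 to group 1, so a max mask value of 8
# yields max_group_id = (8 - 1) // 4 = 1.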

# All reduce max_group_id to ensure all ranks have the same tensor size
if dist.is_available() and dist.is_initialized():
dist.barrier()
max_group_id_tensor = torch.tensor(max_group_id, dtype=torch.long, device=self.device)
dist.all_reduce(max_group_id_tensor, op=dist.ReduceOp.MAX, group=None)
max_group_id = max_group_id_tensor.item()

for group_start in range(0, len(collated_response_masks), accumulation_steps):
group_end = min(group_start + accumulation_steps, len(collated_response_masks))

if mode == "group":
counts = torch.zeros(max_group_id + 1, device=self.device, dtype=torch.float32)
else:
counts = torch.tensor(0.0, device=self.device, dtype=torch.float32)

for i in range(group_start, group_end):
mb_response_masks = collated_response_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if self.args.mask_tool_use and self.args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool()

if mode == "group":
# Filter valid tokens
valid_mask = mb_response_masks_bool
flat_mask = valid_mask.flatten()
if flat_mask.any():
# Get group IDs for valid tokens
flat_response_masks = mb_response_masks[:, 1:].flatten()
valid_response_masks = flat_response_masks[flat_mask]
group_ids = (valid_response_masks - 1) // self.args.num_samples_per_prompt_rollout
# Accumulate counts
counts.scatter_add_(0, group_ids, torch.ones_like(group_ids, dtype=torch.float32))
else:
counts += mb_response_masks_bool.sum().float()

# All reduce counts
if dist.is_available() and dist.is_initialized():
dist.barrier()
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=None)

if mode == "token":
accumulation_counts[group_start] = counts.item()
else:
accumulation_counts[group_start] = counts

return accumulation_counts
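# Illustration of the returned mapping (assumed sizes): with accumulation_steps = 2 and four collated
# minibatches, the keys are 0 and 2; in "token" mode each value is a float holding the global token
# count for that accumulation group, in "group" mode a float32 tensor of per-group token counts with
# max_group_id + 1 entries.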

def train(
self,
collated_query_responses,
@@ -1158,6 +1219,17 @@ def train(
ratio_stats = torch.zeros(len(collated_query_responses))
entropy_stats = torch.zeros(len(collated_query_responses))
for epoch_idx in range(args.num_epochs):
# Pre-compute total tokens for each accumulation group if using "token" normalization
# This ensures all minibatches in an accumulation group are normalized by the same total
accumulation_token_counts = {}
if args.masked_mean_denominator in ["token", "group"]:
accumulation_token_counts = self.calculate_token_counts(
accumulation_steps,
collated_response_masks,
collated_tool_masks,
mode=args.masked_mean_denominator,
)

Collaborator

Can you make this a function?

accumulation_group_tokens = maybe_calculate_group_tokens(...)
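A sketch of the helper the reviewer is asking for; the name maybe_calculate_group_tokens comes from the comment above, and the body simply wraps the mode check plus the pre-computation (an illustration, not the PR's final refactor):

def maybe_calculate_group_tokens(
    self, accumulation_steps: int, collated_response_masks: list, collated_tool_masks: list
) -> dict:
    """Return per-accumulation-group counts, or {} when no shared denominator is requested."""
    if self.args.masked_mean_denominator not in ("token", "group"):
        return {}
    return self.calculate_token_counts(
        accumulation_steps, collated_response_masks, collated_tool_masks, mode=self.args.masked_mean_denominator
    )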

for i in range(len(collated_query_responses)):
mb_query_responses = collated_query_responses[i]
mb_tool_mask = collated_tool_masks[i]
@@ -1167,6 +1239,14 @@
# if masking snippets, do it here.
if args.mask_tool_use and args.tool_use:
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()

# Determine the denominator for masked_mean normalization
loss_denominator = args.masked_mean_denominator
loss_axis = None
if args.masked_mean_denominator == "token":
group_start = (i // accumulation_steps) * accumulation_steps
loss_denominator = accumulation_token_counts[group_start]

mb_attention_mask = collated_attention_masks[i]
mb_position_id = collated_position_ids[i]
mb_local_logprobs, mb_entropy = self.forward(
@@ -1272,6 +1352,50 @@

pg_loss_max = torch.max(pg_losses, pg_losses2)

# for stats computation, for now no denominator is used
# unless masked_mean_denominator is a numeric value.
stats_denominator = (
args.masked_mean_denominator
if args.masked_mean_denominator != "token" and args.masked_mean_denominator != "group"
else None
)

# Define reduction function based on configuration
if args.masked_mean_denominator == "group":
group_start = (i // accumulation_steps) * accumulation_steps
group_counts = accumulation_token_counts[group_start]
total_active_groups = (group_counts > 0).sum().item()
group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout

def reduce_fn(v, m, a=None, d=None):
flat_v = v.flatten()
flat_m = m.flatten().bool()
flat_g = group_ids.flatten()

# if no valid tokens in batch.
if not flat_m.any():
return torch.tensor(0.0, device=v.device)

valid_v = flat_v[flat_m]
valid_g = flat_g[flat_m]

valid_counts = group_counts[valid_g]
# Avoid division by zero if count is 0 (should not happen for valid tokens)
valid_counts = torch.max(
valid_counts, torch.tensor(1.0, device=valid_counts.device, dtype=valid_counts.dtype)
)

weights = 1.0 / (valid_counts * total_active_groups)

# Sum weighted values
loss = (valid_v * weights).sum()
scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
scale *= accumulation_steps

Bug: Incorrect gradient scaling in group mode

In group mode, the loss is multiplied by accumulation_steps on line 1393, causing accumulated gradients to be scaled by accumulation_steps. Since gradients are summed across minibatches during gradient accumulation, this multiplication inflates the effective gradient by a factor of accumulation_steps, making the effective learning rate much larger than intended.
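One possible correction, stated only as an assumption about the intended semantics (each accumulation group's per-group means should enter the summed gradients exactly once), is to undo only DDP's world-size averaging and drop the accumulation_steps factor:

scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
return loss * scale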



return loss * scale
else:
reduce_fn = masked_mean

if args.load_ref_policy:
mb_ref_logprob = collated_ref_logprobs[i]
# Here we recalculate kl: we want the KL loss to backpropagate through the model
@@ -1291,17 +1415,23 @@
elif args.kl_estimator == "kl4":
kl = kl4
# grpo change: directly subtract KL in loss (add)
loss = masked_mean(
pg_loss_max + (args.beta * kl),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
loss = reduce_fn(
pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator
)
else:
loss = masked_mean(
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
loss = loss / accumulation_steps
loss = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator)
if args.masked_mean_denominator == "token":
# In token mode, we divide by the GLOBAL total number of tokens.
# DDP averages gradients across ranks (dividing by world_size).
# To get the true global mean gradient (sum_all_gradients / global_tokens),
# we must multiply by world_size to cancel out DDP's division.
if dist.is_available() and dist.is_initialized():
loss *= dist.get_world_size()
elif args.masked_mean_denominator != "group":
# For "group" mode, the scaling is handled inside reduce_fn.
# For default (None) or numeric modes, we divide by accumulation_steps here.
loss = loss / accumulation_steps
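# Worked example of the "token" branch above (illustrative numbers): with world_size = 4 and a global
# count of 10_000 tokens in this accumulation group, every rank's minibatch loss is
# sum_of_its_token_losses / 10_000. DDP then averages gradients across the 4 ranks (a 1/4 factor),
# so multiplying each loss by 4 restores sum_of_all_token_losses / 10_000, the global per-token mean.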

# Clear CUDA cache before backward pass to free memory for reduce_scatter operations
torch.cuda.empty_cache()
self.model.backward(loss)
@@ -1311,44 +1441,29 @@
with torch.no_grad():
if args.load_ref_policy:
# NOTE: in packed implementation, kl calculation are averages over response tokens
kl1_stats[i] = masked_mean(
kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
kl2_stats[i] = masked_mean(
kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
kl3_stats[i] = masked_mean(
kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
kl4_stats[i] = masked_mean(
kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
if args.kl_estimator == "kl1":
kl_loss_stats[i] = kl1_stats[i] * args.beta
elif args.kl_estimator == "kl2":
kl_loss_stats[i] = kl2_stats[i] * args.beta
elif args.kl_estimator == "kl3":
kl_loss_stats[i] = kl3_stats[i] * args.beta
elif args.kl_estimator == "kl4":
kl_loss_stats[i] = kl4_stats[i] * args.beta
pg_clipfrac_stats[i] = masked_mean(
(pg_losses2 > pg_losses).float(),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
)
pg_loss_stats[i] = masked_mean(
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(
ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = masked_mean(
mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
kl1_stats[i] = reduce_fn(kl1, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl2_stats[i] = reduce_fn(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl3_stats[i] = reduce_fn(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl4_stats[i] = reduce_fn(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float()
if args.kl_estimator == "kl1":
kl_loss_stats[i] = kl1_stats[i] * args.beta
elif args.kl_estimator == "kl2":
kl_loss_stats[i] = kl2_stats[i] * args.beta
elif args.kl_estimator == "kl3":
kl_loss_stats[i] = kl3_stats[i] * args.beta
elif args.kl_estimator == "kl4":
kl_loss_stats[i] = kl4_stats[i] * args.beta
pg_clipfrac_stats[i] = reduce_fn(
(pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator
)
pg_loss_stats[i] = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator)
loss_stats[i] = loss
ratio_stats[i] = reduce_fn(ratio, mb_response_masks_bool, loss_axis, stats_denominator)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = reduce_fn(
mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator
).float()

with torch.no_grad():
if args.load_ref_policy: