123 changes: 75 additions & 48 deletions verl/models/mcore/mtp_patch.py
@@ -26,6 +26,12 @@
     roll_tensor,
 )
 
+try:
+    from megatron.core.transformer.multi_token_prediction import process_mtp_loss as _process_mtp_loss
+    _HAS_PROCESS_MTP_LOSS = True
+except ImportError:
+    _HAS_PROCESS_MTP_LOSS = False
+
 try:
     from megatron.core.utils import unwrap_model
 except ImportError:
@@ -78,6 +84,7 @@ def _megatron_gptmodel_postprocess(
     runtime_gather_output=None,
     extra_block_kwargs=None,
     inference_context=None,
+    is_spec_decode=None,
 ):
     """Postprocesses decoder hidden states to generate logits or compute loss.
 
@@ -111,58 +118,78 @@
 
     # Skip when mtp_num_layers is None or 0
     if self.config.mtp_num_layers and labels is not None:
-        mtp_labels = labels.clone()
-
-        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
-        hidden_states = hidden_states_list[0]
-        if loss_mask is None:
-            # if loss_mask is not provided, use all ones as loss_mask
-            loss_mask = torch.ones_like(mtp_labels)
-        for mtp_layer_number in range(self.config.mtp_num_layers):
-            # Calc loss for the current Multi-Token Prediction (MTP) layers.
-            mtp_labels, _ = roll_tensor(
-                mtp_labels,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-            loss_mask, num_tokens = roll_tensor(
-                loss_mask,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-
-            # Compute mtp loss without storing logits to save memory.
-            mtp_loss = self.compute_output_layer_and_language_model_loss(
-                hidden_states_list[mtp_layer_number + 1],
-                labels=mtp_labels,
-                weight=self.shared_embedding_or_output_weight(),
-                sequence_parallel_enabled=self.output_layer.sequence_parallel,
-                column_parallel_linear=self.output_layer,
-                col_linear_kwargs={
-                    "weight": output_weight,
-                    "runtime_gather_output": runtime_gather_output,
-                },
-            )
-
-            mtp_loss = loss_mask * mtp_loss
-            if self.training:
-                # TODO(shifangx): remove the use of parallel_state here
-                # after moving loss logging to loss_func in pretrain_gpt.py
-                MTPLossLoggingHelper.save_loss_to_tracker(
-                    torch.sum(mtp_loss) / num_tokens,
-                    mtp_layer_number,
-                    self.config.mtp_num_layers,
-                    avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                )
-            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-            if self.config.calculate_per_token_loss:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
-            else:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+        if _HAS_PROCESS_MTP_LOSS:
+            # New Megatron API (builds that ship process_mtp_loss, e.g. the verl Megatron fork):
+            # process_mtp_loss handles chunking, rolling, and loss scaling internally.
+            cp_group = getattr(self, 'cp_group', None) or (
+                self.pg_collection.cp if hasattr(self, 'pg_collection') else None
+            )
+            scale_logits_fn = (
+                self._scale_logits if (hasattr(self, '_scale_logits') and self.config.use_mup) else None
+            )
+            hidden_states = _process_mtp_loss(
+                hidden_states=hidden_states,
+                labels=labels,
+                loss_mask=loss_mask,
+                output_layer=self.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                is_training=self.training,
+                compute_language_model_loss=self.compute_language_model_loss,
+                config=self.config,
+                cp_group=cp_group,
+                packed_seq_params=packed_seq_params,
+                scale_logits_fn=scale_logits_fn,
+            )
+        else:
+            # Legacy Megatron API: manual rolling + compute_output_layer_and_language_model_loss.
+            mtp_labels = labels.clone()
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states = hidden_states_list[0]
+            if loss_mask is None:
+                loss_mask = torch.ones_like(mtp_labels)
+            cp_group = getattr(self, 'cp_group', None)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, num_tokens = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                mtp_loss = self.compute_output_layer_and_language_model_loss(
+                    hidden_states_list[mtp_layer_number + 1],
+                    labels=mtp_labels,
+                    weight=self.shared_embedding_or_output_weight(),
+                    sequence_parallel_enabled=self.output_layer.sequence_parallel,
+                    column_parallel_linear=self.output_layer,
+                    col_linear_kwargs={
+                        "weight": output_weight,
+                        "runtime_gather_output": runtime_gather_output,
+                    },
+                )
+                mtp_loss = loss_mask * mtp_loss
+                if self.training:
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        torch.sum(mtp_loss) / num_tokens,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                    )
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(
+                        hidden_states, mtp_loss_scale * mtp_loss / num_tokens
+                    )
 
     logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
     # [s b h] => [b s h]
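For context on what the legacy branch above computes: MTP layer k predicts the token (k + 1) positions ahead, so before each layer's loss the labels and the loss mask are rolled left by one more position. Below is a minimal, self-contained sketch of that supervision scheme, using plain torch.roll instead of the CP- and packed-sequence-aware roll_tensor; the names mtp_losses_sketch and logits_per_depth are hypothetical, not verl or Megatron APIs.

```python
import torch
import torch.nn.functional as F

def mtp_losses_sketch(logits_per_depth, labels, loss_mask):
    """Per-depth MTP losses: layer k is supervised by labels rolled left k + 1 times."""
    losses = []
    mtp_labels = labels.clone()
    mask = loss_mask.clone().float()
    for logits in logits_per_depth:  # one [batch, seq, vocab] tensor per MTP layer
        # Shift the targets one more position ahead; the wrapped-around last
        # position is meaningless, so zero it in the mask (zeros accumulate as
        # they roll forward, masking the last k + 1 positions at depth k).
        mtp_labels = torch.roll(mtp_labels, shifts=-1, dims=-1)
        mask = torch.roll(mask, shifts=-1, dims=-1)
        mask[..., -1] = 0.0
        ce = F.cross_entropy(logits.transpose(-1, -2), mtp_labels, reduction="none")
        losses.append((ce * mask).sum() / mask.sum().clamp(min=1.0))
    return losses
```

Note that the real patch never materializes logits_per_depth: compute_output_layer_and_language_model_loss in the legacy branch, and process_mtp_loss in the new API, fuse the output projection with the loss precisely to avoid storing per-layer logits.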
14 changes: 10 additions & 4 deletions verl/models/mcore/util.py
@@ -570,8 +570,11 @@ def preprocess_bshd_engine(
     seqlens_in_batch = input_ids.offsets().diff()
     max_seqlen = seqlens_in_batch.max().item()
     tp_size = mpu.get_tensor_model_parallel_world_size()
-    # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size.
-    align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
+    # For CP (zigzag), the sequence length must be divisible by (2 * cp_size).
+    # After the zigzag-CP split, each rank holds s / cp_size tokens, which must also
+    # be divisible by tp_size for the sequence-parallel scatter. Therefore the total
+    # sequence length must be divisible by tp_size * cp_size * 2.
+    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
     if align_size > 1:
         pad_size = (align_size - max_seqlen % align_size) % align_size
         max_seqlen += pad_size
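To make the new alignment rule concrete, here is a small sketch of the padding computation (padded_seqlen is a hypothetical standalone helper; the real code reads tp_size and cp_size from mpu), plus a case the old math.lcm rule got wrong:

```python
def padded_seqlen(max_seqlen: int, tp_size: int, cp_size: int) -> int:
    """Pad max_seqlen up to a multiple of align_size, mirroring preprocess_bshd_engine."""
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    pad_size = (align_size - max_seqlen % align_size) % align_size
    return max_seqlen + pad_size

# tp_size=4, cp_size=2 -> align_size=16, so 1000 tokens pad to 1008.
# The old rule, math.lcm(tp_size, 2 * cp_size) = lcm(4, 4) = 4, would leave 1000
# unpadded: each CP rank then holds 500 tokens, and 500 is not divisible by
# tp_size=4, breaking the sequence-parallel scatter.
```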
@@ -748,10 +751,13 @@ def build_vlm_attn_mask_bshd(input_ids: torch.Tensor, batch_size: int, pad_token
     seqlens_in_batch = input_ids.offsets().diff()
     max_seqlen = seqlens_in_batch.max().item()
 
-    # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size.
+    # For CP (zigzag), the sequence length must be divisible by (2 * cp_size).
+    # After the zigzag-CP split, each rank holds s / cp_size tokens, which must also
+    # be divisible by tp_size for the sequence-parallel scatter. Therefore the total
+    # sequence length must be divisible by tp_size * cp_size * 2.
     tp_size = mpu.get_tensor_model_parallel_world_size()
     cp_size = mpu.get_context_parallel_world_size()
-    align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
+    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
     if align_size > 1:
         pad_size = (align_size - max_seqlen % align_size) % align_size
         max_seqlen += pad_size
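For reference, a sketch of the zigzag context-parallel split the comments refer to (zigzag_cp_shard is an illustrative helper, not a verl or Megatron function): the sequence is cut into 2 * cp_size chunks, and rank r keeps chunk r together with its mirror chunk 2 * cp_size - 1 - r, which balances causal-attention work across ranks and is why the divisibility requirement carries the factor of 2.

```python
import torch

def zigzag_cp_shard(x: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Return the two mirrored chunks owned by cp_rank (sequence on dim -1)."""
    chunks = x.chunk(2 * cp_size, dim=-1)  # requires seqlen % (2 * cp_size) == 0
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=-1)

# seqlen 16, cp_size 2: rank 0 holds chunks (0, 3), rank 1 holds chunks (1, 2);
# each rank's 8 tokens must further be divisible by tp_size for the SP scatter.
```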