191 changes: 184 additions & 7 deletions nemo_automodel/components/datasets/vlm/collate_fns.py
@@ -27,6 +27,33 @@
process_vision_info = MagicMock()


def _maybe_add_gemma3_token_type_ids(batch: dict, processor) -> None:
"""If running with a Gemma-3 style processor and token_type_ids are absent,
mark image tokens (== image_token_id) as 1 and others as 0.

This mirrors the Gemma 3 token-type semantics in LLaMA-Factory's plugins,
where token_type_ids flag image tokens. The key is only added when it is
absent, so existing values are never overwritten.
"""

if "token_type_ids" in batch:
return

processor_type = type(processor).__name__ if processor is not None else ""
if processor_type not in ("Gemma3_VLProcessor", "Gemma3nProcessor"):
return

image_token_id = getattr(processor, "image_token_id", None)
if image_token_id is None:
return

input_ids = batch.get("input_ids", None)
if input_ids is None:
return

# token_type_ids: 1 where image token appears, else 0
batch["token_type_ids"] = (input_ids == image_token_id).long()


def create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token=None):
r"""
Create a loss mask by finding start-of-turn token positions, similar to the approach in squad.py.
@@ -85,15 +112,32 @@ def phi4_mm_collate_fn(examples, processor):
batch = processor(
text=texts, audios=audio_inputs, return_tensors="pt", padding=True, truncation=True, max_length=1024
)
skipped_tokens = extract_skipped_token_ids(processor)

labels = batch["input_ids"].clone()[:, 1:]
labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
labels[torch.isin(labels, skipped_tokens)] = -100

loss_masks = []
for i, conversation in enumerate(conversations):
input_ids = batch["input_ids"][i].tolist()

assistant_content = conversation[1]["content"]
assistant_tokens = processor.tokenizer(assistant_content, add_special_tokens=False)["input_ids"]
# Extract assistant text robustly (supports list-of-chunks or plain string)
if isinstance(assistant_content, list):
assistant_text = "".join(
[
chunk.get("text", "")
for chunk in assistant_content
if isinstance(chunk, dict) and chunk.get("type") == "text"
]
)
elif isinstance(assistant_content, str):
assistant_text = assistant_content
else:
assistant_text = str(assistant_content)

assistant_tokens = processor.tokenizer(assistant_text, add_special_tokens=False)["input_ids"]

loss_mask = [0] * len(input_ids)
for start_idx in range(len(input_ids) - len(assistant_tokens) + 1):
@@ -105,7 +149,7 @@ def phi4_mm_collate_fn(examples, processor):

max_len = max(len(mask) for mask in loss_masks)
padded_loss_masks = [mask + [0] * (max_len - len(mask)) for mask in loss_masks]
batch["loss_mask"] = torch.tensor(padded_loss_masks, dtype=torch.float)
batch["loss_mask"] = torch.tensor(padded_loss_masks, dtype=torch.float, device=batch["input_ids"].device)

labels[batch["loss_mask"] == 0] = -100
batch["labels"] = labels
@@ -139,22 +183,75 @@ def qwen2_5_collate_fn(
labels = batch["input_ids"].clone()[:, 1:]
labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
labels[torch.isin(labels, skipped_tokens)] = -100
batch["labels"] = labels
loss_masks = [
create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token)
for input_ids in batch["input_ids"]
]
batch["loss_mask"] = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device)
labels[batch["loss_mask"] == 0] = -100
batch["labels"] = labels
return batch
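
# Worked example of the label shift used throughout this file: for
# input_ids tensor([[a, b, c, d]]), labels becomes tensor([[b, c, d, -100]]);
# each position predicts the next token, and the final slot plus any skipped
# or unmasked positions are ignored by the loss via -100.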


def default_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
"""Default collate function for VLM models."""
# Helper: Generic collate fn for Qwen-VL style processors (Qwen2, Qwen3, etc.)
def _qwen_vl_generic_collate_fn(
examples: list, processor, start_of_response_token: str = "<|im_start|>assistant\n"
) -> dict[str, torch.Tensor]:
"""Shared logic for Qwen-2/3 VL style collate functions.

This is factorised so we can easily register additional Qwen-VL processor
types without duplicating code. The behaviour is identical to
``qwen2_5_collate_fn`` but is parameterised on *start_of_response_token* so
that future processor versions can override it if necessary.
"""

if not HAVE_QWEN_VL_UTILS:
raise ImportError(MISSING_QWEN_VL_UTILS_MSG)

skipped_tokens = extract_skipped_token_ids(processor)

texts = [processor.apply_chat_template(example["conversation"], tokenize=False) for example in examples]
image_inputs = [process_vision_info(example["conversation"])[0] for example in examples]

batch = processor(
text=texts,
images=image_inputs,
padding=True,
return_tensors="pt",
)

labels = batch["input_ids"].clone()[:, 1:]
labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
labels[torch.isin(labels, skipped_tokens)] = -100

loss_masks = [
create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token)
for input_ids in batch["input_ids"]
]
batch["loss_mask"] = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device)
labels[batch["loss_mask"] == 0] = -100
batch["labels"] = labels

return batch
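
# Registering a further Qwen-VL variant then needs no new logic; a sketch
# (hypothetical processor, same default response prefix):
#
#     def qwen4_vl_collate_fn(examples, processor):
#         return _qwen_vl_generic_collate_fn(examples, processor, "<|im_start|>assistant\n")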


# Collate functions for other Qwen-VL processor variants
def qwen2_vl_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]:
"""Collate function for Qwen-2 VL models (same logic as Qwen-2.5)."""

return _qwen_vl_generic_collate_fn(examples, processor, "<|im_start|>assistant\n")


def qwen3_vl_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]:
"""Collate function for Qwen-3 VL models (identical logic to Qwen-2)."""

return _qwen_vl_generic_collate_fn(examples, processor, "<|im_start|>assistant\n")


def default_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
"""Default collate function for VLM models."""
skipped_tokens = extract_skipped_token_ids(processor)

batch = processor.apply_chat_template(
[example["conversation"] for example in examples],
tokenize=True,
@@ -170,21 +267,101 @@ def default_collate_fn(examples: list, processor, start_of_response_token=None)
torch.arange(seq_len, device=batch["input_ids"].device).unsqueeze(0).expand(batch_size, -1)
)

batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)
if "pixel_values" in batch and isinstance(batch["pixel_values"], torch.Tensor):
batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)
labels = batch["input_ids"].clone()[:, 1:]
labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
labels[torch.isin(labels, skipped_tokens)] = -100
batch["labels"] = labels
loss_masks = [
create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token)
for input_ids in batch["input_ids"]
]
batch["loss_mask"] = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device)
labels[batch["loss_mask"] == 0] = -100
batch["labels"] = labels
_maybe_add_gemma3_token_type_ids(batch, processor)
return batch
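
# The position_ids built above are simply [0, 1, ..., seq_len - 1] broadcast
# across the batch; e.g. for batch_size=2, seq_len=4:
#     tensor([[0, 1, 2, 3],
#             [0, 1, 2, 3]])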


# Thin wrappers per model family to allow future specialization without changing
# call sites. For now, they just delegate to `default_collate_fn`.


def llava_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def llava_next_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def llava_next_video_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def video_llava_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def paligemma_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def pixtral_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def intern_vl_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def kimi_vl_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def llama4_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def gemma3_vl_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def gemma3n_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def glm4v_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def minicpm_v_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


def mllama_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
return default_collate_fn(examples, processor, start_of_response_token)


# Mapping of processor types to their collate functions
COLLATE_FNS = {
"Qwen2_5_VLProcessor": qwen2_5_collate_fn,
"Qwen2_VLProcessor": qwen2_vl_collate_fn,
"Qwen3_VLProcessor": qwen3_vl_collate_fn,
# Per-model wrappers (currently delegate to default_collate_fn)
"LlavaProcessor": llava_collate_fn,
"LlavaNextProcessor": llava_next_collate_fn,
"LlavaNextVideoProcessor": llava_next_video_collate_fn,
"VideoLlavaProcessor": video_llava_collate_fn,
"PaliGemmaProcessor": paligemma_collate_fn,
"PixtralProcessor": pixtral_collate_fn,
"InternVLProcessor": intern_vl_collate_fn,
"KimiVLProcessor": kimi_vl_collate_fn,
"Llama4Processor": llama4_collate_fn,
"Gemma3_VLProcessor": gemma3_vl_collate_fn,
"Gemma3nProcessor": gemma3n_collate_fn,
"GLM4VProcessor": glm4v_collate_fn,
"MiniCPMVProcessor": minicpm_v_collate_fn,
"MllamaProcessor": mllama_collate_fn,
"default": default_collate_fn,
}
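
# Typical dispatch at a call site (a sketch; the actual lookup lives
# elsewhere in the repo):
#
#     collate_fn = COLLATE_FNS.get(type(processor).__name__, COLLATE_FNS["default"])
#     batch = collate_fn(examples, processor)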