VLM: Onboarding native resolution, native aspect ratio, interleaved VLM training #1615

Open: wants to merge 3 commits into main
5 changes: 4 additions & 1 deletion tests/assets/tokenizer/tokenizer.json
@@ -2029,7 +2029,10 @@
"land": 1994,
"?\n": 1995,
" respect": 1996,
"ances": 1997
"ances": 1997,
Contributor:
I would suggest we keep a test tokenizer for VLM only, under the experiments/vlm/assets folder. Currently I'm debugging a known issue #1136; adding these vocab entries will make the test tokenizer's vocab size not divisible by the world size, which might cause failures for other LLM tests when resuming training.

"<|image|>": 1998,
"<|begin_of_image|>": 1999,
"<|end_of_image|>": 2000
},
"merges": [
]
27 changes: 27 additions & 0 deletions tests/assets/tokenizer/tokenizer_config.json
@@ -15,11 +15,38 @@
"rstrip": false,
"single_word": false,
"special": true
},
"1998": {
"content": "<|image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1999": {
"content": "<|begin_of_image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2000": {
"content": "<|end_of_image|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|begin_of_text|>",
"clean_up_tokenization_spaces": true,
"eos_token": "<|end_of_text|>",
"img_token": "<|image|>",
"boi_token": "<|begin_of_image|>",
"eoi_token": "<|end_of_image|>",
"model_input_names": [
"input_ids",
"attention_mask"
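The three added entries give the VLM a fixed image placeholder token plus begin/end-of-image delimiters. A minimal sanity-check sketch (assuming the test tokenizer directory loads with Hugging Face's `AutoTokenizer`):

```python
# Sketch only: confirm the new special tokens resolve to the ids added above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tests/assets/tokenizer")
assert tok.convert_tokens_to_ids("<|image|>") == 1998
assert tok.convert_tokens_to_ids("<|begin_of_image|>") == 1999
assert tok.convert_tokens_to_ids("<|end_of_image|>") == 2000
```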
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -7,3 +7,4 @@
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
import torchtitan.experiments.vlm # noqa: F401
19 changes: 19 additions & 0 deletions torchtitan/experiments/vlm/README.md
@@ -0,0 +1,19 @@
# Vision Language Model training in `torchtitan`

**under active development**

This folder showcases how to train modern Vision Language Models (VLMs) in torchtitan.


## Features:
- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails.
- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.


## Design
Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, the image batch size `N` and the image length `L` (in patches), and pad the actual image patches to this fixed size.
Then we scatter the patch embeddings to their actual positions in the LLM input tokens.
This results in a very simple and general interface for training modern VLMs with interleaved data and native resolution & aspect ratio.
By setting the appropriate dataloader hyperparameters, we can easily reduce the amount of padding tokens.
We leverage Flex Attention to efficiently handle the varying number of patches per image.
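To make the interface concrete, here is an illustrative sketch of the padding-and-scatter scheme (shapes, mask positions, and counts below are made up; the real implementation lives in this experiment's model and dataloader code):

```python
# Illustrative sketch of padding image patches to a fixed (N, L) shape and
# scattering their embeddings into the <|image|> placeholder positions.
import torch

B, S, D = 2, 16, 32      # text batch size, sequence length, model dim
N, L = 4, 8              # fixed image batch size and patches per image (after padding)

tokens = torch.randn(B, S, D)       # LLM input embeddings
patch_emb = torch.randn(N, L, D)    # vision encoder output, zero-padded to (N, L, D)

# Which padded slots hold real patches: image 0 has 5, image 1 has 8,
# images 2-3 exist only to reach the fixed batch size N.
patch_valid = torch.zeros(N, L, dtype=torch.bool)
patch_valid[0, :5] = True
patch_valid[1, :8] = True

# Placeholder positions in each text sequence (5 in sample 0, 8 in sample 1,
# matching the number of real patches).
image_token_mask = torch.zeros(B, S, dtype=torch.bool)
image_token_mask[0, 2:7] = True
image_token_mask[1, 1:9] = True

flat_patches = patch_emb[patch_valid]     # (13, D): padded slots dropped
tokens[image_token_mask] = flat_patches   # scatter into the placeholder positions
```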
101 changes: 101 additions & 0 deletions torchtitan/experiments/vlm/__init__.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .datasets.mm_datasets import build_mm_dataloader
from .infra.parallelize import parallelize_vlm
# from .infra.pipeline import pipeline_llama
from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
from .model.model import Llama3Siglip2Transformer

__all__ = [
"parallelize_vlm",
# "pipeline_llama",
"Llama3Siglip2ModelArgs",
"Llama3Siglip2Transformer",
"llama3_siglip2_configs",
]


siglip2_configs = {
"debugmodel": Siglip2ModelArgs(
dim=128,
ffn_dim=256,
n_layers=4,
n_heads=2,
)
}

llama3_siglip2_configs = {
"debugmodel": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000
),
"debugmodel_flex_attn": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=256,
n_layers=6,
n_heads=16,
vocab_size=2000,
rope_theta=500000,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"8B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=1024,
rope_theta=500000,
),
"70B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=8192,
n_layers=80,
n_heads=64,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=4096,
rope_theta=500000,
),
"405B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=16384,
n_layers=126,
n_heads=128,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=4096,
rope_theta=500000,
),
}


register_train_spec(
TrainSpec(
name="llama3-siglip2",
model_cls=Llama3Siglip2Transformer,
model_args=llama3_siglip2_configs,
parallelize_fn=parallelize_vlm,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_mm_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
# state_dict_adapter=Llama3StateDictAdapter,
)
)
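As a pointer for how these flavors are meant to be consumed, a minimal sketch (the commented lookup and constructor steps are assumptions about torchtitan's TrainSpec workflow, not something this diff shows):

```python
# Sketch only: selecting one of the flavors registered above.
from torchtitan.experiments.vlm import llama3_siglip2_configs

args = llama3_siglip2_configs["debugmodel"]       # smallest flavor, handy for smoke tests
print(args.dim, args.n_layers, args.encoder.dim)  # 256 6 128

# Assumed (not in this diff): the trainer would look up the spec registered as
# "llama3-siglip2" in torchtitan's train-spec registry and construct its
# model_cls from these args.
```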
188 changes: 188 additions & 0 deletions torchtitan/experiments/vlm/datasets/mm_collator_nld.py
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List

import einops as E
import torch
from torch.nn.utils.rnn import pad_sequence

from torchtitan.tools.logging import logger


IGNORE_INDEX = -100


@dataclass
class MultiModalCollatorNLD:
"""Collator that works with patches in NLD format (N=batch, L=patches, D=patch_features)"""

padding_idx: int = 0
ignore_idx: int = IGNORE_INDEX
max_images_per_batch: int = 5
max_patch_per_image: int = 256 # Maximum patches per image
patch_size: int = 16 # Patch size for converting images to patches
merge_size: int = 1 # Merge size for converting spatial patches to channel dim
seq_len: int = 2048

def convert_to_patches(
self, pixel_values: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Direct NTHWC -> NLD conversion using einops."""
N, T, H, W, C = pixel_values.shape
Contributor:
From reading the code, will the T dimension here always be 1 for image samples?

ps = self.patch_size
device = pixel_values.device
patches = E.rearrange(
pixel_values, "n t (h p1) (w p2) c -> n (t h w) (p1 p2 c)", p1=ps, p2=ps
)

coords = torch.meshgrid(
torch.arange(T, device=device),
torch.arange(H // ps, device=device),
torch.arange(W // ps, device=device),
indexing="ij",
)
grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords")
grid = grid.unsqueeze(0).expand(N, -1, -1) # (N, t*h*w, 3)

# All patches are valid since we resize images to be divisible by patch_size
return patches, grid

def _pad_to_max(self, patches, grids):
"""Pad or truncate to max_patch_per_image."""
N, L, D = patches.shape
if L == self.max_patch_per_image:
return patches, grids
elif L < self.max_patch_per_image:
# Pad
pad_len = self.max_patch_per_image - L
zero_patches = torch.zeros(N, pad_len, D, device=patches.device)
invalid_grids = torch.full(
(grids.shape[0], pad_len, 3), -1, device=grids.device
)
return torch.cat([patches, zero_patches], 1), torch.cat(
[grids, invalid_grids], 1
)
else:
# Truncate
return (
patches[:, : self.max_patch_per_image],
grids[:, : self.max_patch_per_image],
)

def __call__(
self, batch: List[Dict[str, Any]]
) -> tuple[Dict[str, torch.Tensor | None], torch.Tensor]:
"""Encode batch with patch-based approach."""
if not batch:
return None

# Count images per sample and total images
images_per_sample = []
for sample in batch:
num_images = (
len(sample.get("pixel_values", [])) if "pixel_values" in sample else 0
)
images_per_sample.append(num_images)

# Remove samples from end until total images <= max_images_per_batch
Contributor:
Does this mean we will drop samples once the total number of images exceeds max_images_per_batch?

total_images = sum(images_per_sample)
while total_images > self.max_images_per_batch and batch:
removed_images = images_per_sample.pop()
total_images -= removed_images
batch.pop()
logger.warning(f"Removed sample with {removed_images} images to keep total images <= {self.max_images_per_batch}")

all_images = [
img
for sample in batch
if "pixel_values" in sample
for img in sample["pixel_values"]
]

if all_images:
patch_list, grid_list = [], []
for img in all_images:
p, g = self.convert_to_patches(img.unsqueeze(0))
p, g = self._pad_to_max(p, g)
patch_list.append(p[0])
grid_list.append(g[0])
patches = torch.stack(patch_list)
grids = torch.stack(grid_list)

if len(all_images) < self.max_images_per_batch:
blank_count = self.max_images_per_batch - len(all_images)
blank_patches = torch.zeros(
blank_count,
self.max_patch_per_image,
patches.shape[2],
device=patches.device,
)
blank_grids = torch.full(
(blank_count, self.max_patch_per_image, 3), -1, device=grids.device
)
patches = torch.cat([patches, blank_patches], dim=0)
grids = torch.cat([grids, blank_grids], dim=0)
else:
patches = grids = None

# Text processing
input_ids = pad_sequence(
[s["input_ids"] for s in batch],
batch_first=True,
padding_value=self.padding_idx,
)
labels = pad_sequence(
[s["labels"] for s in batch],
batch_first=True,
padding_value=self.padding_idx,
)

# Pad along batch dimension if needed
batch_size = len(batch)
if input_ids.size(0) < batch_size:
padding_needed = batch_size - input_ids.size(0)
padding_input = (
torch.ones(padding_needed, input_ids.size(1), dtype=torch.long)
* self.padding_idx
)
padding_labels = (
torch.ones(padding_needed, labels.size(1), dtype=torch.long)
* self.padding_idx
)
input_ids = torch.cat([input_ids, padding_input], dim=0)
labels = torch.cat([labels, padding_labels], dim=0)

# Handle sequence length
current_length = input_ids.size(1)
desired_length = self.seq_len + 1 # Extra token for label shift and cut
if current_length < desired_length:
padding_length = desired_length - current_length
padding_input = (
torch.ones(batch_size, padding_length, dtype=torch.long)
* self.padding_idx
)
padding_labels = (
torch.ones(batch_size, padding_length, dtype=torch.long)
* self.padding_idx
)
input_ids = torch.cat([input_ids, padding_input], dim=1)
labels = torch.cat([labels, padding_labels], dim=1)
elif current_length > self.seq_len:
input_ids = input_ids[:, :desired_length]
labels = labels[:, :desired_length]

labels[labels == self.padding_idx] = self.ignore_idx
# Cut and shift
input_ids = input_ids[:, :-1]
labels = labels[:, 1:]

return {
"input": input_ids,
"pixel_values": patches,
"grid_thw": grids,
}, labels
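A usage sketch for the collator above (hyperparameters and tensor shapes are made up for illustration; only names defined in this file, plus torch, are used):

```python
# Sketch only: exercising MultiModalCollatorNLD on a tiny hand-built batch.
import torch

from torchtitan.experiments.vlm.datasets.mm_collator_nld import MultiModalCollatorNLD

collator = MultiModalCollatorNLD(
    max_images_per_batch=2, max_patch_per_image=64, patch_size=16, seq_len=128
)
batch = [
    {
        "input_ids": torch.arange(1, 11),
        "labels": torch.arange(1, 11),
        # One image shaped (T=1, H=32, W=32, C=3); H and W divisible by patch_size.
        "pixel_values": [torch.rand(1, 32, 32, 3)],
    },
]
inputs, labels = collator(batch)
# inputs["input"]:        (1, 128)     padded to seq_len, then shifted against labels
# inputs["pixel_values"]: (2, 64, 768) (max_images, max_patches, patch_size**2 * C)
# inputs["grid_thw"]:     (2, 64, 3)   (t, h, w) patch coords, -1 for padded slots
# labels:                 (1, 128)     padding positions replaced with IGNORE_INDEX
```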