VLM: Onboarding native resolution, native aspect ratio, interleaved VLM training #1615


Open · wants to merge 3 commits into main

Conversation

lkhphuc (Contributor) commented on Aug 21, 2025

First PR for onboarding modern VLM training to torchtitan.

Features:

  • Native Aspect Ratio: not limited to square crops.
  • Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
  • Native Interleaved Data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.

Design

Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two additional hyperparameters: the number of images per batch N and the maximum image patch sequence length L. The actual image patches are then padded to these fixed sizes.

[Design diagram]
  • After tok_embedding, we obtain text tokens of shape BxS.
  • After the encoder, we obtain visual tokens of shape NxL.
  • We extract only the valid (non-padding) visual tokens,
  • then scatter those tokens to their actual positions in the LLM input tokens (see the sketch below).
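
A minimal sketch of this extract-and-scatter step (my illustration, not the PR's actual code; the tensor names token_embeds, image_token_mask, visual_tokens, and valid_mask are assumptions):

import torch

def scatter_visual_tokens(
    token_embeds: torch.Tensor,      # [B, S, D] embeddings after tok_embedding
    image_token_mask: torch.Tensor,  # [B, S] bool, True at <image> placeholder positions
    visual_tokens: torch.Tensor,     # [N, L, D] encoder outputs (padded)
    valid_mask: torch.Tensor,        # [N, L] bool, False for padding patches
) -> torch.Tensor:
    # Keep only the real (non-padding) visual tokens: [num_valid, D].
    valid_visual = visual_tokens[valid_mask]
    # The dataloader is expected to insert exactly num_valid <image> placeholders,
    # so boolean assignment lines the two up in order.
    out = token_embeds.clone()
    out[image_token_mask] = valid_visual.to(out.dtype)
    return out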

This requires the dataloader to handle the following aspects:

  • Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and each input image's size.
  • Convert images/videos to a 1D sequence of patches:
    • rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)
    • Pad every image's patch sequence to a fixed length and return pixel_values with shape [N, L, D].
  • Return grid_thw with shape [N, L, 3] to keep track of the (t, h, w) location indices of each patch in its image. Padding patches can be tracked in the same tensor with the value -1 (a sketch of this patchify-and-pad step follows below).
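
As a hedged illustration of the patchify-and-pad step above (not the PR's collator; patchify_image, temporal_ps, and max_len are names I am assuming), for a single image or video clip:

import torch
from einops import rearrange

def patchify_image(pixels: torch.Tensor, patch_size: int, temporal_ps: int, max_len: int):
    # pixels: [T, H, W, C], with T a multiple of temporal_ps and H, W multiples of patch_size.
    T, H, W, C = pixels.shape
    t, h, w = T // temporal_ps, H // patch_size, W // patch_size
    # [T, H, W, C] -> [t*h*w, temporal_ps*patch_size*patch_size*C]
    patches = rearrange(
        pixels,
        "(t pt) (h ph) (w pw) c -> (t h w) (pt ph pw c)",
        pt=temporal_ps, ph=patch_size, pw=patch_size,
    )
    # (t, h, w) coordinates of every patch, in the same order as the rearrange above.
    coords = torch.stack(
        torch.meshgrid(torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"),
        dim=-1,
    ).reshape(-1, 3)
    # Pad the patch sequence to max_len; padded positions get coordinate -1.
    pad = max_len - patches.shape[0]  # assumes max_len >= number of patches
    patches = torch.cat([patches, patches.new_zeros(pad, patches.shape[1])])
    coords = torch.cat([coords, coords.new_full((pad, 3), -1)])
    return patches, coords  # stacking N of these gives pixel_values [N, L, D] and grid_thw [N, L, 3]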

This results in a very simple and general interface for training modern VLMs with interleaved data and native resolution & aspect ratio:

  • Depending on the data mixture, we can set the dataloader's hyperparameters N and L to minimize empty image padding (in the batch dimension).
  • Use modern PyTorch features (FlexAttention, compile, etc.) to efficiently handle a different attention mask per image (padding in the sequence dimension); a sketch follows this list.
  • Interfaces nicely with TP, PP, etc.
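
For the attention-mask point above, one way this could look with FlexAttention (a sketch under assumptions, not code from this PR; the function and its arguments are illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def encoder_attention(q, k, v, grid_thw):
    # q, k, v: [N, num_heads, L, head_dim]; grid_thw: [N, L, 3] with -1 marking padding.
    valid = grid_thw[..., 0] >= 0  # [N, L], True for real patches

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only among the real (non-padding) patches of the same image.
        return valid[b, q_idx] & valid[b, kv_idx]

    block_mask = create_block_mask(
        mask_mod, B=q.shape[0], H=None, Q_LEN=q.shape[2], KV_LEN=k.shape[2]
    )
    return flex_attention(q, k, v, block_mask=block_mask)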

In this PR

  • Minimal interleaved Obelics dataloader with native resolution and aspect ratio.
    • The dataloader is currently very slow, as it needs to download images from the internet every time you run. (The same is true for the current implementation in the multimodal experiment.)
  • Siglip2 model code, mostly based on HF.
  • VLM model code, called Llama3Siglip2, connecting the vision encoder and the language decoder.
  • Minimal infra code to run the debug model.

Todo:

  • Add support for a captioning HF dataset that stores images inside the dataset (CC12M, like the Flux experiment?) so it's not super slow to load.
  • FlexAttention for the encoder.
  • Modify the Llama3 tokenizer to add special tokens.
  • Script to combine and load Siglip2 + Llama3 weights.
  • Test Siglip2 encoder correctness.
  • Multimodal CE loss to correct for image token bias (one possible sketch below).
  • All the parallelisms: DP, CP, TP, PP.
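
On the CE-loss item: one possible reading (my assumption, not a stated design in this PR) is a cross-entropy that ignores image-placeholder and padding positions and averages only over real text tokens:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed sentinel already placed in labels at <image>/padding positions

def multimodal_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [B, S, vocab_size]; labels: [B, S] with IGNORE_INDEX at non-text positions.
    return F.cross_entropy(
        logits.flatten(0, 1),       # [B*S, vocab_size]
        labels.flatten(),           # [B*S]
        ignore_index=IGNORE_INDEX,  # image/padding tokens contribute nothing
        reduction="mean",           # averaged over the remaining text tokens only
    )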

wwwjn (Contributor) left a comment:

Thanks for making this great PR! I personally learned a lot from it. However, I feel the data preprocessing part in mm_collator_nld.py is a bit hard to follow and read.

Image preprocessing mainly happens in mm_collator_nld.py, and the collator function performs the following steps for images:

  1. Patchify
  2. Generate grids with coordinates
  3. Pad/truncate
  4. Assemble into batched outputs

Text preprocessing is mainly handled in mm_dataset.py, which also involves several steps, e.g. padding with <image> tokens, tokenization, and masking out <image> tokens in the labels.

I was wondering whether we could further split the image and text preprocessing functions into smaller pieces, add tensor shape hints, or even add examples like experiments/multimodal. That would improve readability and make debugging easier.

The VLM modeling part LGTM, it's clearly structured!

@@ -2029,7 +2029,10 @@
"land": 1994,
"?\n": 1995,
" respect": 1996,
"ances": 1997
"ances": 1997,
A contributor commented:

I would suggest we keep a test tokenizer for VLM only, under the experiments/vlm/assets folder. I'm currently debugging a known issue, #1136; adding these vocab entries will make the test tokenizer's vocab size not divisible by the world size, which might cause failures for other LLMs when resuming training.

Comment on lines +242 to +245
# Normalize with OpenAI CLIP mean/std
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
img_array = (img_array - mean) / std
A contributor commented:

n00b question: why do we use the CLIP mean/std to normalize the dataset? Is it common practice?

)
images_per_sample.append(num_images)

# Remove samples from end until total images <= max_images_per_batch
A contributor commented:

Does this mean we will drop the images beyond max_images_per_batch?

if img is not None:
    try:
        # Handle online case (image URLs)
        if isinstance(img, str) and img.startswith("http"):
A contributor commented:

For this path, we cannot easily run or test on our GPUs because our cluster doesn't have internet access :( I might try to find an alternative dataset to HuggingFaceM4/OBELICS.

    self, pixel_values: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Direct NTHWC -> NLD conversion using einops."""
    N, T, H, W, C = pixel_values.shape
A contributor commented:

From reading the code, it seems the T dimension here will always be 1 for image samples?
