feat: add support for `image_grid_thw` (QwenVL) in `DPOTrainer` #4091
base: main
Conversation
It seems not to work yet, and I'm not sure why:

```python
import numpy as np
import torch
from datasets import Dataset, features
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

from trl import DPOConfig, DPOTrainer

# fmt: off
dataset_dict = {
    "prompt": [
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in great detail."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Is this bus in the USA?"}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Give a thorough description of the image."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Who are the people in the image?"}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is written?"}]}],
    ],
    "chosen": [
        [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, multi-colored train."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Yes, it can be assumed that this bus is in the USA."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "There are two individuals, possibly girls or women."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
    ],
    "rejected": [
        [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, colorful train."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "No, it's not in the USA."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path surrounded by trees."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "In the image, there are two individuals."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
    ],
    "images": [
        [Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))],
        [Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))],
        [Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))],
        [Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))],
        [Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))],
    ],
}
# fmt: on

dataset = Dataset.from_dict(dataset_dict)
dataset = dataset.cast_column("images", features.Sequence(features.Image()))

# Instantiate the model and processor
model_id = "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
model = AutoModelForImageTextToText.from_pretrained(model_id)
ref_model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

training_args = DPOConfig(
    output_dir="t",
    per_device_train_batch_size=2,
    remove_unused_columns=False,
    learning_rate=0.01,  # increase learning rate to speed up test
    max_prompt_length=None,  # don't truncate to avoid issues with patch tokens
    max_length=None,
    report_to="none",
)
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    processing_class=processor,
    train_dataset=dataset,
    eval_dataset=dataset,
)

# Save the initial weights, so we can check if they have changed after training
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()
```

`trainer.train()` fails.
Got it, I’ll check that out.
In Qwen2.5-VL-style models with dynamic visual tokens, `pixel_values` can have different lengths across samples. Padding along the batch dimension causes two issues:

1. it breaks the spatial correspondence of the image features (i.e., the grid/layout), and
2. it creates a mismatch between the number of visual features and the number of image placeholder tokens.

Therefore, when `image_grid_thw` is present, `pixel_values` should not be padded; instead, keep the per-sample patches concatenated, guided by `image_grid_thw`. Also, in this case `pixel_values` must not be taken only from `processed_features["pixel_values"][0]`.
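For illustration, here is a minimal sketch of such a collation step. It assumes each example holds the per-sample processor outputs; the function name and structure are hypothetical, not TRL's actual collator:

```python
import torch

def collate_vision_inputs(examples):
    # Hypothetical helper, not TRL's actual collator. Assumes each example
    # holds the processor outputs for one sample: "pixel_values" of shape
    # (num_patches_i, patch_dim) and "image_grid_thw" of shape (num_images_i, 3).
    batch = {}
    if "image_grid_thw" in examples[0]:
        # Dynamic-resolution models (Qwen2-VL / Qwen2.5-VL): concatenate the
        # patch rows along dim 0 instead of padding the batch dimension, which
        # would break the patch-to-placeholder correspondence.
        batch["pixel_values"] = torch.cat(
            [torch.as_tensor(ex["pixel_values"]) for ex in examples], dim=0
        )
        # Each (t, h, w) row tells the model how to split the concatenated
        # patches back into per-image grids.
        batch["image_grid_thw"] = torch.cat(
            [torch.as_tensor(ex["image_grid_thw"]) for ex in examples], dim=0
        )
    else:
        # Fixed-resolution models: every sample has the same shape, so stacking works.
        batch["pixel_values"] = torch.stack(
            [torch.as_tensor(ex["pixel_values"]) for ex in examples]
        )
    return batch
```

The key point is that `torch.cat` along dim 0 preserves the patch-to-placeholder alignment, since the model uses `image_grid_thw` to split the flat patch sequence back into per-image grids.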
@qgallouedec when you have a moment, could you take a look at this? |
What does this PR do?
Adds support for QwenVL's `image_grid_thw` in the `DPOTrainer`.
Fixes #4071
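For context, here is a hedged sketch of the invariant the trainer must preserve (illustrative only, not code from this PR; `check_vision_alignment` and its arguments are hypothetical). The number of patch rows in `pixel_values` must equal the sum of the `t*h*w` grid volumes in `image_grid_thw`, and the number of image placeholder tokens in `input_ids` is that sum divided by `merge_size**2` (the spatial merge size is 2 for Qwen2.5-VL):

```python
def check_vision_alignment(pixel_values, image_grid_thw, input_ids, image_token_id, merge_size=2):
    # Total patches across all images in the batch: sum of the t*h*w rows.
    n_patches = int(image_grid_thw.prod(dim=-1).sum())
    # Padding pixel_values along the batch dimension would break this equality.
    assert pixel_values.shape[0] == n_patches, "pixel_values rows != sum(t*h*w)"
    # Each group of merge_size**2 patches is merged into one visual token,
    # so the placeholder count must match the merged feature count.
    n_placeholders = int((input_ids == image_token_id).sum())
    assert n_placeholders == n_patches // merge_size**2, "placeholder/feature mismatch"
```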
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.