Commit 1e90873

[internvl] fix chat template (#37656)
* fix chat template
* update
* update conversion
* rename `fake_image_token` in tests
1 parent 9ec8be5 commit 1e90873

File tree

5 files changed, +87 -119 lines changed

docs/source/en/model_doc/internvl.md

Lines changed: 1 addition & 0 deletions

@@ -257,6 +257,7 @@ InternVL models can also handle video inputs. Here is an example of how to perfo
 ...     add_generation_prompt=True,
 ...     tokenize=True,
 ...     return_dict=True,
+...     num_frames=8,
 >>> ).to(model.device, dtype=torch.float16)

 >>> output = model.generate(**inputs, max_new_tokens=25)
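
For context, a sketch of the full call this doc hunk belongs to, assuming the documented video flow; the checkpoint id, video URL, and prompt text are illustrative placeholders, and only the `num_frames=8` argument comes from this change.

# Hedged sketch of the documented video call with the newly shown num_frames argument;
# model id, video URL, and message text are illustrative placeholders.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "OpenGVLab/InternVL3-1B-hf"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://example.com/clip.mp4"},  # placeholder URL
            {"type": "text", "text": "What is happening in this video?"},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    num_frames=8,  # now forwarded to the processor's frame sampler
).to(model.device, dtype=torch.float16)

output = model.generate(**inputs, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))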

src/transformers/models/internvl/convert_internvl_weights_to_hf.py

Lines changed: 1 addition & 0 deletions

@@ -312,6 +312,7 @@ def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None,
             "start_image_token": "<img>",
             "end_image_token": "</img>",
             "context_image_token": "<IMG_CONTEXT>",
+            "video_token": "<video>",
         },
     )
     tokenizer.model_max_length = CONTEXT_LENGTH
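
Entries in `extra_special_tokens` surface as attributes on the converted tokenizer, which is what the processor changes below rely on. A minimal sketch, assuming a converted InternVL checkpoint on the Hub (the id is illustrative):

# Hedged sketch: extra_special_tokens entries become attributes on the tokenizer.
# The checkpoint id is an illustrative placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B-hf")
print(tokenizer.context_image_token)  # "<IMG_CONTEXT>"
print(tokenizer.video_token)          # "<video>", written by the updated conversion script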

src/transformers/models/internvl/processing_internvl.py

Lines changed: 54 additions & 80 deletions

@@ -14,13 +14,11 @@
 # limitations under the License.


-from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 import numpy as np

 from transformers.processing_utils import (
-    AllKwargsForChatTemplate,
     ImagesKwargs,
     ProcessingKwargs,
     ProcessorMixin,
@@ -34,6 +32,7 @@
     VideoInput,
     VideoMetadata,
     concatenate_list,
+    load_video,
     make_batched_videos,
     make_flat_list_of_images,
 )
@@ -75,20 +74,12 @@ class InternVLProcessor(ProcessorMixin):
             image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
         chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
             in a chat into a tokenizable string.
-        fake_image_token (`str`, *optional*, defaults to `"<image>"`):
-            The token to use for the image placeholder in the text. This token will be replaced by the
-            appropriate image tokens when processing the text with images.
-        fake_video_token (`str`, *optional*, defaults to `"<video>"`):
-            The token to use for the video placeholder in the text. This token will be replaced by the
-            appropriate image tokens when processing the text with videos.
     """

     attributes = ["image_processor", "tokenizer"]
     valid_kwargs = [
         "chat_template",
         "image_seq_length",
-        "fake_image_token",
-        "fake_video_token",
     ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
@@ -99,16 +90,14 @@ def __init__(
         tokenizer=None,
         image_seq_length: int = 256,
         chat_template=None,
-        fake_image_token="<image>",
-        fake_video_token="<video>",
         **kwargs,
     ):
         self.image_seq_length = image_seq_length
-        self.fake_image_token = fake_image_token
-        self.fake_video_token = fake_video_token
         self.start_image_token = tokenizer.start_image_token
         self.end_image_token = tokenizer.end_image_token
-        self.context_image_token = tokenizer.context_image_token
+        self.image_token = tokenizer.context_image_token
+        self.video_token = tokenizer.video_token
+        self.image_token_id = tokenizer.context_image_token_id

         super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

@@ -131,24 +120,24 @@ def _insert_media_placeholders(
         video_index = 0
         processed_text = []
         image_video_patches = []
+        replace_strings = []
         # Support interleaved image and video in prompts:
         # Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
         for prompt in text:
             new_prompt = prompt
-            while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
-                if self.fake_image_token in new_prompt and (
-                    self.fake_video_token not in new_prompt
-                    or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
+            while self.image_token in new_prompt or self.video_token in new_prompt:
+                if self.image_token in new_prompt and (
+                    self.video_token not in new_prompt
+                    or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
                 ):
                     # Get the slice of patches corresponding to the current image
                     start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
                     end_index = image_num_patches_indices[image_index]
                     image_video_patches.append(image_pixel_values[start_index:end_index])
                     # Replace the corresponding image placeholder with the correct number of image tokens
-                    new_prompt = new_prompt.replace(
-                        self.fake_image_token,
-                        f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
-                        1,
+                    new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
+                    replace_strings.append(
+                        f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
                     )
                     image_index += 1
                 else:
@@ -163,11 +152,15 @@ def _insert_media_placeholders(
                     # Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
                     num_patches = list(video_num_patches[current_patch_index:end_patch_index])
                     video_prompt = "\n".join(
-                        f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
+                        f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
                         for i in range(len(num_patches))
                     )
-                    new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
+                    replace_strings.append(video_prompt)
+                    new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
                     video_index += 1
+            while "<placeholder>" in new_prompt:
+                replace_str = replace_strings.pop(0)
+                new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
             processed_text.append(new_prompt)

         return processed_text, image_video_patches, image_index, video_index
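
The loop above now does the expansion in two passes: each media token is first swapped for a neutral "<placeholder>" marker while the expanded strings are collected in order, and only afterwards are the expansions substituted back. This matters because, after the rename, the placeholder found in the prompt is the same "<IMG_CONTEXT>" string that appears inside its own expansion, so a direct in-place replace would keep matching tokens it had just inserted. A minimal standalone sketch of the same idea, with toy tokens and a toy sequence length of 3:

# Minimal sketch of the two-pass expansion used in _insert_media_placeholders.
# Tokens and seq_length=3 are toy values, not the model's real settings.
image_token, video_token, seq_length = "<IMG_CONTEXT>", "<video>", 3

prompt = f"user: {image_token} and {video_token} please compare"
replace_strings = []

# Pass 1: swap each media token for a neutral marker and record its expansion.
new_prompt = prompt
while image_token in new_prompt or video_token in new_prompt:
    if image_token in new_prompt and (
        video_token not in new_prompt or new_prompt.index(image_token) < new_prompt.index(video_token)
    ):
        new_prompt = new_prompt.replace(image_token, "<placeholder>", 1)
        replace_strings.append("<img>" + image_token * seq_length + "</img>")
    else:
        new_prompt = new_prompt.replace(video_token, "<placeholder>", 1)
        replace_strings.append("Frame1: <img>" + image_token * seq_length + "</img>")

# Pass 2: substitute the expansions back, in order; the expansions may safely
# contain image_token because only "<placeholder>" is matched here.
while "<placeholder>" in new_prompt:
    new_prompt = new_prompt.replace("<placeholder>", replace_strings.pop(0), 1)

print(new_prompt)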
@@ -269,9 +262,11 @@ def __call__(
             # Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
             image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}

+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])

-        return BatchFeature(data={**text_inputs, **image_videos_inputs})
+        return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)

     def sample_indices_fn(
         self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
@@ -290,15 +285,13 @@ def sample_indices_fn(
         Returns:
             `np.ndarray`: Array of frame indices to sample.
         """
+        num_frames = num_frames if num_frames is not None else metadata.total_num_frames
+
         if initial_shift is True:
             initial_shift = metadata.total_num_frames / num_frames / 2
-        if num_frames is not None:
-            indices = np.arange(
-                initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
-            ).astype(int)
-        else:
-            indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
-
+        indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
+            int
+        )
         return indices

     def batch_decode(self, *args, **kwargs):
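
The sampling rule itself is unchanged; defaulting `num_frames` to the clip length simply collapses the old if/else into a single `np.arange` call. A standalone sketch with illustrative numbers (a 100-frame clip, 8 requested frames):

# Hedged sketch of the simplified frame sampling; the 100-frame clip and
# num_frames=8 are illustrative values, not anything from the diff.
import numpy as np

total_num_frames = 100
num_frames = 8  # with num_frames=None this would default to total_num_frames
initial_shift = total_num_frames / num_frames / 2  # sample from segment centers

indices = np.arange(initial_shift, total_num_frames, total_num_frames / num_frames).astype(int)
print(indices)  # [ 6 18 31 43 56 68 81 93]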
@@ -321,58 +314,39 @@ def model_input_names(self):
         image_processor_input_names = self.image_processor.model_input_names
         return list(tokenizer_input_names) + list(image_processor_input_names)

-    # Add model-specific video sampling method when applying the template
-    def apply_chat_template(
+    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
+    def _load_video_for_model(
         self,
-        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
-        chat_template: Optional[str] = None,
-        num_frames: int = 8,
-        initial_shift: Union[bool, float, int] = True,
-        video_load_backend="pyav",
-        **kwargs: Unpack[AllKwargsForChatTemplate],
-    ):
+        video: Union[str, "VideoInput"],
+        num_frames: Optional[int],
+        backend: str = "pyav",
+        initial_shift: bool = True,
+        **kwargs,
+    ) -> np.array:
         """
-        Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
-        conversations to turn them into a single tokenizable string.
-
-        The input is expected to be in the following format, where each message content is a list consisting of text and
-        optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
-        `pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
-
-        conversation = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
-                    {"type": "text", "text": "Please describe this image in detail."},
-                ],
-            },
-        ]
+        Loads `video` to a numpy array.

         Args:
-            conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
-                The conversation to format.
-            chat_template (`Optional[str]`, *optional*):
-                The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
-                chat template is used.
-            num_frames (`int`, *optional*, defaults to 8):
-                Number of frames to sample from a video when using the default `sample_indices_fn`.
-            initial_shift (`bool`, `float` or `int`, defaults to `0`):
-                The initial shift to apply when sampling frames using the default `sample_indices_fn`.
-                If `True`, the shift is set so that frames are sampled from the middle of the video.
+            video (`str` or `VideoInput`):
+                The video to convert to the numpy array format. Can be a link to video or local path.
+            num_frames (`int`, *optional*):
+                Number of frames to sample uniformly. If not passed, the whole video is loaded.
+            backend (`str`, *optional*, defaults to `"pyav"`):
+                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
+            initial_shift (`bool`, *optional*, defaults to `True`):
+                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
+
+        Returns:
+            Tuple[`np.array`, Dict]: A tuple containing:
+                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+                - Metadata dictionary.
         """
-        sample_indices_fn = kwargs.pop(
-            "sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
-        )

-        return super().apply_chat_template(
-            conversation,
-            chat_template,
-            video_load_backend=video_load_backend,
-            num_frames=num_frames,
-            sample_indices_fn=sample_indices_fn,
-            **kwargs,
-        )
+        def sample_indices_fn_func(metadata, **fn_kwargs):
+            return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
+
+        video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
+        return video, metadata


 __all__ = ["InternVLProcessor"]
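
A hedged usage sketch of the new helper; the clip path and checkpoint id are placeholders, and per the TODO above the method may later move under a public video-processor API:

# Hedged sketch: calling the new private helper directly; the clip path and
# checkpoint id are placeholders, not part of this commit.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
frames, metadata = processor._load_video_for_model(
    "sample_clip.mp4",   # local path or URL (placeholder)
    num_frames=8,        # None would load every frame
    backend="pyav",
    initial_shift=True,  # start sampling from segment centers
)
print(frames.shape)  # expected (8, H, W, 3) RGB frames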

tests/models/internvl/test_modeling_internvl.py

Lines changed: 24 additions & 12 deletions

@@ -296,7 +296,9 @@ def test_qwen2_small_model_integration_generate(self):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         image = Image.open(requests.get(url, stream=True).raw)

-        prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        prompt = (
+            "<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        )
         inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
         with torch.no_grad():
             generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -314,7 +316,9 @@ def test_qwen2_small_model_integration_forward(self):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         image = Image.open(requests.get(url, stream=True).raw)

-        prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        prompt = (
+            "<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        )
         inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)

         # Forward
@@ -378,8 +382,8 @@ def test_qwen2_small_model_integration_batched_generate(self):
         )
         # Prepare inputs
         prompt = [
-            "<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
-            "<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
         ]
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -414,8 +418,8 @@ def test_qwen2_small_model_integration_batched_generate_multi_image(self):
         )
         # Prepare inputs
         prompt = [
-            "<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
-            "<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
         ]
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(
@@ -485,6 +489,7 @@ def test_qwen2_medium_model_integration_video(self):
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
+            num_frames=8,
         ).to(torch_device, dtype=torch.float16)

         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -552,6 +557,7 @@ def test_qwen2_small_model_integration_interleaved_images_videos(self):
             return_dict=True,
             return_tensors="pt",
             padding=True,
+            num_frames=8,
         ).to(torch_device, dtype=torch.bfloat16)

         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -601,7 +607,9 @@ def test_llama_small_model_integration_generate(self):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         image = Image.open(requests.get(url, stream=True).raw)

-        prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        prompt = (
+            "<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        )
         inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
         with torch.no_grad():
             generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -619,7 +627,9 @@ def test_llama_small_model_integration_forward(self):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         image = Image.open(requests.get(url, stream=True).raw)

-        prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        prompt = (
+            "<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
+        )
         inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)

         # Forward
@@ -687,8 +697,8 @@ def test_llama_small_model_integration_batched_generate(self):
         )
         # Prepare inputs
         prompt = [
-            "<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
-            "<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
         ]
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
@@ -724,8 +734,8 @@ def test_llama_small_model_integration_batched_generate_multi_image(self):
         )
         # Prepare inputs
         prompt = [
-            "<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
-            "<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<IMG_CONTEXT><IMG_CONTEXT>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
         ]
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(
@@ -795,6 +805,7 @@ def test_llama_medium_model_integration_video(self):
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
+            num_frames=8,
         ).to(torch_device, dtype=torch.float16)

         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
@@ -862,6 +873,7 @@ def test_llama_small_model_integration_interleaved_images_videos(self):
             return_dict=True,
             return_tensors="pt",
             padding=True,
+            num_frames=8,
         ).to(torch_device, dtype=torch.bfloat16)

         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
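
The test updates reflect the user-facing effect of the rename: raw text prompts now use the model's real `<IMG_CONTEXT>` token as the image placeholder rather than a separate `<image>` string. A minimal sketch, with an illustrative checkpoint id and the COCO image URL used in the tests:

# Hedged sketch of the renamed placeholder in a raw-text prompt (no chat template);
# the checkpoint id is an illustrative placeholder.
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# The placeholder is now the model's context token, expanded by the processor.
prompt = "<|im_start|>user\n<IMG_CONTEXT>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(images=image, text=prompt, return_tensors="pt")
print(inputs["input_ids"].shape)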
