Skip to content

Commit cf3df35

Browse files
committed
update documentation & make fixup
1 parent 8e95dde commit cf3df35

File tree

4 files changed

+11
-37
lines changed

4 files changed

+11
-37
lines changed

docs/source/en/chat_templating_multimodal.md

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ Some vision models also support video inputs. The message format is very similar
111111

112112
- The content `"type"` should be `"video"` to indicate the content is a video.
113113
- For videos, it can be a link to the video (`"url"`) or it could be a file path (`"path"`). Videos loaded from a URL can only be decoded with [PyAV](https://pyav.basswood-io.com/docs/stable/) or [Decord](https://github.com/dmlc/decord).
114+
- In addition to loading videos from a URL or file path, you can also pass decoded video data directly. This is useful if you’ve already preprocessed or decoded video frames elsewhere in memory (e.g., using OpenCV, decord, or torchvision). You don't need to save to files or store it in an URL.
114115

115116
> [!WARNING]
116117
> Loading a video from `"url"` is only supported by the PyAV or Decord backends.
@@ -137,27 +138,11 @@ messages = [
137138
]
138139
```
139140

140-
### Passing decoded video objects
141-
In addition to loading videos from a URL or file path, you can also pass decoded video data directly.
142-
143-
This is useful if you’ve already preprocessed or decoded video frames elsewhere in memory (e.g., using OpenCV, decord, or torchvision). You don't need to save to files or store it in an URL.
144-
145-
- Use the `"video"` type with a dictionary that includes:
146-
- `"frames"` (`np.ndarray` or `torch.Tensor`):
147-
A 4D array of shape (num_frames, channels, height, width) containing decoded video frames.
148-
- `"metadata"` (`"VideoMetadata"` or `"dict"`):
149-
Describes metadata for the video. If you provide a dictionary, it must include at least one of:
150-
- `"fps"` (frames per second)
151-
- `"duration"` (video duration in seconds)
152-
if both `"fps"` and `"duration"` is provided, `"fps"` gets priority and `"duration"` is calculated based on `"fps"`
153-
141+
### Example: Passing decoded video objects
154142
```python
155143
import numpy as np
156144

157-
video_object1 = {
158-
"frames": np.random.randint(0, 255, size=(16, 3, 224, 224), dtype=np.uint8),
159-
"metadata": {"fps": 16, "duration": 2.0}
160-
}
145+
video_object1 = np.random.randint(0, 255, size=(16, 224, 224, 3), dtype=np.uint8),
161146

162147
messages = [
163148
{
@@ -180,15 +165,10 @@ You can also use existing (`"load_video()"`) function to load a video, edit the
180165
from transformers.video_utils import load_video
181166

182167
# load a video file in memory for testing
183-
frames, metadata = load_video(
168+
video_object2, _ = load_video(
184169
"https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"
185170
)
186171

187-
video_object2 = {
188-
"frames": frames,
189-
"metadata": metadata,
190-
}
191-
192172
messages = [
193173
{
194174
"role": "system",

src/transformers/processing_utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,12 @@
3838
from .feature_extraction_utils import BatchFeature
3939
from .image_utils import ChannelDimension, is_vision_available, load_image
4040
from .utils.chat_template_utils import render_jinja_template
41-
from .video_utils import VideoMetadata, convert_pil_frames_to_video, load_video
41+
from .video_utils import VideoMetadata, load_video
4242

4343

4444
if is_vision_available():
4545
from .image_utils import PILImageResampling
4646

47-
if is_torch_available():
48-
import torch
4947

5048
from .tokenization_utils_base import (
5149
PaddingStrategy,
@@ -68,7 +66,6 @@
6866
download_url,
6967
is_offline_mode,
7068
is_remote_url,
71-
is_torch_available,
7269
list_repo_templates,
7370
logging,
7471
)
@@ -1578,11 +1575,6 @@ def apply_chat_template(
15781575
fname,
15791576
backend=mm_load_kwargs["video_load_backend"],
15801577
)
1581-
if metadata is None:
1582-
logger.warning(
1583-
"When loading the video from list of decoded frames, we cannot infer metadata such as `fps` or `duration`. "
1584-
"If your model requires metadata during processing, please load the whole video and let the processor sample frames instead."
1585-
)
15861578
videos.append(video)
15871579
video_metadata.append(metadata)
15881580

src/transformers/video_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,14 +563,18 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
563563

564564
sample_indices_fn = sample_indices_fn_func
565565

566-
if isinstance(video, Union[np.ndarray, torch.Tensor]):
566+
if is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
567567
if not is_valid_video(video):
568568
raise ValueError(
569569
f"When passing video as decoded frames, video should be a 4D numpy array or torch tensor, but got {video.ndim} dimensions instead."
570570
)
571571
# Case 1: Video is provided as a 4D numpy array or torch tensor (frames, height, width, channels)
572572
if is_torch_tensor(video):
573573
video = video.numpy() # Convert torch tensor to numpy array
574+
logger.warning(
575+
"When loading the video from list of decoded frames, we cannot infer metadata such as `fps` or `duration`. "
576+
"If your model requires metadata during processing, please load the whole video and let the processor sample frames instead."
577+
)
574578
return video, None
575579

576580
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:

tests/test_processing_common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@
6363
from transformers.video_utils import load_video
6464

6565
# load a video file in memory for testing
66-
video, _ = load_video(
67-
"https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"
68-
)
66+
video, _ = load_video("https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4")
6967
MODALITY_INPUT_DATA["videos"].append(video)
7068

7169

0 commit comments

Comments
 (0)