Skip to content

Commit 3b3f9c0

Browse files
rev2607 and eustlb authored
fix(voxtral): correct typo in apply_transcription_request (#39572)
* fix(voxtral): correct typo in apply_transcription_request
* temporary wrapper: apply_transcrition_request
* Update processing_voxtral.py
* style: sort imports in processing_voxtral.py
* docs(voxtral): fix typo in voxtral.md
* make style
* doc update

---------

Co-authored-by: eustlb <[email protected]>
Co-authored-by: Eustache Le Bihan <[email protected]>
1 parent 2a82cf0 commit 3b3f9c0

File tree

3 files changed

+76
-13
lines changed

3 files changed

+76
-13
lines changed

docs/source/en/model_doc/voxtral.md

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ Voxtral builds on Ministral-3B by adding audio processing capabilities:
3737

3838
## Usage
3939

40-
Let's first load the model!
40+
### Audio Instruct Mode
41+
42+
The model supports audio-text instructions, including multi-turn and multi-audio interactions, all processed in batches.
43+
44+
➡️ audio + text instruction
4145
```python
4246
from transformers import VoxtralForConditionalGeneration, AutoProcessor
4347
import torch
@@ -47,14 +51,7 @@ repo_id = "mistralai/Voxtral-Mini-3B-2507"
4751

4852
processor = AutoProcessor.from_pretrained(repo_id)
4953
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
50-
```
51-
52-
### Audio Instruct Mode
5354

54-
The model supports audio-text instructions, including multi-turn and multi-audio interactions, all processed in batches.
55-
56-
➡️ audio + text instruction
57-
```python
5855
conversation = [
5956
{
6057
"role": "user",
@@ -82,6 +79,15 @@ print("=" * 80)
8279

8380
➡️ multi-audio + text instruction
8481
```python
82+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
83+
import torch
84+
85+
device = "cuda" if torch.cuda.is_available() else "cpu"
86+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
87+
88+
processor = AutoProcessor.from_pretrained(repo_id)
89+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
90+
8591
conversation = [
8692
{
8793
"role": "user",
@@ -113,6 +119,15 @@ print("=" * 80)
113119

114120
➡️ multi-turn:
115121
```python
122+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
123+
import torch
124+
125+
device = "cuda" if torch.cuda.is_available() else "cpu"
126+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
127+
128+
processor = AutoProcessor.from_pretrained(repo_id)
129+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
130+
116131
conversation = [
117132
{
118133
"role": "user",
@@ -158,6 +173,15 @@ print("=" * 80)
158173

159174
➡️ text only:
160175
```python
176+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
177+
import torch
178+
179+
device = "cuda" if torch.cuda.is_available() else "cpu"
180+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
181+
182+
processor = AutoProcessor.from_pretrained(repo_id)
183+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
184+
161185
conversation = [
162186
{
163187
"role": "user",
@@ -184,6 +208,15 @@ print("=" * 80)
184208

185209
➡️ audio only:
186210
```python
211+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
212+
import torch
213+
214+
device = "cuda" if torch.cuda.is_available() else "cpu"
215+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
216+
217+
processor = AutoProcessor.from_pretrained(repo_id)
218+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
219+
187220
conversation = [
188221
{
189222
"role": "user",
@@ -210,6 +243,15 @@ print("=" * 80)
210243

211244
➡️ batched inference!
212245
```python
246+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
247+
import torch
248+
249+
device = "cuda" if torch.cuda.is_available() else "cpu"
250+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
251+
252+
processor = AutoProcessor.from_pretrained(repo_id)
253+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
254+
213255
conversations = [
214256
[
215257
{
@@ -262,7 +304,16 @@ for decoded_output in decoded_outputs:
262304
Use the model to transcribe audio (supports English, Spanish, French, Portuguese, Hindi, German, Dutch, Italian)!
263305

264306
```python
265-
inputs = processor.apply_transcrition_request(language="en", audio="https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3")
307+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
308+
import torch
309+
310+
device = "cuda" if torch.cuda.is_available() else "cpu"
311+
repo_id = "mistralai/Voxtral-Mini-3B-2507"
312+
313+
processor = AutoProcessor.from_pretrained(repo_id)
314+
model = VoxtralForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map=device)
315+
316+
inputs = processor.apply_transcription_request(language="en", audio="https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3", model_id=repo_id)
266317
inputs = inputs.to(device, dtype=torch.bfloat16)
267318

268319
outputs = model.generate(**inputs, max_new_tokens=500)

src/transformers/models/voxtral/processing_voxtral.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import io
17+
import warnings
1718
from typing import Optional, Union
1819

1920
from ...utils import is_mistral_common_available, is_soundfile_available, is_torch_available, logging
@@ -242,7 +243,7 @@ def __call__(
242243
the text. Please refer to the docstring of the above methods for more information.
243244
This method does not support audio. To prepare the audio, please use:
244245
1. `apply_chat_template` [`~VoxtralProcessor.apply_chat_template`] method.
245-
2. `apply_transcrition_request` [`~VoxtralProcessor.apply_transcrition_request`] method.
246+
2. `apply_transcription_request` [`~VoxtralProcessor.apply_transcription_request`] method.
246247
247248
Args:
248249
text (`str`, `list[str]`, `list[list[str]]`):
@@ -284,7 +285,7 @@ def __call__(
284285
return BatchFeature(data=out, tensor_type=common_kwargs.pop("return_tensors", None))
285286

286287
# TODO: @eustlb, this should be moved to mistral_common + testing
287-
def apply_transcrition_request(
288+
def apply_transcription_request(
288289
self,
289290
language: Union[str, list[str]],
290291
audio: Union[str, list[str], AudioInput],
@@ -306,7 +307,7 @@ def apply_transcrition_request(
306307
language = "en"
307308
audio = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3"
308309
309-
inputs = processor.apply_transcrition_request(language=language, audio=audio, model_id=model_id)
310+
inputs = processor.apply_transcription_request(language=language, audio=audio, model_id=model_id)
310311
```
311312
312313
Args:
@@ -431,6 +432,17 @@ def apply_transcrition_request(
431432

432433
return texts
433434

435+
# Deprecated typo'd method for backward compatibility
436+
def apply_transcrition_request(self, *args, **kwargs):
437+
"""
438+
Deprecated typo'd method. Use `apply_transcription_request` instead.
439+
"""
440+
warnings.warn(
441+
"`apply_transcrition_request` is deprecated due to a typo and will be removed in a future release. Please use `apply_transcription_request` instead.",
442+
FutureWarning,
443+
)
444+
return self.apply_transcription_request(*args, **kwargs)
445+
434446
def batch_decode(self, *args, **kwargs):
435447
"""
436448
This method forwards all its arguments to MistralCommonTokenizer's [`~MistralCommonTokenizer.batch_decode`]. Please

tests/models/voxtral/test_modeling_voxtral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def test_transcribe_mode_audio_input(self):
493493
model = VoxtralForConditionalGeneration.from_pretrained(
494494
self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device
495495
)
496-
inputs = self.processor.apply_transcrition_request(
496+
inputs = self.processor.apply_transcription_request(
497497
language="en",
498498
audio="https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3",
499499
model_id=self.checkpoint_name,

0 commit comments

Comments
 (0)