Commit 59c90e7

Merge pull request #197 from pengzhiliang/vibevoice_asr_ft
add VibeVoice-ASR finetuning code

2 parents 875115c + 8516386 commit 59c90e7

File tree

8 files changed: +1100 −2 lines changed

finetuning/README.md

Lines changed: 154 additions & 0 deletions
# VibeVoice ASR LoRA Fine-tuning

This directory contains scripts for LoRA (Low-Rank Adaptation) fine-tuning of the VibeVoice ASR model.

## Requirements

```bash
# Install vibevoice first
# pip install -e .[asr]

pip install peft
```

## Toy Dataset

> **Note**: The `toy_dataset/` included in this directory contains **synthetic audio generated by VibeVoice TTS** for demonstration purposes only. It is NOT a production-quality dataset.
>
> When using your own data, you should:
> - Prepare real audio recordings with accurate transcriptions
> - Adjust hyperparameters (learning rate, epochs, LoRA rank) based on your dataset size and domain
> - Consider the audio quality and speaker diversity in your data

## Data Format

Training data should be organized as pairs of audio files and JSON labels in the same directory:

```
toy_dataset/
├── 0.mp3
├── 0.json
├── 1.mp3
├── 1.json
└── ...
```
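To illustrate the pairing convention, here is a small sketch (not part of the repository; the extension set is an assumption, the toy dataset uses `.mp3`) that walks a data directory and matches each audio file to the JSON label sharing its stem:

```python
from pathlib import Path

AUDIO_EXTS = {".mp3", ".wav", ".flac"}  # assumed set; the toy dataset uses .mp3

def find_pairs(data_dir):
    """Return sorted (audio_path, label_path) pairs that share a filename stem."""
    data_dir = Path(data_dir)
    pairs = []
    for audio in sorted(p for p in data_dir.iterdir() if p.suffix in AUDIO_EXTS):
        label = audio.with_suffix(".json")
        if label.exists():  # skip audio without a matching label
            pairs.append((audio, label))
    return pairs
```

Audio files without a matching label are simply skipped here; the actual `lora_finetune.py` may handle such cases differently.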
### JSON Label Format

Each JSON file should have the following structure:

```json
{
  "audio_duration": 351.73,
  "audio_path": "0.mp3",
  "segments": [
    {
      "speaker": 0,
      "text": "Hey everyone, welcome back...",
      "start": 0.0,
      "end": 38.68
    },
    {
      "speaker": 1,
      "text": "Thanks for having me...",
      "start": 38.75,
      "end": 77.88
    }
  ],
  "customized_context": ["Tea Brew", "Aiden Host", "The property is near Meter Street."]
}
```

`customized_context` is optional and holds domain-specific terms or context sentences.
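Labels in this shape are easy to sanity-check before training. The following sketch (illustrative, not from the repository; the small overlap tolerance between segments is an assumption) validates the fields shown above:

```python
import json

def check_label(path, tolerance=0.5):
    """Load a label file and assert the schema from the README example."""
    with open(path) as f:
        label = json.load(f)
    assert label["audio_duration"] > 0, "audio_duration must be positive"
    prev_end = 0.0
    for seg in label["segments"]:
        assert isinstance(seg["speaker"], int)
        assert seg["text"].strip(), "empty transcription"
        assert 0.0 <= seg["start"] < seg["end"] <= label["audio_duration"]
        # segments should be roughly time-ordered; allow a small overlap
        assert seg["start"] >= prev_end - tolerance
        prev_end = seg["end"]
    return label
```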
## Training

### Basic

```bash
# 1 GPU
torchrun --nproc_per_node=1 lora_finetune.py \
    --model_path microsoft/VibeVoice-ASR \
    --data_dir ./toy_dataset \
    --output_dir ./output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-4 \
    --bf16 \
    --report_to none

# Specific GPUs (e.g., GPUs 0,1,2,3)
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 lora_finetune.py \
    --model_path microsoft/VibeVoice-ASR \
    --data_dir ./toy_dataset \
    --output_dir ./output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-4 \
    --bf16 \
    --report_to none
```

### Full Options

The script uses HuggingFace's `TrainingArguments`, so all standard options are available:

```bash
torchrun --nproc_per_node=4 lora_finetune.py \
    --model_path microsoft/VibeVoice-ASR \
    --data_dir ./toy_dataset \
    --output_dir ./output \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --learning_rate 1e-4 \
    --warmup_ratio 0.1 \
    --weight_decay 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 10 \
    --save_steps 100 \
    --gradient_checkpointing \
    --bf16 \
    --report_to none
```

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--lora_r` | 16 | LoRA rank (lower = fewer trainable parameters, higher = more expressive) |
| `--lora_alpha` | 32 | LoRA scaling factor (typically 2× the rank) |
| `--lora_dropout` | 0.05 | Dropout for LoRA layers |
| `--per_device_train_batch_size` | 8 | Batch size per device |
| `--gradient_accumulation_steps` | 1 | Effective batch size = batch_size × grad_accum |
| `--learning_rate` | 5e-5 | Learning rate (1e-4 to 2e-4 is typical for LoRA) |
| `--gradient_checkpointing` | False | Enable to reduce memory usage |
| `--use_customized_context` | True | Include `customized_context` from JSON as additional context |
| `--max_audio_length` | None | Skip audio longer than this many seconds |
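To make the table's rank and batch-size arithmetic concrete, here is an illustrative calculation (not from the repository; the 3584 hidden size is an assumption for a Qwen2.5-7B-class backbone, and the real adapted shapes depend on which modules the training script targets):

```python
def effective_batch_size(per_device_batch, grad_accum, num_devices=1):
    """Effective batch size = batch_size x grad_accum (x participating devices)."""
    return per_device_batch * grad_accum * num_devices

def lora_params(d_in, d_out, r):
    """A LoRA adapter adds two low-rank factors: A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r

# "Full Options" example above: batch 1, grad accum 4, 4 GPUs
print(effective_batch_size(1, 4, 4))   # -> 16

# One hypothetical 3584x3584 projection at r=16, vs. the frozen full matrix
print(lora_params(3584, 3584, 16))     # -> 114688 trainable parameters
print(3584 * 3584)                     # -> 12845056 frozen parameters
```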
## Inference with Fine-tuned Model

```bash
python inference_lora.py \
    --base_model microsoft/VibeVoice-ASR \
    --lora_path ./output \
    --audio_file ./toy_dataset/0.mp3 \
    --context_info "Tea Brew, Aiden Host"
```

## Merging LoRA Weights (Optional)

To merge the LoRA weights into the base model for faster inference:

```python
from peft import PeftModel

from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration

# Load base model + LoRA adapter
model = VibeVoiceASRForConditionalGeneration.from_pretrained("microsoft/VibeVoice-ASR", ...)
model = PeftModel.from_pretrained(model, "./output")

# Merge and save
model = model.merge_and_unload()
model.save_pretrained("./merged_model")
```

finetuning/inference_lora.py

Lines changed: 234 additions & 0 deletions
#!/usr/bin/env python
"""
Inference with LoRA Fine-tuned VibeVoice ASR Model

This script loads a LoRA fine-tuned model and runs inference.

Usage:
    python inference_lora.py \
        --base_model microsoft/VibeVoice-ASR \
        --lora_path ./output \
        --audio_file ./toy_dataset/0.mp3
"""

import argparse

import torch
from peft import PeftModel

from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor


def load_lora_model(
    base_model_path: str,
    lora_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
):
    """
    Load the base model and attach the LoRA adapter weights.

    Args:
        base_model_path: Path to the base pretrained model
        lora_path: Path to the LoRA adapter weights
        device: Device to load the model on
        dtype: Data type for the model

    Returns:
        Tuple of (model, processor)
    """
    print(f"Loading base model from {base_model_path}")

    # Load processor
    processor = VibeVoiceASRProcessor.from_pretrained(
        base_model_path,
        language_model_pretrained_name="Qwen/Qwen2.5-7B",
    )

    # Load base model
    model = VibeVoiceASRForConditionalGeneration.from_pretrained(
        base_model_path,
        dtype=dtype,
        device_map=device if device == "auto" else None,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )

    if device != "auto":
        model = model.to(device)

    # Load LoRA adapter
    print(f"Loading LoRA adapter from {lora_path}")
    model = PeftModel.from_pretrained(model, lora_path)

    # Optionally merge LoRA weights into the base model for faster inference
    # model = model.merge_and_unload()

    model.eval()
    print("Model loaded successfully")

    return model, processor
def transcribe(
    model,
    processor,
    audio_path: str,
    max_new_tokens: int = 4096,
    temperature: float = 0.0,
    context_info: str = None,
    device: str = "cuda",
):
    """
    Transcribe an audio file using the LoRA fine-tuned model.

    Args:
        model: The LoRA fine-tuned model
        processor: The processor
        audio_path: Path to the audio file
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (0 = greedy)
        context_info: Optional context info (e.g., hotwords)
        device: Device to run on

    Returns:
        Dict with the raw generated text and parsed segments
    """
    print(f"\nTranscribing: {audio_path}")

    # Process audio
    inputs = processor(
        audio=audio_path,
        sampling_rate=None,
        return_tensors="pt",
        padding=True,
        add_generation_prompt=True,
        context_info=context_info,
    )

    # Move tensors to device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
              for k, v in inputs.items()}

    # Generation config
    gen_config = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": processor.pad_id,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        gen_config["temperature"] = temperature
        gen_config["top_p"] = 0.9

    # Generate
    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_config)

    # Decode only the newly generated tokens
    input_length = inputs['input_ids'].shape[1]
    generated_ids = output_ids[0, input_length:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True)

    # Parse structured output
    try:
        segments = processor.post_process_transcription(generated_text)
    except Exception as e:
        print(f"Warning: Failed to parse structured output: {e}")
        segments = []

    return {
        "raw_text": generated_text,
        "segments": segments,
    }
def main():
    parser = argparse.ArgumentParser(description="Inference with LoRA Fine-tuned VibeVoice ASR")

    parser.add_argument(
        "--base_model",
        type=str,
        default="microsoft/VibeVoice-ASR",
        help="Path to base pretrained model",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        required=True,
        help="Path to LoRA adapter weights",
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,
        help="Path to audio file to transcribe",
    )
    parser.add_argument(
        "--context_info",
        type=str,
        default=None,
        help="Optional context info (e.g., 'Hotwords: Tea Brew, Aiden Host')",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=4096,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (0 = greedy)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use",
    )

    args = parser.parse_args()

    # Load model
    dtype = torch.bfloat16 if args.device != "cpu" else torch.float32
    model, processor = load_lora_model(
        base_model_path=args.base_model,
        lora_path=args.lora_path,
        device=args.device,
        dtype=dtype,
    )

    # Transcribe
    result = transcribe(
        model=model,
        processor=processor,
        audio_path=args.audio_file,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        context_info=args.context_info,
        device=args.device,
    )

    # Print results
    print("\n" + "=" * 60)
    print("Transcription Result")
    print("=" * 60)

    print("\n--- Raw Output ---")
    raw_text = result['raw_text']
    print(raw_text[:2000] + "..." if len(raw_text) > 2000 else raw_text)

    if result['segments']:
        print(f"\n--- Structured Output ({len(result['segments'])} segments) ---")
        for seg in result['segments'][:20]:
            print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] "
                  f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')[:80]}...")
        if len(result['segments']) > 20:
            print(f"  ... and {len(result['segments']) - 20} more segments")


if __name__ == "__main__":
    main()
