Description
For a 2B-parameter wav2vec2 model with a hidden size of 1920, passing an input of 1920 samples (~0.12 s at 16 kHz) with `"return_hidden": true` in config.json makes the heuristic dimension check return true for unencoded data. The input is then returned unchanged instead of being encoded.
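To make the collision concrete, here is a simplified, hypothetical model of a shape-based "already encoded?" heuristic. It assumes the check treats a rank-3 input as encoder output whenever its last dimension equals the encoder hidden size. The real CTranslate2 check may compare additional dimensions, and `looks_encoded` is an illustrative name, not a library function.

```python
# Simplified, hypothetical model of a shape-based "already encoded?" heuristic.
# Assumption: a rank-3 input is treated as encoder output when its last
# dimension equals the encoder hidden size. The real CTranslate2 check may differ.
HIDDEN_SIZE = 1920  # hidden size of the 2B wav2vec2 model in this report


def looks_encoded(shape: tuple, hidden_size: int = HIDDEN_SIZE) -> bool:
    """Return True if `shape` would be mistaken for encoder output."""
    return len(shape) == 3 and shape[-1] == hidden_size


# Raw audio is [batch, 1, num_samples]; encoder output is [batch, frames, hidden_size].
raw_1920 = (1, 1, 1920)  # 0.12 s at 16 kHz: sample count equals the hidden size
raw_1921 = (1, 1, 1921)  # one extra sample: no collision
encoded = (1, 5, 1920)   # hypothetical encoder output shape

for name, shape in [("raw_1920", raw_1920), ("raw_1921", raw_1921), ("encoded", encoded)]:
    print(name, looks_encoded(shape))
```

Under this model, raw audio whose sample count happens to equal the hidden size is indistinguishable from encoder output, which is consistent with the reproduction that follows.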
```python
import ctranslate2
import numpy as np

# Convert the Transformers checkpoint "model" (local path or Hub name) to CTranslate2 format.
converted_model_path = "ctranslate2_model"
converter = ctranslate2.converters.TransformersConverter("model")
converter.convert(converted_model_path)

model = ctranslate2.models.Wav2Vec2(
    converted_model_path,
    device="cpu",
    device_index=[0],
    compute_type="float32",
    intra_threads=0,
    inter_threads=1,
)

# 1920 samples = 0.12 seconds at 16 kHz; shape is [batch, 1, num_samples].
array = np.random.uniform(low=-1, high=1, size=(1, 1, 1920)).astype(np.float32)
inputs = ctranslate2.StorageView.from_array(array)
output = model.encode(inputs)

print("Input: ", inputs)
print("Output: ", output)
# CPU StorageViews implement the array interface, so NumPy can compare them.
print("The same: ", np.array_equal(np.asarray(inputs), np.asarray(output)))
```

produces the output:

```
Input: -0.549616 0.54811 0.688297 ... -0.316497 -0.0261584 -0.179784
[cpu:0 float32 storage viewed as 1x1x1920]
Output: -0.549616 0.54811 0.688297 ... -0.316497 -0.0261584 -0.179784
[cpu:0 float32 storage viewed as 1x1x1920]
The same: True
```
This function is only called from here, and it appears to be copy-pasted between the Whisper, Wav2Vec2 and Wav2Vec2Bert models.
To be honest, I don't understand the reason for including this check. It initially reminded me of the base case of a recursive function, but I couldn't find any recursive calls that would need one. It also seems unlikely that a user would pass in already-encoded data, and if they did, failing would be a sane outcome. Meanwhile, keeping the check breaks the edge case described above.
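A minimal sketch of the alternative argued for here, assuming raw input is always shaped `[batch, 1, num_samples]` as in the reproduction: with no "already encoded?" shortcut, a 1920-sample clip is simply passed on to the encoder, while something shaped like encoder output fails loudly. `validate_raw_audio` is a hypothetical helper for illustration, not part of the CTranslate2 API.

```python
import numpy as np

# Hypothetical sketch: no "already encoded?" shortcut, only an explicit shape
# check on raw audio before encoding.
# Assumption: raw wav2vec2 input is [batch, 1, num_samples], as in the reproduction.


def validate_raw_audio(array: np.ndarray) -> np.ndarray:
    """Accept [batch, 1, num_samples] audio of any length (cast to float32); reject anything else."""
    if array.ndim != 3 or array.shape[1] != 1:
        raise ValueError(
            f"expected raw audio of shape [batch, 1, num_samples], got {array.shape}"
        )
    return array.astype(np.float32, copy=False)


raw = np.zeros((1, 1, 1920), dtype=np.float32)           # same shape as the report
encoded_like = np.zeros((1, 5, 1920), dtype=np.float32)  # shaped like encoder output

print(validate_raw_audio(raw).shape)  # (1, 1, 1920): would be passed to the encoder
try:
    validate_raw_audio(encoded_like)
except ValueError as err:
    print("rejected:", err)
```

The design choice is to make the input contract explicit rather than guess from dimensions, so the sample count can never collide with the hidden size.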
I'll be happy to open a PR removing it from all three models if someone confirms.
Cheers!