Skip to content
14 changes: 14 additions & 0 deletions dspy/adapters/types/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
SF_AVAILABLE = False


def _normalize_audio_format(audio_format: str) -> str:
"""Removes 'x-' prefixes from audio format strings."""
return audio_format.removeprefix("x-")


class Audio(Type):
data: str
audio_format: str
Expand Down Expand Up @@ -61,6 +66,9 @@ def from_url(cls, url: str) -> "Audio":
if not mime_type.startswith("audio/"):
raise ValueError(f"Unsupported MIME type for audio: {mime_type}")
audio_format = mime_type.split("/")[1]

audio_format = _normalize_audio_format(audio_format)

encoded_data = base64.b64encode(response.content).decode("utf-8")
return cls(data=encoded_data, audio_format=audio_format)

Expand All @@ -80,6 +88,9 @@ def from_file(cls, file_path: str) -> "Audio":
file_data = file.read()

audio_format = mime_type.split("/")[1]

audio_format = _normalize_audio_format(audio_format)

encoded_data = base64.b64encode(file_data).decode("utf-8")
return cls(data=encoded_data, audio_format=audio_format)

Expand Down Expand Up @@ -126,6 +137,9 @@ def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: in
header, b64data = audio.split(",", 1)
mime = header.split(";")[0].split(":")[1]
audio_format = mime.split("/")[1]

audio_format = _normalize_audio_format(audio_format)

return {"data": b64data, "audio_format": audio_format}
except Exception as e:
raise ValueError(f"Malformed audio data URI: {e}")
Expand Down
32 changes: 32 additions & 0 deletions tests/adapters/test_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from dspy.adapters.types.audio import _normalize_audio_format


# (input_format, expected_format) pairs fed to test_normalize_audio_format.
_NORMALIZE_AUDIO_FORMAT_CASES = [
    # Formats without an 'x-' prefix come back unchanged.
    ("wav", "wav"),
    ("mp3", "mp3"),
    # A leading 'x-' is stripped.
    ("x-wav", "wav"),
    ("x-mp3", "mp3"),
    ("x-flac", "flac"),
    # 'x-' is only removed at the very start, never from the middle.
    ("my-x-format", "my-x-format"),
    ("x-my-format", "my-format"),
    # Degenerate inputs: the empty string, and a prefix with nothing after it.
    ("", ""),
    ("x-", ""),
]


@pytest.mark.parametrize("input_format, expected_format", _NORMALIZE_AUDIO_FORMAT_CASES)
def test_normalize_audio_format(input_format, expected_format):
    """Check that _normalize_audio_format strips a leading 'x-' and nothing else.

    This exercises the helper directly, one (input, expected) pair per run.
    """
    normalized = _normalize_audio_format(input_format)
    assert normalized == expected_format