diff --git a/dspy/adapters/types/audio.py b/dspy/adapters/types/audio.py index f39d60bda7..0ceb734b73 100644 --- a/dspy/adapters/types/audio.py +++ b/dspy/adapters/types/audio.py @@ -17,6 +17,11 @@ SF_AVAILABLE = False +def _normalize_audio_format(audio_format: str) -> str: + """Strip a single leading 'x-' from an audio format string (e.g. 'x-wav' -> 'wav').""" + return audio_format.removeprefix("x-") + + class Audio(Type): data: str audio_format: str @@ -61,6 +66,9 @@ def from_url(cls, url: str) -> "Audio": if not mime_type.startswith("audio/"): raise ValueError(f"Unsupported MIME type for audio: {mime_type}") audio_format = mime_type.split("/")[1] + + audio_format = _normalize_audio_format(audio_format) + encoded_data = base64.b64encode(response.content).decode("utf-8") return cls(data=encoded_data, audio_format=audio_format) @@ -80,6 +88,9 @@ def from_file(cls, file_path: str) -> "Audio": file_data = file.read() audio_format = mime_type.split("/")[1] + + audio_format = _normalize_audio_format(audio_format) + encoded_data = base64.b64encode(file_data).decode("utf-8") return cls(data=encoded_data, audio_format=audio_format) @@ -126,6 +137,9 @@ def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: in header, b64data = audio.split(",", 1) mime = header.split(";")[0].split(":")[1] audio_format = mime.split("/")[1] + + audio_format = _normalize_audio_format(audio_format) + return {"data": b64data, "audio_format": audio_format} except Exception as e: raise ValueError(f"Malformed audio data URI: {e}") diff --git a/tests/adapters/test_audio.py b/tests/adapters/test_audio.py new file mode 100644 index 0000000000..b7597fbd29 --- /dev/null +++ b/tests/adapters/test_audio.py @@ -0,0 +1,32 @@ +import pytest + +from dspy.adapters.types.audio import _normalize_audio_format + + +@pytest.mark.parametrize( + "input_format, expected_format", + [ + # Case 1: Standard format (no change) + ("wav", "wav"), + ("mp3", "mp3"), + + # Case 2: A leading 'x-' prefix is stripped + ("x-wav", "wav"), + ("x-mp3", "mp3"), 
+ ("x-flac", "flac"), + + # Case 3: 'x-' is stripped only at the start, never mid-string + ("my-x-format", "my-x-format"), + ("x-my-format", "my-format"), + + # Case 4: Empty string and a bare 'x-' prefix + ("", ""), + ("x-", ""), + ], +) +def test_normalize_audio_format(input_format, expected_format): + """ + Tests that the _normalize_audio_format helper strips a single leading 'x-' prefix. + from_url, from_file, and encode_audio all call this helper; their call sites are not exercised here. + """ + assert _normalize_audio_format(input_format) == expected_format