Skip to content

Commit 4d71290

Browse files
authored
[fix] Realtime CLI demo energy threshold (#2166)
1 parent 9f96338 commit 4d71290

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

examples/realtime/cli/demo.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ENERGY_THRESHOLD = 0.015 # RMS threshold for barge‑in while assistant is speaking
2626
PREBUFFER_CHUNKS = 3 # initial jitter buffer (~120ms with 40ms chunks)
2727
FADE_OUT_MS = 12 # short fade to avoid clicks when interrupting
28+
PLAYBACK_ECHO_MARGIN = 0.002 # extra energy above playback echo required to count as speech
2829

2930
# Set up logging for OpenAI agents SDK
3031
# logging.basicConfig(
@@ -78,6 +79,7 @@ def __init__(self) -> None:
7879
self.fade_total_samples = 0
7980
self.fade_done_samples = 0
8081
self.fade_samples = int(SAMPLE_RATE * (FADE_OUT_MS / 1000.0))
82+
self.playback_rms = 0.0 # smoothed playback energy to filter out echo
8183

8284
def _output_callback(self, outdata, frames: int, time, status) -> None:
8385
"""Callback for audio output - handles continuous audio stream from server."""
@@ -123,6 +125,7 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
123125
gain = 1.0 - (idx / float(self.fade_total_samples))
124126
ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16)
125127
outdata[samples_filled : samples_filled + n, 0] = ramped
128+
self._update_playback_rms(ramped)
126129

127130
# Optionally report played bytes (ramped) to playback tracker
128131
try:
@@ -183,6 +186,7 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
183186
chunk_data = samples[self.chunk_position : self.chunk_position + samples_to_copy]
184187
# More efficient: direct assignment for mono audio instead of reshape
185188
outdata[samples_filled : samples_filled + samples_to_copy, 0] = chunk_data
189+
self._update_playback_rms(chunk_data)
186190
samples_filled += samples_to_copy
187191
self.chunk_position += samples_to_copy
188192

@@ -273,14 +277,6 @@ async def capture_audio(self) -> None:
273277
read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S)
274278

275279
try:
276-
# Simple energy-based barge-in: if user speaks while audio is playing, interrupt.
277-
def rms_energy(samples: np.ndarray[Any, np.dtype[Any]]) -> float:
278-
if samples.size == 0:
279-
return 0.0
280-
# Normalize int16 to [-1, 1]
281-
x = samples.astype(np.float32) / 32768.0
282-
return float(np.sqrt(np.mean(x * x)))
283-
284280
while self.recording:
285281
# Check if there's enough data to read
286282
if self.audio_stream.read_available < read_size:
@@ -300,7 +296,13 @@ def rms_energy(samples: np.ndarray[Any, np.dtype[Any]]) -> float:
300296
if assistant_playing:
301297
# Compute RMS energy to detect speech while assistant is talking
302298
samples = data.reshape(-1)
303-
if rms_energy(samples) >= ENERGY_THRESHOLD:
299+
mic_rms = self._compute_rms(samples)
300+
# Require the mic to be louder than the echo of the assistant playback.
301+
playback_gate = max(
302+
ENERGY_THRESHOLD,
303+
self.playback_rms * 0.6 + PLAYBACK_ECHO_MARGIN,
304+
)
305+
if mic_rms >= playback_gate:
304306
# Locally flush queued assistant audio for snappier interruption.
305307
self.interrupt_event.set()
306308
await self.session.send_audio(audio_bytes)
@@ -356,6 +358,18 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None:
356358
except Exception as e:
357359
print(f"Error processing event: {_truncate_str(str(e), 200)}")
358360

361+
def _compute_rms(self, samples: np.ndarray[Any, np.dtype[Any]]) -> float:
362+
"""Compute RMS energy for int16 samples normalized to [-1, 1]."""
363+
if samples.size == 0:
364+
return 0.0
365+
x = samples.astype(np.float32) / 32768.0
366+
return float(np.sqrt(np.mean(x * x)))
367+
368+
def _update_playback_rms(self, samples: np.ndarray[Any, np.dtype[Any]]) -> None:
369+
"""Keep a smoothed estimate of playback energy to filter out echo feedback."""
370+
sample_rms = self._compute_rms(samples)
371+
self.playback_rms = 0.9 * self.playback_rms + 0.1 * sample_rms
372+
359373

360374
if __name__ == "__main__":
361375
demo = NoUIDemo()

0 commit comments

Comments
 (0)