diff --git a/inference/core/interfaces/webrtc_worker/entities.py b/inference/core/interfaces/webrtc_worker/entities.py index 3d2a0e0136..cfee2840ce 100644 --- a/inference/core/interfaces/webrtc_worker/entities.py +++ b/inference/core/interfaces/webrtc_worker/entities.py @@ -39,6 +39,9 @@ class WebRTCWorkerRequest(BaseModel): data_output: Optional[List[str]] = Field(default=None) declared_fps: Optional[float] = None rtsp_url: Optional[str] = None + use_data_channel_frames: bool = ( + False # When True, expect frames via data channel instead of media track + ) processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT processing_session_started: Optional[datetime.datetime] = None requested_plan: Optional[str] = "webrtc-gpu-small" diff --git a/inference/core/interfaces/webrtc_worker/webrtc.py b/inference/core/interfaces/webrtc_worker/webrtc.py index b899d510be..a990d528c4 100644 --- a/inference/core/interfaces/webrtc_worker/webrtc.py +++ b/inference/core/interfaces/webrtc_worker/webrtc.py @@ -2,8 +2,11 @@ import datetime import json import logging +import struct from typing import Any, Callable, Dict, List, Optional, Tuple +import cv2 +import numpy as np from aiortc import ( RTCConfiguration, RTCDataChannel, @@ -60,6 +63,113 @@ logging.getLogger("aiortc").setLevel(logging.WARNING) +# WebRTC data channel chunking configuration +CHUNK_SIZE = 48 * 1024 # 48KB - safe for all WebRTC implementations + + +def create_chunked_binary_message( + frame_id: int, chunk_index: int, total_chunks: int, payload: bytes +) -> bytes: + """Create a binary message with standard 12-byte header. + + Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N] + All integers are uint32 little-endian. + """ + header = struct.pack(" Tuple[int, int, int, bytes]: + """Parse a binary message with standard 12-byte header. + + Returns: (frame_id, chunk_index, total_chunks, payload) + """ + if len(message) < 12: + raise ValueError(f"Message too short: {len(message)} bytes (expected >= 12)") + + frame_id, chunk_index, total_chunks = struct.unpack(" Optional[bytes]: + """Add a chunk and return complete payload if all chunks received. + + Returns: + Complete reassembled payload bytes if all chunks received, None otherwise. + """ + # Initialize buffers for new frame + if frame_id not in self._chunks: + self._chunks[frame_id] = {} + self._total[frame_id] = total_chunks + + # Store chunk + self._chunks[frame_id][chunk_index] = chunk_data + + # Check if all chunks received + if len(self._chunks[frame_id]) >= total_chunks: + # Reassemble in order + complete_payload = b"".join( + self._chunks[frame_id][i] for i in range(total_chunks) + ) + + # Clean up + del self._chunks[frame_id] + del self._total[frame_id] + + return complete_payload + + return None + + +def send_chunked_data( + data_channel: RTCDataChannel, + frame_id: int, + payload_bytes: bytes, + chunk_size: int = CHUNK_SIZE, +) -> None: + """Send payload via data channel, automatically chunking if needed. + + Args: + data_channel: RTCDataChannel to send on + frame_id: Frame identifier + payload_bytes: Data to send (JPEG, JSON UTF-8, etc.) + chunk_size: Maximum chunk size (default 48KB) + """ + if data_channel.readyState != "open": + logger.warning(f"Cannot send response for frame {frame_id}, channel not open") + return + + total_chunks = ( + len(payload_bytes) + chunk_size - 1 + ) // chunk_size # Ceiling division + + if frame_id % 100 == 1: + logger.info( + f"Sending response for frame {frame_id}: {total_chunks} chunk(s), {len(payload_bytes)} bytes" + ) + + for chunk_index in range(total_chunks): + start = chunk_index * chunk_size + end = min(start + chunk_size, len(payload_bytes)) + chunk_data = payload_bytes[start:end] + + message = create_chunked_binary_message( + frame_id, chunk_index, total_chunks, chunk_data + ) + data_channel.send(message) + class RTCPeerConnectionWithLoop(RTCPeerConnection): def __init__( @@ -91,6 +201,7 @@ def __init__( declared_fps: float = 30, termination_date: Optional[datetime.datetime] = None, terminate_event: Optional[asyncio.Event] = None, + use_data_channel_frames: bool = False, ): self._loop = asyncio_loop self._termination_date = termination_date @@ -101,6 +212,11 @@ def __init__( self._received_frames = 0 self._declared_fps = declared_fps self._stop_processing = False + self.use_data_channel_frames = use_data_channel_frames + self._data_frame_queue: "asyncio.Queue[Optional[VideoFrame]]" = asyncio.Queue() + self._chunk_reassembler = ( + ChunkReassembler() + ) # For reassembling inbound frame chunks self.has_video_track = has_video_track self.stream_output = stream_output @@ -185,7 +301,9 @@ async def _send_data_output( ) if self._data_mode == DataOutputMode.NONE: - self.data_channel.send(json.dumps(webrtc_output.model_dump())) + # Even empty responses use binary protocol + json_bytes = json.dumps(webrtc_output.model_dump()).encode("utf-8") + send_chunked_data(self.data_channel, self._received_frames, json_bytes) return if self._data_mode == DataOutputMode.ALL: @@ -216,11 +334,55 @@ async def _send_data_output( webrtc_output.errors.append(f"{field_name}: {e}") serialized_outputs[field_name] = {"__serialization_error__": str(e)} - # Only set serialized_output_data if we have data to send + # Set serialized outputs if serialized_outputs: webrtc_output.serialized_output_data = serialized_outputs - self.data_channel.send(json.dumps(webrtc_output.model_dump(mode="json"))) + # Send using binary chunked protocol + json_bytes = json.dumps(webrtc_output.model_dump(mode="json")).encode("utf-8") + send_chunked_data(self.data_channel, self._received_frames, json_bytes) + + async def _handle_data_channel_frame(self, message: bytes) -> None: + """Handle incoming binary frame chunk from upstream_frames data channel. + + Uses standard binary protocol with 12-byte header + JPEG chunk payload. + """ + try: + # Parse message + frame_id, chunk_index, total_chunks, jpeg_chunk = ( + parse_chunked_binary_message(message) + ) + + # Add chunk and check if complete + jpeg_bytes = self._chunk_reassembler.add_chunk( + frame_id, chunk_index, total_chunks, jpeg_chunk + ) + + if jpeg_bytes is None: + # Still waiting for more chunks + return + + # All chunks received - decode and queue frame + if frame_id % 100 == 1: + logger.info( + f"Received frame {frame_id}: {total_chunks} chunk(s), {len(jpeg_bytes)} bytes JPEG" + ) + + nparr = np.frombuffer(jpeg_bytes, np.uint8) + np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if np_image is None: + logger.error(f"Failed to decode JPEG for frame {frame_id}") + return + + video_frame = VideoFrame.from_ndarray(np_image, format="bgr24") + await self._data_frame_queue.put((frame_id, video_frame)) + + if frame_id % 100 == 1: + logger.info(f"Queued frame {frame_id}") + + except Exception as e: + logger.error(f"Error handling frame chunk: {e}", exc_info=True) async def process_frames_data_only(self): """Process frames for data extraction only, without video track output. @@ -232,24 +394,37 @@ async def process_frames_data_only(self): av_logging.set_libav_level(av_logging.ERROR) self._av_logging_set = True - logger.info("Starting data-only frame processing") + logger.info( + f"Starting data-only frame processing (use_data_channel_frames={self.use_data_channel_frames})" + ) try: - while ( - self.track - and self.track.readyState != "ended" - and not self._stop_processing - ): + while not self._stop_processing: if self._check_termination(): break - # Drain queue if using PlayerStreamTrack (RTSP) - if isinstance(self.track, PlayerStreamTrack): - while self.track._queue.qsize() > 30: - self.track._queue.get_nowait() + # Get frame from appropriate source + if self.use_data_channel_frames: + # Wait for frame from data channel queue + item = await self._data_frame_queue.get() + if item is None: + logger.info("Received stop signal from data channel") + break + frame_id, frame = item + self._received_frames = frame_id + else: + # Get frame from media track (existing behavior) + if not self.track or self.track.readyState == "ended": + break + + # Drain queue if using PlayerStreamTrack (RTSP) + if isinstance(self.track, PlayerStreamTrack): + while self.track._queue.qsize() > 30: + self.track._queue.get_nowait() + + frame = await self.track.recv() + self._received_frames += 1 - frame: VideoFrame = await self.track.recv() - self._received_frames += 1 frame_timestamp = datetime.datetime.now() workflow_output, _, errors = await self._process_frame_async( @@ -372,6 +547,7 @@ def __init__( declared_fps: float = 30, termination_date: Optional[datetime.datetime] = None, terminate_event: Optional[asyncio.Event] = None, + use_data_channel_frames: bool = False, *args, **kwargs, ): @@ -387,6 +563,7 @@ def __init__( declared_fps=declared_fps, termination_date=termination_date, terminate_event=terminate_event, + use_data_channel_frames=use_data_channel_frames, model_manager=model_manager, ) @@ -531,6 +708,7 @@ async def init_rtc_peer_connection_with_loop( declared_fps=webrtc_request.declared_fps, termination_date=termination_date, terminate_event=terminate_event, + use_data_channel_frames=webrtc_request.use_data_channel_frames, ) else: # No video track - use base VideoFrameProcessor @@ -545,6 +723,7 @@ async def init_rtc_peer_connection_with_loop( declared_fps=webrtc_request.declared_fps, termination_date=termination_date, terminate_event=terminate_event, + use_data_channel_frames=webrtc_request.use_data_channel_frames, ) except ( ValidationError, @@ -679,6 +858,23 @@ async def on_connectionstatechange(): def on_datachannel(channel: RTCDataChannel): logger.info("Data channel '%s' received", channel.label) + # Handle upstream frames channel (client sending frames to server) + if channel.label == "upstream_frames": + logger.info( + "Upstream frames channel established, starting data-only processing" + ) + + @channel.on("message") + def on_frame_message(message): + asyncio.create_task(video_processor._handle_data_channel_frame(message)) + + # Start processing immediately since we won't get a media track + if webrtc_request.use_data_channel_frames and not should_send_video: + asyncio.create_task(video_processor.process_frames_data_only()) + + return + + # Handle inference control channel (bidirectional communication) @channel.on("message") def on_message(message): logger.info("Data channel message received: %s", message)