3 changes: 3 additions & 0 deletions inference/core/interfaces/webrtc_worker/entities.py
@@ -39,6 +39,9 @@ class WebRTCWorkerRequest(BaseModel):
data_output: Optional[List[str]] = Field(default=None)
declared_fps: Optional[float] = None
rtsp_url: Optional[str] = None
use_data_channel_frames: bool = (
False # When True, expect frames via data channel instead of media track
)
processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT
processing_session_started: Optional[datetime.datetime] = None
requested_plan: Optional[str] = "webrtc-gpu-small"
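
For illustration, a hedged request fragment that enables the new behaviour (the remaining WebRTCWorkerRequest fields are unchanged and omitted; values are made up):

# Illustrative payload fragment only; other required WebRTCWorkerRequest fields omitted.
request_payload = {
    # ...
    "declared_fps": 15.0,
    "use_data_channel_frames": True,  # frames arrive via the data channel, not a media track
}
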
226 changes: 211 additions & 15 deletions inference/core/interfaces/webrtc_worker/webrtc.py
@@ -2,8 +2,11 @@
import datetime
import json
import logging
import struct
from typing import Any, Callable, Dict, List, Optional, Tuple

import cv2
import numpy as np
from aiortc import (
RTCConfiguration,
RTCDataChannel,
@@ -60,6 +63,113 @@

logging.getLogger("aiortc").setLevel(logging.WARNING)

# WebRTC data channel chunking configuration
CHUNK_SIZE = 48 * 1024 # 48KB - safe for all WebRTC implementations


def create_chunked_binary_message(
frame_id: int, chunk_index: int, total_chunks: int, payload: bytes
) -> bytes:
"""Create a binary message with standard 12-byte header.

Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N]
All integers are uint32 little-endian.
"""
header = struct.pack("<III", frame_id, chunk_index, total_chunks)
return header + payload


def parse_chunked_binary_message(message: bytes) -> Tuple[int, int, int, bytes]:
"""Parse a binary message with standard 12-byte header.

Returns: (frame_id, chunk_index, total_chunks, payload)
"""
if len(message) < 12:
raise ValueError(f"Message too short: {len(message)} bytes (expected >= 12)")

frame_id, chunk_index, total_chunks = struct.unpack("<III", message[0:12])
payload = message[12:]
return frame_id, chunk_index, total_chunks, payload
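
For illustration, a minimal round-trip of the wire format above using the two helpers just defined (values are arbitrary):

# Illustrative only: pack one chunk and parse it back; the header is a little-endian uint32 triplet.
payload = b"\xff\xd8\xff\xe0"  # first bytes of a JPEG, purely as an example
msg = create_chunked_binary_message(frame_id=7, chunk_index=0, total_chunks=1, payload=payload)
assert len(msg) == 12 + len(payload)  # 12-byte header precedes the payload
assert parse_chunked_binary_message(msg) == (7, 0, 1, payload)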


class ChunkReassembler:
"""Helper to reassemble chunked binary messages."""

def __init__(self):
self._chunks: Dict[int, Dict[int, bytes]] = (
{}
) # {frame_id: {chunk_index: data}}
self._total: Dict[int, int] = {} # {frame_id: total_chunks}

def add_chunk(
self, frame_id: int, chunk_index: int, total_chunks: int, chunk_data: bytes
) -> Optional[bytes]:
"""Add a chunk and return complete payload if all chunks received.

Returns:
Complete reassembled payload bytes if all chunks received, None otherwise.
"""
# Initialize buffers for new frame
if frame_id not in self._chunks:
self._chunks[frame_id] = {}
self._total[frame_id] = total_chunks

# Store chunk
self._chunks[frame_id][chunk_index] = chunk_data

# Check if all chunks received
if len(self._chunks[frame_id]) >= total_chunks:
# Reassemble in order
complete_payload = b"".join(
self._chunks[frame_id][i] for i in range(total_chunks)
)

# Clean up
del self._chunks[frame_id]
del self._total[frame_id]

return complete_payload

return None


def send_chunked_data(
data_channel: RTCDataChannel,
frame_id: int,
payload_bytes: bytes,
chunk_size: int = CHUNK_SIZE,
) -> None:
"""Send payload via data channel, automatically chunking if needed.

Args:
data_channel: RTCDataChannel to send on
frame_id: Frame identifier
payload_bytes: Data to send (JPEG, JSON UTF-8, etc.)
chunk_size: Maximum chunk size (default 48KB)
"""
if data_channel.readyState != "open":
logger.warning(f"Cannot send response for frame {frame_id}, channel not open")
return

total_chunks = (
len(payload_bytes) + chunk_size - 1
) // chunk_size # Ceiling division

if frame_id % 100 == 1:
logger.info(
f"Sending response for frame {frame_id}: {total_chunks} chunk(s), {len(payload_bytes)} bytes"
)

for chunk_index in range(total_chunks):
start = chunk_index * chunk_size
end = min(start + chunk_size, len(payload_bytes))
chunk_data = payload_bytes[start:end]

message = create_chunked_binary_message(
frame_id, chunk_index, total_chunks, chunk_data
)
data_channel.send(message)
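
For illustration, a hedged sketch of the consuming side: a client handler that rebuilds each response produced by send_chunked_data, reusing ChunkReassembler and parse_chunked_binary_message from this module (the data channel wiring is assumed to exist):

# Hypothetical client-side receiver; channel setup is omitted.
import json

reassembler = ChunkReassembler()

def on_result_message(message: bytes) -> None:
    frame_id, chunk_index, total_chunks, chunk = parse_chunked_binary_message(message)
    payload = reassembler.add_chunk(frame_id, chunk_index, total_chunks, chunk)
    if payload is None:
        return  # more chunks for this frame are still in flight
    result = json.loads(payload.decode("utf-8"))  # full JSON emitted by _send_data_output
    print(frame_id, sorted(result.keys()))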


class RTCPeerConnectionWithLoop(RTCPeerConnection):
def __init__(
@@ -91,6 +201,7 @@ def __init__(
declared_fps: float = 30,
termination_date: Optional[datetime.datetime] = None,
terminate_event: Optional[asyncio.Event] = None,
use_data_channel_frames: bool = False,
):
self._loop = asyncio_loop
self._termination_date = termination_date
@@ -101,6 +212,11 @@ def __init__(
self._received_frames = 0
self._declared_fps = declared_fps
self._stop_processing = False
self.use_data_channel_frames = use_data_channel_frames
        # Queue of (frame_id, frame) tuples; None is the stop sentinel
        self._data_frame_queue: "asyncio.Queue[Optional[Tuple[int, VideoFrame]]]" = (
            asyncio.Queue()
        )
self._chunk_reassembler = (
ChunkReassembler()
) # For reassembling inbound frame chunks

self.has_video_track = has_video_track
self.stream_output = stream_output
@@ -185,7 +301,9 @@ async def _send_data_output(
)

if self._data_mode == DataOutputMode.NONE:
-            self.data_channel.send(json.dumps(webrtc_output.model_dump()))
# Even empty responses use binary protocol
json_bytes = json.dumps(webrtc_output.model_dump()).encode("utf-8")
send_chunked_data(self.data_channel, self._received_frames, json_bytes)
return

if self._data_mode == DataOutputMode.ALL:
@@ -216,11 +334,55 @@ async def _send_data_output(
webrtc_output.errors.append(f"{field_name}: {e}")
serialized_outputs[field_name] = {"__serialization_error__": str(e)}

-        # Only set serialized_output_data if we have data to send
# Set serialized outputs
if serialized_outputs:
webrtc_output.serialized_output_data = serialized_outputs

-        self.data_channel.send(json.dumps(webrtc_output.model_dump(mode="json")))
# Send using binary chunked protocol
json_bytes = json.dumps(webrtc_output.model_dump(mode="json")).encode("utf-8")
send_chunked_data(self.data_channel, self._received_frames, json_bytes)

async def _handle_data_channel_frame(self, message: bytes) -> None:
"""Handle incoming binary frame chunk from upstream_frames data channel.

Uses standard binary protocol with 12-byte header + JPEG chunk payload.
"""
try:
# Parse message
frame_id, chunk_index, total_chunks, jpeg_chunk = (
parse_chunked_binary_message(message)
)

# Add chunk and check if complete
jpeg_bytes = self._chunk_reassembler.add_chunk(
frame_id, chunk_index, total_chunks, jpeg_chunk
)

if jpeg_bytes is None:
# Still waiting for more chunks
return

# All chunks received - decode and queue frame
if frame_id % 100 == 1:
logger.info(
f"Received frame {frame_id}: {total_chunks} chunk(s), {len(jpeg_bytes)} bytes JPEG"
)

nparr = np.frombuffer(jpeg_bytes, np.uint8)
np_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

if np_image is None:
logger.error(f"Failed to decode JPEG for frame {frame_id}")
return

video_frame = VideoFrame.from_ndarray(np_image, format="bgr24")
await self._data_frame_queue.put((frame_id, video_frame))

if frame_id % 100 == 1:
logger.info(f"Queued frame {frame_id}")

except Exception as e:
logger.error(f"Error handling frame chunk: {e}", exc_info=True)

async def process_frames_data_only(self):
"""Process frames for data extraction only, without video track output.
Expand All @@ -232,24 +394,37 @@ async def process_frames_data_only(self):
av_logging.set_libav_level(av_logging.ERROR)
self._av_logging_set = True

logger.info("Starting data-only frame processing")
logger.info(
f"Starting data-only frame processing (use_data_channel_frames={self.use_data_channel_frames})"
)

try:
-            while (
-                self.track
-                and self.track.readyState != "ended"
-                and not self._stop_processing
-            ):
while not self._stop_processing:
if self._check_termination():
break

-                # Drain queue if using PlayerStreamTrack (RTSP)
-                if isinstance(self.track, PlayerStreamTrack):
-                    while self.track._queue.qsize() > 30:
-                        self.track._queue.get_nowait()
# Get frame from appropriate source
if self.use_data_channel_frames:
# Wait for frame from data channel queue
item = await self._data_frame_queue.get()
if item is None:
logger.info("Received stop signal from data channel")
break
frame_id, frame = item
self._received_frames = frame_id
else:
# Get frame from media track (existing behavior)
if not self.track or self.track.readyState == "ended":
break

# Drain queue if using PlayerStreamTrack (RTSP)
if isinstance(self.track, PlayerStreamTrack):
while self.track._queue.qsize() > 30:
self.track._queue.get_nowait()

frame = await self.track.recv()
self._received_frames += 1

-                frame: VideoFrame = await self.track.recv()
-                self._received_frames += 1
frame_timestamp = datetime.datetime.now()

workflow_output, _, errors = await self._process_frame_async(
@@ -372,6 +547,7 @@ def __init__(
declared_fps: float = 30,
termination_date: Optional[datetime.datetime] = None,
terminate_event: Optional[asyncio.Event] = None,
use_data_channel_frames: bool = False,
*args,
**kwargs,
):
Expand All @@ -387,6 +563,7 @@ def __init__(
declared_fps=declared_fps,
termination_date=termination_date,
terminate_event=terminate_event,
use_data_channel_frames=use_data_channel_frames,
model_manager=model_manager,
)

@@ -531,6 +708,7 @@ async def init_rtc_peer_connection_with_loop(
declared_fps=webrtc_request.declared_fps,
termination_date=termination_date,
terminate_event=terminate_event,
use_data_channel_frames=webrtc_request.use_data_channel_frames,
)
else:
# No video track - use base VideoFrameProcessor
Expand All @@ -545,6 +723,7 @@ async def init_rtc_peer_connection_with_loop(
declared_fps=webrtc_request.declared_fps,
termination_date=termination_date,
terminate_event=terminate_event,
use_data_channel_frames=webrtc_request.use_data_channel_frames,
)
except (
ValidationError,
@@ -679,6 +858,23 @@ async def on_connectionstatechange():
def on_datachannel(channel: RTCDataChannel):
logger.info("Data channel '%s' received", channel.label)

# Handle upstream frames channel (client sending frames to server)
if channel.label == "upstream_frames":
logger.info(
"Upstream frames channel established, starting data-only processing"
)

@channel.on("message")
def on_frame_message(message):
asyncio.create_task(video_processor._handle_data_channel_frame(message))

# Start processing immediately since we won't get a media track
if webrtc_request.use_data_channel_frames and not should_send_video:
asyncio.create_task(video_processor.process_frames_data_only())

return

# Handle inference control channel (bidirectional communication)
@channel.on("message")
def on_message(message):
logger.info("Data channel message received: %s", message)
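
For illustration, a hedged end-to-end client sketch with aiortc: the client opens the "upstream_frames" channel itself and starts pushing frames once it is open (SDP signaling with the worker and the frame source are omitted; send_frame is the made-up helper sketched earlier):

# Hypothetical client-side setup; the offer/answer exchange with the worker is not shown.
from aiortc import RTCPeerConnection

async def run_client(next_frame):
    pc = RTCPeerConnection()
    upstream = pc.createDataChannel("upstream_frames")  # label the worker listens for

    @upstream.on("open")
    def _on_open() -> None:
        send_frame(upstream, frame_id=1, bgr_image=next_frame())  # see sketch above

    offer = await pc.createOffer()
    await pc.setLocalDescription(offer)
    # ... send pc.localDescription to the worker, apply its answer, keep sending frames ...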