2 changes: 1 addition & 1 deletion inference/core/env.py
@@ -122,7 +122,7 @@
# and also ENABLE_STREAM_API environmental variable is set to False
PRELOAD_HF_IDS = os.getenv("PRELOAD_HF_IDS")
if PRELOAD_HF_IDS:
PRELOAD_HF_IDS = [id.strip() for id in PRELOAD_HF_IDS.split(",")]
PRELOAD_HF_IDS = [m.strip() for m in PRELOAD_HF_IDS.split(",")]

# Maximum batch size for GAZE, default is 8
GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
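
Note on the change above: renaming the loop variable avoids shadowing the builtin id(). A minimal sketch of the resulting parsing (the model IDs are hypothetical), which also explains why modal.py below joins the list back into a comma-separated string when forwarding it as a container environment value:

    import os

    # e.g. exported as PRELOAD_HF_IDS="owlv2-base, florence-2-base" (hypothetical values)
    preload_hf_ids = os.getenv("PRELOAD_HF_IDS")
    if preload_hf_ids:
        # split on commas and strip whitespace; "m" avoids shadowing the builtin id()
        preload_hf_ids = [m.strip() for m in preload_hf_ids.split(",")]

    # environment values must be strings, so the Modal worker re-serializes the list:
    env_value = ",".join(preload_hf_ids) if preload_hf_ids else ""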
19 changes: 12 additions & 7 deletions inference/core/interfaces/stream/inference_pipeline.py
@@ -55,6 +55,7 @@
PipelineWatchDog,
)
from inference.core.managers.active_learning import BackgroundTaskActiveLearningManager
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.core.utils.function import experimental
@@ -486,6 +487,7 @@ def init_with_workflow(
serialize_results: bool = False,
predictions_queue_size: int = PREDICTIONS_QUEUE_SIZE,
decoding_buffer_size: int = DEFAULT_BUFFER_SIZE,
model_manager: Optional[ModelManager] = None,
) -> "InferencePipeline":
"""
This class creates the abstraction for making inferences from given workflow against video stream.
@@ -566,6 +568,8 @@ def init_with_workflow(
default value is taken from INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE env variable
decoding_buffer_size (int): size of video source decoding buffer
default value is taken from VIDEO_SOURCE_BUFFER_SIZE env variable
model_manager (Optional[ModelManager]): Model manager to be used by InferencePipeline; defaults to a
BackgroundTaskActiveLearningManager wrapped in WithFixedSizeCache

Other ENV variables involved in low-level configuration:
* INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -623,13 +627,14 @@ def init_with_workflow(
use_cache=use_workflow_definition_cache,
)
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = BackgroundTaskActiveLearningManager(
model_registry=model_registry, cache=cache
)
model_manager = WithFixedSizeCache(
model_manager,
max_size=MAX_ACTIVE_MODELS,
)
if model_manager is None:
model_manager = BackgroundTaskActiveLearningManager(
model_registry=model_registry, cache=cache
)
model_manager = WithFixedSizeCache(
model_manager,
max_size=MAX_ACTIVE_MODELS,
)
if workflow_init_parameters is None:
workflow_init_parameters = {}
thread_pool_executor = ThreadPoolExecutor(
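
The fallback above preserves the previous behaviour: when no model_manager is supplied, init_with_workflow still builds a BackgroundTaskActiveLearningManager wrapped in WithFixedSizeCache. A sketch of a caller injecting its own shared manager instead; the video reference, workspace, workflow id and API key are placeholders:

    from inference.core.env import MAX_ACTIVE_MODELS
    from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
    from inference.core.managers.base import ModelManager
    from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
    from inference.core.registries.roboflow import RoboflowModelRegistry
    from inference.models.utils import ROBOFLOW_MODEL_TYPES

    # build one manager up front (e.g. with preloaded models) and reuse it across pipelines
    shared_manager = WithFixedSizeCache(
        ModelManager(model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)),
        max_size=MAX_ACTIVE_MODELS,
    )

    pipeline = InferencePipeline.init_with_workflow(
        video_reference="rtsp://example.com/stream",  # placeholder source
        workspace_name="my-workspace",                # placeholder
        workflow_id="my-workflow",                    # placeholder
        api_key="<API_KEY>",
        model_manager=shared_manager,                 # new parameter from this diff
    )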
81 changes: 62 additions & 19 deletions inference/core/interfaces/webrtc_worker/modal.py
@@ -16,6 +16,7 @@
MODELS_CACHE_AUTH_CACHE_TTL,
MODELS_CACHE_AUTH_ENABLED,
PRELOAD_HF_IDS,
PRELOAD_MODELS,
PROJECT,
ROBOFLOW_INTERNAL_SERVICE_SECRET,
WEBRTC_MODAL_APP_NAME,
@@ -46,8 +47,12 @@
from inference.core.interfaces.webrtc_worker.webrtc import (
init_rtc_peer_connection_with_loop,
)
from inference.core.managers.base import ModelManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.core.roboflow_api import get_roboflow_workspace
from inference.core.version import __version__
from inference.models.aliases import resolve_roboflow_model_alias
from inference.models.utils import ROBOFLOW_MODEL_TYPES
from inference.usage_tracking.collector import usage_collector
from inference.usage_tracking.plan_details import WebRTCPlan

@@ -109,7 +114,8 @@
"MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
"LOG_LEVEL": LOG_LEVEL,
"ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
"PRELOAD_HF_IDS": PRELOAD_HF_IDS,
"PRELOAD_HF_IDS": ",".join(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
"PRELOAD_MODELS": ",".join(PRELOAD_MODELS) if PRELOAD_MODELS else "",
"PROJECT": PROJECT,
"ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
"ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
@@ -135,7 +141,7 @@
}

class RTCPeerConnectionModal:
_webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
_model_manager: Optional[ModelManager] = modal.parameter(default=None)

@modal.method()
def rtc_peer_connection_modal(
@@ -145,6 +151,14 @@ def rtc_peer_connection_modal(
):
logger.info("*** Spawning %s:", self.__class__.__name__)
logger.info("Inference tag: %s", docker_tag)
logger.info(
"Preloaded models: %s",
(
", ".join(self._model_manager.models().keys())
if self._model_manager
else ""
),
)
_exec_session_started = datetime.datetime.now()
webrtc_request.processing_session_started = _exec_session_started
logger.info(
@@ -170,7 +184,6 @@
else []
),
)
self._webrtc_request = webrtc_request

def send_answer(obj: WebRTCWorkerResult):
logger.info("Sending webrtc answer")
@@ -180,35 +193,36 @@ def send_answer(obj: WebRTCWorkerResult):
init_rtc_peer_connection_with_loop(
webrtc_request=webrtc_request,
send_answer=send_answer,
model_manager=self._model_manager,
)
)
_exec_session_stopped = datetime.datetime.now()
logger.info(
"WebRTC session stopped at %s",
_exec_session_stopped.isoformat(),
)
workflow_id = self._webrtc_request.workflow_configuration.workflow_id
workflow_id = webrtc_request.workflow_configuration.workflow_id
if not workflow_id:
if self._webrtc_request.workflow_configuration.workflow_specification:
if webrtc_request.workflow_configuration.workflow_specification:
workflow_id = usage_collector._calculate_resource_hash(
resource_details=self._webrtc_request.workflow_configuration.workflow_specification
resource_details=webrtc_request.workflow_configuration.workflow_specification
)
else:
workflow_id = "unknown"

# requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
webrtc_plan = self._webrtc_request.requested_plan
webrtc_plan = webrtc_request.requested_plan

video_source = "realtime browser stream"
if self._webrtc_request.rtsp_url:
if webrtc_request.rtsp_url:
video_source = "rtsp"
elif not self._webrtc_request.webrtc_realtime_processing:
elif not webrtc_request.webrtc_realtime_processing:
video_source = "buffered browser stream"

usage_collector.record_usage(
source=workflow_id,
category="modal",
api_key=self._webrtc_request.api_key,
api_key=webrtc_request.api_key,
resource_details={
"plan": webrtc_plan,
"billable": True,
@@ -221,13 +235,6 @@ def send_answer(obj: WebRTCWorkerResult):
usage_collector.push_usage_payloads()
logger.info("Function completed")

# https://modal.com/docs/reference/modal.enter
# https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
@modal.enter(snap=True)
def start(self):
# TODO: pre-load models
logger.info("Starting container")

@modal.exit()
def stop(self):
logger.info("Stopping container")
@@ -238,7 +245,11 @@ def stop(self):
**decorator_kwargs,
)
class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
pass
# https://modal.com/docs/reference/modal.enter
@modal.enter(snap=True)
def start(self):
# TODO: pre-load models on CPU
logger.info("Starting CPU container")

@app.cls(
**{
@@ -250,7 +261,39 @@ class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
}
)
class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
pass
# https://modal.com/docs/reference/modal.enter
# https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
@modal.enter(snap=True)
def start(self):
logger.info("Starting GPU container")
logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
logger.info("Preload models: %s", PRELOAD_MODELS)
if PRELOAD_HF_IDS:
# Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
logger.info("Preloading owlv2 base model")
import inference.models.owlv2.owlv2
if PRELOAD_MODELS:
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = ModelManager(model_registry=model_registry)
for model_id in PRELOAD_MODELS:
try:
de_aliased_model_id = resolve_roboflow_model_alias(
model_id=model_id
)
logger.info(f"Preloading model: {de_aliased_model_id}")
model_manager.add_model(
model_id=de_aliased_model_id,
api_key=None,
countinference=False,
service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
)
except Exception as exc:
logger.error(
"Failed to preload model %s: %s",
model_id,
exc,
)
self._model_manager = model_manager

def spawn_rtc_peer_connection_modal(
webrtc_request: WebRTCWorkerRequest,
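
Taken together, the GPU class now builds its ModelManager once in @modal.enter(snap=True), so the loaded weights can be captured in the memory snapshot, and each rtc_peer_connection_modal call hands the same manager to init_rtc_peer_connection_with_loop instead of loading models per session. A condensed sketch of that preload step (the helper name is mine; logging and per-model error handling are omitted):

    from inference.core.env import PRELOAD_MODELS, ROBOFLOW_INTERNAL_SERVICE_SECRET
    from inference.core.managers.base import ModelManager
    from inference.core.registries.roboflow import RoboflowModelRegistry
    from inference.models.aliases import resolve_roboflow_model_alias
    from inference.models.utils import ROBOFLOW_MODEL_TYPES

    def build_preloaded_manager() -> ModelManager:
        """Load every model listed in PRELOAD_MODELS into one shared manager."""
        manager = ModelManager(
            model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
        )
        for model_id in PRELOAD_MODELS or []:
            # aliases such as "yolov8n-640" resolve to concrete model ids before loading
            manager.add_model(
                model_id=resolve_roboflow_model_alias(model_id=model_id),
                api_key=None,
                countinference=False,
                service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
            )
        return manager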
8 changes: 8 additions & 0 deletions inference/core/interfaces/webrtc_worker/webrtc.py
@@ -49,6 +49,7 @@
detect_image_output,
process_frame,
)
from inference.core.managers.base import ModelManager
from inference.core.roboflow_api import get_workflow_specification
from inference.core.workflows.core_steps.common.serializers import (
serialize_wildcard_kind,
@@ -83,6 +84,7 @@ def __init__(
asyncio_loop: asyncio.AbstractEventLoop,
workflow_configuration: WorkflowConfiguration,
api_key: str,
model_manager: Optional[ModelManager] = None,
data_output: Optional[List[str]] = None,
stream_output: Optional[str] = None,
has_video_track: bool = True,
@@ -134,6 +136,7 @@ def __init__(
workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers,
cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
video_metadata_input_name=workflow_configuration.video_metadata_input_name,
model_manager=model_manager,
)

def set_track(self, track: RemoteStreamTrack):
@@ -362,6 +365,7 @@ def __init__(
asyncio_loop: asyncio.AbstractEventLoop,
workflow_configuration: WorkflowConfiguration,
api_key: str,
model_manager: Optional[ModelManager] = None,
data_output: Optional[List[str]] = None,
stream_output: Optional[str] = None,
has_video_track: bool = True,
@@ -383,6 +387,7 @@ def __init__(
declared_fps=declared_fps,
termination_date=termination_date,
terminate_event=terminate_event,
model_manager=model_manager,
)

async def _auto_detect_stream_output(
@@ -466,6 +471,7 @@ async def init_rtc_peer_connection_with_loop(
webrtc_request: WebRTCWorkerRequest,
send_answer: Callable[[WebRTCWorkerResult], None],
asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
model_manager: Optional[ModelManager] = None,
shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
) -> RTCPeerConnectionWithLoop:
termination_date = None
@@ -517,6 +523,7 @@ async def init_rtc_peer_connection_with_loop(
video_processor = VideoTransformTrackWithLoop(
asyncio_loop=asyncio_loop,
workflow_configuration=webrtc_request.workflow_configuration,
model_manager=model_manager,
api_key=webrtc_request.api_key,
data_output=data_fields,
stream_output=stream_field,
@@ -530,6 +537,7 @@ async def init_rtc_peer_connection_with_loop(
video_processor = VideoFrameProcessor(
asyncio_loop=asyncio_loop,
workflow_configuration=webrtc_request.workflow_configuration,
model_manager=model_manager,
api_key=webrtc_request.api_key,
data_output=data_fields,
stream_output=None,