From 55cb44144639603655c49e6fce5ff6768fa999dc Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Tue, 18 Nov 2025 22:35:05 +0100
Subject: [PATCH 01/10] Preload HF base for owlv2 in modal

---
 inference/core/interfaces/webrtc_worker/modal.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index 05d72ab6d9..0e594858c0 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -227,6 +227,9 @@ def send_answer(obj: WebRTCWorkerResult):
         def start(self):
             # TODO: pre-load models
             logger.info("Starting container")
+            if PRELOAD_HF_IDS:
+                # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
+                import inference.models.owlv2.owlv2

         @modal.exit()
         def stop(self):
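Note: the bare import in PATCH 01 is the entire preload mechanism. Python caches
imported modules in sys.modules, so whatever model singleton
inference.models.owlv2.owlv2 constructs at module level is built exactly once,
when the container starts, and every later import of that module is a cache hit.
A minimal, self-contained sketch of the pattern (the owlv2_stub module is
invented for illustration):

    import sys
    import types

    # Stand-in for a model module whose import-time side effect loads weights.
    stub = types.ModuleType("owlv2_stub")
    stub.BASE_MODEL = {"weights": "loaded"}  # module-level singleton
    sys.modules["owlv2_stub"] = stub

    import owlv2_stub  # resolved from sys.modules; the "load" never runs twice

    assert owlv2_stub.BASE_MODEL is stub.BASE_MODEL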
From f918798239e390632197568983d8619b0531f9db Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 12:13:55 +0100
Subject: [PATCH 02/10] Preload models in modal

---
 inference/core/interfaces/http/http_api.py       |  4 +-
 .../interfaces/stream/inference_pipeline.py      | 19 +++++---
 .../core/interfaces/webrtc_worker/modal.py       | 44 ++++++++++++++-----
 .../core/interfaces/webrtc_worker/webrtc.py      |  8 ++++
 4 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py
index 99e5a2123b..dc99469326 100644
--- a/inference/core/interfaces/http/http_api.py
+++ b/inference/core/interfaces/http/http_api.py
@@ -160,6 +160,8 @@
     PRELOAD_MODELS,
     PROFILE,
     ROBOFLOW_SERVICE_SECRET,
+    WEBRTC_MODAL_TOKEN_ID,
+    WEBRTC_MODAL_TOKEN_SECRET,
     WEBRTC_WORKER_ENABLED,
     WORKFLOWS_MAX_CONCURRENT_STEPS,
     WORKFLOWS_PROFILER_BUFFER_SIZE,
@@ -1603,7 +1605,7 @@ async def consume(
     )

     # Enable preloading models at startup
-    if (
+    if (WEBRTC_MODAL_TOKEN_ID and WEBRTC_MODAL_TOKEN_SECRET) or (
         (PRELOAD_MODELS or DEDICATED_DEPLOYMENT_WORKSPACE_URL)
         and API_KEY
         and not (LAMBDA or GCP_SERVERLESS)
diff --git a/inference/core/interfaces/stream/inference_pipeline.py b/inference/core/interfaces/stream/inference_pipeline.py
index c507c1f39f..382a9f75ad 100644
--- a/inference/core/interfaces/stream/inference_pipeline.py
+++ b/inference/core/interfaces/stream/inference_pipeline.py
@@ -55,6 +55,7 @@
     PipelineWatchDog,
 )
 from inference.core.managers.active_learning import BackgroundTaskActiveLearningManager
+from inference.core.managers.base import ModelManager
 from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
 from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.utils.function import experimental
@@ -486,6 +487,7 @@ def init_with_workflow(
         serialize_results: bool = False,
         predictions_queue_size: int = PREDICTIONS_QUEUE_SIZE,
         decoding_buffer_size: int = DEFAULT_BUFFER_SIZE,
+        model_manager: Optional[ModelManager] = None,
     ) -> "InferencePipeline":
         """
         This class creates the abstraction for making inferences from given workflow against video stream.
@@ -566,6 +568,8 @@ def init_with_workflow(
                 default value is taken from INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE env variable
             decoding_buffer_size (int): size of video source decoding buffer
                 default value is taken from VIDEO_SOURCE_BUFFER_SIZE env variable
+            model_manager (Optional[ModelManager]): Model manager to be used by InferencePipeline, defaults to
+                BackgroundTaskActiveLearningManager with WithFixedSizeCache

         Other ENV variables involved in low-level configuration:
         * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -623,13 +627,14 @@ def init_with_workflow(
             use_cache=use_workflow_definition_cache,
         )
         model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
-        model_manager = BackgroundTaskActiveLearningManager(
-            model_registry=model_registry, cache=cache
-        )
-        model_manager = WithFixedSizeCache(
-            model_manager,
-            max_size=MAX_ACTIVE_MODELS,
-        )
+        if model_manager is None:
+            model_manager = BackgroundTaskActiveLearningManager(
+                model_registry=model_registry, cache=cache
+            )
+            model_manager = WithFixedSizeCache(
+                model_manager,
+                max_size=MAX_ACTIVE_MODELS,
+            )
         if workflow_init_parameters is None:
             workflow_init_parameters = {}
         thread_pool_executor = ThreadPoolExecutor(
logger.info("Sending webrtc answer") @@ -180,6 +188,7 @@ def send_answer(obj: WebRTCWorkerResult): init_rtc_peer_connection_with_loop( webrtc_request=webrtc_request, send_answer=send_answer, + model_manager=self._model_manager, ) ) _exec_session_stopped = datetime.datetime.now() @@ -187,28 +196,28 @@ def send_answer(obj: WebRTCWorkerResult): "WebRTC session stopped at %s", _exec_session_stopped.isoformat(), ) - workflow_id = self._webrtc_request.workflow_configuration.workflow_id + workflow_id = webrtc_request.workflow_configuration.workflow_id if not workflow_id: - if self._webrtc_request.workflow_configuration.workflow_specification: + if webrtc_request.workflow_configuration.workflow_specification: workflow_id = usage_collector._calculate_resource_hash( - resource_details=self._webrtc_request.workflow_configuration.workflow_specification + resource_details=webrtc_request.workflow_configuration.workflow_specification ) else: workflow_id = "unknown" # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal - webrtc_plan = self._webrtc_request.requested_plan + webrtc_plan = webrtc_request.requested_plan video_source = "realtime browser stream" - if self._webrtc_request.rtsp_url: + if webrtc_request.rtsp_url: video_source = "rtsp" - elif not self._webrtc_request.webrtc_realtime_processing: + elif not webrtc_request.webrtc_realtime_processing: video_source = "buffered browser stream" usage_collector.record_usage( source=workflow_id, category="modal", - api_key=self._webrtc_request.api_key, + api_key=webrtc_request.api_key, resource_details={ "plan": webrtc_plan, "billable": True, @@ -229,7 +238,22 @@ def start(self): logger.info("Starting container") if PRELOAD_HF_IDS: # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton) + logger.info("Preloading owlv2 base model") import inference.models.owlv2.owlv2 + if PRELOAD_MODELS: + model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) + self._model_manager = ModelManager(model_registry=model_registry) + for model_id in PRELOAD_MODELS: + de_aliased_model_id = resolve_roboflow_model_alias( + model_id=model_id + ) + logger.info(f"Preloading model: {de_aliased_model_id}") + self._model_manager.add_model( + model_id=de_aliased_model_id, + api_key=None, + countinference=False, + service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET, + ) @modal.exit() def stop(self): diff --git a/inference/core/interfaces/webrtc_worker/webrtc.py b/inference/core/interfaces/webrtc_worker/webrtc.py index 8a576d9197..b899d510be 100644 --- a/inference/core/interfaces/webrtc_worker/webrtc.py +++ b/inference/core/interfaces/webrtc_worker/webrtc.py @@ -49,6 +49,7 @@ detect_image_output, process_frame, ) +from inference.core.managers.base import ModelManager from inference.core.roboflow_api import get_workflow_specification from inference.core.workflows.core_steps.common.serializers import ( serialize_wildcard_kind, @@ -83,6 +84,7 @@ def __init__( asyncio_loop: asyncio.AbstractEventLoop, workflow_configuration: WorkflowConfiguration, api_key: str, + model_manager: Optional[ModelManager] = None, data_output: Optional[List[str]] = None, stream_output: Optional[str] = None, has_video_track: bool = True, @@ -134,6 +136,7 @@ def __init__( workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers, cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit, video_metadata_input_name=workflow_configuration.video_metadata_input_name, + model_manager=model_manager, ) 
From ed304d1d292fe4be6e1cb52785a8928b4dd8749d Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 12:17:55 +0100
Subject: [PATCH 03/10] Not reusing initialize_models from http_api

---
 inference/core/interfaces/http/http_api.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py
index dc99469326..99e5a2123b 100644
--- a/inference/core/interfaces/http/http_api.py
+++ b/inference/core/interfaces/http/http_api.py
@@ -160,8 +160,6 @@
     PRELOAD_MODELS,
     PROFILE,
     ROBOFLOW_SERVICE_SECRET,
-    WEBRTC_MODAL_TOKEN_ID,
-    WEBRTC_MODAL_TOKEN_SECRET,
     WEBRTC_WORKER_ENABLED,
     WORKFLOWS_MAX_CONCURRENT_STEPS,
     WORKFLOWS_PROFILER_BUFFER_SIZE,
@@ -1605,7 +1603,7 @@ async def consume(
     )

     # Enable preloading models at startup
-    if (WEBRTC_MODAL_TOKEN_ID and WEBRTC_MODAL_TOKEN_SECRET) or (
+    if (
         (PRELOAD_MODELS or DEDICATED_DEPLOYMENT_WORKSPACE_URL)
         and API_KEY
         and not (LAMBDA or GCP_SERVERLESS)

From 6ef853d512c097980c77e02e9b1c72c9a8174f44 Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 12:35:58 +0100
Subject: [PATCH 04/10] Handle no preloading

---
 inference/core/interfaces/webrtc_worker/modal.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index ee07104577..93c6e3dc00 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -114,8 +114,8 @@
         "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
         "LOG_LEVEL": LOG_LEVEL,
         "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
-        "PRELOAD_HF_IDS": str(PRELOAD_HF_IDS),
-        "PRELOAD_MODELS": str(PRELOAD_MODELS),
+        "PRELOAD_HF_IDS": str(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
+        "PRELOAD_MODELS": str(PRELOAD_MODELS) if PRELOAD_MODELS else "",
         "PROJECT": PROJECT,
         "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
         "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
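Note: PATCH 04 matters because these dict values are exported into the worker's
environment, where every value must be a string. With PRELOAD_MODELS unset
(None), a bare str(...) would export the literal string "None", which is truthy
on the worker side and would be parsed as a single bogus model id. A tiny
demonstration of the failure mode and the guard:

    PRELOAD_MODELS = None  # nothing configured

    naive = {"PRELOAD_MODELS": str(PRELOAD_MODELS)}
    assert naive["PRELOAD_MODELS"] == "None"  # truthy: looks like one model id

    guarded = {"PRELOAD_MODELS": str(PRELOAD_MODELS) if PRELOAD_MODELS else ""}
    assert guarded["PRELOAD_MODELS"] == ""  # falsy: worker skips preloading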
From 224b87aa381a4356cca127b03e3546348e7e9a2c Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 13:08:05 +0100
Subject: [PATCH 05/10] logging

---
 inference/core/interfaces/webrtc_worker/modal.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index 93c6e3dc00..d4d6baeba5 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -152,7 +152,12 @@ def rtc_peer_connection_modal(
             logger.info("*** Spawning %s:", self.__class__.__name__)
             logger.info("Inference tag: %s", docker_tag)
             logger.info(
-                "Preloaded models: %s", ", ".join(self._model_manager.models().keys())
+                "Preloaded models: %s",
+                (
+                    ", ".join(self._model_manager.models().keys())
+                    if self._model_manager
+                    else ""
+                ),
             )
             _exec_session_started = datetime.datetime.now()
             webrtc_request.processing_session_started = _exec_session_started

From dcf4bf4850ffec99a5b2198e1adc9cb9c1b60778 Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 13:24:21 +0100
Subject: [PATCH 06/10] logging

---
 inference/core/env.py                            | 2 +-
 inference/core/interfaces/webrtc_worker/modal.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/inference/core/env.py b/inference/core/env.py
index 40827f5aa4..7c01868b25 100644
--- a/inference/core/env.py
+++ b/inference/core/env.py
@@ -122,7 +122,7 @@
 # and also ENABLE_STREAM_API environmental variable is set to False
 PRELOAD_HF_IDS = os.getenv("PRELOAD_HF_IDS")
 if PRELOAD_HF_IDS:
-    PRELOAD_HF_IDS = [id.strip() for id in PRELOAD_HF_IDS.split(",")]
+    PRELOAD_HF_IDS = [m.strip() for m in PRELOAD_HF_IDS.split(",")]

 # Maximum batch size for GAZE, default is 8
 GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index d4d6baeba5..e77fb22215 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -241,6 +241,8 @@ def start(self):
             # TODO: pre-load models
             logger.info("Starting container")
+            logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
+            logger.info("Preload models: %s", PRELOAD_MODELS)
             if PRELOAD_HF_IDS:
                 # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
                 logger.info("Preloading owlv2 base model")
                 import inference.models.owlv2.owlv2
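Note: besides the extra logging, PATCH 06 renames the comprehension variable in
env.py from id to m, since id shadowed the id() builtin within the
comprehension's scope. The parsing itself is the usual comma-separated round
trip, tolerant of stray whitespace (the model ids below are invented):

    import os

    os.environ["PRELOAD_HF_IDS"] = "org/model-a , org/model-b"
    PRELOAD_HF_IDS = os.getenv("PRELOAD_HF_IDS")
    if PRELOAD_HF_IDS:
        PRELOAD_HF_IDS = [m.strip() for m in PRELOAD_HF_IDS.split(",")]
    assert PRELOAD_HF_IDS == ["org/model-a", "org/model-b"]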
From 5d2c8fe7b9bcacff2665a4d869c9f06afed70725 Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 13:26:12 +0100
Subject: [PATCH 07/10] skip models that cannot be preloaded

---
 .../core/interfaces/webrtc_worker/modal.py | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index e77fb22215..a95907be1b 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -251,16 +251,23 @@ def start(self):
                 model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
                 self._model_manager = ModelManager(model_registry=model_registry)
                 for model_id in PRELOAD_MODELS:
-                    de_aliased_model_id = resolve_roboflow_model_alias(
-                        model_id=model_id
-                    )
-                    logger.info(f"Preloading model: {de_aliased_model_id}")
-                    self._model_manager.add_model(
-                        model_id=de_aliased_model_id,
-                        api_key=None,
-                        countinference=False,
-                        service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
-                    )
+                    try:
+                        de_aliased_model_id = resolve_roboflow_model_alias(
+                            model_id=model_id
+                        )
+                        logger.info(f"Preloading model: {de_aliased_model_id}")
+                        self._model_manager.add_model(
+                            model_id=de_aliased_model_id,
+                            api_key=None,
+                            countinference=False,
+                            service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
+                        )
+                    except Exception as exc:
+                        logger.error(
+                            "Failed to preload model %s: %s",
+                            model_id,
+                            exc,
+                        )

         @modal.exit()
         def stop(self):

From fd79c9f9d0a2940e8b2995f1ca0b983cd7a10162 Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 13:30:32 +0100
Subject: [PATCH 08/10] Construct env back from list

---
 inference/core/interfaces/webrtc_worker/modal.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index a95907be1b..60651a4061 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -114,8 +114,8 @@
         "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
         "LOG_LEVEL": LOG_LEVEL,
         "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
-        "PRELOAD_HF_IDS": str(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
-        "PRELOAD_MODELS": str(PRELOAD_MODELS) if PRELOAD_MODELS else "",
+        "PRELOAD_HF_IDS": ",".join(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
+        "PRELOAD_MODELS": ",".join(PRELOAD_MODELS) if PRELOAD_MODELS else "",
         "PROJECT": PROJECT,
         "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
         "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,

From 3f1fea368de6a112b30acded11442e408baa88f5 Mon Sep 17 00:00:00 2001
From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 14:29:59 +0100
Subject: [PATCH 09/10] attempt to avoid modal segfault in enter

---
 inference/core/interfaces/webrtc_worker/modal.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index 60651a4061..73fef84727 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -249,14 +249,14 @@ def start(self):
                 import inference.models.owlv2.owlv2
             if PRELOAD_MODELS:
                 model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
-                self._model_manager = ModelManager(model_registry=model_registry)
+                model_manager = ModelManager(model_registry=model_registry)
                 for model_id in PRELOAD_MODELS:
                     try:
                         de_aliased_model_id = resolve_roboflow_model_alias(
                             model_id=model_id
                         )
                         logger.info(f"Preloading model: {de_aliased_model_id}")
-                        self._model_manager.add_model(
+                        model_manager.add_model(
                             model_id=de_aliased_model_id,
                             api_key=None,
                             countinference=False,
@@ -268,6 +268,7 @@ def start(self):
                             model_id,
                             exc,
                         )
+                self._model_manager = model_manager

         @modal.exit()
         def stop(self):
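Note: PATCH 08 is the counterpart of the env.py parser above: the parent process
holds PRELOAD_MODELS as a list, and str(list) does not round-trip through the
worker's split(",")-based parsing, while ",".join(...) does. PATCH 09 is, per
its own message, an attempt; the series never establishes the segfault's root
cause, but the pattern it switches to, building the manager in a local and
publishing it to the Modal parameter only once complete, is uncontroversial.
The join/split round trip in miniature (model ids are invented):

    models = ["yolov8n-640", "yolov8s-640"]  # hypothetical model ids

    # str(list) produces a Python repr that split(",") cannot undo:
    assert str(models) == "['yolov8n-640', 'yolov8s-640']"

    # ",".join(...) round-trips cleanly through the env.py-style parser:
    env_value = ",".join(models)
    assert [m.strip() for m in env_value.split(",")] == models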
<166530809+grzegorz-roboflow@users.noreply.github.com>
Date: Wed, 19 Nov 2025 14:36:09 +0100
Subject: [PATCH 10/10] Do not preload models on CPU

---
 .../core/interfaces/webrtc_worker/modal.py | 53 ++++++++++---------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/inference/core/interfaces/webrtc_worker/modal.py b/inference/core/interfaces/webrtc_worker/modal.py
index 73fef84727..7c4a77b6b1 100644
--- a/inference/core/interfaces/webrtc_worker/modal.py
+++ b/inference/core/interfaces/webrtc_worker/modal.py
@@ -235,12 +235,37 @@ def send_answer(obj: WebRTCWorkerResult):
             usage_collector.push_usage_payloads()
             logger.info("Function completed")

+        @modal.exit()
+        def stop(self):
+            logger.info("Stopping container")
+
+    # Modal derives function name from class name
+    # https://modal.com/docs/reference/modal.App#cls
+    @app.cls(
+        **decorator_kwargs,
+    )
+    class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
+        # https://modal.com/docs/reference/modal.enter
+        @modal.enter(snap=True)
+        def start(self):
+            # TODO: pre-load models on CPU
+            logger.info("Starting CPU container")
+
+    @app.cls(
+        **{
+            **decorator_kwargs,
+            "gpu": WEBRTC_MODAL_FUNCTION_GPU,  # https://modal.com/docs/guide/gpu#specifying-gpu-type
+            "experimental_options": {
+                "enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
+            },
+        }
+    )
+    class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
         # https://modal.com/docs/reference/modal.enter
         # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
         @modal.enter(snap=True)
         def start(self):
-            # TODO: pre-load models
-            logger.info("Starting container")
+            logger.info("Starting GPU container")
             logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
             logger.info("Preload models: %s", PRELOAD_MODELS)
             if PRELOAD_HF_IDS:
@@ -270,30 +295,6 @@ def start(self):
                         )
                 self._model_manager = model_manager

-        @modal.exit()
-        def stop(self):
-            logger.info("Stopping container")
-
-    # Modal derives function name from class name
-    # https://modal.com/docs/reference/modal.App#cls
-    @app.cls(
-        **decorator_kwargs,
-    )
-    class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
-        pass
-
-    @app.cls(
-        **{
-            **decorator_kwargs,
-            "gpu": WEBRTC_MODAL_FUNCTION_GPU,  # https://modal.com/docs/guide/gpu#specifying-gpu-type
-            "experimental_options": {
-                "enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
-            },
-        }
-    )
-    class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
-        pass
-
 def spawn_rtc_peer_connection_modal(
     webrtc_request: WebRTCWorkerRequest,
 ) -> WebRTCWorkerResult:
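Note: after PATCH 10 the series lands in its final shape: RTCPeerConnectionModal
keeps the shared session logic and the @modal.exit hook, while the two
@app.cls-decorated subclasses differ only in their @modal.enter(snap=True)
hooks. The CPU flavor starts empty, and the GPU flavor, which also requests a
GPU and opts into the experimental GPU memory snapshot, does all the
preloading, so model weights are captured in the snapshot and never loaded on
CPU workers. The split in miniature, with Modal's decorators elided (class
names and the model id below are illustrative):

    class Base:
        """Shared session logic; start() differs per flavor."""
        _model_manager = None

    class CPUWorker(Base):
        def start(self):  # stands in for @modal.enter(snap=True)
            pass          # no preloading on CPU

    class GPUWorker(Base):
        def start(self):
            manager = {"yolov8n-640": "loaded"}  # hypothetical model id
            self._model_manager = manager        # publish only when complete

    w = GPUWorker()
    w.start()
    assert w._model_manager is not None and CPUWorker()._model_manager is None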