1616 MODELS_CACHE_AUTH_CACHE_TTL ,
1717 MODELS_CACHE_AUTH_ENABLED ,
1818 PRELOAD_HF_IDS ,
19+ PRELOAD_MODELS ,
1920 PROJECT ,
2021 ROBOFLOW_INTERNAL_SERVICE_SECRET ,
2122 WEBRTC_MODAL_APP_NAME ,
4647from inference .core .interfaces .webrtc_worker .webrtc import (
4748 init_rtc_peer_connection_with_loop ,
4849)
50+ from inference .core .managers .base import ModelManager
51+ from inference .core .registries .roboflow import RoboflowModelRegistry
4952from inference .core .roboflow_api import get_roboflow_workspace
5053from inference .core .version import __version__
54+ from inference .models .aliases import resolve_roboflow_model_alias
55+ from inference .models .utils import ROBOFLOW_MODEL_TYPES
5156from inference .usage_tracking .collector import usage_collector
5257from inference .usage_tracking .plan_details import WebRTCPlan
5358
109114 "MODELS_CACHE_AUTH_ENABLED" : str (MODELS_CACHE_AUTH_ENABLED ),
110115 "LOG_LEVEL" : LOG_LEVEL ,
111116 "ONNXRUNTIME_EXECUTION_PROVIDERS" : "[CUDAExecutionProvider,CPUExecutionProvider]" ,
112- "PRELOAD_HF_IDS" : PRELOAD_HF_IDS ,
117+ "PRELOAD_HF_IDS" : str (PRELOAD_HF_IDS ),
118+ "PRELOAD_MODELS" : str (PRELOAD_MODELS ),
113119 "PROJECT" : PROJECT ,
114120 "ROBOFLOW_INTERNAL_SERVICE_NAME" : WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME ,
115121 "ROBOFLOW_INTERNAL_SERVICE_SECRET" : ROBOFLOW_INTERNAL_SERVICE_SECRET ,
135141 }
136142
137143 class RTCPeerConnectionModal :
138- _webrtc_request : Optional [WebRTCWorkerRequest ] = modal .parameter (default = None )
144+ _model_manager : Optional [ModelManager ] = modal .parameter (default = None )
139145
140146 @modal .method ()
141147 def rtc_peer_connection_modal (
@@ -145,6 +151,9 @@ def rtc_peer_connection_modal(
145151 ):
146152 logger .info ("*** Spawning %s:" , self .__class__ .__name__ )
147153 logger .info ("Inference tag: %s" , docker_tag )
154+ logger .info (
155+ "Preloaded models: %s" , ", " .join (self ._model_manager .models ().keys ())
156+ )
148157 _exec_session_started = datetime .datetime .now ()
149158 webrtc_request .processing_session_started = _exec_session_started
150159 logger .info (
@@ -170,7 +179,6 @@ def rtc_peer_connection_modal(
170179 else []
171180 ),
172181 )
173- self ._webrtc_request = webrtc_request
174182
175183 def send_answer (obj : WebRTCWorkerResult ):
176184 logger .info ("Sending webrtc answer" )
@@ -180,35 +188,36 @@ def send_answer(obj: WebRTCWorkerResult):
180188 init_rtc_peer_connection_with_loop (
181189 webrtc_request = webrtc_request ,
182190 send_answer = send_answer ,
191+ model_manager = self ._model_manager ,
183192 )
184193 )
185194 _exec_session_stopped = datetime .datetime .now ()
186195 logger .info (
187196 "WebRTC session stopped at %s" ,
188197 _exec_session_stopped .isoformat (),
189198 )
190- workflow_id = self . _webrtc_request .workflow_configuration .workflow_id
199+ workflow_id = webrtc_request .workflow_configuration .workflow_id
191200 if not workflow_id :
192- if self . _webrtc_request .workflow_configuration .workflow_specification :
201+ if webrtc_request .workflow_configuration .workflow_specification :
193202 workflow_id = usage_collector ._calculate_resource_hash (
194- resource_details = self . _webrtc_request .workflow_configuration .workflow_specification
203+ resource_details = webrtc_request .workflow_configuration .workflow_specification
195204 )
196205 else :
197206 workflow_id = "unknown"
198207
199208 # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
200- webrtc_plan = self . _webrtc_request .requested_plan
209+ webrtc_plan = webrtc_request .requested_plan
201210
202211 video_source = "realtime browser stream"
203- if self . _webrtc_request .rtsp_url :
212+ if webrtc_request .rtsp_url :
204213 video_source = "rtsp"
205- elif not self . _webrtc_request .webrtc_realtime_processing :
214+ elif not webrtc_request .webrtc_realtime_processing :
206215 video_source = "buffered browser stream"
207216
208217 usage_collector .record_usage (
209218 source = workflow_id ,
210219 category = "modal" ,
211- api_key = self . _webrtc_request .api_key ,
220+ api_key = webrtc_request .api_key ,
212221 resource_details = {
213222 "plan" : webrtc_plan ,
214223 "billable" : True ,
@@ -229,7 +238,22 @@ def start(self):
229238 logger .info ("Starting container" )
230239 if PRELOAD_HF_IDS :
231240 # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
241+ logger .info ("Preloading owlv2 base model" )
232242 import inference .models .owlv2 .owlv2
243+ if PRELOAD_MODELS :
244+ model_registry = RoboflowModelRegistry (ROBOFLOW_MODEL_TYPES )
245+ self ._model_manager = ModelManager (model_registry = model_registry )
246+ for model_id in PRELOAD_MODELS :
247+ de_aliased_model_id = resolve_roboflow_model_alias (
248+ model_id = model_id
249+ )
250+ logger .info (f"Preloading model: { de_aliased_model_id } " )
251+ self ._model_manager .add_model (
252+ model_id = de_aliased_model_id ,
253+ api_key = None ,
254+ countinference = False ,
255+ service_secret = ROBOFLOW_INTERNAL_SERVICE_SECRET ,
256+ )
233257
234258 @modal .exit ()
235259 def stop (self ):
0 commit comments