Skip to content

Commit f918798

Browse files
Preload models in modal
1 parent 55cb441 commit f918798

File tree

4 files changed

+57
-18
lines changed

4 files changed

+57
-18
lines changed

inference/core/interfaces/http/http_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@
160160
PRELOAD_MODELS,
161161
PROFILE,
162162
ROBOFLOW_SERVICE_SECRET,
163+
WEBRTC_MODAL_TOKEN_ID,
164+
WEBRTC_MODAL_TOKEN_SECRET,
163165
WEBRTC_WORKER_ENABLED,
164166
WORKFLOWS_MAX_CONCURRENT_STEPS,
165167
WORKFLOWS_PROFILER_BUFFER_SIZE,
@@ -1603,7 +1605,7 @@ async def consume(
16031605
)
16041606

16051607
# Enable preloading models at startup
1606-
if (
1608+
if (WEBRTC_MODAL_TOKEN_ID and WEBRTC_MODAL_TOKEN_SECRET) or (
16071609
(PRELOAD_MODELS or DEDICATED_DEPLOYMENT_WORKSPACE_URL)
16081610
and API_KEY
16091611
and not (LAMBDA or GCP_SERVERLESS)

inference/core/interfaces/stream/inference_pipeline.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
PipelineWatchDog,
5656
)
5757
from inference.core.managers.active_learning import BackgroundTaskActiveLearningManager
58+
from inference.core.managers.base import ModelManager
5859
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
5960
from inference.core.registries.roboflow import RoboflowModelRegistry
6061
from inference.core.utils.function import experimental
@@ -486,6 +487,7 @@ def init_with_workflow(
486487
serialize_results: bool = False,
487488
predictions_queue_size: int = PREDICTIONS_QUEUE_SIZE,
488489
decoding_buffer_size: int = DEFAULT_BUFFER_SIZE,
490+
model_manager: Optional[ModelManager] = None,
489491
) -> "InferencePipeline":
490492
"""
491493
This class creates the abstraction for making inferences from given workflow against video stream.
@@ -566,6 +568,8 @@ def init_with_workflow(
566568
default value is taken from INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE env variable
567569
decoding_buffer_size (int): size of video source decoding buffer
568570
default value is taken from VIDEO_SOURCE_BUFFER_SIZE env variable
571+
model_manager (Optional[ModelManager]): Model manager to be used by InferencePipeline, defaults to
572+
BackgroundTaskActiveLearningManager with WithFixedSizeCache
569573
570574
Other ENV variables involved in low-level configuration:
571575
* INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -623,13 +627,14 @@ def init_with_workflow(
623627
use_cache=use_workflow_definition_cache,
624628
)
625629
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
626-
model_manager = BackgroundTaskActiveLearningManager(
627-
model_registry=model_registry, cache=cache
628-
)
629-
model_manager = WithFixedSizeCache(
630-
model_manager,
631-
max_size=MAX_ACTIVE_MODELS,
632-
)
630+
if model_manager is None:
631+
model_manager = BackgroundTaskActiveLearningManager(
632+
model_registry=model_registry, cache=cache
633+
)
634+
model_manager = WithFixedSizeCache(
635+
model_manager,
636+
max_size=MAX_ACTIVE_MODELS,
637+
)
633638
if workflow_init_parameters is None:
634639
workflow_init_parameters = {}
635640
thread_pool_executor = ThreadPoolExecutor(

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
MODELS_CACHE_AUTH_CACHE_TTL,
1717
MODELS_CACHE_AUTH_ENABLED,
1818
PRELOAD_HF_IDS,
19+
PRELOAD_MODELS,
1920
PROJECT,
2021
ROBOFLOW_INTERNAL_SERVICE_SECRET,
2122
WEBRTC_MODAL_APP_NAME,
@@ -46,8 +47,12 @@
4647
from inference.core.interfaces.webrtc_worker.webrtc import (
4748
init_rtc_peer_connection_with_loop,
4849
)
50+
from inference.core.managers.base import ModelManager
51+
from inference.core.registries.roboflow import RoboflowModelRegistry
4952
from inference.core.roboflow_api import get_roboflow_workspace
5053
from inference.core.version import __version__
54+
from inference.models.aliases import resolve_roboflow_model_alias
55+
from inference.models.utils import ROBOFLOW_MODEL_TYPES
5156
from inference.usage_tracking.collector import usage_collector
5257
from inference.usage_tracking.plan_details import WebRTCPlan
5358

@@ -109,7 +114,8 @@
109114
"MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
110115
"LOG_LEVEL": LOG_LEVEL,
111116
"ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
112-
"PRELOAD_HF_IDS": PRELOAD_HF_IDS,
117+
"PRELOAD_HF_IDS": str(PRELOAD_HF_IDS),
118+
"PRELOAD_MODELS": str(PRELOAD_MODELS),
113119
"PROJECT": PROJECT,
114120
"ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
115121
"ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
@@ -135,7 +141,7 @@
135141
}
136142

137143
class RTCPeerConnectionModal:
138-
_webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
144+
_model_manager: Optional[ModelManager] = modal.parameter(default=None)
139145

140146
@modal.method()
141147
def rtc_peer_connection_modal(
@@ -145,6 +151,9 @@ def rtc_peer_connection_modal(
145151
):
146152
logger.info("*** Spawning %s:", self.__class__.__name__)
147153
logger.info("Inference tag: %s", docker_tag)
154+
logger.info(
155+
"Preloaded models: %s", ", ".join(self._model_manager.models().keys())
156+
)
148157
_exec_session_started = datetime.datetime.now()
149158
webrtc_request.processing_session_started = _exec_session_started
150159
logger.info(
@@ -170,7 +179,6 @@ def rtc_peer_connection_modal(
170179
else []
171180
),
172181
)
173-
self._webrtc_request = webrtc_request
174182

175183
def send_answer(obj: WebRTCWorkerResult):
176184
logger.info("Sending webrtc answer")
@@ -180,35 +188,36 @@ def send_answer(obj: WebRTCWorkerResult):
180188
init_rtc_peer_connection_with_loop(
181189
webrtc_request=webrtc_request,
182190
send_answer=send_answer,
191+
model_manager=self._model_manager,
183192
)
184193
)
185194
_exec_session_stopped = datetime.datetime.now()
186195
logger.info(
187196
"WebRTC session stopped at %s",
188197
_exec_session_stopped.isoformat(),
189198
)
190-
workflow_id = self._webrtc_request.workflow_configuration.workflow_id
199+
workflow_id = webrtc_request.workflow_configuration.workflow_id
191200
if not workflow_id:
192-
if self._webrtc_request.workflow_configuration.workflow_specification:
201+
if webrtc_request.workflow_configuration.workflow_specification:
193202
workflow_id = usage_collector._calculate_resource_hash(
194-
resource_details=self._webrtc_request.workflow_configuration.workflow_specification
203+
resource_details=webrtc_request.workflow_configuration.workflow_specification
195204
)
196205
else:
197206
workflow_id = "unknown"
198207

199208
# requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
200-
webrtc_plan = self._webrtc_request.requested_plan
209+
webrtc_plan = webrtc_request.requested_plan
201210

202211
video_source = "realtime browser stream"
203-
if self._webrtc_request.rtsp_url:
212+
if webrtc_request.rtsp_url:
204213
video_source = "rtsp"
205-
elif not self._webrtc_request.webrtc_realtime_processing:
214+
elif not webrtc_request.webrtc_realtime_processing:
206215
video_source = "buffered browser stream"
207216

208217
usage_collector.record_usage(
209218
source=workflow_id,
210219
category="modal",
211-
api_key=self._webrtc_request.api_key,
220+
api_key=webrtc_request.api_key,
212221
resource_details={
213222
"plan": webrtc_plan,
214223
"billable": True,
@@ -229,7 +238,22 @@ def start(self):
229238
logger.info("Starting container")
230239
if PRELOAD_HF_IDS:
231240
# Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
241+
logger.info("Preloading owlv2 base model")
232242
import inference.models.owlv2.owlv2
243+
if PRELOAD_MODELS:
244+
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
245+
self._model_manager = ModelManager(model_registry=model_registry)
246+
for model_id in PRELOAD_MODELS:
247+
de_aliased_model_id = resolve_roboflow_model_alias(
248+
model_id=model_id
249+
)
250+
logger.info(f"Preloading model: {de_aliased_model_id}")
251+
self._model_manager.add_model(
252+
model_id=de_aliased_model_id,
253+
api_key=None,
254+
countinference=False,
255+
service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
256+
)
233257

234258
@modal.exit()
235259
def stop(self):

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
detect_image_output,
5050
process_frame,
5151
)
52+
from inference.core.managers.base import ModelManager
5253
from inference.core.roboflow_api import get_workflow_specification
5354
from inference.core.workflows.core_steps.common.serializers import (
5455
serialize_wildcard_kind,
@@ -83,6 +84,7 @@ def __init__(
8384
asyncio_loop: asyncio.AbstractEventLoop,
8485
workflow_configuration: WorkflowConfiguration,
8586
api_key: str,
87+
model_manager: Optional[ModelManager] = None,
8688
data_output: Optional[List[str]] = None,
8789
stream_output: Optional[str] = None,
8890
has_video_track: bool = True,
@@ -134,6 +136,7 @@ def __init__(
134136
workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers,
135137
cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
136138
video_metadata_input_name=workflow_configuration.video_metadata_input_name,
139+
model_manager=model_manager,
137140
)
138141

139142
def set_track(self, track: RemoteStreamTrack):
@@ -362,6 +365,7 @@ def __init__(
362365
asyncio_loop: asyncio.AbstractEventLoop,
363366
workflow_configuration: WorkflowConfiguration,
364367
api_key: str,
368+
model_manager: Optional[ModelManager] = None,
365369
data_output: Optional[List[str]] = None,
366370
stream_output: Optional[str] = None,
367371
has_video_track: bool = True,
@@ -383,6 +387,7 @@ def __init__(
383387
declared_fps=declared_fps,
384388
termination_date=termination_date,
385389
terminate_event=terminate_event,
390+
model_manager=model_manager,
386391
)
387392

388393
async def _auto_detect_stream_output(
@@ -466,6 +471,7 @@ async def init_rtc_peer_connection_with_loop(
466471
webrtc_request: WebRTCWorkerRequest,
467472
send_answer: Callable[[WebRTCWorkerResult], None],
468473
asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
474+
model_manager: Optional[ModelManager] = None,
469475
shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
470476
) -> RTCPeerConnectionWithLoop:
471477
termination_date = None
@@ -517,6 +523,7 @@ async def init_rtc_peer_connection_with_loop(
517523
video_processor = VideoTransformTrackWithLoop(
518524
asyncio_loop=asyncio_loop,
519525
workflow_configuration=webrtc_request.workflow_configuration,
526+
model_manager=model_manager,
520527
api_key=webrtc_request.api_key,
521528
data_output=data_fields,
522529
stream_output=stream_field,
@@ -530,6 +537,7 @@ async def init_rtc_peer_connection_with_loop(
530537
video_processor = VideoFrameProcessor(
531538
asyncio_loop=asyncio_loop,
532539
workflow_configuration=webrtc_request.workflow_configuration,
540+
model_manager=model_manager,
533541
api_key=webrtc_request.api_key,
534542
data_output=data_fields,
535543
stream_output=None,

0 commit comments

Comments
 (0)