Commit d67dac6

Merge pull request #1732 from roboflow/feat/modal-preload-owlv2
Preload HF base for owlv2 and models specified within PRELOAD_MODELS in modal
2 parents d7ab4dd + ccfbd7c commit d67dac6

4 files changed: 83 additions, 27 deletions

inference/core/env.py

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@
 # and also ENABLE_STREAM_API environmental variable is set to False
 PRELOAD_HF_IDS = os.getenv("PRELOAD_HF_IDS")
 if PRELOAD_HF_IDS:
-    PRELOAD_HF_IDS = [id.strip() for id in PRELOAD_HF_IDS.split(",")]
+    PRELOAD_HF_IDS = [m.strip() for m in PRELOAD_HF_IDS.split(",")]

 # Maximum batch size for GAZE, default is 8
 GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8))
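
The one-line change renames the loop variable so it no longer shadows Python's builtin id; the parsing itself is a plain comma-split with whitespace trimming. A minimal sketch of that behavior, using a hypothetical value (the model IDs are placeholders, not mandated by this PR):

import os

# Hypothetical value; PRELOAD_HF_IDS holds a comma-separated list of HF model IDs.
os.environ["PRELOAD_HF_IDS"] = "google/owlv2-base-patch16-ensemble, google/owlv2-large-patch14-ensemble"

preload_hf_ids = os.getenv("PRELOAD_HF_IDS")
if preload_hf_ids:
    # Same logic as env.py: split on commas, strip surrounding whitespace.
    preload_hf_ids = [m.strip() for m in preload_hf_ids.split(",")]

print(preload_hf_ids)
# ['google/owlv2-base-patch16-ensemble', 'google/owlv2-large-patch14-ensemble']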

inference/core/interfaces/stream/inference_pipeline.py

Lines changed: 12 additions & 7 deletions

@@ -55,6 +55,7 @@
     PipelineWatchDog,
 )
 from inference.core.managers.active_learning import BackgroundTaskActiveLearningManager
+from inference.core.managers.base import ModelManager
 from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
 from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.utils.function import experimental

@@ -486,6 +487,7 @@ def init_with_workflow(
         serialize_results: bool = False,
         predictions_queue_size: int = PREDICTIONS_QUEUE_SIZE,
         decoding_buffer_size: int = DEFAULT_BUFFER_SIZE,
+        model_manager: Optional[ModelManager] = None,
     ) -> "InferencePipeline":
         """
         This class creates the abstraction for making inferences from given workflow against video stream.

@@ -566,6 +568,8 @@
                 default value is taken from INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE env variable
             decoding_buffer_size (int): size of video source decoding buffer
                 default value is taken from VIDEO_SOURCE_BUFFER_SIZE env variable
+            model_manager (Optional[ModelManager]): Model manager to be used by InferencePipeline, defaults to
+                BackgroundTaskActiveLearningManager with WithFixedSizeCache

         Other ENV variables involved in low-level configuration:
         * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching

@@ -623,13 +627,14 @@ def init_with_workflow(
             use_cache=use_workflow_definition_cache,
         )
         model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
-        model_manager = BackgroundTaskActiveLearningManager(
-            model_registry=model_registry, cache=cache
-        )
-        model_manager = WithFixedSizeCache(
-            model_manager,
-            max_size=MAX_ACTIVE_MODELS,
-        )
+        if model_manager is None:
+            model_manager = BackgroundTaskActiveLearningManager(
+                model_registry=model_registry, cache=cache
+            )
+            model_manager = WithFixedSizeCache(
+                model_manager,
+                max_size=MAX_ACTIVE_MODELS,
+            )
         if workflow_init_parameters is None:
             workflow_init_parameters = {}
         thread_pool_executor = ThreadPoolExecutor(
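
The new parameter turns the model manager into an injectable dependency: the BackgroundTaskActiveLearningManager wrapped in WithFixedSizeCache is now only a fallback built when no manager is supplied. A minimal usage sketch under that assumption (the workspace, workflow, and video reference below are hypothetical placeholders, not part of this PR):

from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.core.managers.base import ModelManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.models.utils import ROBOFLOW_MODEL_TYPES

# A manager that could have been created (and warmed) ahead of time,
# e.g. by the Modal container's snapshot-time preload.
manager = ModelManager(model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES))

pipeline = InferencePipeline.init_with_workflow(
    video_reference="rtsp://example.local/stream",  # hypothetical video source
    workspace_name="my-workspace",                  # hypothetical
    workflow_id="my-workflow",                      # hypothetical
    api_key="<ROBOFLOW_API_KEY>",
    model_manager=manager,  # new in this PR; skips building the default manager
)
pipeline.start()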

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 62 additions & 19 deletions

@@ -16,6 +16,7 @@
     MODELS_CACHE_AUTH_CACHE_TTL,
     MODELS_CACHE_AUTH_ENABLED,
     PRELOAD_HF_IDS,
+    PRELOAD_MODELS,
     PROJECT,
     ROBOFLOW_INTERNAL_SERVICE_SECRET,
     WEBRTC_MODAL_APP_NAME,

@@ -46,8 +47,12 @@
 from inference.core.interfaces.webrtc_worker.webrtc import (
     init_rtc_peer_connection_with_loop,
 )
+from inference.core.managers.base import ModelManager
+from inference.core.registries.roboflow import RoboflowModelRegistry
 from inference.core.roboflow_api import get_roboflow_workspace
 from inference.core.version import __version__
+from inference.models.aliases import resolve_roboflow_model_alias
+from inference.models.utils import ROBOFLOW_MODEL_TYPES
 from inference.usage_tracking.collector import usage_collector
 from inference.usage_tracking.plan_details import WebRTCPlan

@@ -109,7 +114,8 @@
     "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
     "LOG_LEVEL": LOG_LEVEL,
     "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
-    "PRELOAD_HF_IDS": PRELOAD_HF_IDS,
+    "PRELOAD_HF_IDS": ",".join(PRELOAD_HF_IDS) if PRELOAD_HF_IDS else "",
+    "PRELOAD_MODELS": ",".join(PRELOAD_MODELS) if PRELOAD_MODELS else "",
     "PROJECT": PROJECT,
     "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
     "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,

@@ -135,7 +141,7 @@
 }

 class RTCPeerConnectionModal:
-    _webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
+    _model_manager: Optional[ModelManager] = modal.parameter(default=None)

     @modal.method()
     def rtc_peer_connection_modal(

@@ -145,6 +151,14 @@ def rtc_peer_connection_modal(
     ):
         logger.info("*** Spawning %s:", self.__class__.__name__)
         logger.info("Inference tag: %s", docker_tag)
+        logger.info(
+            "Preloaded models: %s",
+            (
+                ", ".join(self._model_manager.models().keys())
+                if self._model_manager
+                else ""
+            ),
+        )
         _exec_session_started = datetime.datetime.now()
         webrtc_request.processing_session_started = _exec_session_started
         logger.info(

@@ -170,7 +184,6 @@ def rtc_peer_connection_modal(
                 else []
             ),
         )
-        self._webrtc_request = webrtc_request

         def send_answer(obj: WebRTCWorkerResult):
             logger.info("Sending webrtc answer")

@@ -180,35 +193,36 @@ def send_answer(obj: WebRTCWorkerResult):
             init_rtc_peer_connection_with_loop(
                 webrtc_request=webrtc_request,
                 send_answer=send_answer,
+                model_manager=self._model_manager,
             )
         )
         _exec_session_stopped = datetime.datetime.now()
         logger.info(
             "WebRTC session stopped at %s",
             _exec_session_stopped.isoformat(),
         )
-        workflow_id = self._webrtc_request.workflow_configuration.workflow_id
+        workflow_id = webrtc_request.workflow_configuration.workflow_id
         if not workflow_id:
-            if self._webrtc_request.workflow_configuration.workflow_specification:
+            if webrtc_request.workflow_configuration.workflow_specification:
                 workflow_id = usage_collector._calculate_resource_hash(
-                    resource_details=self._webrtc_request.workflow_configuration.workflow_specification
+                    resource_details=webrtc_request.workflow_configuration.workflow_specification
                 )
             else:
                 workflow_id = "unknown"

         # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
-        webrtc_plan = self._webrtc_request.requested_plan
+        webrtc_plan = webrtc_request.requested_plan

         video_source = "realtime browser stream"
-        if self._webrtc_request.rtsp_url:
+        if webrtc_request.rtsp_url:
             video_source = "rtsp"
-        elif not self._webrtc_request.webrtc_realtime_processing:
+        elif not webrtc_request.webrtc_realtime_processing:
             video_source = "buffered browser stream"

         usage_collector.record_usage(
             source=workflow_id,
             category="modal",
-            api_key=self._webrtc_request.api_key,
+            api_key=webrtc_request.api_key,
             resource_details={
                 "plan": webrtc_plan,
                 "billable": True,

@@ -221,13 +235,6 @@ def send_answer(obj: WebRTCWorkerResult):
         usage_collector.push_usage_payloads()
         logger.info("Function completed")

-    # https://modal.com/docs/reference/modal.enter
-    # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
-    @modal.enter(snap=True)
-    def start(self):
-        # TODO: pre-load models
-        logger.info("Starting container")
-
     @modal.exit()
     def stop(self):
         logger.info("Stopping container")

@@ -238,7 +245,11 @@ def stop(self):
     **decorator_kwargs,
 )
 class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
-    pass
+    # https://modal.com/docs/reference/modal.enter
+    @modal.enter(snap=True)
+    def start(self):
+        # TODO: pre-load models on CPU
+        logger.info("Starting CPU container")

 @app.cls(
     **{

@@ -250,7 +261,39 @@ class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
     }
 )
 class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
-    pass
+    # https://modal.com/docs/reference/modal.enter
+    # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
+    @modal.enter(snap=True)
+    def start(self):
+        logger.info("Starting GPU container")
+        logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
+        logger.info("Preload models: %s", PRELOAD_MODELS)
+        if PRELOAD_HF_IDS:
+            # Kick off pre-loading of models (owlv2 preloading is based on module-level singleton)
+            logger.info("Preloading owlv2 base model")
+            import inference.models.owlv2.owlv2
+        if PRELOAD_MODELS:
+            model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
+            model_manager = ModelManager(model_registry=model_registry)
+            for model_id in PRELOAD_MODELS:
+                try:
+                    de_aliased_model_id = resolve_roboflow_model_alias(
+                        model_id=model_id
+                    )
+                    logger.info(f"Preloading model: {de_aliased_model_id}")
+                    model_manager.add_model(
+                        model_id=de_aliased_model_id,
+                        api_key=None,
+                        countinference=False,
+                        service_secret=ROBOFLOW_INTERNAL_SERVICE_SECRET,
+                    )
+                except Exception as exc:
+                    logger.error(
+                        "Failed to preload model %s: %s",
+                        model_id,
+                        exc,
+                    )
+            self._model_manager = model_manager

 def spawn_rtc_peer_connection_modal(
     webrtc_request: WebRTCWorkerRequest,
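
The lifecycle hooks moved into the concrete CPU/GPU classes so that only the GPU variant does real preloading. The key mechanism is Modal's memory snapshot: code in a @modal.enter(snap=True) method runs once before the snapshot is captured, so modules imported and weights loaded there are restored on later cold starts instead of being re-loaded. A minimal, self-contained sketch of that pattern (the app name and the model stand-in are hypothetical, not part of this PR):

import time

import modal

app = modal.App("snapshot-preload-sketch")  # hypothetical app name

@app.cls(enable_memory_snapshot=True)  # snapshotting must be enabled on the class
class Worker:
    @modal.enter(snap=True)
    def start(self):
        # Runs once, before the snapshot is captured; whatever is loaded here
        # is restored on subsequent cold starts without paying the load cost again.
        self.model = {"weights": "stand-in for real model loading"}
        self.loaded_at = time.time()

    @modal.method()
    def infer(self, payload: str) -> str:
        # By the time any method runs, self.model already exists.
        return f"processed {payload!r} with model loaded at {self.loaded_at}"

Note the error handling in the real diff: a model that fails to preload is logged and skipped, so one bad entry in PRELOAD_MODELS cannot keep the container from starting.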

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 8 additions & 0 deletions

@@ -49,6 +49,7 @@
     detect_image_output,
     process_frame,
 )
+from inference.core.managers.base import ModelManager
 from inference.core.roboflow_api import get_workflow_specification
 from inference.core.workflows.core_steps.common.serializers import (
     serialize_wildcard_kind,

@@ -83,6 +84,7 @@ def __init__(
         asyncio_loop: asyncio.AbstractEventLoop,
         workflow_configuration: WorkflowConfiguration,
         api_key: str,
+        model_manager: Optional[ModelManager] = None,
         data_output: Optional[List[str]] = None,
         stream_output: Optional[str] = None,
         has_video_track: bool = True,

@@ -134,6 +136,7 @@ def __init__(
             workflows_thread_pool_workers=workflow_configuration.workflows_thread_pool_workers,
             cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
             video_metadata_input_name=workflow_configuration.video_metadata_input_name,
+            model_manager=model_manager,
         )

     def set_track(self, track: RemoteStreamTrack):

@@ -362,6 +365,7 @@ def __init__(
         asyncio_loop: asyncio.AbstractEventLoop,
         workflow_configuration: WorkflowConfiguration,
         api_key: str,
+        model_manager: Optional[ModelManager] = None,
         data_output: Optional[List[str]] = None,
         stream_output: Optional[str] = None,
         has_video_track: bool = True,

@@ -383,6 +387,7 @@ def __init__(
             declared_fps=declared_fps,
             termination_date=termination_date,
             terminate_event=terminate_event,
+            model_manager=model_manager,
         )

     async def _auto_detect_stream_output(

@@ -466,6 +471,7 @@ async def init_rtc_peer_connection_with_loop(
     webrtc_request: WebRTCWorkerRequest,
     send_answer: Callable[[WebRTCWorkerResult], None],
     asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
+    model_manager: Optional[ModelManager] = None,
     shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
 ) -> RTCPeerConnectionWithLoop:
     termination_date = None

@@ -517,6 +523,7 @@ async def init_rtc_peer_connection_with_loop(
         video_processor = VideoTransformTrackWithLoop(
             asyncio_loop=asyncio_loop,
             workflow_configuration=webrtc_request.workflow_configuration,
+            model_manager=model_manager,
             api_key=webrtc_request.api_key,
             data_output=data_fields,
             stream_output=stream_field,

@@ -530,6 +537,7 @@ async def init_rtc_peer_connection_with_loop(
         video_processor = VideoFrameProcessor(
             asyncio_loop=asyncio_loop,
             workflow_configuration=webrtc_request.workflow_configuration,
+            model_manager=model_manager,
             api_key=webrtc_request.api_key,
             data_output=data_fields,
             stream_output=None,
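
These plumbing-only additions complete the chain: the manager built at snapshot time in modal.py flows through init_rtc_peer_connection_with_loop into the video processors and on to InferencePipeline.init_with_workflow. A condensed, runnable sketch of that flow using hypothetical stand-ins (not the real signatures):

from typing import Optional

class ModelManager:  # stand-in for inference.core.managers.base.ModelManager
    pass

def init_with_workflow(model_manager: Optional[ModelManager] = None) -> ModelManager:
    # Mirrors the fallback in inference_pipeline.py: build a default manager
    # only when the caller did not inject one.
    return model_manager if model_manager is not None else ModelManager()

def video_frame_processor(model_manager: Optional[ModelManager] = None) -> ModelManager:
    return init_with_workflow(model_manager=model_manager)

def init_rtc_peer_connection(model_manager: Optional[ModelManager] = None) -> ModelManager:
    return video_frame_processor(model_manager=model_manager)

preloaded = ModelManager()  # created once in @modal.enter(snap=True)
assert init_rtc_peer_connection(model_manager=preloaded) is preloaded

Because every hop defaults model_manager to None, existing callers are unaffected; only the Modal worker opts in.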
