Skip to content

Commit ccfbd7c

Browse files
Do not preload models on CPU
1 parent: 3f1fea3 · commit: ccfbd7c

File tree

1 file changed

+27
-26
lines changed
  • inference/core/interfaces/webrtc_worker

1 file changed

+27
-26
lines changed

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 27 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -235,12 +235,37 @@ def send_answer(obj: WebRTCWorkerResult):
235235
usage_collector.push_usage_payloads()
236236
logger.info("Function completed")
237237

238+
@modal.exit()
239+
def stop(self):
240+
logger.info("Stopping container")
241+
242+
# Modal derives function name from class name
243+
# https://modal.com/docs/reference/modal.App#cls
244+
@app.cls(
245+
**decorator_kwargs,
246+
)
247+
class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
248+
# https://modal.com/docs/reference/modal.enter
249+
@modal.enter(snap=True)
250+
def start(self):
251+
# TODO: pre-load models on CPU
252+
logger.info("Starting CPU container")
253+
254+
@app.cls(
255+
**{
256+
**decorator_kwargs,
257+
"gpu": WEBRTC_MODAL_FUNCTION_GPU, # https://modal.com/docs/guide/gpu#specifying-gpu-type
258+
"experimental_options": {
259+
"enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
260+
},
261+
}
262+
)
263+
class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
238264
# https://modal.com/docs/reference/modal.enter
239265
# https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
240266
@modal.enter(snap=True)
241267
def start(self):
242-
# TODO: pre-load models
243-
logger.info("Starting container")
268+
logger.info("Starting GPU container")
244269
logger.info("Preload hf ids: %s", PRELOAD_HF_IDS)
245270
logger.info("Preload models: %s", PRELOAD_MODELS)
246271
if PRELOAD_HF_IDS:
@@ -270,30 +295,6 @@ def start(self):
270295
)
271296
self._model_manager = model_manager
272297

273-
@modal.exit()
274-
def stop(self):
275-
logger.info("Stopping container")
276-
277-
# Modal derives function name from class name
278-
# https://modal.com/docs/reference/modal.App#cls
279-
@app.cls(
280-
**decorator_kwargs,
281-
)
282-
class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
283-
pass
284-
285-
@app.cls(
286-
**{
287-
**decorator_kwargs,
288-
"gpu": WEBRTC_MODAL_FUNCTION_GPU, # https://modal.com/docs/guide/gpu#specifying-gpu-type
289-
"experimental_options": {
290-
"enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
291-
},
292-
}
293-
)
294-
class RTCPeerConnectionModalGPU(RTCPeerConnectionModal):
295-
pass
296-
297298
def spawn_rtc_peer_connection_modal(
298299
webrtc_request: WebRTCWorkerRequest,
299300
) -> WebRTCWorkerResult:

0 commit comments

Comments (0)