@@ -235,12 +235,37 @@ def send_answer(obj: WebRTCWorkerResult):
235235 usage_collector .push_usage_payloads ()
236236 logger .info ("Function completed" )
237237
238+ @modal .exit ()
239+ def stop (self ):
240+ logger .info ("Stopping container" )
241+
242+ # Modal derives function name from class name
243+ # https://modal.com/docs/reference/modal.App#cls
244+ @app .cls (
245+ ** decorator_kwargs ,
246+ )
247+ class RTCPeerConnectionModalCPU (RTCPeerConnectionModal ):
248+ # https://modal.com/docs/reference/modal.enter
249+ @modal .enter (snap = True )
250+ def start (self ):
251+ # TODO: pre-load models on CPU
252+ logger .info ("Starting CPU container" )
253+
254+ @app .cls (
255+ ** {
256+ ** decorator_kwargs ,
257+ "gpu" : WEBRTC_MODAL_FUNCTION_GPU , # https://modal.com/docs/guide/gpu#specifying-gpu-type
258+ "experimental_options" : {
259+ "enable_gpu_snapshot" : WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
260+ },
261+ }
262+ )
263+ class RTCPeerConnectionModalGPU (RTCPeerConnectionModal ):
238264 # https://modal.com/docs/reference/modal.enter
239265 # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
240266 @modal .enter (snap = True )
241267 def start (self ):
242- # TODO: pre-load models
243- logger .info ("Starting container" )
268+ logger .info ("Starting GPU container" )
244269 logger .info ("Preload hf ids: %s" , PRELOAD_HF_IDS )
245270 logger .info ("Preload models: %s" , PRELOAD_MODELS )
246271 if PRELOAD_HF_IDS :
@@ -270,30 +295,6 @@ def start(self):
270295 )
271296 self ._model_manager = model_manager
272297
273- @modal .exit ()
274- def stop (self ):
275- logger .info ("Stopping container" )
276-
277- # Modal derives function name from class name
278- # https://modal.com/docs/reference/modal.App#cls
279- @app .cls (
280- ** decorator_kwargs ,
281- )
282- class RTCPeerConnectionModalCPU (RTCPeerConnectionModal ):
283- pass
284-
285- @app .cls (
286- ** {
287- ** decorator_kwargs ,
288- "gpu" : WEBRTC_MODAL_FUNCTION_GPU , # https://modal.com/docs/guide/gpu#specifying-gpu-type
289- "experimental_options" : {
290- "enable_gpu_snapshot" : WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
291- },
292- }
293- )
294- class RTCPeerConnectionModalGPU (RTCPeerConnectionModal ):
295- pass
296-
297298 def spawn_rtc_peer_connection_modal (
298299 webrtc_request : WebRTCWorkerRequest ,
299300 ) -> WebRTCWorkerResult :
0 commit comments