
Commit c38f9c6

Merge branch 'main' into jetson-620-cuda-base-pr

2 parents 12dbdaf + d7ab4dd

5 files changed, +102 -55 lines changed

inference/core/env.py

Lines changed: 3 additions & 0 deletions

@@ -709,6 +709,8 @@
 WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT = int(
     os.getenv("WEBRTC_MODAL_FUNCTION_MAX_TIME_LIMIT", "604800")  # 7 days
 )
+# seconds
+WEBRTC_MODAL_SHUTDOWN_RESERVE = int(os.getenv("WEBRTC_MODAL_SHUTDOWN_RESERVE", "1"))
 WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT = str2bool(
     os.getenv("WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT", "True")
 )
@@ -739,6 +741,7 @@
 )
 WEBRTC_MODAL_RTSP_PLACEHOLDER = os.getenv("WEBRTC_MODAL_RTSP_PLACEHOLDER")
 WEBRTC_MODAL_RTSP_PLACEHOLDER_URL = os.getenv("WEBRTC_MODAL_RTSP_PLACEHOLDER_URL")
+WEBRTC_MODAL_GCP_SECRET_NAME = os.getenv("WEBRTC_MODAL_GCP_SECRET_NAME")
 HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED = str2bool(
     os.getenv("HTTP_API_SHARED_WORKFLOWS_THREAD_POOL_ENABLED", "True")
 )
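
The new WEBRTC_MODAL_SHUTDOWN_RESERVE setting (seconds, default 1) is subtracted from the processing timeout when the worker computes its termination date (see the inference/core/interfaces/webrtc_worker/webrtc.py hunk later in this commit). A minimal sketch of that arithmetic, with illustrative numbers only:

import datetime

# Illustrative values: with the default 1-second reserve, a session with a
# 600-second processing timeout is scheduled to stop after 599 seconds.
processing_timeout = 600
shutdown_reserve = 1  # WEBRTC_MODAL_SHUTDOWN_RESERVE
session_started = datetime.datetime(2025, 1, 1, 12, 0, 0)
termination_date = session_started + datetime.timedelta(
    seconds=processing_timeout - shutdown_reserve
)
assert termination_date == datetime.datetime(2025, 1, 1, 12, 9, 59)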

inference/core/interfaces/webrtc_worker/entities.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+import datetime
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union

@@ -39,6 +40,7 @@ class WebRTCWorkerRequest(BaseModel):
     declared_fps: Optional[float] = None
     rtsp_url: Optional[str] = None
     processing_timeout: Optional[int] = WEBRTC_MODAL_FUNCTION_TIME_LIMIT
+    processing_session_started: Optional[datetime.datetime] = None
     requested_plan: Optional[str] = "webrtc-gpu-small"
     # TODO: replaced with requested_plan
     requested_gpu: Optional[str] = None
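
The new field travels inside the request handed to the Modal worker. A minimal stand-in model (not the real WebRTCWorkerRequest, and assuming pydantic v2) showing that an Optional[datetime.datetime] field survives the JSON round trip:

import datetime
from typing import Optional

from pydantic import BaseModel


class RequestSketch(BaseModel):  # hypothetical stand-in, not the real class
    processing_timeout: Optional[int] = None
    processing_session_started: Optional[datetime.datetime] = None


req = RequestSketch(
    processing_timeout=600,
    processing_session_started=datetime.datetime(2025, 1, 1, 12, 0, 0),
)
payload = req.model_dump_json()  # the datetime is serialized as an ISO 8601 string
restored = RequestSketch.model_validate_json(payload)
assert restored.processing_session_started == req.processing_session_started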

inference/core/interfaces/webrtc_worker/modal.py

Lines changed: 87 additions & 48 deletions
@@ -15,6 +15,7 @@
     MODELS_CACHE_AUTH_CACHE_MAX_SIZE,
     MODELS_CACHE_AUTH_CACHE_TTL,
     MODELS_CACHE_AUTH_ENABLED,
+    PRELOAD_HF_IDS,
     PROJECT,
     ROBOFLOW_INTERNAL_SERVICE_SECRET,
     WEBRTC_MODAL_APP_NAME,
@@ -26,6 +27,7 @@
     WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS,
     WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW,
     WEBRTC_MODAL_FUNCTION_TIME_LIMIT,
+    WEBRTC_MODAL_GCP_SECRET_NAME,
     WEBRTC_MODAL_IMAGE_NAME,
     WEBRTC_MODAL_IMAGE_TAG,
     WEBRTC_MODAL_RESPONSE_TIMEOUT,
@@ -44,6 +46,7 @@
 from inference.core.interfaces.webrtc_worker.webrtc import (
     init_rtc_peer_connection_with_loop,
 )
+from inference.core.roboflow_api import get_roboflow_workspace
 from inference.core.version import __version__
 from inference.usage_tracking.collector import usage_collector
 from inference.usage_tracking.plan_details import WebRTCPlan
@@ -55,14 +58,20 @@


 if modal is not None:
-    # https://modal.com/docs/reference/modal.Image
-    video_processing_image = (
-        modal.Image.from_registry(
-            f"{WEBRTC_MODAL_IMAGE_NAME}:{WEBRTC_MODAL_IMAGE_TAG if WEBRTC_MODAL_IMAGE_TAG else __version__}"
+    docker_tag: str = WEBRTC_MODAL_IMAGE_TAG if WEBRTC_MODAL_IMAGE_TAG else __version__
+    if WEBRTC_MODAL_GCP_SECRET_NAME:
+        # https://modal.com/docs/reference/modal.Secret#from_name
+        secret = modal.Secret.from_name(WEBRTC_MODAL_GCP_SECRET_NAME)
+        # https://modal.com/docs/reference/modal.Image#from_gcp_artifact_registry
+        video_processing_image = modal.Image.from_gcp_artifact_registry(
+            f"{WEBRTC_MODAL_IMAGE_NAME}:{docker_tag}",
+            secret=secret,
         )
-        .pip_install("modal")
-        .entrypoint([])
-    )
+    else:
+        video_processing_image = modal.Image.from_registry(
+            f"{WEBRTC_MODAL_IMAGE_NAME}:{docker_tag}"
+        )
+    video_processing_image = video_processing_image.pip_install("modal").entrypoint([])

     # https://modal.com/docs/reference/modal.Volume
     rfcache_volume = modal.Volume.from_name("rfcache", create_if_missing=True)
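
With this change the worker image can be pulled from a private GCP Artifact Registry: when WEBRTC_MODAL_GCP_SECRET_NAME is set it names a Modal secret that grants pull access to the repository (see the linked from_gcp_artifact_registry docs for the expected secret contents); otherwise the previous from_registry path is used. A hypothetical configuration, with placeholder image path, tag, and secret name:

import os

# Placeholder values for illustration only; none of these are real.
os.environ["WEBRTC_MODAL_IMAGE_NAME"] = "us-docker.pkg.dev/example-project/example-repo/inference"
os.environ["WEBRTC_MODAL_IMAGE_TAG"] = "0.61.1"  # falls back to __version__ when unset
os.environ["WEBRTC_MODAL_GCP_SECRET_NAME"] = "gcp-artifact-registry-pull"
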
@@ -71,6 +80,7 @@
     app = modal.App(
         name=WEBRTC_MODAL_APP_NAME,
         image=video_processing_image,
+        tags={"tag": docker_tag},
     )

     decorator_kwargs = {
@@ -81,56 +91,51 @@
         "enable_memory_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT,
         "max_inputs": WEBRTC_MODAL_FUNCTION_MAX_INPUTS,
         "env": {
-            "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
-            "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
-            "PROJECT": PROJECT,
-            "LOG_LEVEL": LOG_LEVEL,
-            "INTERNAL_WEIGHTS_URL_SUFFIX": INTERNAL_WEIGHTS_URL_SUFFIX,
-            "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
-            "MODELS_CACHE_AUTH_CACHE_TTL": str(MODELS_CACHE_AUTH_CACHE_TTL),
-            "MODELS_CACHE_AUTH_CACHE_MAX_SIZE": str(MODELS_CACHE_AUTH_CACHE_MAX_SIZE),
-            "METRICS_ENABLED": "False",
             "ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS": str(
                 ALLOW_CUSTOM_PYTHON_EXECUTION_IN_WORKFLOWS
             ),
-            "WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE": WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
+            "ALLOW_WORKFLOW_BLOCKS_ACCESSING_ENVIRONMENTAL_VARIABLES": "False",
+            "DISABLE_INFERENCE_CACHE": "True",
+            "DISABLE_VERSION_CHECK": "True",
+            "HF_HOME": Path(MODEL_CACHE_DIR).joinpath("hf_home").as_posix(),
+            "INTERNAL_WEIGHTS_URL_SUFFIX": INTERNAL_WEIGHTS_URL_SUFFIX,
+            "METRICS_ENABLED": "False",
             "MODAL_TOKEN_ID": MODAL_TOKEN_ID,
             "MODAL_TOKEN_SECRET": MODAL_TOKEN_SECRET,
             "MODAL_WORKSPACE_NAME": MODAL_WORKSPACE_NAME,
-            "ALLOW_WORKFLOW_BLOCKS_ACCESSING_ENVIRONMENTAL_VARIABLES": "False",
-            "DISABLE_VERSION_CHECK": "True",
             "MODEL_CACHE_DIR": MODEL_CACHE_DIR,
-            "HF_HOME": Path(MODEL_CACHE_DIR).joinpath("hf_home").as_posix(),
+            "MODELS_CACHE_AUTH_CACHE_MAX_SIZE": str(MODELS_CACHE_AUTH_CACHE_MAX_SIZE),
+            "MODELS_CACHE_AUTH_CACHE_TTL": str(MODELS_CACHE_AUTH_CACHE_TTL),
+            "MODELS_CACHE_AUTH_ENABLED": str(MODELS_CACHE_AUTH_ENABLED),
+            "LOG_LEVEL": LOG_LEVEL,
+            "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
+            "PRELOAD_HF_IDS": PRELOAD_HF_IDS,
+            "PROJECT": PROJECT,
+            "ROBOFLOW_INTERNAL_SERVICE_NAME": WEBRTC_MODAL_ROBOFLOW_INTERNAL_SERVICE_NAME,
+            "ROBOFLOW_INTERNAL_SERVICE_SECRET": ROBOFLOW_INTERNAL_SERVICE_SECRET,
+            "WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE": WORKFLOWS_CUSTOM_PYTHON_EXECUTION_MODE,
             "TELEMETRY_USE_PERSISTENT_QUEUE": "False",
-            "DISABLE_INFERENCE_CACHE": "True",
-            "WEBRTC_MODAL_FUNCTION_GPU": WEBRTC_MODAL_FUNCTION_GPU,
-            "WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW": str(
-                WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW
-            ),
             "WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS": str(
                 WEBRTC_MODAL_FUNCTION_BUFFER_CONTAINERS
             ),
+            "WEBRTC_MODAL_FUNCTION_GPU": WEBRTC_MODAL_FUNCTION_GPU,
             "WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS": str(
                 WEBRTC_MODAL_FUNCTION_MIN_CONTAINERS
             ),
+            "WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW": str(
+                WEBRTC_MODAL_FUNCTION_SCALEDOWN_WINDOW
+            ),
             "WEBRTC_MODAL_FUNCTION_TIME_LIMIT": str(WEBRTC_MODAL_FUNCTION_TIME_LIMIT),
             "WEBRTC_MODAL_IMAGE_NAME": WEBRTC_MODAL_IMAGE_NAME,
             "WEBRTC_MODAL_IMAGE_TAG": WEBRTC_MODAL_IMAGE_TAG,
             "WEBRTC_MODAL_RTSP_PLACEHOLDER": WEBRTC_MODAL_RTSP_PLACEHOLDER,
             "WEBRTC_MODAL_RTSP_PLACEHOLDER_URL": WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
-            "ONNXRUNTIME_EXECUTION_PROVIDERS": "[CUDAExecutionProvider,CPUExecutionProvider]",
         },
         "volumes": {MODEL_CACHE_DIR: rfcache_volume},
     }

     class RTCPeerConnectionModal:
         _webrtc_request: Optional[WebRTCWorkerRequest] = modal.parameter(default=None)
-        _exec_session_started: Optional[datetime.datetime] = modal.parameter(
-            default=None
-        )
-        _exec_session_stopped: Optional[datetime.datetime] = modal.parameter(
-            default=None
-        )

         @modal.method()
         def rtc_peer_connection_modal(
@@ -139,6 +144,12 @@ def rtc_peer_connection_modal(
             q: modal.Queue,
         ):
             logger.info("*** Spawning %s:", self.__class__.__name__)
+            logger.info("Inference tag: %s", docker_tag)
+            _exec_session_started = datetime.datetime.now()
+            webrtc_request.processing_session_started = _exec_session_started
+            logger.info(
+                "WebRTC session started at %s", _exec_session_started.isoformat()
+            )
             logger.info(
                 "webrtc_realtime_processing: %s",
                 webrtc_request.webrtc_realtime_processing,
@@ -171,18 +182,11 @@ def send_answer(obj: WebRTCWorkerResult):
                     send_answer=send_answer,
                 )
             )
-
-        # https://modal.com/docs/reference/modal.enter
-        # Modal usage calculation is relying on no concurrency and no hot instances
-        @modal.enter()
-        def start(self):
-            self._exec_session_started = datetime.datetime.now()
-
-        @modal.exit()
-        def stop(self):
-            if not self._webrtc_request:
-                return
-            self._exec_session_stopped = datetime.datetime.now()
+            _exec_session_stopped = datetime.datetime.now()
+            logger.info(
+                "WebRTC session stopped at %s",
+                _exec_session_stopped.isoformat(),
+            )
             workflow_id = self._webrtc_request.workflow_configuration.workflow_id
             if not workflow_id:
                 if self._webrtc_request.workflow_configuration.workflow_specification:
@@ -195,16 +199,38 @@ def stop(self):
             # requested plan is guaranteed to be set due to validation in spawn_rtc_peer_connection_modal
             webrtc_plan = self._webrtc_request.requested_plan

+            video_source = "realtime browser stream"
+            if self._webrtc_request.rtsp_url:
+                video_source = "rtsp"
+            elif not self._webrtc_request.webrtc_realtime_processing:
+                video_source = "buffered browser stream"
+
             usage_collector.record_usage(
                 source=workflow_id,
                 category="modal",
                 api_key=self._webrtc_request.api_key,
-                resource_details={"plan": webrtc_plan},
+                resource_details={
+                    "plan": webrtc_plan,
+                    "billable": True,
+                    "video_source": video_source,
+                },
                 execution_duration=(
-                    self._exec_session_stopped - self._exec_session_started
+                    _exec_session_stopped - _exec_session_started
                 ).total_seconds(),
             )
             usage_collector.push_usage_payloads()
+            logger.info("Function completed")
+
+        # https://modal.com/docs/reference/modal.enter
+        # https://modal.com/docs/guide/memory-snapshot#gpu-memory-snapshot
+        @modal.enter(snap=True)
+        def start(self):
+            # TODO: pre-load models
+            logger.info("Starting container")
+
+        @modal.exit()
+        def stop(self):
+            logger.info("Stopping container")

     # Modal derives function name from class name
     # https://modal.com/docs/reference/modal.App#cls
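
For usage reporting, the session's video source is now classified from the request. The same decision, extracted into a hypothetical helper for clarity (not part of the diff itself):

from typing import Optional


def classify_video_source(rtsp_url: Optional[str], realtime_processing: bool) -> str:
    # Mirrors the branch added in the hunk above.
    if rtsp_url:
        return "rtsp"
    if not realtime_processing:
        return "buffered browser stream"
    return "realtime browser stream"


assert classify_video_source("rtsp://camera.local/stream", True) == "rtsp"
assert classify_video_source(None, False) == "buffered browser stream"
assert classify_video_source(None, True) == "realtime browser stream"
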
@@ -217,7 +243,6 @@ class RTCPeerConnectionModalCPU(RTCPeerConnectionModal):
     @app.cls(
         **{
             **decorator_kwargs,
-            "enable_memory_snapshot": False,
             "gpu": WEBRTC_MODAL_FUNCTION_GPU,  # https://modal.com/docs/guide/gpu#specifying-gpu-type
             "experimental_options": {
                 "enable_gpu_snapshot": WEBRTC_MODAL_FUNCTION_ENABLE_MEMORY_SNAPSHOT
@@ -266,7 +291,21 @@ def spawn_rtc_peer_connection_modal(
         )
     except modal.exception.NotFoundError:
         logger.info("Deploying webrtc modal app %s", WEBRTC_MODAL_APP_NAME)
-        app.deploy(name=WEBRTC_MODAL_APP_NAME, client=client)
+        app.deploy(name=WEBRTC_MODAL_APP_NAME, client=client, tag=docker_tag)
+
+    workspace_id = webrtc_request.workflow_configuration.workspace_name
+    if not workspace_id:
+        try:
+            workspace_id = get_roboflow_workspace(api_key=webrtc_request.api_key)
+            webrtc_request.workflow_configuration.workspace_name = workspace_id
+        except Exception:
+            pass
+
+    tags = {"tag": docker_tag}
+    if workspace_id:
+        tags["workspace_id"] = workspace_id
+
+    # TODO: tag function run

     if webrtc_request.requested_gpu:
         RTCPeerConnectionModal = RTCPeerConnectionModalGPU
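
Session timing now lives inside the @modal.method call rather than in @modal.enter/@modal.exit, and the enter hook is snapshot-enabled. A stripped-down sketch of that lifecycle shape (illustrative only; none of the real app's GPU, volume, or env wiring is included):

import datetime

import modal

app = modal.App("lifecycle-sketch")  # placeholder app name


@app.cls()
class WorkerSketch:
    @modal.enter(snap=True)
    def start(self) -> None:
        # Runs once per container and is captured by the memory snapshot,
        # so per-session timing cannot live here.
        print("container ready")

    @modal.method()
    def process(self) -> float:
        # Each call times itself, like the relocated _exec_session_started/_stopped.
        started = datetime.datetime.now()
        # ... handle one WebRTC session ...
        stopped = datetime.datetime.now()
        return (stopped - started).total_seconds()

    @modal.exit()
    def stop(self) -> None:
        print("container shutting down")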

inference/core/interfaces/webrtc_worker/webrtc.py

Lines changed: 9 additions & 6 deletions
@@ -23,6 +23,7 @@
 from inference.core.env import (
     WEBRTC_MODAL_RTSP_PLACEHOLDER,
     WEBRTC_MODAL_RTSP_PLACEHOLDER_URL,
+    WEBRTC_MODAL_SHUTDOWN_RESERVE,
 )
 from inference.core.exceptions import (
     MissingApiKeyError,
@@ -465,22 +466,25 @@ async def init_rtc_peer_connection_with_loop(
     webrtc_request: WebRTCWorkerRequest,
     send_answer: Callable[[WebRTCWorkerResult], None],
     asyncio_loop: Optional[asyncio.AbstractEventLoop] = None,
+    shutdown_reserve: int = WEBRTC_MODAL_SHUTDOWN_RESERVE,
 ) -> RTCPeerConnectionWithLoop:
     termination_date = None
     terminate_event = asyncio.Event()

     if webrtc_request.processing_timeout is not None:
         try:
             time_limit_seconds = int(webrtc_request.processing_timeout)
-            datetime_now = datetime.datetime.now()
+            datetime_now = webrtc_request.processing_session_started
+            if datetime_now is None:
+                datetime_now = datetime.datetime.now()
             termination_date = datetime_now + datetime.timedelta(
-                seconds=time_limit_seconds - 1
+                seconds=time_limit_seconds - shutdown_reserve
             )
             logger.info(
                 "Setting termination date to %s (%s seconds from %s)",
-                termination_date,
+                termination_date.isoformat(),
                 time_limit_seconds,
-                datetime_now,
+                datetime_now.isoformat(),
             )
         except (TypeError, ValueError):
             pass
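
When the request carries processing_session_started (stamped by the Modal worker), the countdown is anchored to the moment the worker picked the session up rather than to when this coroutine runs; without it, the old behaviour (count from now) is preserved. A compact, self-contained sketch of that fallback plus the terminate_event pattern (function and parameter names here are illustrative, not the real API):

import asyncio
import datetime
from typing import Optional


async def wait_for_termination(
    processing_timeout: int,
    processing_session_started: Optional[datetime.datetime] = None,
    shutdown_reserve: int = 1,
) -> None:
    # Prefer the timestamp stamped by the worker; otherwise fall back to "now".
    started = processing_session_started or datetime.datetime.now()
    termination_date = started + datetime.timedelta(
        seconds=processing_timeout - shutdown_reserve
    )
    terminate_event = asyncio.Event()
    remaining = (termination_date - datetime.datetime.now()).total_seconds()

    async def trip() -> None:
        await asyncio.sleep(max(remaining, 0))
        terminate_event.set()

    asyncio.create_task(trip())
    await terminate_event.wait()
    print("terminated at", datetime.datetime.now().isoformat())


asyncio.run(wait_for_termination(processing_timeout=2))
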
@@ -653,7 +657,7 @@ def on_track(track: RemoteStreamTrack):

     @peer_connection.on("connectionstatechange")
     async def on_connectionstatechange():
-        logger.info("Connection state is %s", peer_connection.connectionState)
+        logger.info("on_connectionstatechange: %s", peer_connection.connectionState)
         if peer_connection.connectionState in {"failed", "closed"}:
             if video_processor.track:
                 logger.info("Stopping video processor track")
@@ -662,7 +666,6 @@ async def on_connectionstatechange():
             logger.info("Stopping WebRTC peer")
             await peer_connection.close()
             terminate_event.set()
-            logger.info("'connectionstatechange' event handler finished")

     @peer_connection.on("datachannel")
     def on_datachannel(channel: RTCDataChannel):

inference/core/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.61.0"
+__version__ = "0.61.1"


 if __name__ == "__main__":

0 commit comments