Skip to content

Commit bd6d013

Browse files
authored
Add jetstream_server_startup_latency metric (#118)
* first commit * no labels on metric * format * change measurement * fmt * rename metric * Time -> time * nits * fixed args * int -> float * int -> float * move endpoint to server_lib.py * nit * missing labels
1 parent 6ec67e4 commit bd6d013

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

jetstream/core/metrics/prometheus.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def __new__(cls):
5050
documentation="The percentage of decode slots currently being used",
5151
labelnames=["id", "idx"],
5252
)
53+
_server_startup_latency = Gauge(
54+
name="jetstream_server_startup_latency",
55+
documentation="Total time taken to start the Jetstream server",
56+
labelnames=["id"],
57+
)
5358

5459
def get_prefill_backlog_metric(self):
5560
return self._prefill_backlog.labels(id=self._id)
@@ -62,3 +67,6 @@ def get_generate_backlog_metric(self, idx: int):
6267

6368
def get_slots_used_percentage_metric(self, idx: int):
6469
return self._slots_used_percentage.labels(id=self._id, idx=idx)
70+
71+
def get_server_startup_latency_metric(self):
72+
return self._server_startup_latency.labels(id=self._id)

jetstream/core/server_lib.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import signal
2525
import threading
26+
import time
2627
import traceback
2728
from typing import Any, Type
2829

@@ -122,6 +123,9 @@ def run(
122123
Returns:
123124
JetStreamServer that wraps the grpc server and orchestrator driver.
124125
"""
126+
127+
server_start_time = time.time()
128+
125129
logging.info("Kicking off gRPC server.")
126130
engines = config_lib.get_engines(config, devices=devices)
127131
prefill_params = [pe.load_params() for pe in engines.prefill_engines]
@@ -196,6 +200,11 @@ def run(
196200

197201
jetstream_server.start()
198202

203+
if metrics_collector:
204+
metrics_collector.get_server_startup_latency_metric().set(
205+
time.time() - server_start_time
206+
)
207+
199208
# Setup Jax Profiler
200209
if enable_jax_profiler:
201210
logging.info("Starting JAX profiler server on port %s", jax_profiler_port)

0 commit comments

Comments
 (0)