diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py
index dc8a00e9..06807e0d 100644
--- a/jetstream/core/metrics/prometheus.py
+++ b/jetstream/core/metrics/prometheus.py
@@ -255,3 +255,12 @@ def get_request_output_length(self):
 
   def get_request_success_count_metric(self):
     return self._request_success_count.labels(id=self._id)
+
+  _total_tokens_in_current_batch = Gauge(
+      name="jetstream_total_tokens_in_current_batch",
+      documentation="Total number of tokens in the decode batch",
+      labelnames=["id", "idx"],
+  )
+
+  def get_total_tokens_in_current_batch_metric(self, idx: int):
+    return self._total_tokens_in_current_batch.labels(id=self._id, idx=idx)
diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index cefabd05..f54b3c79 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -815,6 +815,7 @@ def _detokenize_thread(self, idx: int):
         # is a result tokens, and we can't annotate the tuple.
         result_tokens = result_tokens.convert_to_numpy()
 
+        total_tokens_in_batch = 0
         for slot, request in my_live_requests.items():
           if request is not None:
             results, complete = token_utils.process_result_tokens(
@@ -826,6 +827,9 @@
                 complete=request.complete,
             )
             request.complete = complete
+            total_tokens_in_batch += result_tokens.get_result_at_slot(
+                slot
+            ).lengths
             # Return some output samples.
             request.enqueue_samples(results)
             if request.complete.all():
@@ -873,6 +877,10 @@
            generate_timestep_added,
            (time.time() - start_detokenize_time) * 10**3,
        )
+        if self._metrics_collector:
+          self._metrics_collector.get_total_tokens_in_current_batch_metric(
+              idx=idx
+          ).set(total_tokens_in_batch)
      else:
        # We want to update a slot with the new channel.
        slot, active_request = data
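
For context, below is a minimal standalone sketch of the labeled-Gauge pattern this patch uses, written against the prometheus_client library. The metric name and label names mirror the diff; the explicit registry, the "server-0" id, and the sample values are illustrative only, not part of the change.

    from prometheus_client import CollectorRegistry, Gauge, generate_latest

    # Illustrative standalone registry; the real collector registers
    # against the default registry.
    registry = CollectorRegistry()

    # Same shape as the gauge added in prometheus.py: one time series per
    # (id, idx) label pair, where idx identifies a detokenize thread.
    total_tokens_in_current_batch = Gauge(
        name="jetstream_total_tokens_in_current_batch",
        documentation="Total number of tokens in the decode batch",
        labelnames=["id", "idx"],
        registry=registry,
    )

    # .labels(...) lazily creates the child series and .set(...) overwrites
    # its value, matching how orchestrator.py reports the running batch
    # total once per generate step.
    total_tokens_in_current_batch.labels(id="server-0", idx=0).set(128)

    print(generate_latest(registry).decode())

Because the metric is a Gauge set once per step rather than a Counter, each scrape reflects the most recent batch's token total per thread, and the value can move down as requests complete.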