Skip to content

Commit d3f578b

Browse files
vidhyav and facebook-github-bot
authored and committed
Added call_one endpoint latency
Summary: As stated in the title, this adds latency measurement for the call_one endpoint. We should also be able to get throughput metrics as a result. Reviewed By: dulinriley. Differential Revision: D85998427
1 parent d003266 commit d3f578b

File tree

1 file changed

+93
-30
lines changed

1 file changed

+93
-30
lines changed

python/monarch/_src/actor/endpoint.py

Lines changed: 93 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
# pyre-strict
88

99
import functools
10+
import time
1011
from abc import ABC, abstractmethod
11-
from datetime import datetime
1212
from typing import (
1313
Any,
1414
Awaitable,
1515
Callable,
1616
cast,
1717
Concatenate,
18+
Coroutine,
1819
Dict,
1920
Generator,
2021
Generic,
@@ -44,6 +45,51 @@
4445
description="Latency of endpoint call operations in microseconds",
4546
)
4647

48+
# Latency samples, in microseconds, for `call_one` invocations. This is a
# separate instrument from the one `call` records into.
endpoint_call_one_latency_histogram: Histogram = METER.create_histogram(
    name="endpoint_call_one_latency.us",
    description="Latency of endpoint call_one operations in microseconds",
)


# Result type of the coroutine passed to (and returned from) `_measure_latency`.
T = TypeVar("T")
55+
56+
57+
def _measure_latency(
    coro: Coroutine[Any, Any, T],
    histogram: Histogram,
    method_name: str,
    actor_count: int,
) -> Coroutine[Any, Any, T]:
    """
    Wrap a coroutine so that its latency is recorded once it finishes.

    This is a plain wrapper function, not a decorator: it takes an
    already-created coroutine object and returns a new coroutine that awaits
    it. The clock starts when this function is called, not when the returned
    coroutine is first awaited, so any delay before the first await is part
    of the recorded duration.

    Args:
        coro: The coroutine to measure
        histogram: The histogram to record the duration (in microseconds) to
        method_name: Name of the method being called; recorded as the
            "method" attribute
        actor_count: Number of actors involved in the call; recorded as the
            "actor_count" attribute

    Returns:
        A coroutine that yields the result of `coro` (or propagates its
        exception) and records exactly one latency sample when it completes
    """
    # perf_counter_ns is monotonic, uses the highest-resolution clock available,
    # and returns an int, so the microsecond value below is computed without
    # any float rounding.
    start_ns: int = time.perf_counter_ns()

    async def _wrapper() -> T:
        try:
            return await coro
        finally:
            # `finally` also runs when `coro` raises or is cancelled, so failed
            # calls contribute a sample as well.
            duration_us = (time.perf_counter_ns() - start_ns) // 1_000
            histogram.record(
                duration_us,
                attributes={
                    "method": method_name,
                    "actor_count": actor_count,
                },
            )

    return _wrapper()
91+
92+
4793
if TYPE_CHECKING:
4894
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
4995
OncePortReceiver as HyOncePortReceiver,
@@ -65,6 +111,19 @@ def __init__(self, propagator: Propagator) -> None:
65111
self._propagator_arg = propagator
66112
self._cache: Optional[Dict[Any, Any]] = None
67113

114+
def _get_method_name(self) -> str:
    """
    Extract the method name from this endpoint's method specifier.

    Reads the ``name`` attribute of the object returned by
    ``self._call_name()``. The value is coerced to ``str`` so the declared
    return type always holds and the result is safe to use as a metric
    attribute.

    Returns:
        The method name, or "unknown" if the specifier has no ``name``
        attribute or that attribute is None
    """
    method_specifier = self._call_name()
    # Not every specifier variant carries a `.name`; getattr with a default
    # covers both the missing-attribute and the explicit-None case.
    name = getattr(method_specifier, "name", None)
    if name is None:
        return "unknown"
    return str(name)
126+
68127
@abstractmethod
69128
def _send(
70129
self,
@@ -119,56 +178,60 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
119178
return r.recv()
120179

121180
def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
    """
    Send the call and return a Future for the single reply.

    The returned Future is wrapped by `_measure_latency`, so awaiting it
    records one sample in `endpoint_call_one_latency_histogram` with an
    actor count of 1.

    Raises:
        ValueError: If the extent returned by `_send` does not contain
            exactly one element.
    """
    send_port, once_receiver = self._port(once=True)
    receiver: PortReceiver[R] = once_receiver
    # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
    extent = self._send(args, kwargs, port=send_port, selection="choose")
    # Note: `_send` has already been invoked by the time this check runs.
    if extent.nelements != 1:
        raise ValueError(
            f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
        )

    method_name = self._get_method_name()

    async def receive_reply() -> R:
        # `_measure_latency` expects a coroutine, so the receive is wrapped in
        # one here.
        return await receiver.recv()

    return Future(
        coro=_measure_latency(
            receive_reply(),
            endpoint_call_one_latency_histogram,
            method_name,
            1,
        )
    )
130203

131204
def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
    """
    Send the call and return a Future for a ValueMesh of the replies.

    One reply is awaited per element of the extent returned by `_send`; each
    reply is stored at the rank it arrives with. The returned Future is
    wrapped by `_measure_latency`, so awaiting it records one sample in
    `endpoint_call_latency_histogram` tagged with the method name and the
    number of elements in the extent.
    """
    from monarch._src.actor.actor_mesh import RankedPortReceiver, ValueMesh

    send_port, unranked = self._port()
    ranked: RankedPortReceiver[R] = unranked.ranked()
    # pyre-ignore[6]: ParamSpec kwargs is compatible with Dict[str, Any]
    extent: Extent = self._send(args, kwargs, port=send_port)

    method_name = self._get_method_name()

    async def gather() -> "ValueMesh[R]":
        from monarch._rust_bindings.monarch_hyperactor.shape import Shape
        from monarch._src.actor.shape import NDSlice

        # Pre-sized because replies are placed by rank, not by arrival order.
        by_rank: List[R] = [None] * extent.nelements  # pyre-fixme[9]
        for _ in range(extent.nelements):
            rank, value = await ranked._recv()
            by_rank[rank] = value
        return ValueMesh(
            Shape(
                extent.labels,
                NDSlice.new_row_major(extent.sizes),
            ),
            by_rank,
        )

    return Future(
        coro=_measure_latency(
            gather(),
            endpoint_call_latency_histogram,
            method_name,
            extent.nelements,
        )
    )
172235

173236
def stream(
174237
self, *args: P.args, **kwargs: P.kwargs

0 commit comments

Comments
 (0)