|
26 | 26 | import jax |
27 | 27 | from jax._src import stages |
28 | 28 | from jax._src import util |
| 29 | +from jax._src.lib import _gpu_ondevice_tracing as gpu_ondevice_tracing |
| 30 | +from jax._src.lib import _profiler |
29 | 31 | import jax.numpy as jnp |
30 | 32 | from jaxlib.mlir import ir |
31 | 33 | from jaxlib.mlir.dialects import arith |
@@ -185,6 +187,20 @@ def __init__(self, entries_per_warpgroup: int, dump_path: str = "sponge"): |
185 | 187 | ) |
186 | 188 | else: |
187 | 189 | self.dump_path = dump_path |
| 190 | + self.tracing_version = self.injection_id = 0 |
| 191 | + self.check_gpu_ondevice_tracing() |
| 192 | + |
| 193 | + def check_gpu_ondevice_tracing(self): |
| 194 | + if ( |
| 195 | + gpu_ondevice_tracing is not None |
| 196 | + and self.tracing_version == 0 |
| 197 | + and self.injection_id == 0 |
| 198 | + ): |
| 199 | + self.tracing_version = gpu_ondevice_tracing.active_version() |
| 200 | + if self.tracing_version > 0: |
| 201 | + self.injection_id = gpu_ondevice_tracing.start_injection_instance( |
| 202 | + self.tracing_version |
| 203 | + ) |
188 | 204 |
|
189 | 205 | def _num_warpgroups( |
190 | 206 | self, grid: tuple[int, ...], block: tuple[int, ...] |
@@ -294,7 +310,42 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): |
294 | 310 | events.append(block_events) |
295 | 311 | events = sorted(events, key=lambda x: x[0]["ts"]) |
296 | 312 | flat_events = list(itertools.chain.from_iterable(events)) |
297 | | - return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f) |
| 313 | + |
| 314 | + if f is not None: |
| 315 | + json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f) |
| 316 | + |
| 317 | + if ( |
| 318 | + gpu_ondevice_tracing is not None |
| 319 | + and self.tracing_version > 0 |
| 320 | + and self.injection_id > 0 |
| 321 | + ): |
| 322 | + with _profiler.TraceMe( |
| 323 | + "MosaicGpuProfilerDump", inject_id=self.injection_id |
| 324 | + ): |
| 325 | + range_dict = {} |
| 326 | + for event in flat_events: |
| 327 | + range_key = (event["name"], event["pid"], event["tid"]) |
| 328 | + if event["ph"] == "B": |
| 329 | + range_dict[range_key] = event["ts"] |
| 330 | + elif event["ph"] == "E": |
| 331 | + if range_key in range_dict: |
| 332 | + begin_ts = range_dict[range_key] |
| 333 | + range_dict.pop(range_key) |
| 334 | + gpu_ondevice_tracing.inject( |
| 335 | + version=self.tracing_version, |
| 336 | + injection_instance_id=self.injection_id, |
| 337 | + tag_name=event["name"], |
| 338 | + tag_id=self.interned_names[event["name"]], |
| 339 | + pid=event["pid"], |
| 340 | + tid=event["tid"], |
| 341 | + start_time_ns=int(begin_ts * 1e3), |
| 342 | + duration_ps=int((event["ts"] - begin_ts) * 1e6), |
| 343 | + ) |
| 344 | + else: |
| 345 | + warnings.warn( |
| 346 | + "Event start time not found for ending event:" |
| 347 | + f" {event['name']}@{event['ts']}" |
| 348 | + ) |
298 | 349 |
|
299 | 350 |
|
300 | 351 | @dataclasses.dataclass(frozen=True) |
|
0 commit comments