Skip to content

Commit 335a1d1

Browse files
Add TraceMe for Mosaic Kernel Injection Id
PiperOrigin-RevId: 803537235
1 parent d846dcf commit 335a1d1

File tree

8 files changed

+97
-23
lines changed

8 files changed

+97
-23
lines changed

jax/_src/lib/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,6 @@ def _parse_version(v: str) -> tuple[int, ...]:
100100
import jaxlib._jax as _jax # noqa: F401
101101

102102

103-
104103
import jaxlib.mlir._mlir_libs._jax_mlir_ext as jax_mlir_ext # noqa: F401
105104
from jaxlib._jax import guard_lib as guard_lib # noqa: F401
106105
from jaxlib._jax import jax_jit as jax_jit # noqa: F401
@@ -110,6 +109,11 @@ def _parse_version(v: str) -> tuple[int, ...]:
110109
from jaxlib import _profiler as _profiler # noqa: F401
111110
from jaxlib import _profile_data as _profile_data # noqa: F401
112111

112+
try:
113+
from jaxlib import _gpu_ondevice_tracing as _gpu_ondevice_tracing # noqa: F401
114+
except (ImportError, ModuleNotFoundError):
115+
_gpu_ondevice_tracing = None
116+
113117
from jaxlib._jax import ffi as ffi # noqa: F401
114118
import jaxlib.cpu_sparse as cpu_sparse # noqa: F401
115119
has_cpu_sparse = True

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ def __post_init__(self):
121121
object.__setattr__(
122122
self, "dimension_semantics", tuple(self.dimension_semantics)
123123
)
124-
if bool(self.profile_space) ^ bool(self.profile_dir):
124+
if bool(self.profile_dir) and not bool(self.profile_space):
125125
raise ValueError(
126-
"Either both profile_space and profile_dir must be set, or neither."
126+
"If profile_dir is set, the profile_space must be set too."
127127
)
128128

129129

jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,24 +120,33 @@ def zero_init_gmem_scratch():
120120
)
121121
if (prof_spec := lowering_result.profiler_spec) is not None:
122122
*outs, prof_buffer = outs
123-
out_file = os.path.join(
124-
prof_spec.dump_path,
125-
f"{mlir.sanitize_name(debug_info.func_name)}-{time.time_ns()}-trace.json",
126-
)
127123
def dump_profile(prof_buffer):
128-
try:
129-
with open(out_file, "x") as f:
130-
prof_spec.dump(
131-
prof_buffer,
132-
f,
133-
grid=lowering_result.grid,
134-
block=lowering_result.block,
135-
)
136-
except FileExistsError:
137-
warnings.warn(
138-
f"Failed to dump profile for pallas_call {debug_info.func_src_info}, "
139-
f"profile already exists at {out_file}"
124+
if prof_spec.dump_path is None:
125+
prof_spec.dump(
126+
prof_buffer,
127+
None,
128+
grid=lowering_result.grid,
129+
block=lowering_result.block,
140130
)
131+
else:
132+
out_file = os.path.join(
133+
prof_spec.dump_path,
134+
f"{mlir.sanitize_name(debug_info.func_name)}-{time.time_ns()}-trace.json",
135+
)
136+
try:
137+
with open(out_file, "x") as f:
138+
prof_spec.dump(
139+
prof_buffer,
140+
f,
141+
grid=lowering_result.grid,
142+
block=lowering_result.block,
143+
)
144+
except FileExistsError:
145+
warnings.warn(
146+
"Failed to dump profile for pallas_call"
147+
f" {debug_info.func_src_info}, profile already exists at"
148+
f" {out_file}"
149+
)
141150
def do_callback(prof_buffer):
142151
jax.debug.callback(dump_profile, prof_buffer)
143152
return ()

jax/experimental/mosaic/gpu/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -925,10 +925,15 @@ def prof_kernel(*args):
925925
_check_args(*args)
926926
*results, prof_buffer = bind(*args)
927927
def dump_profile(prof_buffer):
928-
out_file = os.path.join(prof_spec.dump_path, f"{time.time_ns()}-trace.json")
929928
try:
930-
with open(out_file, "x") as f:
931-
prof_spec.dump(prof_buffer, f, grid=grid, block=block)
929+
if prof_spec.dump_path is not None:
930+
out_file = os.path.join(
931+
prof_spec.dump_path, f"{time.time_ns()}-trace.json"
932+
)
933+
with open(out_file, "x") as f:
934+
prof_spec.dump(prof_buffer, f, grid=grid, block=block)
935+
else:
936+
prof_spec.dump(prof_buffer, None, grid=grid, block=block)
932937
except FileExistsError:
933938
pass # TODO: Retry
934939
jax.debug.callback(dump_profile, prof_buffer)

jax/experimental/mosaic/gpu/profiler.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import jax
2727
from jax._src import stages
2828
from jax._src import util
29+
from jax._src.lib import _gpu_ondevice_tracing as gpu_ondevice_tracing
30+
from jax._src.lib import _profiler
2931
import jax.numpy as jnp
3032
from jaxlib.mlir import ir
3133
from jaxlib.mlir.dialects import arith
@@ -185,6 +187,20 @@ def __init__(self, entries_per_warpgroup: int, dump_path: str = "sponge"):
185187
)
186188
else:
187189
self.dump_path = dump_path
190+
self.tracing_version = self.injection_id = 0
191+
self.check_gpu_ondevice_tracing()
192+
193+
def check_gpu_ondevice_tracing(self):
  """Populate ``tracing_version`` and ``injection_id`` if they are still unset.

  Does nothing when the optional ``gpu_ondevice_tracing`` module is ``None``
  or when either ``self.tracing_version`` or ``self.injection_id`` is already
  nonzero. Otherwise stores the value returned by
  ``gpu_ondevice_tracing.active_version()`` and, only if that value is
  positive, stores the id returned by
  ``gpu_ondevice_tracing.start_injection_instance()`` for that version.
  """
  # The tracing extension is optional; without it there is nothing to set up.
  if gpu_ondevice_tracing is None:
    return
  # Only probe while both fields still hold their initial value of 0.
  if self.tracing_version != 0 or self.injection_id != 0:
    return

  version = gpu_ondevice_tracing.active_version()
  self.tracing_version = version
  if version > 0:
    self.injection_id = gpu_ondevice_tracing.start_injection_instance(version)
188204

189205
def _num_warpgroups(
190206
self, grid: tuple[int, ...], block: tuple[int, ...]
@@ -294,7 +310,42 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
294310
events.append(block_events)
295311
events = sorted(events, key=lambda x: x[0]["ts"])
296312
flat_events = list(itertools.chain.from_iterable(events))
297-
return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)
313+
314+
if f is not None:
315+
json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)
316+
317+
if (
318+
gpu_ondevice_tracing is not None
319+
and self.tracing_version > 0
320+
and self.injection_id > 0
321+
):
322+
with _profiler.TraceMe(
323+
"MosaicGpuProfilerDump", inject_id=self.injection_id
324+
):
325+
range_dict = {}
326+
for event in flat_events:
327+
range_key = (event["name"], event["pid"], event["tid"])
328+
if event["ph"] == "B":
329+
range_dict[range_key] = event["ts"]
330+
elif event["ph"] == "E":
331+
if range_key in range_dict:
332+
begin_ts = range_dict[range_key]
333+
range_dict.pop(range_key)
334+
gpu_ondevice_tracing.inject(
335+
version=self.tracing_version,
336+
injection_instance_id=self.injection_id,
337+
tag_name=event["name"],
338+
tag_id=self.interned_names[event["name"]],
339+
pid=event["pid"],
340+
tid=event["tid"],
341+
start_time_ns=int(begin_ts * 1e3),
342+
duration_ps=int((event["ts"] - begin_ts) * 1e6),
343+
)
344+
else:
345+
warnings.warn(
346+
"Event start time not found for ending event:"
347+
f" {event['name']}@{event['ts']}"
348+
)
298349

299350

300351
@dataclasses.dataclass(frozen=True)

jaxlib/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pytype_strict_library(
9292
"//jaxlib/mosaic/python:gpu_dialect",
9393
"//jaxlib/mosaic/python:tpu_dialect",
9494
"//jaxlib/triton",
95+
"@xla//xla/python:_gpu_ondevice_tracing",
9596
"@xla//xla/python:_profile_data",
9697
"@xla//xla/python:_profiler",
9798
],
@@ -178,6 +179,7 @@ pywrap_library(
178179
"//jaxlib/mlir/_mlir_libs:_stablehlo",
179180
"//jaxlib/mlir/_mlir_libs:_tpu_ext",
180181
"//jaxlib/mlir/_mlir_libs:_triton_ext",
182+
"@xla//xla/python:_gpu_ondevice_tracing",
181183
"@xla//xla/python:_profile_data",
182184
"@xla//xla/python:_profiler",
183185
],

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
662662
mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
663663
reinterpret_cast<cudaStream_t>(stream));
664664
}
665+
tsl::profiler::TraceMe trace("MosaicGpuLaunchKernel");
665666
std::get<1>(ctx_kernel_comm)(args);
666667
}
667668

@@ -724,6 +725,7 @@ absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs,
724725
mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
725726
reinterpret_cast<cudaStream_t>(stream));
726727
}
728+
tsl::profiler::TraceMe trace("MosaicGpuLaunchKernel");
727729
std::get<1>(ctx_kernel_comm)(args);
728730
return absl::OkStatus();
729731
}

jaxlib/tools/build_wheel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
213213
f"{source_file_prefix}jaxlib/_jax.{pyext}",
214214
f"{source_file_prefix}jaxlib/_sdy_mpmd.{pyext}",
215215
f"{source_file_prefix}jaxlib/_pathways.{pyext}",
216+
f"{source_file_prefix}jaxlib/_gpu_ondevice_tracing.{pyext}",
216217
f"{source_file_prefix}jaxlib/_profiler.{pyext}",
217218
f"{source_file_prefix}jaxlib/_profile_data.{pyext}",
218219
],

0 commit comments

Comments
 (0)