Skip to content

Commit e2e67a0

Browse files
masnesralpytorchmergebot
authored andcommitted
[logging] Add dynamo_compile fields for pre-dispatch/joint/post-dispatch times (#140306)
Tested internally: P1679622670 Differential Revision: [D65986059](https://our.internmc.facebook.com/intern/diff/D65986059) Pull Request resolved: pytorch/pytorch#140306 Approved by: https://github.com/ezyang
1 parent 1b95ca9 commit e2e67a0

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

test/dynamo/test_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
156156
'Scheduler.__init__': [0.0, 0.0],
157157
'Scheduler.codegen': [0.0, 0.0],
158158
'_compile.compile_inner': [0.0],
159+
'_recursive_joint_graph_passes': [0.0],
159160
'_recursive_post_grad_passes': [0.0, 0.0],
160161
'_recursive_pre_grad_passes': [0.0],
161162
'async_compile.wait': [0.0, 0.0],
@@ -173,7 +174,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
173174
self.assertExpectedInline(
174175
pprint.pformat(time_spent),
175176
"""\
176-
{'backend_compile': 0.0,
177+
{'_recursive_joint_graph_passes': 0.0,
178+
'_recursive_post_grad_passes': 0.0,
179+
'_recursive_pre_grad_passes': 0.0,
180+
'backend_compile': 0.0,
177181
'code_gen': 0.0,
178182
'entire_backward_compile': 0.0,
179183
'entire_frame_compile': 0.0,
@@ -235,9 +239,12 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
235239
'inductor_compile_time_s': 0.0,
236240
'inductor_cumulative_compile_time_us': 0,
237241
'is_forward': True,
242+
'joint_graph_pass_time_us': 0,
238243
'log_format_version': 3,
239244
'non_compliant_ops': set(),
240245
'num_triton_bundles': None,
246+
'post_grad_pass_time_us': 0,
247+
'pre_grad_pass_time_us': 0,
241248
'remote_cache_time_saved_s': None,
242249
'remote_fx_graph_cache_get_time_ms': None,
243250
'remote_fx_graph_cache_get_time_us': None,
@@ -295,9 +302,12 @@ def test_dynamo_timed(self, mock_time, mock_time_ns):
295302
'inductor_compile_time_s': 0.0,
296303
'inductor_cumulative_compile_time_us': 0,
297304
'is_forward': False,
305+
'joint_graph_pass_time_us': None,
298306
'log_format_version': 3,
299307
'non_compliant_ops': None,
300308
'num_triton_bundles': None,
309+
'post_grad_pass_time_us': 0,
310+
'pre_grad_pass_time_us': None,
301311
'remote_cache_time_saved_s': None,
302312
'remote_fx_graph_cache_get_time_ms': None,
303313
'remote_fx_graph_cache_get_time_us': None,

torch/_dynamo/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,9 @@ class CompilationMetrics:
859859
remote_fx_graph_cache_put_time_us: Optional[int] = None
860860
backward_cumulative_compile_time_us: Optional[int] = None
861861
end_time_us: Optional[int] = None
862+
pre_grad_pass_time_us: Optional[int] = None
863+
post_grad_pass_time_us: Optional[int] = None
864+
joint_graph_pass_time_us: Optional[int] = None
862865
log_format_version: int = LOG_FORMAT_VERSION
863866

864867

torch/_inductor/compile_fx.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,11 @@ def _get_subgraph_names(gm: GraphModule) -> Generator[str, None, None]:
325325
def _recursive_pre_grad_passes(
326326
gm: GraphModule, example_inputs: Sequence[InputType]
327327
) -> GraphModule:
328-
with dynamo_timed("_recursive_pre_grad_passes", log_pt2_compile_event=True):
328+
with dynamo_timed(
329+
"_recursive_pre_grad_passes",
330+
log_pt2_compile_event=True,
331+
dynamo_compile_column_us="pre_grad_pass_time_us",
332+
):
329333
for subgraph_name in _get_subgraph_names(gm):
330334
subgraph = getattr(gm, subgraph_name)
331335
# as we don't have recursive example inputs, passing empty set here
@@ -335,14 +339,23 @@ def _recursive_pre_grad_passes(
335339

336340

337341
def _recursive_joint_graph_passes(gm: GraphModule) -> None:
338-
for subgraph_name in _get_subgraph_names(gm):
339-
subgraph = getattr(gm, subgraph_name)
340-
_recursive_joint_graph_passes(subgraph)
341-
joint_graph_passes(gm)
342+
with dynamo_timed(
343+
"_recursive_joint_graph_passes",
344+
log_pt2_compile_event=True,
345+
dynamo_compile_column_us="joint_graph_pass_time_us",
346+
):
347+
for subgraph_name in _get_subgraph_names(gm):
348+
subgraph = getattr(gm, subgraph_name)
349+
_recursive_joint_graph_passes(subgraph)
350+
joint_graph_passes(gm)
342351

343352

344353
def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None:
345-
with dynamo_timed("_recursive_post_grad_passes", log_pt2_compile_event=True):
354+
with dynamo_timed(
355+
"_recursive_post_grad_passes",
356+
log_pt2_compile_event=True,
357+
dynamo_compile_column_us="post_grad_pass_time_us",
358+
):
346359
for subgraph_name in _get_subgraph_names(gm):
347360
subgraph = getattr(gm, subgraph_name)
348361
_recursive_post_grad_passes(subgraph, is_inference)

0 commit comments

Comments
 (0)