Commit a8270dd

Revert "kill allow_complex_guards_as_runtime_asserts (pytorch#160198)"
This reverts commit 196232b. Reverted pytorch#160198 on behalf of https://github.com/atalman due to dynamo/test_activation_checkpointing.py::ActivationCheckpointingViaTagsTestsCUDA::test_compile_selective_checkpoint_triton_kernel_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/17289619543/job/49074475338) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/196232bb935cb346f143d5c39e9a73c44121a033) ([comment](pytorch#160198 (comment)))
Parent: 63632fc · Commit: a8270dd

File tree: 9 files changed, +67 −47 lines


test/dynamo/test_misc.py

Lines changed: 2 additions & 2 deletions
@@ -10916,8 +10916,8 @@ def test_shape_env_equal_constructor(self):
 ShapeEnv not equal: field values don't match:
 
 ==> settings: values don't match.
-> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
-> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
+> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
+> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
 """,
 )
 self._replay_and_check(main)

test/export/test_export.py

Lines changed: 19 additions & 19 deletions
@@ -5609,11 +5609,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 dim0_x = torch.export.Dim("dim0_x", min=3)
 dim1_x = torch.export.Dim("dim1_x", max=8000)
 dynamic_shapes = {"x": (dim0_x, dim1_x)}
-em = torch.export.export(
+em = torch.export._trace._export(
 m,
 (a,),
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 em.module()(torch.randn(4, 3))
 with self.assertRaisesRegex(
@@ -13497,7 +13497,7 @@ def forward(self, x):
 
 def test_disable_forced_specializations_ok(self):
 # check that we don't force specialization, and defer to runtime asserts
-# with prefer_deferred_runtime_asserts_over_guards=True to successfully export
+# with allow_complex_guards_as_runtime_asserts=True to successfully export
 # case 1: modulo guards
 from torch.export import dims
 
@@ -13507,11 +13507,11 @@ def forward(self, x):
 
 inputs = (torch.randn(10, 72),)
 dx, dy = dims("dx", "dy")
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Mod4Reshape(),
 inputs,
 dynamic_shapes={"x": (dx, dy)},
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 out1 = ep.module()(torch.randn(8, 7))
 self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@@ -13541,11 +13541,11 @@ def forward(self, x, y, z):
 
 for private_api in (True, False):
 if private_api:
-ep = torch.export.export(
+ep = torch.export._trace._export(
 FreeReshape(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 else:
 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13582,11 +13582,11 @@ def forward(self, x, y):
 "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
 "y": (Dim("dy", min=8),),
 }
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Reshape3d(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
 self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13708,11 +13708,11 @@ def forward(self, x):
 model = Model()
 x = torch.rand(1024, 20, 16)
 dynamic_shapes = {"x": {0: Dim("batch")}}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 model,
 (x,),
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 with self.assertRaisesRegex(
 RuntimeError,
@@ -13785,11 +13785,11 @@ def forward(self, x, y):
 
 inputs = (torch.randn(6), torch.randn(12))
 dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 # check forward pass
 out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13824,7 +13824,7 @@ def forward(self, x, y):
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 ).run_decompositions()
 
 self.assertEqual(
@@ -14236,11 +14236,11 @@ def forward(self, x, y):
 
 inputs = (torch.randn(5), torch.randn(3))
 shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Foo(),
 inputs,
 dynamic_shapes=shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 # count 2 pow nodes, 2 sym_size.int nodes
 self.assertEqual(
@@ -15039,11 +15039,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 for private_api in (True, False):
 if private_api:
-ep = torch.export.export(
+ep = torch.export._trace._export(
 ModConstraint(),
 (torch.randn(3, 4),),
 dynamic_shapes={"x": (dynamic, dynamic)},
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 else:
 ep = export(
@@ -15057,7 +15057,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 for node in ep.graph.nodes
 ].count(True)
 if private_api:
-self.assertEqual(num_asserts, 6)
+self.assertEqual(num_asserts, 7)
 with self.assertRaisesRegex(
 RuntimeError,
 r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",

torch/_dynamo/config.py

Lines changed: 6 additions & 0 deletions
@@ -258,6 +258,12 @@
 # hybrid backed unbacked symints
 prefer_deferred_runtime_asserts_over_guards = False
 
+# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
+# range constraints + dims + derived dims language, we raise constraint violation
+# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
+# and allows complex guards as runtime assertions in the graph.
+allow_complex_guards_as_runtime_asserts = False
+
 # By default, dynamo will treat all ints as backed SymInts, which means (1) it
 # will wait to see the int change over multiple runs before generalizing and
 # (2) it will still always 0/1 specialize an int. When true, this knob
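Since this is an ordinary dynamo config flag, it can be flipped globally or scoped with the standard `torch._dynamo.config.patch` context manager. A brief sketch, with semantics as described in the restored comment above:

```python
import torch

# Global: complex guards become runtime assertions in the traced graph
# instead of constraint-violation errors or specialization.
torch._dynamo.config.allow_complex_guards_as_runtime_asserts = True

# Scoped alternative via the standard config.patch context manager:
with torch._dynamo.config.patch(allow_complex_guards_as_runtime_asserts=False):
    pass  # tracing in here sees the default (False) behavior again
```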

torch/_dynamo/eval_frame.py

Lines changed: 2 additions & 0 deletions
@@ -1734,6 +1734,7 @@ def export(
 same_signature: bool = True,
 disable_constraint_solver: bool = False,
 prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 _log_export_usage: bool = True,
 constraints: Optional[list[Constraint]] = None,
 **extra_kwargs: Any,
@@ -1960,6 +1961,7 @@ def fakify_with_ambient(
 capture_dynamic_output_shape_ops=True,
 capture_scalar_outputs=True,
 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 ),
 _compiling_state_context(),
 ):
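The restored keyword threads from `torch._dynamo.export` into the shape-env config shown above. A hedged usage sketch; the toy function and input shapes are illustrative, not from the diff:

```python
import torch


def f(x, y):
    # A fully dynamic reshape-and-add relates three symbols via s0*s1 == s2,
    # which the dims/derived-dims language cannot express directly.
    return x.reshape([-1]) + y


# torch._dynamo.export returns a callable that is then invoked on example
# inputs; the result unpacks to (graph_module, guards).
gm, guards = torch._dynamo.export(
    f,
    allow_complex_guards_as_runtime_asserts=True,
)(torch.randn(3, 4), torch.randn(12))
```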

torch/_dynamo/output_graph.py

Lines changed: 1 addition & 0 deletions
@@ -468,6 +468,7 @@ def __init__(
 allow_scalar_outputs=config.capture_scalar_outputs,
 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
 prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
 co_fields=self.co_fields,
 )
torch/_export/non_strict_utils.py

Lines changed: 3 additions & 2 deletions
@@ -330,7 +330,7 @@ def make_fake_inputs(
 args,
 kwargs,
 dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=False,
+allow_complex_guards_as_runtime_asserts=False,
 ):
 """
 Given an nn module, example inputs, and constraints, return a new fake mode,
@@ -382,7 +382,8 @@ def make_fake_inputs(
 shape_env=ShapeEnv(
 tracked_fakes=[],
 co_fields=co_fields,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 trace_asserts=True,
 ),
 allow_non_fake_inputs=True,

torch/export/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ def export_for_training(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
 )
 
 
@@ -282,7 +282,7 @@ def export(
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
 pre_dispatch=True,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
 )
 except Exception as e:
 draft_export_msg = (
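Note the mapping this restores: the public functions keep their `prefer_deferred_runtime_asserts_over_guards` keyword and simply forward its value into the internal `allow_complex_guards_as_runtime_asserts` parameter, so public callers are unaffected. A sketch of the public spelling; the module and shapes are illustrative:

```python
import torch
from torch.export import Dim


class Flatten(torch.nn.Module):
    def forward(self, x):
        return x.reshape([-1])


ep = torch.export.export(
    Flatten(),
    (torch.randn(4, 6),),
    dynamic_shapes={"x": (Dim("d0"), Dim("d1"))},
    # Public kwarg; per the hunks above, it now maps onto the internal
    # allow_complex_guards_as_runtime_asserts flag.
    prefer_deferred_runtime_asserts_over_guards=True,
)
```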

torch/export/_trace.py

Lines changed: 15 additions & 12 deletions
@@ -750,7 +750,7 @@ def _export_to_torch_ir(
 *,
 preserve_module_call_signature: tuple[str, ...] = (),
 disable_constraint_solver: bool = False,
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 restore_fqn: bool = True,
 _log_export_usage: bool = True,
 same_signature: bool = True,
@@ -810,7 +810,10 @@ def _export_to_torch_ir(
 assume_static_by_default=True,
 tracing_mode="symbolic",
 disable_constraint_solver=disable_constraint_solver,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+# currently the following 2 flags are tied together for export purposes,
+# but untangle for sake of dynamo export api
+prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _log_export_usage=_log_export_usage,
 same_signature=same_signature,
 )(
@@ -1399,7 +1402,7 @@ def _strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-prefer_deferred_runtime_asserts_over_guards: bool,
+allow_complex_guards_as_runtime_asserts: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1413,7 +1416,7 @@ def _strict_export(
 dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 restore_fqn=False,  # don't need to restore because we will do it later
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _log_export_usage=False,
 )
 
@@ -1861,7 +1864,7 @@ def _non_strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-prefer_deferred_runtime_asserts_over_guards: bool,
+allow_complex_guards_as_runtime_asserts: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1958,7 +1961,7 @@ def forward(self, *args, **kwargs):
 args,
 kwargs,
 dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,  # for shape env initialization
 )
 
 fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
@@ -2076,7 +2079,7 @@ def _export_for_training(
 *,
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 ) -> ExportedProgram:
 global _EXPORT_MODULE_HIERARCHY
 _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2106,7 +2109,7 @@ def _export_for_training(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=orig_in_spec,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _to_aten_func=_export_to_aten_ir_make_fx,
 )
 
@@ -2177,7 +2180,7 @@ def _export(
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
 pre_dispatch: bool = False,
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 ) -> ExportedProgram:
 """
 Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2208,7 +2211,7 @@ def _export(
 preserve_module_call_signature: A list of submodule paths for which the original
 calling conventions are preserved as metadata.
 
-prefer_deferred_runtime_asserts_over_guards:
+allow_complex_guards_as_runtime_asserts:
 With the current dynamic shapes language for dims and derived dims, we can run into constraints
 that are not expressible with the language. For example, flattening a matrix and adding to a vector,
 both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
@@ -2252,7 +2255,7 @@ def _export(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 )
 dtrace_structured("exported_program", payload_fn=lambda: str(ep))
 return ep
@@ -2277,7 +2280,7 @@ def _export(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=original_in_spec,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _to_aten_func=functools.partial(
 _export_to_aten_ir,
 pre_dispatch=pre_dispatch,
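The docstring's `x.reshape([-1]) + y` example can be made concrete. A hedged sketch under this commit's state of the tree; the module, dim names, and shapes are illustrative, not from the diff:

```python
import torch
from torch.export import Dim


class MatVecAdd(torch.nn.Module):
    def forward(self, x, y):
        # Emits the inexpressible guard s0*s1 == s2 described in the docstring.
        return x.reshape([-1]) + y


ep = torch.export._trace._export(
    MatVecAdd(),
    (torch.randn(3, 4), torch.randn(12)),
    dynamic_shapes={"x": (Dim("d0"), Dim("d1")), "y": (Dim("d2"),)},
    allow_complex_guards_as_runtime_asserts=True,
)
ep.module()(torch.randn(2, 5), torch.randn(10))  # satisfies the runtime assert
# A mismatched pair such as (torch.randn(2, 5), torch.randn(11)) would instead
# fail the graph's runtime assertion rather than being rejected at export time.
```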
