@@ -5609,11 +5609,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5609
5609
dim0_x = torch.export.Dim("dim0_x", min=3)
5610
5610
dim1_x = torch.export.Dim("dim1_x", max=8000)
5611
5611
dynamic_shapes = {"x": (dim0_x, dim1_x)}
5612
- em = torch.export.export (
5612
+ em = torch.export._trace._export (
5613
5613
m,
5614
5614
(a,),
5615
5615
dynamic_shapes=dynamic_shapes,
5616
- prefer_deferred_runtime_asserts_over_guards =True,
5616
+ allow_complex_guards_as_runtime_asserts =True,
5617
5617
)
5618
5618
em.module()(torch.randn(4, 3))
5619
5619
with self.assertRaisesRegex(
@@ -13497,7 +13497,7 @@ def forward(self, x):
13497
13497
13498
13498
def test_disable_forced_specializations_ok(self):
13499
13499
# check that we don't force specialization, and defer to runtime asserts
13500
- # with prefer_deferred_runtime_asserts_over_guards =True to successfully export
13500
+ # with allow_complex_guards_as_runtime_asserts =True to successfully export
13501
13501
# case 1: modulo guards
13502
13502
from torch.export import dims
13503
13503
@@ -13507,11 +13507,11 @@ def forward(self, x):
13507
13507
13508
13508
inputs = (torch.randn(10, 72),)
13509
13509
dx, dy = dims("dx", "dy")
13510
- ep = torch.export.export (
13510
+ ep = torch.export._trace._export (
13511
13511
Mod4Reshape(),
13512
13512
inputs,
13513
13513
dynamic_shapes={"x": (dx, dy)},
13514
- prefer_deferred_runtime_asserts_over_guards =True,
13514
+ allow_complex_guards_as_runtime_asserts =True,
13515
13515
)
13516
13516
out1 = ep.module()(torch.randn(8, 7))
13517
13517
self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@@ -13541,11 +13541,11 @@ def forward(self, x, y, z):
13541
13541
13542
13542
for private_api in (True, False):
13543
13543
if private_api:
13544
- ep = torch.export.export (
13544
+ ep = torch.export._trace._export (
13545
13545
FreeReshape(),
13546
13546
inputs,
13547
13547
dynamic_shapes=dynamic_shapes,
13548
- prefer_deferred_runtime_asserts_over_guards =True,
13548
+ allow_complex_guards_as_runtime_asserts =True,
13549
13549
)
13550
13550
else:
13551
13551
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13582,11 +13582,11 @@ def forward(self, x, y):
13582
13582
"x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
13583
13583
"y": (Dim("dy", min=8),),
13584
13584
}
13585
- ep = torch.export.export (
13585
+ ep = torch.export._trace._export (
13586
13586
Reshape3d(),
13587
13587
inputs,
13588
13588
dynamic_shapes=dynamic_shapes,
13589
- prefer_deferred_runtime_asserts_over_guards =True,
13589
+ allow_complex_guards_as_runtime_asserts =True,
13590
13590
)
13591
13591
out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
13592
13592
self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13708,11 +13708,11 @@ def forward(self, x):
13708
13708
model = Model()
13709
13709
x = torch.rand(1024, 20, 16)
13710
13710
dynamic_shapes = {"x": {0: Dim("batch")}}
13711
- ep = torch.export.export (
13711
+ ep = torch.export._trace._export (
13712
13712
model,
13713
13713
(x,),
13714
13714
dynamic_shapes=dynamic_shapes,
13715
- prefer_deferred_runtime_asserts_over_guards =True,
13715
+ allow_complex_guards_as_runtime_asserts =True,
13716
13716
)
13717
13717
with self.assertRaisesRegex(
13718
13718
RuntimeError,
@@ -13785,11 +13785,11 @@ def forward(self, x, y):
13785
13785
13786
13786
inputs = (torch.randn(6), torch.randn(12))
13787
13787
dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
13788
- ep = torch.export.export (
13788
+ ep = torch.export._trace._export (
13789
13789
Foo(),
13790
13790
inputs,
13791
13791
dynamic_shapes=dynamic_shapes,
13792
- prefer_deferred_runtime_asserts_over_guards =True,
13792
+ allow_complex_guards_as_runtime_asserts =True,
13793
13793
)
13794
13794
# check forward pass
13795
13795
out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13824,7 +13824,7 @@ def forward(self, x, y):
13824
13824
Foo(),
13825
13825
inputs,
13826
13826
dynamic_shapes=dynamic_shapes,
13827
- prefer_deferred_runtime_asserts_over_guards =True,
13827
+ allow_complex_guards_as_runtime_asserts =True,
13828
13828
).run_decompositions()
13829
13829
13830
13830
self.assertEqual(
@@ -14236,11 +14236,11 @@ def forward(self, x, y):
14236
14236
14237
14237
inputs = (torch.randn(5), torch.randn(3))
14238
14238
shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
14239
- ep = torch.export.export (
14239
+ ep = torch.export._trace._export (
14240
14240
Foo(),
14241
14241
inputs,
14242
14242
dynamic_shapes=shapes,
14243
- prefer_deferred_runtime_asserts_over_guards =True,
14243
+ allow_complex_guards_as_runtime_asserts =True,
14244
14244
)
14245
14245
# count 2 pow nodes, 2 sym_size.int nodes
14246
14246
self.assertEqual(
@@ -15039,11 +15039,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
15039
15039
15040
15040
for private_api in (True, False):
15041
15041
if private_api:
15042
- ep = torch.export.export (
15042
+ ep = torch.export._trace._export (
15043
15043
ModConstraint(),
15044
15044
(torch.randn(3, 4),),
15045
15045
dynamic_shapes={"x": (dynamic, dynamic)},
15046
- prefer_deferred_runtime_asserts_over_guards =True,
15046
+ allow_complex_guards_as_runtime_asserts =True,
15047
15047
)
15048
15048
else:
15049
15049
ep = export(
@@ -15057,7 +15057,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
15057
15057
for node in ep.graph.nodes
15058
15058
].count(True)
15059
15059
if private_api:
15060
- self.assertEqual(num_asserts, 6 )
15060
+ self.assertEqual(num_asserts, 7 )
15061
15061
with self.assertRaisesRegex(
15062
15062
RuntimeError,
15063
15063
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
0 commit comments