Skip to content

Commit cc8cd91

Browse files
StrongerXipobin6
authored andcommitted
[dynamo] Support is comparison for symnodes (pytorch#140754)
Fixes pytorch#109504. Pull Request resolved: pytorch#140754 Approved by: https://github.com/williamwen42
1 parent 2b993a8 commit cc8cd91

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

test/dynamo/test_repros.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6338,6 +6338,30 @@ def forward(self, x):
63386338
res = opt_mod(x)
63396339
self.assertEqual(ref, res)
63406340

6341+
def test_symnode_is_op(self):
6342+
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
6343+
def f(x, xs):
6344+
if x.size(0) is xs:
6345+
return x + 1
6346+
else:
6347+
return x * 2
6348+
6349+
t = torch.randn(2)
6350+
res = f(t, [1, 2])
6351+
self.assertEqual(t * 2, res)
6352+
6353+
def test_symnode_is_not_op(self):
6354+
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
6355+
def f(x, xs):
6356+
if x.size(0) is not xs:
6357+
return x + 1
6358+
else:
6359+
return x * 2
6360+
6361+
t = torch.randn(2)
6362+
res = f(t, [1, 2])
6363+
self.assertEqual(t + 1, res)
6364+
63416365

63426366
instantiate_parametrized_tests(ReproTests)
63436367

torch/_dynamo/variables/tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
"<=": operator.le,
7070
"==": operator.eq,
7171
"!=": operator.ne,
72+
"is": operator.is_,
73+
"is not": operator.is_not,
7274
}
7375
# Ops that allow tensor <op> None
7476
supported_const_comparison_ops = {

0 commit comments

Comments
 (0)