Skip to content

Commit 80db2ac

Browse files
authored
Merge branch 'main' into fix-coreml-to-edge-transform-and-lower
2 parents 223db05 + a624083 commit 80db2ac

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

exir/backend/test/test_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def false_fn(x, y):
10331033

10341034
def f(x, y):
10351035
x = x + y
1036-
x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
1036+
x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
10371037
x = x - y
10381038
return x
10391039

exir/tests/control_flow_models.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ def true_branch(x):
2020
def false_branch(x):
2121
return x * x
2222

23-
return torch.ops.higher_order.cond(
24-
inp.sum() > 4, true_branch, false_branch, [inp]
25-
)
23+
return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
2624

2725
def get_random_inputs(self):
2826
return (torch.rand(5),)
@@ -39,9 +37,7 @@ def true_branch(x):
3937
def false_branch(x):
4038
return x * x * x
4139

42-
return torch.ops.higher_order.cond(
43-
inp.sum() > 4, true_branch, false_branch, [inp]
44-
)
40+
return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
4541

4642
def get_upper_bound_inputs(self):
4743
return (torch.rand(8),)
@@ -72,9 +68,7 @@ def true_branch(x):
7268
def false_branch(x):
7369
return x * 2
7470

75-
return torch.ops.higher_order.cond(
76-
inp.sum() > 4, true_branch, false_branch, [inp]
77-
)
71+
return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])
7872

7973
def get_random_inputs(self):
8074
return (torch.eye(5) * 2,)

exir/tests/test_passes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,9 +1463,7 @@ def forward(self, pred, x):
14631463
out = torch.nn.functional.linear(
14641464
x, self.w.to(torch.float16).to(torch.float32)
14651465
)
1466-
return torch.ops.higher_order.cond(
1467-
pred, self.true_fn, self.false_fn, [out]
1468-
)
1466+
return torch.cond(pred, self.true_fn, self.false_fn, [out])
14691467

14701468
mod = Module()
14711469
x = torch.randn([3, 3])

0 commit comments

Comments
 (0)