|
21 | 21 | pre_greedy_node_rewriter, |
22 | 22 | ) |
23 | 23 | from pytensor.raise_op import assert_op |
24 | | -from pytensor.tensor.math import Dot, add, dot |
| 24 | +from pytensor.tensor.math import Dot, add, dot, exp |
25 | 25 | from pytensor.tensor.rewriting.basic import constant_folding |
26 | 26 | from pytensor.tensor.subtensor import AdvancedSubtensor |
27 | | -from pytensor.tensor.type import matrix, values_eq_approx_always_true |
| 27 | +from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector |
28 | 28 | from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype |
29 | 29 | from tests.graph.utils import ( |
30 | 30 | MyOp, |
@@ -441,6 +441,23 @@ def test_merge_noinput(self): |
441 | 441 | assert fg.outputs[0] is fg.outputs[1] |
442 | 442 | assert fg.outputs[0] is not fg.outputs[2] |
443 | 443 |
|
| 444 | + @pytest.mark.parametrize("reverse", [False, True]) |
| 445 | + def test_merge_more_specific_types(self, reverse): |
| 446 | + """Check that we choose the most specific static type when merging variables.""" |
| 447 | + |
| 448 | + x1 = vector("x1", shape=(None,)) |
| 449 | + x2 = vector("x2", shape=(500,)) |
| 450 | + |
| 451 | + y1 = exp(x1) |
| 452 | + y2 = exp(x2) |
| 453 | + |
| 454 | + # Simulate case where we find that x2 is equivalent to x1 |
| 455 | + fg = FunctionGraph([x1, x2], [y2, y1] if reverse else [y1, y2], clone=False) |
| 456 | + fg.replace(x1, x2) |
| 457 | + |
| 458 | + MergeOptimizer().rewrite(fg) |
| 459 | + assert fg.outputs == [y2, y2] |
| 460 | + |
444 | 461 |
|
445 | 462 | class TestEquilibrium: |
446 | 463 | def test_1(self): |
|
0 commit comments