@@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
4
4
import .. NodeUtilsModule: tree_mapreduce, is_node_constant
5
5
import .. OperatorEnumModule: AbstractOperatorEnum
6
6
import .. ValueInterfaceModule: is_valid
7
+ import .. EvaluateModule: any_special_operators
7
8
8
9
_una_op_kernel (f:: F , l:: T ) where {F,T} = f (l)
9
10
_bin_op_kernel (f:: F , l:: T , r:: T ) where {F,T} = f (l, r)
@@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
19
20
# This is only defined for `Node` as it is not possible for, e.g.,
20
21
# `GraphNode`.
21
22
function combine_operators (tree:: Node{T} , operators:: AbstractOperatorEnum ) where {T}
23
+ # Skip simplification if special operators are in use
24
+ any_special_operators (operators) && return tree
25
+ return _combine_operators (tree, operators)
26
+ end
27
+
28
+ function _combine_operators (tree:: Node{T} , operators:: AbstractOperatorEnum ) where {T}
22
29
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
23
30
# ((const + var) + const) => (const + var)
24
31
# ((const * var) * const) => (const * var)
@@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
28
35
if tree. degree == 0
29
36
return tree
30
37
elseif tree. degree == 1
31
- tree. l = combine_operators (tree. l, operators)
38
+ tree. l = _combine_operators (tree. l, operators)
32
39
elseif tree. degree == 2
33
- tree. l = combine_operators (tree. l, operators)
34
- tree. r = combine_operators (tree. r, operators)
40
+ tree. l = _combine_operators (tree. l, operators)
41
+ tree. r = _combine_operators (tree. r, operators)
35
42
end
36
43
37
44
top_level_constant =
123
130
124
131
# Simplify tree
125
132
function simplify_tree! (tree:: AbstractExpressionNode , operators:: AbstractOperatorEnum )
133
+ # Skip simplification if special operators are in use
134
+ if any_special_operators (operators)
135
+ return tree
136
+ end
137
+
126
138
return tree_mapreduce (
127
139
identity, (p, c... ) -> combine_children! (operators, p, c... ), tree, typeof (tree);
128
140
)
0 commit comments