Skip to content

Commit f3f78ae

Browse files
committed
fix: avoid simplification when given special operators
1 parent 6015291 commit f3f78ae

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/Simplify.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
44
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
55
import ..OperatorEnumModule: AbstractOperatorEnum
66
import ..ValueInterfaceModule: is_valid
7+
import ..EvaluateModule: any_special_operators
78

89
_una_op_kernel(f::F, l::T) where {F,T} = f(l)
910
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
@@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
1920
# This is only defined for `Node` as it is not possible for, e.g.,
2021
# `GraphNode`.
2122
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
23+
# Skip simplification if special operators are in use
24+
any_special_operators(operators) && return tree
25+
return _combine_operators(tree, operators)
26+
end
27+
28+
function _combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
2229
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
2330
# ((const + var) + const) => (const + var)
2431
# ((const * var) * const) => (const * var)
@@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
2835
if tree.degree == 0
2936
return tree
3037
elseif tree.degree == 1
31-
tree.l = combine_operators(tree.l, operators)
38+
tree.l = _combine_operators(tree.l, operators)
3239
elseif tree.degree == 2
33-
tree.l = combine_operators(tree.l, operators)
34-
tree.r = combine_operators(tree.r, operators)
40+
tree.l = _combine_operators(tree.l, operators)
41+
tree.r = _combine_operators(tree.r, operators)
3542
end
3643

3744
top_level_constant =
@@ -123,6 +130,11 @@ end
123130

124131
# Simplify tree
125132
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
133+
# Skip simplification if special operators are in use
134+
if any_special_operators(operators)
135+
return tree
136+
end
137+
126138
return tree_mapreduce(
127139
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
128140
)

test/test_special_operators.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,39 @@ end
6464
@test completed == true
6565
@test result == X[1, :] .* 4.0
6666
end
67+
68+
@testitem "Simplification disabled with special operators" begin
69+
using DynamicExpressions
70+
using Test
71+
72+
# Create operators with and without special operator
73+
assign_op = AssignOperator(; target_register=1)
74+
special_operators = OperatorEnum(;
75+
binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op]
76+
)
77+
normal_operators = OperatorEnum(;
78+
binary_operators=[+, -, *, /], unary_operators=[sin, cos]
79+
)
80+
81+
@test DynamicExpressions.SpecialOperatorsModule.any_special_operators(special_operators)
82+
@test !DynamicExpressions.SpecialOperatorsModule.any_special_operators(normal_operators)
83+
84+
# Create expressions using the Expression constructor
85+
const_val = 2.0
86+
87+
# Simple expression that should simplify: 2.0 + 2.0
88+
raw_node = Node(; op=1, l=Node(; val=const_val), r=Node(; val=const_val))
89+
simple_expr = Expression(copy(raw_node); operators=normal_operators)
90+
simple_expr_special = Expression(copy(raw_node); operators=special_operators)
91+
92+
@test string_tree(simple_expr) == "2.0 + 2.0"
93+
@test string_tree(simple_expr_special) == "2.0 + 2.0"
94+
95+
# Test normal simplification works
96+
simplified = simplify_tree!(simple_expr)
97+
@test string_tree(simplified) == "4.0"
98+
99+
# Test simplification is disabled with special operators
100+
not_simplified = simplify_tree!(simple_expr_special)
101+
@test string_tree(not_simplified) == "2.0 + 2.0"
102+
end

0 commit comments

Comments
 (0)