Skip to content

Commit 6e0334b

Browse files
committed
feat: create WhileOperator
1 parent f3f78ae commit 6e0334b

File tree

4 files changed

+120
-27
lines changed

4 files changed

+120
-27
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ import .StringsModule: get_op_name
7777
@reexport import .EvaluateModule:
7878
eval_tree_array, differentiable_eval_tree_array, EvalOptions
7979
import .EvaluateModule: ArrayBuffer
80-
@reexport import .SpecialOperatorsModule: AssignOperator
80+
@reexport import .SpecialOperatorsModule: AssignOperator, WhileOperator
8181
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
8282
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
8383
@reexport import .SimplifyModule: combine_operators, simplify_tree!

src/Evaluate.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ end
343343
if long_compilation_time
344344
return quote
345345
op = operators.binops[op_idx]
346-
special_operator(op) && return deg2_eval_special(tree, cX, op, eval_options)
346+
special_operator(op) &&
347+
return deg2_eval_special(tree, cX, op, eval_options, operators)
347348
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
348349
!result_l.ok && return result_l
349350
@return_on_nonfinite_array(eval_options, result_l.x)
@@ -360,7 +361,7 @@ end
360361
i -> i == op_idx,
361362
i -> let op = operators.binops[i]
362363
if special_operator(op)
363-
deg2_eval_special(tree, cX, op, eval_options)
364+
deg2_eval_special(tree, cX, op, eval_options, operators)
364365
elseif tree.l.degree == 0 && tree.r.degree == 0
365366
deg2_l0_r0_eval(tree, cX, op, eval_options)
366367
elseif tree.r.degree == 0

src/SpecialOperators.jl

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module SpecialOperatorsModule
22

33
using ..OperatorEnumModule: OperatorEnum
4-
using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval
4+
using ..EvaluateModule:
5+
_eval_tree_array, @return_on_nonfinite_array, deg2_eval, ResultOk, get_filled_array
56
using ..ExpressionModule: AbstractExpression
67
using ..ExpressionAlgebraModule: @declare_expression_operator
78

@@ -24,29 +25,6 @@ end
2425
@inline special_operator(::Type{AssignOperator}) = true
2526
get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]"
2627

27-
# Base.@kwdef struct WhileOperator <: Function
28-
# max_iters::Int = 100
29-
# end
30-
# @inline special_operator(::Type{WhileOperator}) = true
31-
# function deg2_eval_special(tree, cX, op::WhileOperator, eval_options)
32-
# cond = tree.l
33-
# body = tree.r
34-
# for _ in 1:(op.max_iters)
35-
# let cond_result = _eval_tree_array(cond, cX, operators, eval_options)
36-
# !cond_result.ok && return cond_result
37-
# @return_on_nonfinite_array(eval_options, cond_result.x)
38-
# end
39-
# let body_result = _eval_tree_array(body, cX, operators, eval_options)
40-
# !body_result.ok && return body_result
41-
# @return_on_nonfinite_array(eval_options, body_result.x)
42-
# # TODO: Need to somehow mask instances
43-
# end
44-
# end
45-
46-
# return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
47-
# end
48-
# TODO: Need to void any instance of buffer when using while loop.
49-
5028
function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators)
5129
result = _eval_tree_array(tree.l, cX, operators, eval_options)
5230
!result.ok && return result
@@ -58,4 +36,47 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators
5836
return result
5937
end
6038

39+
Base.@kwdef struct WhileOperator <: Function
40+
max_iters::Int = 100
41+
end
42+
43+
@declare_expression_operator((op::WhileOperator), 2)
44+
@inline special_operator(::Type{WhileOperator}) = true
45+
get_op_name(o::WhileOperator) = "while"
46+
47+
# TODO: Need to void any instance of buffer when using while loop.
48+
function deg2_eval_special(tree, cX, op::WhileOperator, eval_options, operators)
49+
cond = tree.l
50+
body = tree.r
51+
mask = trues(size(cX, 2))
52+
X = @view cX[:, mask]
53+
# Initialize the result array for all columns
54+
result_array = get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
55+
body_result = ResultOk(result_array, true)
56+
57+
for _ in 1:(op.max_iters)
58+
cond_result = _eval_tree_array(cond, X, operators, eval_options)
59+
!cond_result.ok && return cond_result
60+
@return_on_nonfinite_array(eval_options, cond_result.x)
61+
62+
new_mask = cond_result.x .> 0.0
63+
any(new_mask) || return body_result
64+
65+
# Track which columns are still active
66+
mask[mask] .= new_mask
67+
X = @view cX[:, mask]
68+
69+
# Evaluate just for active columns
70+
iter_result = _eval_tree_array(body, X, operators, eval_options)
71+
!iter_result.ok && return iter_result
72+
73+
# Update the corresponding elements in the result array
74+
body_result.x[mask] .= iter_result.x
75+
@return_on_nonfinite_array(eval_options, body_result.x)
76+
end
77+
78+
# We passed max_iters, so this result is invalid
79+
return ResultOk(body_result.x, false)
80+
end
81+
6182
end

test/test_special_operators.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,74 @@ end
100100
not_simplified = simplify_tree!(simple_expr_special)
101101
@test string_tree(not_simplified) == "2.0 + 2.0"
102102
end
103+
104+
@testitem "WhileOperator basic functionality" begin
105+
using DynamicExpressions
106+
using Test
107+
108+
# Define operators
109+
while_op = WhileOperator(; max_iters=100)
110+
assign_x2 = AssignOperator(; target_register=2)
111+
operators = OperatorEnum(;
112+
binary_operators=[+, -, *, /, while_op], # While is binary operator
113+
unary_operators=[assign_x2],
114+
)
115+
variable_names = ["x1", "x2", "x3"]
116+
117+
# Test data - x2 starts at 1.0 for all samples
118+
X = zeros(Float64, 2, 3)
119+
X[2, :] .= 1.0 # x2 initial value
120+
121+
# Build expression: while (3.0 - x2 > 0) do x2 = x2 + 1.0
122+
x2 = Expression(Node(; feature=2); operators, variable_names)
123+
expr = while_op(3.0 - x2, assign_x2(x2 + 1.0))
124+
125+
@test string_tree(expr) == "while(3.0 - x2, [x2 =](x2 + 1.0))"
126+
127+
result, completed = eval_tree_array(expr, X)
128+
@test completed == true
129+
@test all(result .≈ 3.0) # After 2 iterations, x2 becomes 3.0
130+
@test X[2, :] == [1.0, 1.0, 1.0] # Original data unchanged
131+
end
132+
133+
@testitem "Fibonacci sequence with WhileOperator" begin
134+
using DynamicExpressions
135+
using Test
136+
137+
# Define operators
138+
while_op = WhileOperator(; max_iters=100)
139+
assign_ops = [AssignOperator(; target_register=i) for i in 1:5]
140+
operators = OperatorEnum(;
141+
binary_operators=[+, -, *, /, while_op], unary_operators=assign_ops
142+
)
143+
variable_names = ["x1", "x2", "x3", "x4", "x5"]
144+
145+
# Test data - x2=5 (counter), x3=0 (F(0)), x4=1 (F(1))
146+
X = zeros(Float64, 5, 4)
147+
# Set different Fibonacci sequence positions to calculate
148+
X[2, :] = [3.0, 5.0, 7.0, 10.0] # Calculate F(3), F(5), F(7), F(10)
149+
150+
# Initialize all rows with F(0)=0, F(1)=1
151+
X[3, :] .= 0.0 # x3 = 0.0 (F(0))
152+
X[4, :] .= 1.0 # x4 = 1.0 (F(1))
153+
154+
xs = [Expression(Node(; feature=i); operators, variable_names) for i in 1:5]
155+
156+
# Build expression:
157+
condition = xs[2] # WhileOperator implicitly checks if > 0
158+
body =
159+
assign_ops[5](xs[3]) +
160+
assign_ops[3](xs[4]) +
161+
assign_ops[4](xs[5] + xs[4]) +
162+
assign_ops[2](xs[2] - 1.0)
163+
expr = (while_op(condition, body) * 0.0) + xs[3]
164+
165+
@test string_tree(expr) ==
166+
"(while(x2, (([x5 =](x3) + [x3 =](x4)) + [x4 =](x5 + x4)) + [x2 =](x2 - 1.0)) * 0.0) + x3"
167+
168+
result, completed = eval_tree_array(expr, X)
169+
@test completed == true
170+
171+
# Test each Fibonacci number is correctly calculated
172+
@test result [2.0, 5.0, 13.0, 55.0] # F(3)=2, F(5)=5, F(7)=13, F(10)=55
173+
end

0 commit comments

Comments
 (0)