1
1
module SpecialOperatorsModule
2
2
3
3
using .. OperatorEnumModule: OperatorEnum
4
- using .. EvaluateModule: _eval_tree_array, @return_on_nonfinite_array , deg2_eval
4
+ using .. EvaluateModule:
5
+ _eval_tree_array, @return_on_nonfinite_array , deg2_eval, ResultOk, get_filled_array
5
6
using .. ExpressionModule: AbstractExpression
6
7
using .. ExpressionAlgebraModule: @declare_expression_operator
7
8
24
25
@inline special_operator (:: Type{AssignOperator} ) = true
25
26
get_op_name (o:: AssignOperator ) = " [{FEATURE_" * string (o. target_register) * " } =]"
26
27
27
- # Base.@kwdef struct WhileOperator <: Function
28
- # max_iters::Int = 100
29
- # end
30
- # @inline special_operator(::Type{WhileOperator}) = true
31
- # function deg2_eval_special(tree, cX, op::WhileOperator, eval_options)
32
- # cond = tree.l
33
- # body = tree.r
34
- # for _ in 1:(op.max_iters)
35
- # let cond_result = _eval_tree_array(cond, cX, operators, eval_options)
36
- # !cond_result.ok && return cond_result
37
- # @return_on_nonfinite_array(eval_options, cond_result.x)
38
- # end
39
- # let body_result = _eval_tree_array(body, cX, operators, eval_options)
40
- # !body_result.ok && return body_result
41
- # @return_on_nonfinite_array(eval_options, body_result.x)
42
- # # TODO : Need to somehow mask instances
43
- # end
44
- # end
45
-
46
- # return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
47
- # end
48
- # TODO : Need to void any instance of buffer when using while loop.
49
-
50
28
function deg1_eval_special (tree, cX, op:: AssignOperator , eval_options, operators)
51
29
result = _eval_tree_array (tree. l, cX, operators, eval_options)
52
30
! result. ok && return result
@@ -58,4 +36,47 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators
58
36
return result
59
37
end
60
38
39
+ Base. @kwdef struct WhileOperator <: Function
40
+ max_iters:: Int = 100
41
+ end
42
+
43
+ @declare_expression_operator ((op:: WhileOperator ), 2 )
44
+ @inline special_operator (:: Type{WhileOperator} ) = true
45
+ get_op_name (o:: WhileOperator ) = " while"
46
+
47
+ # TODO : Need to void any instance of buffer when using while loop.
48
+ function deg2_eval_special (tree, cX, op:: WhileOperator , eval_options, operators)
49
+ cond = tree. l
50
+ body = tree. r
51
+ mask = trues (size (cX, 2 ))
52
+ X = @view cX[:, mask]
53
+ # Initialize the result array for all columns
54
+ result_array = get_filled_array (eval_options. buffer, zero (eltype (cX)), cX, axes (cX, 2 ))
55
+ body_result = ResultOk (result_array, true )
56
+
57
+ for _ in 1 : (op. max_iters)
58
+ cond_result = _eval_tree_array (cond, X, operators, eval_options)
59
+ ! cond_result. ok && return cond_result
60
+ @return_on_nonfinite_array (eval_options, cond_result. x)
61
+
62
+ new_mask = cond_result. x .> 0.0
63
+ any (new_mask) || return body_result
64
+
65
+ # Track which columns are still active
66
+ mask[mask] .= new_mask
67
+ X = @view cX[:, mask]
68
+
69
+ # Evaluate just for active columns
70
+ iter_result = _eval_tree_array (body, X, operators, eval_options)
71
+ ! iter_result. ok && return iter_result
72
+
73
+ # Update the corresponding elements in the result array
74
+ body_result. x[mask] .= iter_result. x
75
+ @return_on_nonfinite_array (eval_options, body_result. x)
76
+ end
77
+
78
+ # We passed max_iters, so this result is invalid
79
+ return ResultOk (body_result. x, false )
80
+ end
81
+
61
82
end
0 commit comments