@@ -170,27 +170,26 @@ end
170170# # Base interface
171171
172172Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlVector , dims:: Nothing , init:: Nothing ) =
173- scan ! (op, output, input; dims= 1 )
173+ @inline AK . accumulate ! (op, output, input; dims, init = AK . neutral_element (op, eltype (output)), alg = AK . ScanPrefixes () )
174174
175175Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlArray , dims:: Integer , init:: Nothing ) =
176- scan! (op, output, input; dims= dims)
177-
176+ @inline AK. accumulate! (op, output, input; dims, init= AK. neutral_element (op, eltype (output)), alg= AK. ScanPrefixes ())
178177Base. _accumulate! (op, output:: WrappedMtlArray , input:: MtlVector , dims:: Nothing , init:: Some ) =
179- scan ! (op, output, input; dims= 1 , init= init)
178+ @inline AK . accumulate ! (op, output, input; dims, init= something ( init), alg = AK . ScanPrefixes () )
180179
181180Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlArray , dims:: Integer , init:: Some ) =
182- scan ! (op, output, input; dims= dims , init= init)
181+ @inline AK . accumulate ! (op, output, input; dims, init= something ( init), alg = AK . ScanPrefixes () )
183182
184- Base. accumulate_pairwise! (op, result:: WrappedMtlVector , v:: WrappedMtlVector ) = accumulate! (op, result, v)
183+ Base. accumulate_pairwise! (op, result:: WrappedMtlVector , v:: WrappedMtlVector ) = @inline AK . accumulate! (op, result, v; init = AK . neutral_element (op, eltype (result)), alg = AK . ScanPrefixes () )
185184
186185# default behavior unless dims are specified by the user
187186function Base. accumulate (op, A:: WrappedMtlArray ;
188187 dims:: Union{Nothing,Integer} = nothing , kw... )
188+ nt = values (kw)
189189 if dims === nothing && ! (A isa AbstractVector)
190190 # This branch takes care of the cases not handled by `_accumulate!`.
191- return reshape (accumulate (op, A[:]; kw ... ), size (A))
191+ return reshape (AK . accumulate (op, A[:]; init = ( :init in keys (kw) ? nt . init : AK . neutral_element (op, eltype (A))), alg = AK . ScanPrefixes () ), size (A))
192192 end
193- nt = values (kw)
194193 if isempty (kw)
195194 out = similar (A, Base. promote_op (op, eltype (A), eltype (A)))
196195 elseif keys (nt) === (:init ,)
0 commit comments