Description
Currently, calling accumulate() with an MtlArray dispatches to the Metal extension's accumulate!, but still attempts to use the decoupled lookback algorithm. This happens because the generic (non-extension) accumulate has DecoupledLookback() as its default alg argument and passes it explicitly to the Metal accumulate!, which then happily tries to use DecoupledLookback.
AcceleratedKernels.jl/src/accumulate/accumulate.jl
Lines 229 to 261 in 111c89b
function accumulate(
    op, v::AbstractArray, backend::Backend=get_backend(v);
    init,
    neutral=neutral_element(op, eltype(v)),
    dims::Union{Nothing, Int}=nothing,
    inclusive::Bool=true,
    # Algorithm choice
    alg::AccumulateAlgorithm=DecoupledLookback(),
    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    temp_flags::Union{Nothing, AbstractArray}=nothing,
)
    dst_type = Base.promote_op(op, eltype(v), typeof(init))
    vcopy = similar(v, dst_type)
    copyto!(vcopy, v)
    accumulate!(
        op, vcopy, backend;
        init=init,
        neutral=neutral,
        dims=dims,
        inclusive=inclusive,
        alg=alg,
        block_size=block_size,
        temp=temp,
        temp_flags=temp_flags,
    )
    vcopy
end
The quick solution would be to define accumulate in the Metal extension as well. However, I think it would be better (assuming no performance impact) to not specify the default arguments in the higher-level definitions at all, and instead forward the keyword arguments and leave the defaults to the implementations.
This will reduce the amount of copy-pasting default arguments, and the docstrings will still let users know what the potential kwargs are.
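
A minimal sketch of the second option, using toy stand-in types rather than the real AcceleratedKernels.jl or Metal.jl definitions. The names ToyBackend, GenericBackend, MetalLikeBackend, toy_accumulate and toy_accumulate! are hypothetical, and the choice of ScanPrefixes as the Metal-like default is an assumption made only for illustration. The scan itself is a plain sequential loop. The point is only that the high-level function forwards kwargs... so each backend method supplies its own default alg.

```julia
# Toy stand-ins (hypothetical); not the real package types.
abstract type ToyBackend end
struct GenericBackend <: ToyBackend end
struct MetalLikeBackend <: ToyBackend end

abstract type ToyAlgorithm end
struct DecoupledLookback <: ToyAlgorithm end
struct ScanPrefixes <: ToyAlgorithm end

# Sequential inclusive scan shared by both toy backends.
function _toy_scan!(op, v, init)
    acc = init
    for i in eachindex(v)
        acc = op(acc, v[i])
        v[i] = acc
    end
    return v
end

# Each implementation owns its default `alg`.
function toy_accumulate!(op, v, ::GenericBackend;
                         init, alg::ToyAlgorithm=DecoupledLookback())
    _toy_scan!(op, v, init)
    return (result=v, alg=alg)  # return `alg` only so the choice is observable
end

# Assumption: the Metal-like backend prefers a different default.
function toy_accumulate!(op, v, ::MetalLikeBackend;
                         init, alg::ToyAlgorithm=ScanPrefixes())
    _toy_scan!(op, v, init)
    return (result=v, alg=alg)
end

# High-level wrapper: no defaults repeated here, kwargs are forwarded as given.
function toy_accumulate(op, v, backend::ToyBackend; init, kwargs...)
    dst_type = Base.promote_op(op, eltype(v), typeof(init))
    vcopy = similar(v, dst_type)
    copyto!(vcopy, v)
    return toy_accumulate!(op, vcopy, backend; init=init, kwargs...)
end
```

With this shape, toy_accumulate(+, [1, 2, 3], MetalLikeBackend(); init=0) picks up the backend's own default, while an explicit alg=DecoupledLookback() from the caller is still passed through unchanged. One trade-off is that a misspelled keyword is only rejected at the accumulate! level rather than at the wrapper.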