
accumulate algorithm selection for Metal implementation being overridden #37

@christiangnrd


Currently, calling `accumulate()` with an `MtlArray` does dispatch to the Metal extension's `accumulate!`, but it still attempts to use the decoupled lookback algorithm. The default (non-extension) `accumulate` declares `DecoupledLookback()` as its default `alg` argument and always passes it on to the Metal `accumulate!`, which happily tries to use `DecoupledLookback`:

```julia
function accumulate(
    op, v::AbstractArray, backend::Backend=get_backend(v);
    init,
    neutral=neutral_element(op, eltype(v)),
    dims::Union{Nothing, Int}=nothing,
    inclusive::Bool=true,
    # Algorithm choice
    alg::AccumulateAlgorithm=DecoupledLookback(),
    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    temp_flags::Union{Nothing, AbstractArray}=nothing,
)
    dst_type = Base.promote_op(op, eltype(v), typeof(init))
    vcopy = similar(v, dst_type)
    copyto!(vcopy, v)
    accumulate!(
        op, vcopy, backend;
        init=init,
        neutral=neutral,
        dims=dims,
        inclusive=inclusive,
        alg=alg,
        block_size=block_size,
        temp=temp,
        temp_flags=temp_flags,
    )
    vcopy
end
```
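A minimal, self-contained sketch of the pitfall. It uses hypothetical stand-in types (`CPUBackend`, `MetalBackend`, and `ScanPrefixes` as an assumed Metal-friendly alternative algorithm) and hypothetical functions (`select_alg`, `wrapper`) rather than the package's real ones. Because Julia keyword arguments do not take part in dispatch, a default declared in the wrapper is always passed explicitly, so the backend method's own default never gets a chance to apply.

```julia
# Hypothetical stand-ins for the package's algorithm and backend types.
abstract type AccumulateAlgorithm end
struct DecoupledLookback <: AccumulateAlgorithm end
struct ScanPrefixes <: AccumulateAlgorithm end  # assumed Metal-friendly alternative

struct CPUBackend end
struct MetalBackend end

# Backend-specific methods, each declaring its own default algorithm.
select_alg(::CPUBackend; alg::AccumulateAlgorithm=DecoupledLookback()) = alg
select_alg(::MetalBackend; alg::AccumulateAlgorithm=ScanPrefixes()) = alg

# Wrapper mirroring the current pattern: it re-declares a default and
# always forwards it, so the callee's default is never used.
wrapper(backend; alg::AccumulateAlgorithm=DecoupledLookback()) =
    select_alg(backend; alg=alg)

select_alg(MetalBackend())  # ScanPrefixes(): the Metal method's default
wrapper(MetalBackend())     # DecoupledLookback(): the wrapper's default wins
```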

The quick fix would be to define `accumulate` in the Metal extension as well, but I think it would be better (assuming no performance impact) to not specify default arguments in the higher-level definitions at all, and leave them to the implementations.

This would also reduce the amount of copy-pasted default arguments, while the docstrings would still tell users which keyword arguments are available.
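One possible shape for this proposal, sketched with hypothetical names (`my_accumulate`, `my_accumulate!`, `GenericGPU`, `MetalGPU`, `ScanPrefixes`, `LAST_ALG`) and a serial CPU scan standing in for the GPU kernels; it is not the package's actual implementation. The high-level function declares only `init` and collects every other keyword in `kwargs...`, so a keyword the caller omits is never forwarded and each `accumulate!` method falls back to its own default.

```julia
# Hypothetical stand-ins; the real package defines its own types.
abstract type AccumulateAlgorithm end
struct DecoupledLookback <: AccumulateAlgorithm end
struct ScanPrefixes <: AccumulateAlgorithm end

abstract type SketchBackend end
struct GenericGPU <: SketchBackend end
struct MetalGPU <: SketchBackend end

# Records the algorithm chosen by the last call (illustration only).
const LAST_ALG = Ref{Any}(nothing)

# Serial inclusive/exclusive scan standing in for the GPU kernels.
function serial_scan!(op, v, init, inclusive)
    acc = init
    for i in eachindex(v)
        x = v[i]
        if inclusive
            acc = op(acc, x)
            v[i] = acc
        else
            v[i] = acc
            acc = op(acc, x)
        end
    end
    return v
end

# Generic implementation owns its default algorithm.
function my_accumulate!(op, v::AbstractArray, ::SketchBackend;
                        init, inclusive::Bool=true,
                        alg::AccumulateAlgorithm=DecoupledLookback())
    LAST_ALG[] = alg
    return serial_scan!(op, v, init, inclusive)
end

# Metal-like implementation declares a different default.
function my_accumulate!(op, v::AbstractArray, ::MetalGPU;
                        init, inclusive::Bool=true,
                        alg::AccumulateAlgorithm=ScanPrefixes())
    LAST_ALG[] = alg
    return serial_scan!(op, v, init, inclusive)
end

# High-level entry point: no defaults for implementation keywords;
# only what the caller actually passed is forwarded.
function my_accumulate(op, v::AbstractArray, backend::SketchBackend; init, kwargs...)
    dst_type = Base.promote_op(op, eltype(v), typeof(init))
    vcopy = similar(v, dst_type)
    copyto!(vcopy, v)
    return my_accumulate!(op, vcopy, backend; init=init, kwargs...)
end
```

An explicit `alg=DecoupledLookback()` from the caller still reaches the Metal method unchanged; only the implicit default moves into the implementations.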

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests