Skip to content

Commit e1935ed

Browse files
committed
Use AK for supported reductions
1 parent d2e94b8 commit e1935ed

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/mapreduce.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,33 @@ end
142142

143143
## COV_EXCL_STOP
144144

145+
Base.mapreduce(f, op, A::WrappedMtlArray;
146+
dims=:, init=nothing) = _mapreduce(f, op, A, init, dims)
147+
# dims=:, init=nothing) = AK.mapreduce(f, op, A, init, dims=dims isa Colon ? nothing : dims)
148+
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:MtlArrayStyle};
149+
dims=:, init=nothing) = _mapreduce(f, op, A, init, dims)
150+
# dims=:, init=nothing) = AK.mapreduce(f, op, A, init, dims=dims isa Colon ? nothing : dims)
151+
152+
# "Borrowed" from GPUArrays
153+
@inline function _init_value(f, op, init, As...)
154+
if init === nothing
155+
ET = Broadcast.combine_eltypes(f, As)
156+
ET = Base.promote_op(op, ET, ET)
157+
(ET === Union{} || ET === Any) &&
158+
error("mapreduce cannot figure the output element type, please pass an explicit init value")
159+
160+
init = AK.neutral_element(op, ET)
161+
end
162+
return init
163+
end
164+
165+
function _mapreduce(f, op, A, init, dims::Union{Nothing, Integer})
166+
init_val = _init_value(f, op, init, A)
167+
AK.mapreduce(f, op, A; init=init_val, neutral=init_val, dims)
168+
end
169+
_mapreduce(f, op, A, init, ::Colon) = _mapreduce(f, op, A, init, nothing)
170+
_mapreduce(f, op, A, init, dims) = GPUArrays._mapreduce(f, op, A; dims, init)
171+
145172
function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
146173
A::Union{AbstractArray,Broadcast.Broadcasted};
147174
init=nothing) where {F, OP, T}

0 commit comments

Comments
 (0)