@@ -28,8 +28,14 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
2828# resolve ambiguities
2929Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
3030 dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31+ # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
3132Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
3233 dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34+ # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35+ Base. mapreduce (f, op, A:: AnyGPUArray ;
36+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37+ Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
3339
3440function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
3541 # figure out the destination container type by looking at the initializer element,
@@ -40,7 +46,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
4046 (ET === Union{} || ET === Any) &&
4147 error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
4248
43- init = neutral_element (op, ET)
49+ init = AK . neutral_element (op, ET)
4450 else
4551 ET = typeof (init)
4652 end
@@ -85,14 +91,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
8591 end
8692end
8793
88- Base. any (A:: AnyGPUArray{Bool} ) = mapreduce (identity, | , A)
89- Base. all (A:: AnyGPUArray{Bool} ) = mapreduce (identity, & , A)
94+ Base. any (A:: AnyGPUArray{Bool} ) = AK . any (identity, A)
95+ Base. all (A:: AnyGPUArray{Bool} ) = AK . all (identity, A)
9096
91- Base. any (f:: Function , A:: AnyGPUArray ) = mapreduce (f, | , A)
92- Base. all (f:: Function , A:: AnyGPUArray ) = mapreduce (f, & , A)
97+ Base. any (f:: Function , A:: AnyGPUArray ) = AK . any (f , A)
98+ Base. all (f:: Function , A:: AnyGPUArray ) = AK . all (f , A)
9399
94100Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
95- mapreduce (pred, Base . add_sum, A; init= init , dims= dims)
101+ AK . count (pred, A; init, dims= dims isa Colon ? nothing : dims)
96102
97103# avoid calling into `initarray!`
98104for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
@@ -101,7 +107,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
101107 fname! = Symbol (fname, ' !' )
102108 @eval begin
103109 Base.$ (fname!)(f:: Function , r:: AnyGPUArray , A:: AnyGPUArray{T} ) where T =
104- GPUArrays. mapreducedim! (f, $ (op), r, A; init= neutral_element ($ (op), T))
110+ GPUArrays. mapreducedim! (f, $ (op), r, A; init= AK . neutral_element ($ (op), T))
105111 end
106112end
107113
0 commit comments