|
142 | 142 |
|
143 | 143 | ## COV_EXCL_STOP |
144 | 144 |
|
| 145 | +Base.mapreduce(f, op, A::WrappedMtlArray; |
| 146 | + dims=:, init=nothing) = _mapreduce(f, op, A, init, dims) |
| 147 | + # dims=:, init=nothing) = AK.mapreduce(f, op, A, init, dims=dims isa Colon ? nothing : dims) |
| 148 | +Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:MtlArrayStyle}; |
| 149 | + dims=:, init=nothing) = _mapreduce(f, op, A, init, dims) |
| 150 | + # dims=:, init=nothing) = AK.mapreduce(f, op, A, init, dims=dims isa Colon ? nothing : dims) |
| 151 | + |
| 152 | +# "Borrowed" from GPUArrays |
| 153 | +@inline function _init_value(f, op, init, As...) |
| 154 | + if init === nothing |
| 155 | + ET = Broadcast.combine_eltypes(f, As) |
| 156 | + ET = Base.promote_op(op, ET, ET) |
| 157 | + (ET === Union{} || ET === Any) && |
| 158 | + error("mapreduce cannot figure the output element type, please pass an explicit init value") |
| 159 | + |
| 160 | + init = AK.neutral_element(op, ET) |
| 161 | + end |
| 162 | + return init |
| 163 | +end |
| 164 | + |
| 165 | +function _mapreduce(f, op, A, init, dims::Union{Nothing, Integer}) |
| 166 | + init_val = _init_value(f, op, init, A) |
| 167 | + AK.mapreduce(f, op, A; init=init_val, neutral=init_val, dims) |
| 168 | +end |
| 169 | +_mapreduce(f, op, A, init, ::Colon) = _mapreduce(f, op, A, init, nothing) |
| 170 | +_mapreduce(f, op, A, init, dims) = GPUArrays._mapreduce(f, op, A; dims, init) |
| 171 | + |
145 | 172 | function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T}, |
146 | 173 | A::Union{AbstractArray,Broadcast.Broadcasted}; |
147 | 174 | init=nothing) where {F, OP, T} |
|
0 commit comments