diff --git a/src/indexing.jl b/src/indexing.jl index f59d423ea..67f3a56b3 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -23,9 +23,12 @@ end function Base.findall(bools::WrappedMtlArray{Bool}) I = keytype(bools) - indices = cumsum(reshape(bools, prod(size(bools)))) + boolslen = prod(size(bools)) - n = @allowscalar indices[end] + indices = MtlVector{Int64, Metal.SharedStorage}(undef, boolslen) + cumsum!(indices, reshape(bools, boolslen)) + + n = indices[end] ys = similar(bools, I, n) if n > 0 diff --git a/test/array.jl b/test/array.jl index cade39b93..81a676fe2 100644 --- a/test/array.jl +++ b/test/array.jl @@ -490,7 +490,6 @@ end end @testset "accumulate" begin - testf(f, x) = Array(f(MtlArray(x))) ≈ f(x) for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not @test testf(x->accumulate(+, x), rand(Float32, n)) @test testf(x->accumulate(+, x), rand(Float32, n, 2)) @@ -500,17 +499,17 @@ end # multidimensional for (sizes, dims) in ((2,) => 2, (3,4,5) => 2, - (1, 70, 50, 20) => 3) - @test testf(x->accumulate(+, x; dims=dims), rand(Int, sizes)) - @test testf(x->accumulate(+, x), rand(Int, sizes)) + (1, 70, 50, 20) => 3,) + @test testf(x->accumulate(+, x; dims=dims), rand(-10:10, sizes)) + @test testf(x->accumulate(+, x), rand(-10:10, sizes)) end # using initializer for (sizes, dims) in ((2,) => 2, (3,4,5) => 2, (1, 70, 50, 20) => 3) - @test testf(Base.Fix2((x,y)->accumulate(+, x; dims=dims, init=y), rand(Int)), rand(Int, sizes)) - @test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(Int)), rand(Int, sizes)) + @test testf(Base.Fix2((x,y)->accumulate(+, x; dims=dims, init=y), rand(-10:10)), rand(-10:10, sizes)) + @test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(-10:10)), rand(-10:10, sizes)) end # in place