Skip to content

Commit c580ee2

Browse files
committed
Use shared memory for findall
1 parent e1cb12e commit c580ee2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/indexing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ end
2323

2424
function Base.findall(bools::WrappedMtlArray{Bool})
2525
I = keytype(bools)
26-
indices = cumsum(reshape(bools, prod(size(bools))))
26+
boolslen = prod(size(bools))
2727

28-
n = @allowscalar indices[end]
28+
indices = MtlVector{Int64, Metal.SharedStorage}(undef, boolslen)
29+
cumsum!(indices, reshape(bools, boolslen))
30+
31+
n = indices[end]
2932
ys = similar(bools, I, n)
3033

3134
if n > 0

0 commit comments

Comments
 (0)