Skip to content

Commit 763c047

Browse files
committed
Implement sorting with AK
1 parent ea0080a commit 763c047

File tree

5 files changed

+119
-1
lines changed

5 files changed

+119
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "11.2.6"
66
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
9+
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
910
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1011
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -26,6 +27,7 @@ JLD2Ext = "JLD2"
2627
AcceleratedKernels = "0.4.3"
2728
Adapt = "4.0"
2829
GPUArraysCore = "= 0.2.0"
30+
GPUToolbox = "0.2, 0.3, 1"
2931
JLD2 = "0.4, 0.5, 0.6"
3032
KernelAbstractions = "0.9.28, 0.10"
3133
LLVM = "3.9, 4, 5, 6, 7, 8, 9"

src/GPUArrays.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module GPUArrays
22

3+
using GPUToolbox
34
using KernelAbstractions
45
using Serialization
56
using Random
@@ -15,7 +16,7 @@ using LLVM.Interop
1516
using Reexport
1617
@reexport using GPUArraysCore
1718

18-
using KernelAbstractions
19+
import KernelAbstractions as KA
1920
import AcceleratedKernels as AK
2021

2122
# device functionality
@@ -32,6 +33,7 @@ include("host/mapreduce.jl")
3233
include("host/linalg.jl")
3334
include("host/math.jl")
3435
include("host/random.jl")
36+
include("host/sorting.jl")
3537
include("host/quirks.jl")
3638
include("host/uniformscaling.jl")
3739
include("host/statistics.jl")

src/host/sorting.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
2+
# Abstract supertype for the algorithm tags accepted through the `alg` argument of the
# GPU `sort!` / `partialsort!` methods defined in this file.
abstract type SortingAlgorithm end
3+
# Field-less tag type selecting merge sort; the methods in this file that dispatch on it
# forward to `AK.merge_sort!`.
struct MergeSortAlg <: SortingAlgorithm end
4+
5+
# Instance of `MergeSortAlg`, used as the default value of the `alg` keyword in this file.
# NOTE(review): this shares its name with Base's exported `MergeSort` — confirm that the
# shadowing inside this module is intended.
const MergeSort = MergeSortAlg()
6+
7+
8+
# Sort the GPU vector `c` in place with `AK.merge_sort!` and return `c`.
#
# Keywords:
# - `lt`:  "less than" predicate taking two elements (default `isless`).
# - `by`:  transformation applied to elements before comparison (default `identity`).
# - `rev`: when `true`, sort in reverse order.
function Base.sort!(c::AnyGPUVector, alg::MergeSortAlg; lt=isless, by=identity, rev=false)
    # Reverse ordering is obtained by swapping the arguments of `lt`, mirroring what
    # `Base.Order.ReverseOrdering` does. Negating the predicate (`!lt`) would instead
    # produce a non-strict "greater than or equal" comparison that reports ties as
    # ordered in both directions, which is not a valid strict ordering.
    ordered_lt = rev ? ((a, b) -> lt(b, a)) : lt

    AK.merge_sort!(c; lt=ordered_lt, by)
    return c
end
17+
18+
# Keyword-driven entry point: moves the `alg` keyword (default `MergeSort`) into a
# positional argument so dispatch can select the algorithm-specific method; all other
# keywords are forwarded unchanged.
function Base.sort!(c::AnyGPUArray; alg::SortingAlgorithm = MergeSort, kwargs...)
    sorted = sort!(c, alg; kwargs...)
    return sorted
end
21+
22+
# Non-mutating sort: runs the in-place `sort!` on a copy of `c`, forwarding all keywords,
# and returns the result of that call.
function Base.sort(c::AnyGPUArray; kwargs...)
    duplicate = copy(c)
    return sort!(duplicate; kwargs...)
end
25+
26+
# Partial sort for GPU vectors using the merge sort tag. The whole vector `c` is sorted
# in place first, then a copy of `c[k]` (a single element or a range of elements) is
# returned. The indexing runs under `@allowscalar`.
function Base.partialsort!(
    c::AnyGPUVector, k::Union{Integer, OrdinalRange}, alg::MergeSortAlg;
    lt=isless, by=identity, rev=false,
)
    sort!(c, alg; lt=lt, by=by, rev=rev)
    selected = @allowscalar copy(c[k])
    return selected
end
32+
33+
# Keyword-driven entry point for `partialsort!`: moves the `alg` keyword (default
# `MergeSort`) into a positional argument and forwards every other keyword.
function Base.partialsort!(
    c::AnyGPUArray, k::Union{Integer, OrdinalRange};
    alg::SortingAlgorithm=MergeSort, kwargs...,
)
    picked = partialsort!(c, k, alg; kwargs...)
    return picked
end
37+
38+
# Non-mutating partial sort: applies `partialsort!` to a copy of `c`, so the argument
# itself is left untouched.
function Base.partialsort(c::AnyGPUArray, k::Union{Integer, OrdinalRange}; kwargs...)
    duplicate = copy(c)
    return partialsort!(duplicate, k; kwargs...)
end
41+
42+
# In-place `sortperm!` for GPU arrays: `AK.merge_sortperm!` is called with `ix` and `A`,
# and `ix` is returned.
#
# Throws `ArgumentError` when `axes(ix)` and `axes(A)` differ, or when `dims` is given
# (sorting along a dimension is not implemented). The `initialized` keyword is accepted
# but not used by this method; remaining keywords are forwarded to `AK.merge_sortperm!`.
function Base.sortperm!(
    ix::AnyGPUArray, A::AnyGPUArray;
    initialized=false, dims=nothing, kwargs...,
)
    axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
    dims === nothing || throw(ArgumentError("GPUArrays sort with `dims` kwarg not yet implemented."))

    AK.merge_sortperm!(ix, A; kwargs...)
    return ix
end
53+
54+
# `sortperm` for GPU vectors: allocates an `Int` index buffer of `length(c)` on the
# backend of `c` and hands it to `AK.merge_sortperm!`, whose result is returned.
# The `initialized` keyword is accepted but not used.
function Base.sortperm(c::AnyGPUVector; initialized=false, kwargs...)
    backend = get_backend(c)
    perm = KA.allocate(backend, Int, length(c))
    return AK.merge_sortperm!(perm, c; kwargs...)
end
57+
58+
# `sortperm` for general GPU arrays. `dims` is a required keyword here; the original
# author's intent is to mirror Base, which errors for matrices when `dims` is missing.
# Sorting along a dimension is not implemented, so this method always raises an error.
function Base.sortperm(c::AnyGPUArray; dims, kwargs...)
    # Sketch of a future implementation, kept for reference:
    # sortperm!(reshape(adapt(get_backend(c), collect(1:length(c))), size(c)), c; initialized=true, dims, kwargs...)
    error("GPU sort with `dims` kwarg not yet implemented.")
end

test/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ include("testsuite/broadcasting.jl")
9494
include("testsuite/linalg.jl")
9595
include("testsuite/math.jl")
9696
include("testsuite/random.jl")
97+
include("testsuite/sorting.jl")
9798
include("testsuite/uniformscaling.jl")
9899
include("testsuite/statistics.jl")
99100
include("testsuite/alloc_cache.jl")

test/testsuite/sorting.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
@testsuite "sorting/sort" (AT, eltypes)->begin
    # Randomized correctness check: `sort!` on the array type under test is compared
    # against the reference via `compare`, for every real element type.
    real_types = filter(T -> T <: Real, eltypes)
    @testset "$ET" for ET in real_types
        for _ in 1:10
            len = rand(1:100_000)
            @test compare(A -> Base.sort!(A), AT, rand(ET, len))
        end
        # Sorting along a dimension is not yet implemented; enable once it is:
        # for _ in 1:5
        #     size = rand(1:100, 2)
        #     @test compare((A)->Base.sort!(A; dims=1), AT, rand(ET, size...))
        #     @test compare((A)->Base.sort!(A; dims=2), AT, rand(ET, size...))
        # end
    end
end
16+
17+
@testsuite "sorting/sortperm" (AT, eltypes)->begin
    # Randomized correctness check: `sortperm!` into an `Int32` index buffer is compared
    # against the reference via `compare`, for every real element type.
    real_types = filter(T -> T <: Real, eltypes)
    @testset "$ET" for ET in real_types
        for _ in 1:10
            len = rand(1:100_000)
            @test compare((ix, A) -> Base.sortperm!(ix, A), AT, zeros(Int32, len), rand(ET, len))
        end
        # Sorting along a dimension is not yet implemented; enable once it is:
        # for _ in 1:5
        #     size = rand(1:100, 2)
        #     @test compare((ix, A)->Base.sortperm!(ix, A; dims=1), AT, zeros(Int32, size...), rand(ET, size...))
        #     @test compare((ix, A)->Base.sortperm!(ix, A; dims=2), AT, zeros(Int32, size...), rand(ET, size...))
        # end
    end
end
32+
33+
@testsuite "sorting/partialsort" (AT, eltypes)->begin
    local N = 10000
    @testset "$ET" for ET in filter(T -> T <: Real, eltypes)
        # Selectors exercised: first, last and middle element, an interior range, and
        # the full range. Each is checked in forward and in reverse order.
        for k in (1, N, N÷2, (N÷10):(2N÷10), 1:N)
            @test compare(A -> Base.partialsort!(A, k), AT, rand(ET, N))
            @test compare(A -> Base.partialsort!(A, k; rev=true), AT, rand(ET, N))
        end
    end
end

0 commit comments

Comments
 (0)