diff --git a/docs/src/threading.md b/docs/src/threading.md new file mode 100644 index 0000000..467298b --- /dev/null +++ b/docs/src/threading.md @@ -0,0 +1,109 @@ +# Threading Support + +DiskArrays.jl provides support for threaded algorithms when the underlying +storage backend supports thread-safe read operations. + +## Threading Trait System + +The threading support is based on a trait system that allows backends to +declare whether they support thread-safe operations: + +```julia +using DiskArrays + +# Check if an array supports threading +is_thread_safe(my_array) + +# Get the threading trait +threading_trait(my_array) # Returns ThreadSafe() or NotThreadSafe() +``` + +## Global Threading Control + +You can globally enable or disable threading for all DiskArray operations: + +```julia +# Disable threading globally +disable_threading() + +# Enable threading globally (default) +enable_threading() + +# Check current status +threading_enabled() +``` + +## Implementing Threading Support in Backends + +Backend developers can opt into threading support by overriding the threading_trait method: + +```julia +# For a hypothetical ThreadSafeArray type +DiskArrays.threading_trait(::Type{ThreadSafeArray}) = DiskArrays.ThreadSafe() +``` + +Important: Only declare your backend as thread-safe if: + +* Multiple threads can safely read from the storage simultaneously +* The underlying storage system (files, network, etc.) supports concurrent access +* No global state is modified during read operations + +## Implementing Threading Support for Disk Array Methods + +Add a (or rename the existing) single-threaded method using this signature: + +``` +function Base.myfun(::Type{SingleThreaded}, ...) +``` + +Write a threaded version using this signature: + +``` +function Base.myfun(::Type{MultiThreaded}, ...) +``` + +Add this additional method to automatically dispatch between the two: + +``` +Base.myfun(v::AbstractDiskArray, ...) = myfun(should_use_threading(v), ...) +``` + +## Threaded Algorithms + +Currently supported threaded algorithms: + +### unique + +```julia +# Will automatically use threading if backend supports it +result = unique(my_disk_array) + +# With a function +result = unique(x -> x % 10, my_disk_array) + +# Explicitly use threaded version +result = unique(MultiThreaded, f, my_disk_array) +``` + +The threaded unique algorithm: + +* Processes each chunk in parallel using `Threads.@threads` +* Combines results using a reduction operation +* Falls back to single-threaded implementation for non-thread-safe backends + +### count + +Similarly to `unique`, threads will be automatically used unless disabled, or +can be explicitly used: + +``` +count([f], my_disk_array) +count(MultiThreaded, f, my_disk_array) +``` + +## Performance Considerations + +* Threading is most beneficial for arrays with many chunks +* I/O bound operations may see limited speedup due to storage bottlenecks +* Consider the overhead of thread coordination for small arrays +* Test with your specific storage backend and access patterns diff --git a/src/DiskArrays.jl b/src/DiskArrays.jl index 9f0bba6..6fe3aed 100644 --- a/src/DiskArrays.jl +++ b/src/DiskArrays.jl @@ -19,6 +19,7 @@ export AbstractDiskArray, eachchunk, ChunkIndex, ChunkIndices include("scalar.jl") include("chunks.jl") include("diskarray.jl") +include("threading.jl") include("batchgetindex.jl") include("diskindex.jl") include("indexing.jl") @@ -37,6 +38,9 @@ include("show.jl") include("cached.jl") include("pad.jl") +export ThreadingTrait, ThreadSafe, NotThreadSafe, threading_trait, is_thread_safe, + AlgorithmTrait, SingleThreaded, MultiThreaded, enable_threading, threading_enabled + # The all-in-one macro macro implement_diskarray(t) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 5df3c83..53a306b 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -60,16 +60,40 @@ for fname in [:sum, :prod, :all, :any, :minimum, :maximum] end end -Base.count(v::AbstractDiskArray) = count(identity, v::AbstractDiskArray) -function Base.count(f, v::AbstractDiskArray) +Base.count(v::AbstractDiskArray) = count(should_use_threading(v), identity, v::AbstractDiskArray) +Base.count(f, v::AbstractDiskArray) = count(should_use_threading(v), f, v::AbstractDiskArray) + +function Base.count(::Type{SingleThreaded}, f, v::AbstractDiskArray) sum(eachchunk(v)) do chunk count(f, v[chunk...]) end end -Base.unique(v::AbstractDiskArray) = unique(identity, v) -function Base.unique(f, v::AbstractDiskArray) +function Base.count(::Type{MultiThreaded}, f, v::AbstractDiskArray) + chunks = eachchunk(v) + u = Vector{Int}(undef, length(chunks)) + Threads.@threads for i in 1:length(chunks) + u[i] = count(f, v[chunks[i]...]) + end + sum(u) +end + +Base.unique(v::AbstractDiskArray) = unique(should_use_threading(v), identity, v) +Base.unique(f, v::AbstractDiskArray) = unique(should_use_threading(v), f, v) + +function Base.unique(::Type{SingleThreaded}, f, v::AbstractDiskArray) reduce((unique(f, v[c...]) for c in eachchunk(v))) do acc, u unique!(f, append!(acc, u)) end end + +function Base.unique(::Type{MultiThreaded}, f, v::AbstractDiskArray) + chunks = eachchunk(v) + u = Vector{Vector{eltype(v)}}(undef, length(chunks)) + Threads.@threads for i in 1:length(chunks) + u[i] = unique(f, v[chunks[i]...]) + end + reduce(u) do acc, t + unique!(f, append!(acc, t)) + end +end diff --git a/src/threading.jl b/src/threading.jl new file mode 100644 index 0000000..b2f0467 --- /dev/null +++ b/src/threading.jl @@ -0,0 +1,85 @@ +""" + ThreadingTrait + +Trait to indicate whether a DiskArray backend supports thread-safe operations. +""" +abstract type ThreadingTrait end + +""" + ThreadSafe() + +Indicates that the DiskArray backend supports thread-safe read operations. +""" +struct ThreadSafe <: ThreadingTrait end + +""" + NotThreadSafe() + +Indicates that the DiskArray backend does not support thread-safe operations. +Default for all backends unless explicitly overridden. +""" +struct NotThreadSafe <: ThreadingTrait end + +""" + threading_trait(::Type{T}) -> ThreadingTrait + threading_trait(x) -> ThreadingTrait + +Return the threading trait for a DiskArray type or instance. +Defaults to `NotThreadSafe()` for safety. +""" +threading_trait(::Type{<:AbstractDiskArray}) = NotThreadSafe() +threading_trait(x::AbstractDiskArray) = threading_trait(typeof(x)) + +""" + is_thread_safe(x) -> Bool + +Check if a DiskArray supports thread-safe operations. +""" +is_thread_safe(x) = threading_trait(x) isa ThreadSafe + +""" + AlgorithmTrait + +Trait to indicate whether a method is multithreaded or not +""" +abstract type AlgorithmTrait end + +""" + SingleThreaded() + +Indicates that a method uses just one thread +""" +struct SingleThreaded <: AlgorithmTrait end + +""" + MultiThreaded() + +Indicates that a method uses all threads available +""" +struct MultiThreaded <: AlgorithmTrait end + +# Global threading control +const THREADING_ENABLED = Ref(true) + +""" + enable_threading(enable::Bool=true) + +Globally enable or disable threading for DiskArray operations. +When disabled, all algorithms will run single-threaded regardless of backend support. +""" +enable_threading(enable::Bool=true) = (THREADING_ENABLED[] = enable) + +""" + threading_enabled() -> Bool + +Check if threading is globally enabled. +""" +threading_enabled() = THREADING_ENABLED[] + +""" + should_use_threading(x) -> Val(Bool) + +Determine if threading should be used for a given DiskArray. +Returns true only if both global threading is enabled AND the backend is thread-safe. +""" +should_use_threading(x) = threading_enabled() && is_thread_safe(x) ? MultiThreaded : SingleThreaded diff --git a/test/runtests.jl b/test/runtests.jl index eaa3b49..ae5bf42 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,8 @@ using TraceFuns, Suppressor # using JET # JET.report_package(DiskArrays) +include("threading.jl") + @testset "Aqua.jl" begin Aqua.test_ambiguities([DiskArrays, Base, Core]) Aqua.test_unbound_args(DiskArrays) diff --git a/test/threading.jl b/test/threading.jl new file mode 100644 index 0000000..e09a56e --- /dev/null +++ b/test/threading.jl @@ -0,0 +1,110 @@ +# Mock thread-safe DiskArray for testing +struct MockThreadSafeDiskArray{T,N} <: AbstractDiskArray{T,N} + data::Array{T,N} + chunks::NTuple{N,Int} +end + +Base.size(a::MockThreadSafeDiskArray) = size(a.data) +Base.getindex(a::MockThreadSafeDiskArray, i::Int...) = a.data[i...] +DiskArrays.eachchunk(a::MockThreadSafeDiskArray) = DiskArrays.GridChunks(a, a.chunks) +DiskArrays.haschunks(::MockThreadSafeDiskArray) = DiskArrays.Chunked() +DiskArrays.readblock!(a::MockThreadSafeDiskArray, aout, r::AbstractUnitRange...) = (aout .= a.data[r...]) + +# Override threading trait for our mock array +DiskArrays.threading_trait(::Type{<:MockThreadSafeDiskArray}) = DiskArrays.ThreadSafe() + +@testset "Threading Traits" begin + # Test default behavior (not thread safe) + regular_array = ChunkedDiskArray(rand(10, 10), (5, 5)) + @test DiskArrays.threading_trait(regular_array) isa DiskArrays.NotThreadSafe + @test !DiskArrays.is_thread_safe(regular_array) + + # Test thread-safe array + thread_safe_array = MockThreadSafeDiskArray(rand(10, 10), (5, 5)) + @test DiskArrays.threading_trait(thread_safe_array) isa DiskArrays.ThreadSafe + @test DiskArrays.is_thread_safe(thread_safe_array) +end + +@testset "Threading Control" begin + # Test global threading control + @test DiskArrays.threading_enabled() # Should be true by default + + DiskArrays.enable_threading(false) + @test !DiskArrays.threading_enabled() + + DiskArrays.enable_threading() + @test DiskArrays.threading_enabled() + + # Test should_use_threading logic + thread_safe_array = MockThreadSafeDiskArray(rand(10, 10), (5, 5)) + regular_array = ChunkedDiskArray(rand(10, 10), (5, 5)) + + DiskArrays.enable_threading() + @test DiskArrays.should_use_threading(thread_safe_array) == MultiThreaded + @test DiskArrays.should_use_threading(regular_array) == SingleThreaded + + DiskArrays.enable_threading(false) + @test DiskArrays.should_use_threading(thread_safe_array) == SingleThreaded + @test DiskArrays.should_use_threading(regular_array) == SingleThreaded + + # Reset to default + DiskArrays.enable_threading() +end + +@testset "Threaded unique" begin + # Test with thread-safe array + data = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 1, 2, 3, 4, 5, 5, 6, 6, 6, 7] + reshape_data = reshape(data, 4, 5) + thread_safe_array = MockThreadSafeDiskArray(reshape_data, (2, 3)) + + result = unique(thread_safe_array) + expected = unique(data) + @test sort(result) == sort(expected) + + # Test with function + result_with_func = unique(x -> x % 3, thread_safe_array) + expected_with_func = unique(x -> x % 3, data) + @test sort(result_with_func) == sort(expected_with_func) + + # Test fallback for non-thread-safe array + regular_array = ChunkedDiskArray(reshape_data, (2, 3)) + result_fallback = unique(regular_array) + @test sort(result_fallback) == sort(expected) + + # Test with threading disabled + DiskArrays.enable_threading(false) + result_no_threading = unique(thread_safe_array) + @test sort(result_no_threading) == sort(expected) + DiskArrays.enable_threading() # Reset +end + +@testset "Threaded count" begin + # Test with thread-safe array + data_int = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 1, 2, 3, 4, 5, 5, 6, 6, 6, 7] + f(x) = x % 3 == 0 + data = Array(f.(data_int)) # instead of BitMatrix + reshape_data_int = reshape(data_int, 4, 5) + thread_safe_array_int = MockThreadSafeDiskArray(reshape_data_int, (2, 3)) + reshape_data = Array(reshape(data, 4, 5)) + thread_safe_array = MockThreadSafeDiskArray(reshape_data, (2, 3)) + + result = count(thread_safe_array) + expected = count(data) + @test result == expected + + # Test with function + result_with_func = count(f, thread_safe_array_int) + expected_with_func = count(f, data_int) + @test result_with_func == expected_with_func + + # Test fallback for non-thread-safe array + regular_array = ChunkedDiskArray(reshape_data, (2, 3)) + result_fallback = count(regular_array) + @test result_fallback == expected + + # Test with threading disabled + DiskArrays.enable_threading(false) + result_no_threading = count(thread_safe_array) + @test result_no_threading == expected + DiskArrays.enable_threading() # Reset +end