diff --git a/Project.toml b/Project.toml index f76e97ab..6cf62ea6 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -26,6 +27,7 @@ Distances = "0.10" FillArrays = "0.10, 0.11, 0.12, 0.13, 1" IterativeSolvers = "0.9" MPI = "0.16, 0.17, 0.18, 0.19, 0.20" +Preferences = "1" SparseMatricesCSR = "0.6" StaticArrays = "1" julia = "1.1" diff --git a/src/PartitionedArrays.jl b/src/PartitionedArrays.jl index 163c5578..3da0a844 100644 --- a/src/PartitionedArrays.jl +++ b/src/PartitionedArrays.jl @@ -11,6 +11,10 @@ import IterativeSolvers import Distances using BlockArrays using Adapt +using Preferences + +export set_default_find_rcv_ids +include("preferences.jl") export length_to_ptrs! export rewind_ptrs! diff --git a/src/mpi_array.jl b/src/mpi_array.jl index 7b3d0d32..3fa07d96 100644 --- a/src/mpi_array.jl +++ b/src/mpi_array.jl @@ -660,9 +660,14 @@ end Issend(data, dest::Integer, tag::Integer, comm::MPI.Comm, req=MPI.Request()) = Issend(MPI.Buffer_send(data), dest, tag, comm, req) - function default_find_rcv_ids(::MPIArray) - find_rcv_ids_gather_scatter + @static if default_find_rcv_ids_algorithm == "gather_scatter" + find_rcv_ids_gather_scatter + elseif default_find_rcv_ids_algorithm == "ibarrier" + find_rcv_ids_ibarrier + else + error("Unknown algorithm: $(default_find_rcv_ids_algorithm)") + end end """ diff --git a/src/preferences.jl b/src/preferences.jl new file mode 100644 index 00000000..187e5adf --- /dev/null +++ b/src/preferences.jl @@ -0,0 +1,28 @@ + +""" + set_default_find_rcv_ids(algorithm::String) + +Sets the default algorithm to discover communication neighbors. The available algorithms are: + +- `gather_scatter`: Gathers neighbors in a single processor, builds the communications graph + and then scatters the information back to all processors. + +- `ibarrier`: Implements Alg. 2 in https://dl.acm.org/doi/10.1145/1837853.1693476 + +Feature only available in Julia 1.6 and later due to restrictions from `Preferences.jl`. +""" +function set_default_find_rcv_ids(algorithm::String) + if !(algorithm in ("gather_scatter", "ibarrier")) + throw(ArgumentError("Invalid algorihtm: \"$(algorithm)\"")) + end + + # Set it in our runtime values, as well as saving it to disk + @set_preferences!("default_find_rcv_ids" => algorithm) + @info("New deafult algorithm set; restart your Julia session for this change to take effect!") +end + +@static if VERSION >= v"1.6" + const default_find_rcv_ids_algorithm = @load_preference("default_find_rcv_ids", "gather_scatter") +else + const default_find_rcv_ids_algorithm = "gather_scatter" +end