diff --git a/Project.toml b/Project.toml index 8c98be9..0249c97 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" version = "0.4.22" [deps] +ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -10,6 +11,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Distances = "0.10.12" StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0" julia = "1.6" +ArraysOfArrays = "0.6" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 875ede0..d5dbaca 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.1" +julia_version = "1.11.5" manifest_format = "2.0" project_hash = "c2d4f1e1a4db771bb121b0dd2aff4834a9af3804" @@ -13,6 +13,22 @@ version = "0.4.5" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.2" +[[deps.ArraysOfArrays]] +deps = ["Statistics"] +git-tree-sha1 = "8e64c97ac7bffbd3327d8ddadf8dad26b87a2664" +uuid = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" +version = "0.6.6" + + [deps.ArraysOfArrays.extensions] + ArraysOfArraysAdaptExt = "Adapt" + ArraysOfArraysChainRulesCoreExt = "ChainRulesCore" + ArraysOfArraysStaticArraysCoreExt = "StaticArraysCore" + + [deps.ArraysOfArrays.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" version = "1.11.0" @@ -167,10 +183,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.12.12" [[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] +deps = ["ArraysOfArrays", "Distances", "StaticArrays"] path = ".." uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.21" +version = "0.4.22" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -184,7 +200,7 @@ version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" +version = "0.8.5+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6edbc1c..426f347 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -4,6 +4,7 @@ using Distances import Distances: PreMetric, Metric, UnionMinkowskiMetric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters using StaticArrays +using ArraysOfArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree diff --git a/src/inrange.jl b/src/inrange.jl index f10675e..9ae3b64 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -22,10 +22,13 @@ function inrange(tree::NNTree, check_input(tree, points) check_radius(radius) - idxs = [Vector{Int}() for _ in 1:length(points)] + idxs = VectorOfArrays{Int, 1}() + idx = Int[] for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) + inrange_point!(tree, points[i], radius, sortres, idx) + push!(idxs, idx) + resize!(idx, 0) end return idxs end @@ -79,11 +82,14 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb check_input(tree, points) check_radius(radius) n_points = size(points, 2) - idxs = [Vector{Int}() for _ in 1:n_points] + idxs = VectorOfArrays{Int, 1}() + idx = Int[] for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - inrange_point!(tree, point, radius, sortres, idxs[i]) + inrange_point!(tree, point, radius, sortres, idx) + push!(idxs, idx) + resize!(idx, 0) end return idxs end diff --git a/src/knn.jl b/src/knn.jl index 7e13659..23f0065 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -26,10 +26,17 @@ function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, check_input(tree, points) check_k(tree, k) n_points = length(points) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + dists = VectorOfArrays{get_T(eltype(V)), 1}() + idxs = VectorOfArrays{Int, 1}() + dist = zeros(get_T(eltype(V)), k) + idx = zeros(Int, k) + for i in 1:n_points - knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) + knn_point!(tree, points[i], sortres, dist, idx, skip) + push!(dists, dist) + push!(idxs, idx) + fill!(dist, 0) + fill!(idx, 0) end return idxs, dists end @@ -93,12 +100,18 @@ function knn_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, ::Val{di check_input(tree, points) check_k(tree, k) n_points = size(points, 2) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + dists = VectorOfArrays{Float64, 1}() + idxs = VectorOfArrays{Int, 1}() + dist = zeros(Float64, k) + idx = zeros(Int, k) for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - knn_point!(tree, point, sortres, dists[i], idxs[i], skip) + knn_point!(tree, point, sortres, dist, idx, skip) + push!(dists, dist) + push!(idxs, idx) + fill!(dist, 0) + fill!(idx, 0) end return idxs, dists end diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 02b0914..d2ec6a3 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -81,7 +81,7 @@ end points = rand(SVector{3, Float64}, 100) kdtree = KDTree(points) idxs = inrange(kdtree, view(points, 1:10), 0.1) - @test idxs isa Vector{Vector{Int}} + @test eltype(idxs) <: AbstractVector{Int} end @testset "mutating" begin diff --git a/test/test_knn.jl b/test/test_knn.jl index 3667772..99fd085 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -133,8 +133,8 @@ end points = rand(SVector{3, Float64}, 100) kdtree = KDTree(points) idxs, dists = knn(kdtree, view(points, 1:10), 3) - @test idxs isa Vector{Vector{Int}} - @test dists isa Vector{Vector{Float64}} + @test eltype(idxs) <: AbstractVector{Int} + @test eltype(dists) <: AbstractVector{Float64} end @testset "mutating" begin