2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -29,7 +29,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v4
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "NearestNeighbors"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
version = "0.4.22"
version = "0.4.21"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
63 changes: 46 additions & 17 deletions README.md
@@ -27,7 +27,7 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for
- A matrix of size `nd × np` where `nd` is the dimensionality and `np` is the number of points, or
- A vector of vectors with fixed dimensionality `nd`, i.e., `data` should be a `Vector{V}` where `V` is a subtype of `AbstractVector` with defined `length(V)`. For example a `Vector{V}` where `V = SVector{3, Float64}` is ok because `length(V) = 3` is defined.
* `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`).
* `leafsize`: Determines the number of points (default 25) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points.
* `leafsize`: Determines the number of points (default 10) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points.
* `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data.

All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation.
@@ -49,20 +49,19 @@ brutetree = BruteTree(data)
A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods:

```julia
knn(tree, point[s], k [, skip=always_false]) -> idxs, dists
knn!(idxs, dists, tree, point, k [, skip=always_false])
knn(tree, point[s], k, skip = always_false) -> idxs, dists
knn!(idxs, dists, tree, point, k, skip = always_false)
```

* `tree`: The tree instance.
* `point[s]`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors.
* `k`: Number of nearest neighbors to find.
* `skip` (optional): A predicate function to skip certain points, e.g., points already visited.
* `points`: A vector or matrix of points to find the `k` nearest neighbors for. A vector of numbers represents a single point; a matrix means the `k` nearest neighbors for each point (column) will be computed. `points` can also be a vector of vectors.
* `skip` (optional): A predicate to skip certain points, e.g., points already visited.


For the single closest neighbor, you can use `nn`:

```julia
nn(tree, point[s] [, skip=always_false]) -> idx, dist
nn(tree, points, skip = always_false) -> idxs, dists
```

Examples:
@@ -74,7 +73,7 @@ k = 3
point = rand(3)

kdtree = KDTree(data)
idxs, dists = knn(kdtree, point, k)
idxs, dists = knn(kdtree, point, k, true)

idxs
# 3-element Array{Int64,1}:
@@ -90,7 +89,7 @@ dists

# Multiple points
points = rand(3, 4)
idxs, dists = knn(kdtree, points, k)
idxs, dists = knn(kdtree, points, k, true)

idxs
# 4-element Array{Array{Int64,1},1}:
@@ -110,7 +109,7 @@ idxs
using StaticArrays
v = @SVector[0.5, 0.3, 0.2];

idxs, dists = knn(kdtree, v, k)
idxs, dists = knn(kdtree, v, k, true)

idxs
# 3-element Array{Int64,1}:
@@ -134,15 +133,11 @@ knn!(idxs, dists, kdtree, v, k)
A range search finds all neighbors within the range `r` of given point(s). This is done with the methods:

```julia
inrange(tree, point[s], radius) -> idxs
inrange!(idxs, tree, point, radius)
inrange(tree, points, r) -> idxs
inrange!(idxs, tree, point, r)
```

* `tree`: The tree instance.
* `point[s]`: A vector or matrix of points to find neighbors for.
* `radius`: Search radius.

Note: Distances are not returned, only indices.
Distances are not returned.

Example:

@@ -169,6 +164,40 @@ inrange!(idxs, balltree, point, r)
neighborscount = inrangecount(balltree, point, r)
```

### Passing a runtime function into the range search
```julia
inrange_callback!(tree, points, radius, callback)
```

Example:
```julia
using NearestNeighbors
data = rand(3,10^4)
data_values = rand(10^4)
r = 0.05
points = rand(3,10)
results = zeros(10)

# This function sums the `data_values` of every data point that is in range of a query point.
# `p_idx` is the index into `points`, i.e. 1-10
# `data_idx` is the index of the in-range data point in the tree
# `values` holds the per-point data needed for the operation
# `results` is the storage for the results
function sum_values!(p_idx, data_idx, values, results)
    results[p_idx] += values[data_idx]
end

# `callback` must be of the form f(p_idx, data_idx)
callback(p_idx, data_idx) = sum_values!(p_idx, data_idx, data_values, results)

kdtree = KDTree(data)

# Runs the callback for every tree data point in range of each query point.
# Here it sums the `data_values` of the in-range data into `results`.
inrange_callback!(kdtree, points, r, callback)
```
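
To make the callback contract concrete, here is a minimal plain-Julia sketch that needs no tree and no package. It assumes the callback is invoked as `callback(p_idx, data_idx)` once per in-range pair and that the metric is Euclidean. `brute_inrange_callback!` is a hypothetical helper written for illustration only; it is not part of `NearestNeighbors.jl`.

```julia
# Hypothetical stand-in for a range search with a callback (exhaustive scan, Euclidean metric).
function brute_inrange_callback!(data::AbstractMatrix, points::AbstractMatrix,
                                 radius::Real, callback)
    for p_idx in axes(points, 2), data_idx in axes(data, 2)
        # Distance between query column `p_idx` and data column `data_idx`
        d = sqrt(sum(abs2, view(data, :, data_idx) .- view(points, :, p_idx)))
        # Call back only for pairs that are within the radius
        d <= radius && callback(p_idx, data_idx)
    end
    return nothing
end

data = [0.0 1.0 0.1;
        0.0 1.0 0.0]            # three 2-D points stored as columns
data_values = [10.0, 20.0, 30.0]
points = [0.0 1.0;
          0.0 1.0]              # two query points stored as columns
results = zeros(2)

# The closure captures `data_values` and `results`; the search only passes two indices.
callback(p_idx, data_idx) = (results[p_idx] += data_values[data_idx])

brute_inrange_callback!(data, points, 0.5, callback)
```

Because the closure carries its own state, the search routine never needs to allocate or return an index vector, which appears to be the motivation for passing a function into the range search.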



## Using On-Disk Data Sets

By default, trees store a copy of the `data` provided during construction. For data sets larger than available memory, `DataFreeTree` can be used to strip a tree of its data field and re-link it later.
2 changes: 1 addition & 1 deletion src/NearestNeighbors.jl
@@ -7,7 +7,7 @@ using StaticArrays
import Base.show

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs
export knn, knn!, nn, inrange, inrange!,inrangecount, inrange_callback! # TODOs? , allpairs, distmat, npairs
export injectdata

export Euclidean,
29 changes: 10 additions & 19 deletions src/ball_tree.jl
@@ -14,20 +14,9 @@ end


"""
BallTree(data [, metric = Euclidean(); leafsize = 25, reorder = true])::BallTree
BallTree(data [, metric = Euclidean(); leafsize = 25, reorder = true]) -> balltree

Creates a `BallTree` from the data using the given `metric` and `leafsize`.

# Arguments
- `data`: Point data as a matrix of size `nd × np` or vector of vectors
- `metric`: Distance metric to use (can be any `Metric` from Distances.jl). Default: `Euclidean()`
- `leafsize`: Number of points at which to stop splitting the tree. Default: `25`
- `reorder`: If `true`, reorder data to improve cache locality. Default: `true`

# Returns
- `balltree`: A `BallTree` instance

BallTree works with any metric and is often better for high-dimensional data.
"""
function BallTree(data::AbstractVector{V},
metric::Metric = Euclidean();
@@ -188,16 +177,18 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
point_index::Int = 1,
callback::Union{Nothing, Function} = nothing) where {V}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return inrange_kernel!(tree, 1, point, ball, callback, point_index) # Call the recursive range finder
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
callback::Union{Nothing, Function},
point_index::Int)

if index > length(tree.hyper_spheres)
return 0
@@ -215,19 +206,19 @@ function inrange_kernel!(tree::BallTree,
# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r
return add_points_inrange!(idx_in_ball, tree, index, point, r)
return add_points_inrange!(tree, index, point, r, callback, point_index)
end

count = 0

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses_fast(dist, tree.metric, sphere, query_ball)
count += addall(tree, index, idx_in_ball)
count += addall(tree, index, callback, point_index)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getleft(index), point, query_ball, callback, point_index)
count += inrange_kernel!(tree, getright(index), point, query_ball, callback, point_index)
end
return count
end
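
The "encloses" shortcut in the kernel above rests on a simple geometric test, sketched here in plain Julia. This assumes a Euclidean metric; `Ball` and `encloses` are hypothetical stand-ins, and the package's own `HyperSphere` and `encloses_fast` may be implemented differently.

```julia
# Hypothetical stand-in for a bounding sphere: a center and a radius.
struct Ball
    center::Vector{Float64}
    r::Float64
end

# `inner` lies entirely inside `outer` when the distance between the centers
# plus the inner radius does not exceed the outer radius (triangle inequality).
function encloses(inner::Ball, outer::Ball)
    dist = sqrt(sum(abs2, inner.center .- outer.center))
    return dist + inner.r <= outer.r
end

query = Ball([0.0, 0.0], 2.0)
inside = encloses(Ball([0.5, 0.0], 1.0), query)    # 0.5 + 1.0 <= 2.0
outside = encloses(Ball([1.5, 0.0], 1.0), query)   # 1.5 + 1.0 >  2.0
```

When the test succeeds, every point in the subtree is within the query radius, so each one can be reported (or passed to the callback) without evaluating the metric again.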
22 changes: 7 additions & 15 deletions src/brute_tree.jl
@@ -5,19 +5,9 @@ struct BruteTree{V <: AbstractVector,M <: PreMetric} <: NNTree{V,M}
end

"""
BruteTree(data [, metric = Euclidean()])::Brutetree
BruteTree(data [, metric = Euclidean()]) -> brutetree

Creates a `BruteTree` from the data using the given `metric`.

# Arguments
- `data`: Point data as a matrix of size `nd × np` or vector of vectors
- `metric`: Distance metric to use (can be any `PreMetric` from Distances.jl). Default: `Euclidean()`

# Returns
- `brutetree`: A `BruteTree` instance

BruteTree performs exhaustive linear search and is useful as a baseline or for small datasets.
Note: `leafsize` and `reorder` parameters are ignored for BruteTree.
"""
function BruteTree(data::AbstractVector{V}, metric::PreMetric = Euclidean();
reorder::Bool=false, leafsize::Int=0, storedata::Bool=true) where {V <: AbstractVector}
@@ -71,21 +61,23 @@ end
function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
return inrange_kernel!(tree, point, radius, idx_in_ball)
point_index::Int = 1,
callback::Union{Nothing, Function} = nothing)
return inrange_kernel!(tree, point, radius, callback, point_index)
end


function inrange_kernel!(tree::BruteTree,
point::AbstractVector,
r::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
callback::Union{Nothing, Function},
point_index::Int)
count = 0
for i in 1:length(tree.data)
d = evaluate(tree.metric, tree.data[i], point)
if d <= r
count += 1
idx_in_ball !== nothing && push!(idx_in_ball, i)
!isnothing(callback) && callback(point_index, i)
end
end
return count
1 change: 1 addition & 0 deletions src/evaluation.jl
@@ -4,6 +4,7 @@
@inline eval_pow(d::Minkowski, s) = abs(s)^d.p

@inline eval_diff(::NonweightedMinkowskiMetric, a, b, dim) = a - b
@inline eval_diff(::Chebyshev, ::Any, b, dim) = b
@inline eval_diff(m::WeightedMinkowskiMetric, a, b, dim) = m.weights[dim] * (a-b)

function evaluate_maybe_end(d::Distances.UnionMetrics, a::AbstractVector,
33 changes: 0 additions & 33 deletions src/hyperrectangles.jl
@@ -40,36 +40,3 @@ get_max_distance_no_end(m, rec, point) =

get_min_distance_no_end(m, rec, point) =
get_min_max_distance_no_end(distance_function_min, m, rec, point)

@inline function update_new_min(M::Metric, old_min, hyper_rec, p_dim, split_dim, split_val)
@inbounds begin
lo = hyper_rec.mins[split_dim]
hi = hyper_rec.maxes[split_dim]
end
ddiff = distance_function_min(p_dim, hi, lo)
split_diff = abs(p_dim - split_val)
split_diff_pow = eval_pow(M, split_diff)
ddiff_pow = eval_pow(M, ddiff)
diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim)
return old_min + diff_tot
end

# Compute per-dimension contributions for max distance
function get_max_distance_contributions(m::Metric, rec::HyperRectangle{V}, point::AbstractVector{T}) where {V,T}
p = Distances.parameters(m)
return V(
@inbounds begin
v = distance_function_max(point[dim], rec.maxes[dim], rec.mins[dim])
p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])
end for dim in eachindex(point)
)
end

# Compute single dimension contribution for max distance
function get_max_distance_contribution_single(m::Metric, point_dim, min_bound::T, max_bound::T, dim::Integer) where {T}
v = distance_function_max(point_dim, max_bound, min_bound)
p = Distances.parameters(m)
return p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])
end

