Skip to content

Commit 9ac48f3

Browse files
committed
Search visitor (#31)
* Improved search visitor handling. * Added approximate versions of SearchNn, SearchRadius, and search_radius. * Added support for Eigen::Map<const Eigen::Matrix<>>. * Added RKdTree to pico_understory. * Added the mnist example. * Version bump.
1 parent 9d83fa7 commit 9ac48f3

25 files changed

+929
-230
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/utils.cmake)
66

77
project(pico_tree
88
LANGUAGES CXX
9-
VERSION 0.8.0
9+
VERSION 0.8.1
1010
DESCRIPTION "PicoTree is a C++ header only library for fast nearest neighbor searches and range searches using a KdTree."
1111
HOMEPAGE_URL "https://github.com/Jaybro/pico_tree")
1212

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ PicoTree is a C++ header only library with [Python bindings](https://github.com/
1111
| [Scikit-learn KDTree][skkd] 1.2.2 | ... | 6.2s | ... | 42.2s |
1212
| [pykdtree][pykd] 1.3.7 | ... | 1.0s | ... | 6.6s |
1313
| [OpenCV FLANN][cvfn] 4.6.0 | 1.9s | ... | 4.7s | ... |
14-
| PicoTree KdTree v0.8.0 | 0.9s | 1.0s | 2.8s | 3.1s |
14+
| PicoTree KdTree v0.8.1 | 0.9s | 1.0s | 2.8s | 3.1s |
1515

1616
Two [LiDAR](./docs/benchmark.md) based point clouds of sizes 7733372 and 7200863 were used to generate these numbers. The first point cloud was the input to the build algorithm and the second to the query algorithm. All benchmarks were run on a single thread with the following parameters: `max_leaf_size=10` and `knn=1`. A more detailed [C++ comparison](./docs/benchmark.md) of PicoTree is available with respect to [nanoflann][nano].
1717

@@ -61,6 +61,7 @@ PicoTree can interface with different types of points and point sets through tra
6161
* Creating a [custom search visitor](./examples/kd_tree/kd_tree_custom_search_visitor.cpp).
6262
* [Saving and loading](./examples/kd_tree/kd_tree_save_and_load.cpp) a KdTree to and from a file.
6363
* Support for [Eigen](./examples/eigen/eigen.cpp) and [OpenCV](./examples/opencv/opencv.cpp) data types.
64+
* Running the KdTree on the [MNIST](./examples/mnist/mnist.cpp) [database](http://yann.lecun.com/exdb/mnist/).
6465
* How to use the [KdTree with Python](./examples/python/kd_tree.py).
6566

6667
# Requirements

examples/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ else()
3535
message(STATUS "benchmark not found. PicoTree benchmarks skipped.")
3636
endif()
3737

38+
if(Eigen3_FOUND)
39+
add_subdirectory(mnist)
40+
endif()
41+
3842
# The Python examples only get copied when the bindings module will be build.
3943
if(TARGET _pyco_tree)
4044
add_subdirectory(python)

examples/benchmark/bm_opencv_flann.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,10 @@ BENCHMARK_DEFINE_F(BmOpenCvFlann, KnnCt)(benchmark::State& state) {
116116
// There is also the option to query them all at once, but this doesn't really
117117
// change performance and this version looks more like the other benchmarks.
118118
for (auto _ : state) {
119-
std::vector<Index> indices(knn_count);
119+
// The only supported index type is int.
120+
std::vector<int> indices(knn_count);
120121
std::vector<Scalar> distances(knn_count);
121-
fl::Matrix<Index> mat_indices(indices.data(), 1, knn_count);
122+
fl::Matrix<int> mat_indices(indices.data(), 1, knn_count);
122123
fl::Matrix<Scalar> mat_distances(distances.data(), 1, knn_count);
123124

124125
for (auto& p : points_test_) {

examples/kd_tree/kd_tree_custom_search_visitor.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,39 @@
33
#include <pico_tree/kd_tree.hpp>
44
#include <pico_tree/vector_traits.hpp>
55

6-
//! \brief Search visitor that counts how many points were considered as a
7-
//! nearest neighbor.
6+
// Search visitor that counts how many points were considered as a possible
7+
// nearest neighbor.
88
template <typename Neighbor>
99
class SearchNnCounter {
1010
public:
1111
using NeighborType = Neighbor;
1212
using IndexType = typename Neighbor::IndexType;
1313
using ScalarType = typename Neighbor::ScalarType;
1414

15-
//! \brief Creates a visitor for approximate nearest neighbor searching.
16-
//! \param nn Search result.
15+
// Create a visitor for approximate nearest neighbor searching. The argument
16+
// is the search result.
1717
inline SearchNnCounter(Neighbor& nn) : count_(0), nn_(nn) {
1818
// Initial search distance.
1919
nn_.distance = std::numeric_limits<ScalarType>::max();
2020
}
2121

22-
//! \brief Visit current point.
23-
//! \details This method is required. The KdTree calls this function when it
24-
//! finds a point that is closer to the query than the result of this
25-
//! visitors' max() function. I.e., it found a new nearest neighbor.
26-
//! \param idx Point index.
27-
//! \param d Point distance (that depends on the metric).
22+
// Visit current point. This method is required. The search algorithm calls
23+
// this function for every point it encounters in the KdTree. The arguments of
24+
// the method are respectively the index and distance of the visited point.
2825
inline void operator()(IndexType const idx, ScalarType const dst) {
26+
// Only update the nearest neighbor when the point we visit is actually
27+
// closer to the query point.
28+
if (max() > dst) {
29+
nn_ = {idx, dst};
30+
}
2931
count_++;
30-
nn_ = {idx, dst};
3132
}
3233

33-
//! \brief Maximum search distance with respect to the query point.
34-
//! \details This method is required.
34+
// Maximum search distance with respect to the query point. This method is
35+
// required. The nodes of the KdTree are filtered using this method.
3536
inline ScalarType const& max() const { return nn_.distance; }
3637

37-
//! \brief Returns the number of points that were considered the nearest
38-
//! neighbor.
39-
//! \details This method is not required.
38+
// The amount of points visited during a query.
4039
inline IndexType const& count() const { return count_; }
4140

4241
private:
@@ -62,7 +61,7 @@ int main() {
6261
SearchNnCounter<Neighbor> v(nn);
6362
tree.SearchNearest(q, v);
6463

65-
std::cout << "Custom visitor # nns considered: " << v.count() << std::endl;
64+
std::cout << "Number of points visited: " << v.count() << std::endl;
6665

6766
return 0;
6867
}

examples/mnist/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
add_executable(mnist mnist.cpp)
2+
set_default_target_properties(mnist)
3+
target_link_libraries(mnist PUBLIC pico_toolshed pico_understory Eigen3::Eigen)

examples/mnist/mnist.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#include <algorithm>
2+
#include <filesystem>
3+
#include <iostream>
4+
#include <pico_toolshed/format/format_bin.hpp>
5+
#include <pico_toolshed/format/format_mnist.hpp>
6+
#include <pico_toolshed/scoped_timer.hpp>
7+
#include <pico_tree/array_traits.hpp>
8+
#include <pico_tree/kd_tree.hpp>
9+
#include <pico_tree/vector_traits.hpp>
10+
#include <pico_understory/rkd_tree.hpp>
11+
12+
template <typename U, typename T, std::size_t N>
13+
std::array<U, N> Cast(std::array<T, N> const& i) {
14+
std::array<U, N> c;
15+
std::transform(i.begin(), i.end(), c.begin(), [](T a) -> U {
16+
return static_cast<U>(a);
17+
});
18+
return c;
19+
}
20+
21+
template <std::size_t N>
22+
std::vector<std::array<float, N>> Cast(
23+
std::vector<std::array<std::byte, N>> const& i) {
24+
std::vector<std::array<float, N>> c;
25+
std::transform(
26+
i.begin(),
27+
i.end(),
28+
std::back_inserter(c),
29+
[](std::array<std::byte, N> const& a) -> std::array<float, N> {
30+
return Cast<float>(a);
31+
});
32+
return c;
33+
}
34+
35+
int main(int argc, char** argv) {
36+
using ImageByte = std::array<std::byte, 28 * 28>;
37+
using ImageFloat = std::array<float, 28 * 28>;
38+
39+
std::string fn_images_train = "train-images.idx3-ubyte";
40+
std::string fn_images_test = "t10k-images.idx3-ubyte";
41+
std::string fn_mnist_nns_gt = "mnist_nns_gt.bin";
42+
43+
if (!std::filesystem::exists(fn_images_train)) {
44+
std::cout << fn_images_train << " doesn't exist." << std::endl;
45+
return 0;
46+
}
47+
48+
if (!std::filesystem::exists(fn_images_test)) {
49+
std::cout << fn_images_test << " doesn't exist." << std::endl;
50+
return 0;
51+
}
52+
53+
std::vector<ImageFloat> images_train;
54+
{
55+
std::vector<ImageByte> images_train_u8;
56+
pico_tree::ReadMnistImages(fn_images_train, images_train_u8);
57+
images_train = Cast(images_train_u8);
58+
}
59+
60+
std::vector<ImageFloat> images_test;
61+
{
62+
std::vector<ImageByte> images_test_u8;
63+
pico_tree::ReadMnistImages(fn_images_test, images_test_u8);
64+
images_test = Cast(images_test_u8);
65+
}
66+
67+
std::size_t max_leaf_size_ex = 16;
68+
std::size_t max_leaf_size_rp = 128;
69+
// With 16 trees we can get a precision of around 85-90%.
70+
// With 32 trees we can get a precision of around 95-97%.
71+
std::size_t forest_size = 2;
72+
std::size_t count = images_test.size();
73+
std::vector<pico_tree::Neighbor<int, float>> nns(count);
74+
75+
if (!std::filesystem::exists(fn_images_train)) {
76+
auto kd_tree = [&images_train, &max_leaf_size_ex]() {
77+
ScopedTimer t0("kd_tree build");
78+
return pico_tree::KdTree<std::reference_wrapper<std::vector<ImageFloat>>>(
79+
images_train, max_leaf_size_ex);
80+
}();
81+
82+
{
83+
ScopedTimer t1("kd_tree query");
84+
for (std::size_t i = 0; i < nns.size(); ++i) {
85+
kd_tree.SearchNn(images_test[i], nns[i]);
86+
}
87+
}
88+
89+
std::cout << "Writing " << fn_mnist_nns_gt << "." << std::endl;
90+
pico_tree::WriteBin(fn_mnist_nns_gt, nns);
91+
} else {
92+
std::cout << "Reading " << fn_mnist_nns_gt << "." << std::endl;
93+
pico_tree::ReadBin(fn_mnist_nns_gt, nns);
94+
}
95+
96+
std::size_t equal = 0;
97+
98+
{
99+
auto rkd_tree = [&images_train, &max_leaf_size_rp, &forest_size]() {
100+
ScopedTimer t0("rkd_tree build");
101+
return pico_tree::RKdTree<
102+
std::reference_wrapper<std::vector<ImageFloat>>>(
103+
images_train, max_leaf_size_rp, forest_size);
104+
}();
105+
106+
ScopedTimer t1("rkd_tree query");
107+
pico_tree::Neighbor<int, float> nn;
108+
for (std::size_t i = 0; i < nns.size(); ++i) {
109+
rkd_tree.SearchNn(images_test[i], nn);
110+
111+
if (nns[i].index == nn.index) {
112+
++equal;
113+
}
114+
}
115+
}
116+
117+
std::cout << "Precision: "
118+
<< (static_cast<float>(equal) / static_cast<float>(count))
119+
<< std::endl;
120+
121+
return 0;
122+
}

examples/pico_understory/CMakeLists.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@ target_include_directories(pico_understory INTERFACE ${CMAKE_CURRENT_LIST_DIR})
33
target_link_libraries(pico_understory INTERFACE PicoTree::PicoTree)
44
target_sources(pico_understory
55
INTERFACE
6-
${CMAKE_CURRENT_LIST_DIR}/pico_understory/cover_tree.hpp
7-
${CMAKE_CURRENT_LIST_DIR}/pico_understory/metric.hpp
6+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_base.hpp
7+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_builder.hpp
8+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_data.hpp
9+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_node.hpp
10+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_search.hpp
11+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_builder.hpp
12+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_rr_data.hpp
13+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_search.hpp
14+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/static_buffer.hpp
15+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/cover_tree.hpp
16+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/metric.hpp
17+
${CMAKE_CURRENT_LIST_DIR}/pico_understory/rkd_tree.hpp
818
)

0 commit comments

Comments
 (0)