#include #include #include #include #include #include /** * Class for representing a point. coordinate_type must be a numeric type. */ template class point { public: point(std::array c) : coords_(c) {} point(std::initializer_list list) { size_t n = std::min(dimensions, list.size()); std::copy_n(list.begin(), n, coords_.begin()); } /** * Returns the coordinate in the given dimension. * * @param index dimension index (zero based) * @return coordinate in the given dimension */ coordinate_type get(size_t index) const { return coords_[index]; } /** * Returns the distance squared from this point to another * point. * * @param pt another point * @return distance squared from this point to the other point */ double distance(const point& pt) const { double dist = 0; for (size_t i = 0; i < dimensions; ++i) { double d = get(i) - pt.get(i); dist += d * d; } return dist; } private: std::array coords_; }; template std::ostream& operator<<(std::ostream& out, const point& pt) { out << '('; for (size_t i = 0; i < dimensions; ++i) { if (i > 0) out << ", "; out << pt.get(i); } out << ')'; return out; } /** * C++ k-d tree implementation, based on the C version at rosettacode.org. */ template class kdtree { public: typedef point point_type; private: struct node { node(const point_type& pt) : point_(pt), left_(nullptr), right_(nullptr) {} coordinate_type get(size_t index) const { return point_.get(index); } double distance(const point_type& pt) const { return point_.distance(pt); } point_type point_; node* left_; node* right_; }; node* root_ = nullptr; node* best_ = nullptr; double best_dist_ = 0; size_t visited_ = 0; std::vector nodes_; struct node_cmp { node_cmp(size_t index) : index_(index) {} bool operator()(const node& n1, const node& n2) const { return n1.point_.get(index_) < n2.point_.get(index_); } size_t index_; }; node* make_tree(size_t begin, size_t end, size_t index) { if (end <= begin) return nullptr; size_t n = begin + (end - begin)/2; auto i = nodes_.begin(); std::nth_element(i + begin, i + n, i + end, node_cmp(index)); index = (index + 1) % dimensions; nodes_[n].left_ = make_tree(begin, n, index); nodes_[n].right_ = make_tree(n + 1, end, index); return &nodes_[n]; } void nearest(node* root, const point_type& point, size_t index) { if (root == nullptr) return; ++visited_; double d = root->distance(point); if (best_ == nullptr || d < best_dist_) { best_dist_ = d; best_ = root; } if (best_dist_ == 0) return; double dx = root->get(index) - point.get(index); index = (index + 1) % dimensions; nearest(dx > 0 ? root->left_ : root->right_, point, index); if (dx * dx >= best_dist_) return; nearest(dx > 0 ? root->right_ : root->left_, point, index); } public: kdtree(const kdtree&) = delete; kdtree& operator=(const kdtree&) = delete; /** * Constructor taking a pair of iterators. Adds each * point in the range [begin, end) to the tree. * * @param begin start of range * @param end end of range */ template kdtree(iterator begin, iterator end) : nodes_(begin, end) { root_ = make_tree(0, nodes_.size(), 0); } /** * Constructor taking a function object that generates * points. The function object will be called n times * to populate the tree. * * @param f function that returns a point * @param n number of points to add */ template kdtree(func&& f, size_t n) { nodes_.reserve(n); for (size_t i = 0; i < n; ++i) nodes_.push_back(f()); root_ = make_tree(0, nodes_.size(), 0); } /** * Returns true if the tree is empty, false otherwise. */ bool empty() const { return nodes_.empty(); } /** * Returns the number of nodes visited by the last call * to nearest(). */ size_t visited() const { return visited_; } /** * Returns the distance between the input point and return value * from the last call to nearest(). */ double distance() const { return std::sqrt(best_dist_); } /** * Finds the nearest point in the tree to the given point. * It is not valid to call this function if the tree is empty. * * @param pt a point * @return the nearest point in the tree to the given point */ const point_type& nearest(const point_type& pt) { if (root_ == nullptr) throw std::logic_error("tree is empty"); best_ = nullptr; visited_ = 0; best_dist_ = 0; nearest(root_, pt, 0); return best_->point_; } }; void test_wikipedia() { typedef point point2d; typedef kdtree tree2d; point2d points[] = { { 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 } }; tree2d tree(std::begin(points), std::end(points)); point2d n = tree.nearest({ 9, 2 }); std::cout << "Wikipedia example data:\n"; std::cout << "nearest point: " << n << '\n'; std::cout << "distance: " << tree.distance() << '\n'; std::cout << "nodes visited: " << tree.visited() << '\n'; } typedef point point3d; typedef kdtree tree3d; struct random_point_generator { random_point_generator(double min, double max) : engine_(std::random_device()()), distribution_(min, max) {} point3d operator()() { double x = distribution_(engine_); double y = distribution_(engine_); double z = distribution_(engine_); return point3d({x, y, z}); } std::mt19937 engine_; std::uniform_real_distribution distribution_; }; void test_random(size_t count) { random_point_generator rpg(0, 1); tree3d tree(rpg, count); point3d pt(rpg()); point3d n = tree.nearest(pt); std::cout << "Random data (" << count << " points):\n"; std::cout << "point: " << pt << '\n'; std::cout << "nearest point: " << n << '\n'; std::cout << "distance: " << tree.distance() << '\n'; std::cout << "nodes visited: " << tree.visited() << '\n'; } int main() { try { test_wikipedia(); std::cout << '\n'; test_random(1000); std::cout << '\n'; test_random(1000000); } catch (const std::exception& e) { std::cerr << e.what() << '\n'; } return 0; }