import java.util.*; public class KdTree { private int dimensions_; private Node root_ = null; private Node best_ = null; private double bestDistance_ = 0; private int visited_ = 0; public KdTree(int dimensions, List nodes) { dimensions_ = dimensions; root_ = makeTree(nodes, 0, nodes.size(), 0); } public Node findNearest(Node target) { if (root_ == null) throw new IllegalStateException("Tree is empty!"); best_ = null; visited_ = 0; bestDistance_ = 0; nearest(root_, target, 0); return best_; } public int visited() { return visited_; } public double distance() { return Math.sqrt(bestDistance_); } private void nearest(Node root, Node target, int index) { if (root == null) return; ++visited_; double d = root.distance(target); if (best_ == null || d < bestDistance_) { bestDistance_ = d; best_ = root; } if (bestDistance_ == 0) return; double dx = root.get(index) - target.get(index); index = (index + 1) % dimensions_; nearest(dx > 0 ? root.left_ : root.right_, target, index); if (dx * dx >= bestDistance_) return; nearest(dx > 0 ? root.right_ : root.left_, target, index); } private Node makeTree(List nodes, int begin, int end, int index) { if (end <= begin) return null; int n = begin + (end - begin)/2; Node node = QuickSelect.select(nodes, begin, end - 1, n, new NodeComparator(index)); index = (index + 1) % dimensions_; node.left_ = makeTree(nodes, begin, n, index); node.right_ = makeTree(nodes, n + 1, end, index); return node; } private static class NodeComparator implements Comparator { private int index_; private NodeComparator(int index) { index_ = index; } public int compare(Node n1, Node n2) { return Double.compare(n1.get(index_), n2.get(index_)); } } public static class Node { private double[] coords_; private Node left_ = null; private Node right_ = null; public Node(double[] coords) { coords_ = coords; } public Node(double x, double y) { this(new double[]{x, y}); } public Node(double x, double y, double z) { this(new double[]{x, y, z}); } double get(int index) { return coords_[index]; } double distance(Node node) { double dist = 0; for (int i = 0; i < coords_.length; ++i) { double d = coords_[i] - node.coords_[i]; dist += d * d; } return dist; } public String toString() { StringBuilder s = new StringBuilder("("); for (int i = 0; i < coords_.length; ++i) { if (i > 0) s.append(", "); s.append(coords_[i]); } s.append(')'); return s.toString(); } } }