265 lines
8.2 KiB
Plaintext
265 lines
8.2 KiB
Plaintext
use std::cmp::Ordering;
|
|
use std::cmp::Ordering::Less;
|
|
use std::ops::Sub;
|
|
use std::time::Instant;
|
|
|
|
use rand::prelude::*;
|
|
|
|
/// A point in k-dimensional space, stored as one `f32` per dimension.
#[derive(Clone, PartialEq, Debug)]
struct Point {
    /// Coordinates of the point; `coords[i]` is the value along dimension `i`.
    pub coords: Vec<f32>,
}

/// Component-wise subtraction of two borrowed points, yielding a new `Point`.
///
/// # Panics
///
/// Panics if the two points do not have the same number of coordinates.
impl Sub<&Point> for &Point {
    type Output = Point;

    fn sub(self, rhs: &Point) -> Point {
        assert_eq!(
            self.coords.len(),
            rhs.coords.len(),
            "cannot subtract points of different dimensionality"
        );
        Point {
            coords: self
                .coords
                .iter()
                .zip(rhs.coords.iter())
                .map(|(&x, &y)| x - y)
                .collect(),
        }
    }
}

impl Point {
    /// Returns the sum of the squared coordinates, i.e. the squared Euclidean
    /// length of the vector from the origin to this point.
    ///
    /// A point with no coordinates yields `0.0`.
    fn norm_sq(&self) -> f32 {
        self.coords.iter().map(|n| n * n).sum()
    }
}
|
|
|
|
struct KDTreeNode {
|
|
point: Point,
|
|
dim: usize,
|
|
// Construction could become faster if we use an arena allocator,
|
|
// but this is easier to use.
|
|
left: Option<Box<KDTreeNode>>,
|
|
right: Option<Box<KDTreeNode>>,
|
|
}
|
|
|
|
impl KDTreeNode {
|
|
/// Create a new KDTreeNode around the `dim`th dimension.
|
|
/// Alternatively, we could dynamically determine the dimension to
|
|
/// split on by using the longest dimension.
|
|
pub fn new(points: &mut [Point], dim: usize) -> KDTreeNode {
|
|
let points_len = points.len();
|
|
if points_len == 1 {
|
|
return KDTreeNode {
|
|
point: points[0].clone(),
|
|
dim,
|
|
left: None,
|
|
right: None,
|
|
};
|
|
}
|
|
|
|
// Split around the median
|
|
let pivot = quickselect_by(points, points_len / 2, &|a, b| {
|
|
a.coords[dim].partial_cmp(&b.coords[dim]).unwrap()
|
|
});
|
|
|
|
let left = Some(Box::new(KDTreeNode::new(
|
|
&mut points[0..points_len / 2],
|
|
(dim + 1) % pivot.coords.len(),
|
|
)));
|
|
let right = if points.len() >= 3 {
|
|
Some(Box::new(KDTreeNode::new(
|
|
&mut points[points_len / 2 + 1..points_len],
|
|
(dim + 1) % pivot.coords.len(),
|
|
)))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
KDTreeNode {
|
|
point: pivot,
|
|
dim,
|
|
left,
|
|
right,
|
|
}
|
|
}
|
|
|
|
pub fn find_nearest_neighbor<'a>(&'a self, point: &Point) -> (&'a Point, usize) {
|
|
self.find_nearest_neighbor_helper(point, &self.point, (point - &self.point).norm_sq(), 1)
|
|
}
|
|
|
|
fn find_nearest_neighbor_helper<'a>(
|
|
&'a self,
|
|
point: &Point,
|
|
best: &'a Point,
|
|
best_dist_sq: f32,
|
|
n_visited: usize,
|
|
) -> (&'a Point, usize) {
|
|
let mut my_best = best;
|
|
let mut my_best_dist_sq = best_dist_sq;
|
|
let mut my_n_visited = n_visited;
|
|
|
|
// We should always examine the near side
|
|
if self.point.coords[self.dim] < point.coords[self.dim] && self.right.is_some() {
|
|
let (a, b) = self.right.as_ref().unwrap().find_nearest_neighbor_helper(
|
|
point,
|
|
my_best,
|
|
my_best_dist_sq,
|
|
my_n_visited,
|
|
);
|
|
my_best = a;
|
|
my_n_visited = b;
|
|
} else if self.left.is_some() {
|
|
let (a, b) = self.left.as_ref().unwrap().find_nearest_neighbor_helper(
|
|
point,
|
|
my_best,
|
|
my_best_dist_sq,
|
|
my_n_visited,
|
|
);
|
|
my_best = a;
|
|
my_n_visited = b;
|
|
}
|
|
|
|
// distance along this node's axis
|
|
let axis_dist_sq = (self.point.coords[self.dim] - point.coords[self.dim]).powi(2);
|
|
if axis_dist_sq <= my_best_dist_sq {
|
|
// self can only be nearer than best if axis_dist_sq is less than
|
|
// best_dist_sq because axis_dist_sq is a lower bound for
|
|
// self_dist_sq
|
|
let self_dist_sq = (point - &self.point).norm_sq();
|
|
if self_dist_sq < my_best_dist_sq {
|
|
my_best = &self.point;
|
|
my_best_dist_sq = self_dist_sq;
|
|
}
|
|
|
|
// bookkeeping
|
|
my_n_visited += 1;
|
|
|
|
// same reasoning applies for the far side of the split
|
|
if self.point.coords[self.dim] < point.coords[self.dim] && self.left.is_some() {
|
|
let (a, b) = self.left.as_ref().unwrap().find_nearest_neighbor_helper(
|
|
point,
|
|
my_best,
|
|
my_best_dist_sq,
|
|
my_n_visited,
|
|
);
|
|
my_best = a;
|
|
my_n_visited = b;
|
|
} else if self.right.is_some() {
|
|
let (a, b) = self.right.as_ref().unwrap().find_nearest_neighbor_helper(
|
|
point,
|
|
my_best,
|
|
my_best_dist_sq,
|
|
my_n_visited,
|
|
);
|
|
my_best = a;
|
|
my_n_visited = b;
|
|
}
|
|
}
|
|
|
|
(my_best, my_n_visited)
|
|
}
|
|
}
|
|
|
|
pub fn main() {
|
|
let mut rng = thread_rng();
|
|
|
|
// wordpress
|
|
let mut wp_points: Vec<Point> = [
|
|
[2.0, 3.0],
|
|
[5.0, 4.0],
|
|
[9.0, 6.0],
|
|
[4.0, 7.0],
|
|
[8.0, 1.0],
|
|
[7.0, 2.0],
|
|
]
|
|
.iter()
|
|
.map(|x| Point { coords: x.to_vec() })
|
|
.collect();
|
|
let wp_tree = KDTreeNode::new(&mut wp_points, 0);
|
|
|
|
let wp_target = Point {
|
|
coords: vec![9.0, 2.0],
|
|
};
|
|
let (point, n_visited) = wp_tree.find_nearest_neighbor(&wp_target);
|
|
println!("Wikipedia example data:");
|
|
println!("Point: [9, 2]");
|
|
println!("Nearest neighbor: {:?}", point);
|
|
println!("Distance: {}", (point - &wp_target).norm_sq().sqrt());
|
|
println!("Nodes visited: {}", n_visited);
|
|
|
|
// randomly generated 3D
|
|
let n_random = 1000;
|
|
let mut make_random_point = || Point {
|
|
coords: (0..3).map(|_| (rng.gen::<f32>() - 0.5) * 1000.0).collect(),
|
|
};
|
|
let mut random_points: Vec<Point> = (0..n_random).map(|_| make_random_point()).collect();
|
|
|
|
let start_cons_time = Instant::now();
|
|
let random_tree = KDTreeNode::new(&mut random_points, 0);
|
|
let cons_time = start_cons_time.elapsed();
|
|
println!(
|
|
"1,000 3d points (Construction time: {}ms)",
|
|
cons_time.as_millis(),
|
|
);
|
|
|
|
let random_target = make_random_point();
|
|
|
|
let (point, n_visited) = random_tree.find_nearest_neighbor(&random_target);
|
|
println!("Point: {:?}", random_target);
|
|
println!("Nearest neighbor: {:?}", point);
|
|
println!("Distance: {}", (point - &random_target).norm_sq().sqrt());
|
|
println!("Nodes visited: {}", n_visited);
|
|
|
|
// benchmark search time
|
|
let n_searches = 1000;
|
|
let random_targets: Vec<Point> = (0..n_searches).map(|_| make_random_point()).collect();
|
|
|
|
let start_search_time = Instant::now();
|
|
let mut total_n_visited = 0;
|
|
for target in &random_targets {
|
|
let (_, n_visited) = random_tree.find_nearest_neighbor(target);
|
|
total_n_visited += n_visited;
|
|
}
|
|
let search_time = start_search_time.elapsed();
|
|
println!(
|
|
"Visited an average of {} nodes on {} searches in {} ms",
|
|
total_n_visited as f32 / n_searches as f32,
|
|
n_searches,
|
|
search_time.as_millis(),
|
|
);
|
|
}
|
|
|
|
fn quickselect_by<T>(arr: &mut [T], position: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> T
|
|
where
|
|
T: Clone,
|
|
{
|
|
// We use `thread_rng` here because it was already initialized in `main`.
|
|
let mut pivot_index = thread_rng().gen_range(0, arr.len());
|
|
// Need to wrap in another closure or we get ownership complaints.
|
|
// Tried using an unboxed closure to get around this but couldn't get it to work.
|
|
pivot_index = partition_by(arr, pivot_index, &|a: &T, b: &T| cmp(a, b));
|
|
let array_len = arr.len();
|
|
match position.cmp(&pivot_index) {
|
|
Ordering::Equal => arr[position].clone(),
|
|
Ordering::Less => quickselect_by(&mut arr[0..pivot_index], position, cmp),
|
|
Ordering::Greater => quickselect_by(
|
|
&mut arr[pivot_index + 1..array_len],
|
|
position - pivot_index - 1,
|
|
cmp,
|
|
),
|
|
}
|
|
}
|
|
|
|
/// Partitions `arr` in place around the element initially at `pivot_index`
/// and returns the pivot's final index.
///
/// After the call, every element before the returned index compares `Less`
/// than the pivot under `cmp`, and no element after it does.
///
/// # Panics
///
/// Panics if `pivot_index` is not a valid index into `arr` (which includes
/// the case of an empty slice).
fn partition_by<T>(arr: &mut [T], pivot_index: usize, cmp: &dyn Fn(&T, &T) -> Ordering) -> usize {
    assert!(
        pivot_index < arr.len(),
        "pivot index {} out of range for slice of length {}",
        pivot_index,
        arr.len()
    );

    // Park the pivot in the last slot so it stays put during the sweep.
    let last = arr.len() - 1;
    arr.swap(pivot_index, last);

    // `boundary` is the index where the next smaller-than-pivot element goes.
    let mut boundary = 0;
    for i in 0..last {
        if cmp(&arr[i], &arr[last]) == Ordering::Less {
            arr.swap(i, boundary);
            boundary += 1;
        }
    }

    // Move the pivot between the "less" and "not less" groups.
    arr.swap(last, boundary);
    boundary
}
|