127 lines
3.3 KiB
Plaintext
127 lines
3.3 KiB
Plaintext
# One node of the k-d tree:
#   d     - the point stored at this node
#   split - index of the coordinate this node partitions on
#   left  - subtree built from the points before the pivot (nil when empty)
#   right - subtree built from the points after the pivot (nil when empty)
struct Kd_node { d, split, left, right }
|
|
|
|
# Axis-aligned box given by its per-coordinate lower (`min`) and upper
# (`max`) corners. Field order matters: it is also built positionally.
struct Orthotope { min, max }
|
|
|
|
# A k-d tree over a list of equal-length numeric points.
#   n      - on construction, the array of points; `init` replaces it with
#            the root Kd_node (nil when the list is empty)
#   bounds - an Orthotope handed to the search as the starting box
class Kd_tree(n, bounds) {

    # Build the tree from the point list supplied in `n`.
    method init {
        n = self.nk2(0, n);
    }

    # Recursively build a subtree from the points in `e`, partitioning on
    # coordinate index `split`. Returns a Kd_node, or nil for an empty list.
    method nk2(split, e) {
        return(nil) if (e.len <= 0);

        var exset = e.sort_by { _[split] }
        var m = (exset.len // 2);
        var d = exset[m];

        # Move past points sharing the median's split coordinate so every
        # point in the right subtree is strictly greater on this axis.
        while ((m+1 < exset.len) && (exset[m+1][split] == d[split])) {
            ++m;
        }

        # The pivot must be the element at the FINAL index m; it is then
        # excluded from both slices below (neither lost nor duplicated).
        d = exset[m];

        var s2 = ((split + 1) % d.len);      # cycle coordinates

        Kd_node(d: d, split: split,
                left:  self.nk2(s2, exset.first(m)),                   # indices 0 .. m-1
                right: self.nk2(s2, exset.last(exset.len - m - 1)));   # indices m+1 .. end
    }
}
|
|
|
|
# Result of a nearest-neighbour search: the best point found, its squared
# distance to the query (Inf when nothing was found) and the count of
# tree nodes examined (0 by default).
struct T3 { nearest, dist_sqd = Inf, nodes_visited = 0 }
|
|
|
|
# Nearest-neighbour search in a Kd_tree.
#   k - number of coordinates; only used to build the placeholder result
#       ([0]*k) returned for an empty subtree
#   t - the Kd_tree to search (its root `n` and its `bounds` are read)
#   p - the query point
# Returns a T3 holding the nearest stored point, its squared distance to p
# and the number of nodes visited. An empty tree yields the placeholder
# result ([0]*k, Inf, 0).
func find_nearest(k, t, p) {

    # Search the subtree `kd` for `target`.
    #   hr           - box for this subtree; it is split at each pivot and
    #                  passed down, but not otherwise used for pruning here
    #   max_dist_sqd - best squared distance known so far, used to decide
    #                  whether the far side of the split must be searched
    func nn(kd, target, hr, max_dist_sqd) {
        kd || return T3(nearest: [0]*k);

        var nodes_visited = 1;
        var s     = kd.split;
        var pivot = kd.d;

        # Split the box at the pivot. The coordinate arrays are cloned so
        # the caller's box (ultimately t.bounds) is not modified in place.
        var left_hr  = Orthotope(hr.min.dclone, hr.max.dclone);
        var right_hr = Orthotope(hr.min.dclone, hr.max.dclone);
        left_hr.max[s]  = pivot[s];
        right_hr.min[s] = pivot[s];

        var nearer_kd;
        var further_kd;
        var nearer_hr;
        var further_hr;

        # Descend first into the side of the split that contains the target.
        if (target[s] <= pivot[s]) {
            (nearer_kd, nearer_hr) = (kd.left, left_hr);
            (further_kd, further_hr) = (kd.right, right_hr);
        }
        else {
            (nearer_kd, nearer_hr) = (kd.right, right_hr);
            (further_kd, further_hr) = (kd.left, left_hr);
        }

        var n1 = nn(nearer_kd, target, nearer_hr, max_dist_sqd);
        var nearest = n1.nearest;
        var dist_sqd = n1.dist_sqd;
        nodes_visited += n1.nodes_visited;

        if (dist_sqd < max_dist_sqd) {
            max_dist_sqd = dist_sqd;
        }

        # Squared distance from the target to the splitting plane: if even
        # that exceeds the best so far, neither the pivot nor the far side
        # can be closer.
        var d = (pivot[s] - target[s] -> sqr);
        if (d > max_dist_sqd) {
            return T3(nearest: nearest, dist_sqd: dist_sqd, nodes_visited: nodes_visited);
        }

        # Full squared Euclidean distance between pivot and target.
        d = (pivot ~Z- target »sqr»() «+»);
        if (d < dist_sqd) {
            nearest = pivot;
            dist_sqd = d;
            max_dist_sqd = dist_sqd;
        }

        var n2 = nn(further_kd, target, further_hr, max_dist_sqd);
        nodes_visited += n2.nodes_visited;
        if (n2.dist_sqd < dist_sqd) {
            nearest = n2.nearest;
            dist_sqd = n2.dist_sqd;
        }

        T3(nearest: nearest, dist_sqd: dist_sqd, nodes_visited: nodes_visited);
    }

    return nn(t.n, p, t.bounds, Inf);
}
|
|
|
|
# Print a heading and the query point, run find_nearest(k, kd, p), then
# print the neighbour found, its Euclidean distance (square root of the
# returned squared distance), the visited-node count and a blank line.
func show_nearest(k, heading, kd, p) {
    say "#{heading}:"
    say "Point: [#{p.join(',')}]"

    var result = find_nearest(k, kd, p)

    say "Nearest neighbor: [#{result.nearest.join(',')}]"
    say "Distance: #{sqrt(result.dist_sqd)}"
    say "Nodes visited: #{result.nodes_visited}"
    say ""
}
|
|
|
|
# A point with k coordinates, each one drawn from 1.rand.
func random_point(k) {
    (1..k).map { 1.rand }
}

# A list of n independent random points, each with k coordinates.
func random_points(k, n) {
    (1..n).map { random_point(k) }
}
|
|
|
|
# Wikipedia's 2-D example data, queried at (9,2).
var kd1 = Kd_tree([[2, 3],[5, 4],[9, 6],[4, 7],[8, 1],[7, 2]],
                  Orthotope(min: [0, 0], max: [10, 10]));

show_nearest(2, "Wikipedia example data", kd1, [9, 2]);

# Random 3-D data: N points with coordinates from 1.rand, bounded by the
# unit cube. Only the tree construction is timed.
var N = 1000
var t0 = Time.micro
var kd2 = Kd_tree(random_points(3, N), Orthotope(min: [0,0,0], max: [1,1,1]))
var t1 = Time.micro

# k must match the dimensionality of the points (3 here, not 2).
show_nearest(3,
    "k-d tree with #{N} random 3D points (generation time: #{t1 - t0}s)",
    kd2, random_point(3))
|