294 lines
7.7 KiB
Plaintext
294 lines
7.7 KiB
Plaintext
(*
  Floyd-Warshall algorithm.

  See https://en.wikipedia.org/w/index.php?title=Floyd%E2%80%93Warshall_algorithm&oldid=1082310013
*)
|
|
|
|
#include "share/atspre_staload.hats"
|
|
|
|
(* Shorthand for the two list constructors, so that lists can be
   built and pattern-matched as "x :: xs" and "NIL". *)
#define NIL list_nil ()
#define :: list_cons

(* An int whose static index is known to be at least 1. *)
typedef Pos = [i : int | 1 <= i] int i
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Square arrays with 1-based indexing. *)
|
|
|
|
(* Axiom (declared, not proved here): for a square array of side n,
   any 1-based index pair (i, j) with i <= n and j <= n yields an
   offset (i - 1) + (j - 1) * n satisfying 0 <= offset < n * n. *)
extern praxi
lemma_square_array_indices
  {n : int | 1 <= n}
  {i, j : int | 1 <= i; i <= n; 1 <= j; j <= n}
  () :<prf> [0 <= (i - 1) + ((j - 1) * n);
             (i - 1) + ((j - 1) * n) < n * n] void
|
|
|
|
(* A square array of side n: the side length itself, together with
   n * n elements of type t stored in a single flat arrayref. *)
typedef square_array (t : t@ype+, n : int) =
  '{ side_length = int n, elements = arrayref (t, n * n) }
|
|
|
|
(* Creates an n-by-n square array with every element set to "fill".
   n may be zero. *)
fn {t : t@ype}
make_square_array {n : nat}
                  (n    : int n,
                   fill : t) : square_array (t, n) =
  let
    (* With n >= 0 this gives n * n >= 0, which is what lets the
       product be converted to a size. *)
    prval () = mul_gte_gte_gte {n, n} ()

    val n_elements = i2sz (n * n)
    val storage = arrayref_make_elt (n_elements, fill)
  in
    '{ side_length = n, elements = storage }
  end
|
|
|
|
(* Returns the element of "arr" at the 1-based position (i, j).
   The static constraints require 1 <= i <= n and 1 <= j <= n. *)
fn {t : t@ype}
square_array_get_at {n : pos}
                    {i, j : pos | i <= n; j <= n}
                    (arr : square_array (t, n),
                     i   : int i,
                     j   : int j) : t =
  let
    (* Brings into scope the fact that the offset below lies within
       the n * n elements. *)
    prval () = lemma_square_array_indices {n} {i, j} ()

    (* Consecutive values of i are adjacent; j selects the stride. *)
    val offset = (i - 1) + ((j - 1) * arr.side_length)
  in
    arrayref_get_at (arr.elements, offset)
  end
|
|
|
|
(* Stores x into "arr" at the 1-based position (i, j).
   The static constraints require 1 <= i <= n and 1 <= j <= n. *)
fn {t : t@ype}
square_array_set_at {n : pos}
                    {i, j : pos | i <= n; j <= n}
                    (arr : square_array (t, n),
                     i   : int i,
                     j   : int j,
                     x   : t) : void =
  let
    (* Brings into scope the fact that the offset below lies within
       the n * n elements. *)
    prval () = lemma_square_array_indices {n} {i, j} ()

    (* Same offset formula as used when reading an element. *)
    val offset = (i - 1) + ((j - 1) * arr.side_length)
  in
    arrayref_set_at (arr.elements, offset, x)
  end
|
|
|
|
(* Bracket notation for square arrays: arr[i, j] reads an element and
   arr[i, j] := x writes one. *)
overload [] with square_array_get_at
overload [] with square_array_set_at
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* The floating point type used for edge weights and distances. *)
typedef floatpt = float

(* Turns an int into a floatpt; declared as a cast function. *)
extern castfn i2floatpt : int -<> floatpt

(* A placeholder for distance entries that have no meaningful value. *)
macdef arbitrary_floatpt = i2floatpt (12345)
|
|
|
|
(* distance[u, v] holds the length of the best known path from u to v. *)
typedef distance_array (n : int) = square_array (floatpt, n)

(* A vertex is a non-negative int. The value 0 is reserved as
   NIL_VERTEX, meaning "no vertex". *)
typedef vertex = [i : int | 0 <= i] int i
#define NIL_VERTEX 0

(* next_vertex[u, v] holds the vertex that follows u on the path from
   u to v, or NIL_VERTEX when no such path is known. *)
typedef next_vertex_array (n : int) = square_array (vertex, n)

(* A weighted directed edge from u to v. The ' means the record is
   allocated by the garbage collector. *)
typedef edge = '{ u = vertex, weight = floatpt, v = vertex }

(* A list of edges: first with its length given, then with any length. *)
typedef edge_list (n : int) = list (edge, n)
typedef edge_list = [n : int] edge_list (n)
|
|
|
|
(* Proof function: the length index of any edge_list is non-negative.
   It simply defers to the corresponding lemma for lists. *)
prfn
lemma_edge_list_param {n : int}
                      (edges : edge_list n) :<prf> [0 <= n] void =
  lemma_list_param (edges)
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Returns the largest vertex number appearing as either endpoint of
   any edge in "edges". Stops the program with an assertion failure
   if the list is empty. *)
fn
find_max_vertex (edges : edge_list) : vertex =
  let
    (* Walks the list while carrying the largest vertex seen so far.
       The termination metric is the length of what remains. *)
    fun
    scan {m : nat} .<m>.
         (remaining : edge_list m,
          largest   : vertex) : vertex =
      case+ remaining of
      | list_nil () => largest
      | list_cons (e, rest) =>
        let
          val from_vertex = e.u
          val to_vertex = e.v
          val largest = max (largest, from_vertex)
          val largest = max (largest, to_vertex)
        in
          scan (rest, largest)
        end

    (* Establish that the list length is non-negative, as "scan"
       requires. *)
    prval () = lemma_edge_list_param edges

    val () = assertloc (isneqz edges)
  in
    scan (edges, 0)
  end
|
|
|
|
(* Runs the Floyd-Warshall all-pairs shortest path algorithm.

   edges       - the weighted directed edges of the graph; every
                 endpoint must be in 1..n (checked by assertion).
   n           - the number of vertices; must be at least 1 (checked
                 by assertion).
   distance    - output: distance[u, v] is the length of the path
                 found from u to v.
   next_vertex - output: next_vertex[u, v] is the vertex after u on
                 that path, or NIL_VERTEX if there is no path.

   Entries of "distance" for which no path exists are NOT set to any
   meaningful value (the Wikipedia pseudocode would make them
   "infinite"). Consult "next_vertex" to find out whether a finite
   path exists before trusting a distance. This avoids any dependence
   on IEEE floating point or on the settings of the FPU. *)
fn
floyd_warshall {n : int}
               (edges       : edge_list,
                n           : int n,
                distance    : distance_array n,
                next_vertex : next_vertex_array n) : void =
  let
    val () = assertloc (1 <= n)
  in
    (* Step 1: mark every pair as having no known path. *)
    let
      var r : Pos
    in
      for (r := 1; r <= n; r := succ r)
        let
          var c : Pos
        in
          for (c := 1; c <= n; c := succ c)
            next_vertex[r, c] := NIL_VERTEX
        end
    end;

    (* Step 2: record each edge as a one-hop path. *)
    let
      var rest : edge_list
    in
      for (rest := edges; list_is_cons rest; rest := list_tail rest)
        let
          val e = list_head rest
          val src = e.u
          val () = assertloc (src <> NIL_VERTEX)
          val () = assertloc (src <= n)
          val dst = e.v
          val () = assertloc (dst <> NIL_VERTEX)
          val () = assertloc (dst <= n)
        in
          distance[src, dst] := e.weight;
          next_vertex[src, dst] := dst
        end
    end;

    (* Step 3: the distance from a vertex to itself is zero. This
       overrides any self-loop edge recorded in step 2. *)
    let
      var d : Pos
    in
      for (d := 1; d <= n; d := succ d)
        ( distance[d, d] := i2floatpt (0);
          next_vertex[d, d] := d )
    end;

    (* Step 4: for each intermediate vertex k, try to improve every
       path i -> j by routing it through k. *)
    let
      var k : Pos
    in
      for (k := 1; k <= n; k := succ k)
        let
          var i : Pos
        in
          for (i := 1; i <= n; i := succ i)
            let
              var j : Pos
            in
              for (j := 1; j <= n; j := succ j)
                let
                  (* First hop of the current path from i to k. *)
                  val hop = next_vertex[i, k]
                in
                  (* Only consider k when both i -> k and k -> j are
                     known to exist. *)
                  if hop <> NIL_VERTEX
                       && next_vertex[k, j] <> NIL_VERTEX then
                    let
                      val through_k = distance[i, k] + distance[k, j]
                    in
                      (* Accept the route through k if there was no
                         path at all, or if it is strictly shorter. *)
                      if next_vertex[i, j] = NIL_VERTEX
                           || through_k < distance[i, j] then
                        ( distance[i, j] := through_k;
                          next_vertex[i, j] := hop )
                    end
                end
            end
        end
    end
  end
|
|
|
|
(* Prints the path from u to v as "u -> ... -> v", following the
   entries of "next_vertex". Prints nothing when n is not positive or
   when next_vertex[u, v] is NIL_VERTEX. Both u and v must be at most
   n, and every vertex met along the way must be in 1..n; these
   conditions are checked by assertion. *)
fn
print_path {n : int}
           (n           : int n,
            next_vertex : next_vertex_array n,
            u           : Pos,
            v           : Pos) : void =
  if 0 < n then
    let
      val () = assertloc (u <= n)
      val () = assertloc (v <= n)
    in
      if next_vertex[u, v] <> NIL_VERTEX then
        let
          (* The vertex most recently printed. *)
          var here : Int = u
        in
          print! (here);
          while (here <> v)
            let
              val () = assertloc (1 <= here)
              val () = assertloc (here <= n)
              val step = next_vertex[here, v]
            in
              print! (" -> ", step);
              here := step
            end
        end
    end
|
|
|
|
(* Builds a small example graph, runs floyd_warshall on it, and prints
   a table with the distance and the path for each ordered pair of
   distinct vertices. *)
implement
main0 () =
  let
    (* Each new edge is consed onto the front of the list, so the
       edges end up in the reverse of the order written here. The
       order of the edges makes no difference to the result. *)
    val graph0 = NIL
    val graph1 =
      '{u = 1, weight = i2floatpt (~2), v = 3} :: graph0
    val graph2 =
      '{u = 3, weight = i2floatpt (2), v = 4} :: graph1
    val graph3 =
      '{u = 4, weight = i2floatpt (~1), v = 2} :: graph2
    val graph4 =
      '{u = 2, weight = i2floatpt (4), v = 1} :: graph3
    val example_graph =
      '{u = 2, weight = i2floatpt (3), v = 3} :: graph4

    (* Vertices are numbered from 1, so the largest vertex number is
       also the side length needed for the two arrays. *)
    val n = find_max_vertex (example_graph)
    val distance = make_square_array<floatpt> (n, arbitrary_floatpt)
    val next_vertex = make_square_array<vertex> (n, NIL_VERTEX)
  in
    floyd_warshall (example_graph, n, distance, next_vertex);

    println! (" pair distance path");
    println! ("------------------------------------------");

    let
      var src : Pos
    in
      for (src := 1; src <= n; src := succ src)
        let
          var dst : Pos
        in
          for (dst := 1; dst <= n; dst := succ dst)
            if src <> dst then
              let
                val d = distance[src, dst]
              in
                print! (" ", src, " -> ", dst, " ");
                (* A non-negative distance gets one extra leading
                   space, taking the place of a minus sign. *)
                if i2floatpt (0) <= d then
                  print! (" ");
                print! (d, " ");
                print_path (n, next_vertex, src, dst);
                println! ()
              end
        end
    end
  end
|