544 lines
15 KiB
Plaintext
544 lines
15 KiB
Plaintext
(*
|
|
Floyd-Warshall algorithm.
|
|
|
|
See https://en.wikipedia.org/w/index.php?title=Floyd%E2%80%93Warshall_algorithm&oldid=1082310013
|
|
|
|
|
|
-------------------------
|
|
WHY A SECOND ATS VERSION?
|
|
-------------------------
|
|
|
|
From the first ATS version, I derived a version in OCaml, which
|
|
modularized the code. From the OCaml, I produced a Standard ML
|
|
implementation that also made the types abstract.
|
|
|
|
Now I am returning to the ATS, to backport (among other things) the
|
|
abstraction of types. In fact I increase the abstraction, in a way
|
|
that protects the programmer against accidentally using the
|
|
"uninitialized" entries of the "distance" array.
|
|
|
|
Thus one can follow the chain of improvement, and also compare how
|
|
type abstraction is done in Standard ML and in ATS. In ATS, type
|
|
abstraction can be done using "assume" statements or type casts.
|
|
|
|
*)
|
|
|
|
#include "share/atspre_staload.hats"
|
|
|
|
#define NIL list_nil ()
|
|
#define :: list_cons
|
|
|
|
typedef Pos = [i : pos] int i
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* The numeric type used for edge weights and path distances. Replace
   "float" by "double" (or another type) here if you wish. *)

typedef floatpt = float

(* Cast from a plain int to floatpt, also reachable through the
   overloaded name "i2fp". *)
extern castfn int2floatpt (x : int) :<> floatpt

overload i2fp with int2floatpt
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Square (n-by-n) arrays addressed by 1-based (i, j) index pairs.
   The representation is hidden behind the abstract type
   "square_array". *)

local

  (* Concrete representation. The leading "'" in '{ ... } means the
     type is a pointer to a record allocated by the garbage
     collector. All n * n elements live in one flat arrayref. *)
  typedef _square_array (t : t@ype+, n : int) =
    '{
      side_length = int n,
      elements = arrayref (t, n * n)
    }

in

  abstype square_array (t : t@ype+, n : int)

  assume square_array (t, n) = _square_array (t, n)

  (* Axiom (declared, not proven): for 1 <= i, j <= n the flat offset
     (i - 1) + (j - 1) * n lies within the n * n elements. *)
  extern praxi
  lemma_square_array_indices
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            () :<prf>
      [0 <= (i - 1) + ((j - 1) * n);
       (i - 1) + ((j - 1) * n) < n * n]
      void

  (* Make an n-by-n array with every element set to "fill". *)
  fn {t : t@ype}
  square_array_make
            {n : nat}
            (n : int n,
             fill : t) :<!wrt> square_array (t, n) =
    let
      (* n * n >= 0, so it may be converted to a size. *)
      prval () = mul_gte_gte_gte {n, n} ()
      val size = i2sz (n * n)
    in
      '{ side_length = n, elements = arrayref_make_elt (size, fill) }
    end

  (* Read the element at 1-based position (i, j). *)
  fn {t : t@ype}
  square_array_get_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (arr : square_array (t, n),
             i : int i,
             j : int j) :<!ref> t =
    let
      prval () = lemma_square_array_indices {n} {i, j} ()
      val offset = (i - 1) + ((j - 1) * arr.side_length)
    in
      arrayref_get_at (arr.elements, offset)
    end

  (* Store x at 1-based position (i, j). *)
  fn {t : t@ype}
  square_array_set_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (arr : square_array (t, n),
             i : int i,
             j : int j,
             x : t) :<!refwrt> void =
    let
      prval () = lemma_square_array_indices {n} {i, j} ()
      val offset = (i - 1) + ((j - 1) * arr.side_length)
    in
      arrayref_set_at (arr.elements, offset, x)
    end

  (* Allow the arr[i, j] and arr[i, j] := x notations. *)
  overload [] with square_array_get_at
  overload [] with square_array_set_at

end (* local *)
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Vertices are kept abstract, rather than simply being identified
   with integers. *)

(* The "abst@ype ... = int" form tells the compiler only that a
   "vertex" has the same size as an "int" (plain "abstype" would
   assume the size of a pointer). It does *not* identify "vertex"
   with "int". *)
abst@ype vertex (i : int) = int

typedef vertex = [i : nat] vertex i

(* Casts for converting between int and the abstract vertex type. *)
extern castfn int2vertex {i : nat} (i : int i) :<> vertex i
extern castfn vertex2int {i : nat} (u : vertex i) :<> int i

(* The vertex numbered 0. *)
macdef nil_vertex = int2vertex 0
|
|
|
|
(* Is u the nil vertex? The result type records statically whether
   the vertex number is zero. *)
fn
vertex_is_nil {u : nat}
              (u : vertex u) :<> bool (u == 0) =
  let
    val number = vertex2int u
  in
    number = vertex2int nil_vertex
  end
|
|
|
|
(* The negation of vertex_is_nil. *)
fn
vertex_isnot_nil {u : nat}
                 (u : vertex u) :<> bool (u != 0) =
  let
    val is_nil = vertex_is_nil u
  in
    ~is_nil
  end
|
|
|
|
(* Do u and v carry the same vertex number? *)
fn
vertex_eq {u, v : nat}
          (u : vertex u,
           v : vertex v) :<> bool (u == v) =
  let
    val number_u = vertex2int u
    val number_v = vertex2int v
  in
    number_u = number_v
  end
|
|
|
|
(* The negation of vertex_eq. *)
fn
vertex_neq {u, v : nat}
           (u : vertex u,
            v : vertex v) :<> bool (u <> v) =
  let
    val same = vertex_eq (u, v)
  in
    ~same
  end
|
|
|
|
(* The vertex with the larger of the two vertex numbers. *)
fn
vertex_max {u, v : nat}
           (u : vertex u,
            v : vertex v) :<> vertex (max (u, v)) =
  let
    val larger = max (vertex2int u, vertex2int v)
  in
    int2vertex larger
  end
|
|
|
|
(* The vertex number of u, rendered as a string. *)
fn
tostring_vertex (u : vertex) :<> string =
  let
    val number = vertex2int u
  in
    tostring_int number
  end
|
|
|
|
(* Render a list of vertices as their numbers joined by " -> ". An
   empty list gives the empty string. *)
fn
tostring_directed_vertex_list (lst : List vertex) :<!wrt> string =
  let
    (* Walk the list, carrying the text built so far. An empty
       accumulator means no vertex has been emitted yet, so no
       separator is wanted. *)
    fun
    loop {n : nat} .<n>.
         (remaining : list (vertex, n),
          acc : string) :<!wrt> string =
      case+ remaining of
      | NIL => acc
      | u :: rest =>
        let
          val label = tostring_vertex u
          val joined : string =
            if acc = "" then
              label
            else
              strptr2string (string_append (acc, " -> ", label))
        in
          loop (rest, joined)
        end

    (* Establish that the list length is a natural number. *)
    prval () = lemma_list_param lst
  in
    loop (lst, "")
  end
|
|
|
|
overload iseqz with vertex_is_nil
|
|
overload isneqz with vertex_isnot_nil
|
|
overload = with vertex_eq
|
|
overload <> with vertex_neq
|
|
overload max with vertex_max
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Weighted, directed graph edges. *)

local

  (* Concrete representation: a pointer to a garbage-collected tuple
     (first vertex, weight, second vertex). Both vertex numbers are
     constrained to be at least 1. *)
  typedef _edge (u : int, v : int) =
    [1 <= u; 1 <= v] '(vertex u, floatpt, vertex v)

in

  abstype edge (u : int, v : int)
  typedef edge = [u, v : pos] edge (u, v)

  assume edge (u, v) = _edge (u, v)

  (* Build the edge from u to v having the given weight. *)
  fn
  edge_make {u, v : pos}
            (u : vertex u,
             weight : floatpt,
             v : vertex v) :<> edge (u, v) =
    '(u, weight, v)

  (* The vertex an edge starts at. *)
  fn
  edge_first {u, v : pos}
             (edge : edge (u, v)) :<> vertex u =
    let
      val first = edge.0
    in
      first
    end

  (* The weight attached to an edge. *)
  fn
  edge_weight (edge : edge) :<> floatpt =
    let
      val weight = edge.1
    in
      weight
    end

  (* The vertex an edge ends at. *)
  fn
  edge_second {u, v : pos}
              (edge : edge (u, v)) :<> vertex v =
    let
      val second = edge.2
    in
      second
    end

  (* The largest vertex occurring as an endpoint of any edge in the
     list, or nil_vertex if the list is empty. *)
  fn
  max_vertex_in_edge_list (lst : List edge) :<> vertex =
    let
      fun
      loop {n : nat} .<n>.
           (remaining : list (edge, n),
            best : vertex) :<> vertex =
        case+ remaining of
        | NIL => best
        | e :: rest =>
          let
            val endpoints_max = max (edge_first e, edge_second e)
          in
            loop (rest, max (endpoints_max, best))
          end

      (* Establish that the list length is a natural number. *)
      prval () = lemma_list_param lst
    in
      loop (lst, nil_vertex)
    end

end (* local *)
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Floyd-Warshall. *)

local

  (* Concrete representation of a result: the vertex count n, the
     n-by-n array of path distances, and the n-by-n array holding,
     for each (i, j), the vertex that follows i on the path to j
     (nil_vertex when no path is known). A "dist" entry is meaningful
     only where the matching "next" entry is not nil. *)
  typedef _floyd_warshall_result (n : int) =
    '{
      n = int n,
      dist = square_array (floatpt, n),
      next = square_array (vertex, n)
    }

  (* Monomorphic wrappers around the square_array accessors. *)

  fn {}
  _dist_get_at {n : pos}
               {i, j : pos | i <= n; j <= n}
               (dist : square_array (floatpt, n),
                i : int i,
                j : int j) :<!ref> floatpt =
    square_array_get_at (dist, i, j)

  fn
  _dist_set_at {n : pos}
               {i, j : pos | i <= n; j <= n}
               (dist : square_array (floatpt, n),
                i : int i,
                j : int j,
                x : floatpt) :<!refwrt> void =
    square_array_set_at (dist, i, j, x)

  fn {}
  _next_get_at {n : pos}
               {i, j : pos | i <= n; j <= n}
               (next : square_array (vertex, n),
                i : int i,
                j : int j) :<!ref> vertex =
    square_array_get_at (next, i, j)

  fn
  _next_set_at {n : pos}
               {i, j : pos | i <= n; j <= n}
               (next : square_array (vertex, n),
                i : int i,
                j : int j,
                x : vertex) :<!refwrt> void =
    square_array_set_at (next, i, j, x)

in

  abstype floyd_warshall_result (n : int)
  typedef floyd_warshall_result = [n : nat] floyd_warshall_result n

  assume floyd_warshall_result n = _floyd_warshall_result n

  exception FloydWarshallError of (string)

  (* The number of vertices covered by the result. *)
  fn
  vertex_count {n : pos}
               (fw : floyd_warshall_result n) :<> int n =
    fw.n

  (* The length of the shortest path found from i to j, or None () if
     no path from i to j was found. *)
  fn
  get_distance {n : pos}
               {i, j : pos | i <= n; j <= n}
               (fw : floyd_warshall_result n,
                i : vertex i,
                j : vertex j) :<!ref> Option floatpt =

    (* Notice there is *no way* to return one of the "uninitialized"
       values in the "dist" array (which were actually set to a
       meaningless value, or could have been set to positive
       infinity). Instead you get "None()".

       This kind of behavior is better than returning "positive
       infinity", because it does not depend on any particular sort of
       floating point. Indeed, in Ada you could use fixed point. *)

    let
      val i = vertex2int i
      val j = vertex2int j
      val u = _next_get_at (fw.next, i, j)
    in
      if iseqz u then
        None ()                 (* There is no finite path. *)
      else
        Some (_dist_get_at (fw.dist, i, j))
    end

  (* The vertex that follows i on the path found from i to j; the
     nil vertex if there is no such path. *)
  fn
  get_next_vertex {n : pos}
                  {i, j : pos | i <= n; j <= n}
                  (fw : floyd_warshall_result n,
                   i : vertex i,
                   j : vertex j) :<!ref> vertex =
    _next_get_at (fw.next, vertex2int i, vertex2int j)

  (* Run the Floyd-Warshall algorithm on a list of weighted, directed
     edges. The vertex count is taken to be the largest vertex number
     that occurs in the list.

     Raises FloydWarshallError ("no vertices") if that number is zero,
     which is the case for an empty edge list.

     If several edges join the same ordered pair of vertices, the
     lightest of them is used. *)
  fn
  floyd_warshall (edges : List edge)
      :<1> [n : pos] floyd_warshall_result n =
    let
      val n = vertex2int (max_vertex_in_edge_list edges)
    in
      if n = 0 then
        $raise FloydWarshallError ("no vertices")
      else
        let
          (* A placeholder for "dist" entries that have no meaning
             yet; it is never compared or returned, because the
             matching "next" entry is always checked first. *)
          macdef arbitrary_floatpt = i2fp (12345)
          val dist = square_array_make<floatpt> (n, arbitrary_floatpt)

          (* Every "next" entry starts out as nil_vertex, so no
             separate clearing pass is needed. *)
          val next = square_array_make<vertex> (n, nil_vertex)
        in

          (* Initialize. *)

          let
            var p : List edge
          in
            for (p := edges; list_is_cons p; p := list_tail p)
              let
                val edge = list_head p
                val u = edge_first edge
                val () = assertloc (isneqz u)
                val () = assertloc (vertex2int u <= n)
                val v = edge_second edge
                val () = assertloc (isneqz v)
                val () = assertloc (vertex2int v <= n)
                val weight = edge_weight edge
              in
                (* Record the edge only if it is the first one seen
                   for (u, v) or is lighter than the one recorded
                   earlier, so that parallel edges resolve to the
                   minimum weight rather than to the last one
                   listed. *)
                if iseqz next[vertex2int u, vertex2int v]
                     || weight < dist[vertex2int u, vertex2int v] then
                  begin
                    dist[vertex2int u, vertex2int v] := weight;
                    next[vertex2int u, vertex2int v] := v
                  end
              end
          end;
          let
            var i : Pos
          in
            for (i := 1; i <= n; i := succ i)
              begin
                (* Distance from a vertex to itself is zero. *)
                dist[i, i] := int2floatpt (0);
                next[i, i] := int2vertex i
              end
          end;

          (* Perform the algorithm. *)

          let
            var k : Pos
          in
            for (k := 1; k <= n; k := succ k)
              let
                var i : Pos
              in
                for (i := 1; i <= n; i := succ i)
                  let
                    var j : Pos
                  in
                    for (j := 1; j <= n; j := succ j)
                      (* Going through k is possible only if both
                         halves of the detour exist. *)
                      if isneqz next[i, k] && isneqz next[k, j] then
                        let
                          val dist_ikj = dist[i, k] + dist[k, j]
                        in
                          (* Take the detour if (i, j) had no path
                             yet or the detour is shorter. *)
                          if iseqz next[i, j]
                               || dist_ikj < dist[i, j] then
                            begin
                              dist[i, j] := dist_ikj;
                              next[i, j] := next[i, k]
                            end
                        end
                  end
              end
          end;

          (* Return the result. *)

          '{ n = n, dist = dist, next = next }

        end
    end

  (* The list of vertices on the path found from u to v, both
     endpoints included, or the empty list if there is no path.

     Raises FloydWarshallError ("vertex not found") if u or v is
     larger than the vertex count of fw. *)
  fn
  get_path {n : int}
           {u, v : pos}
           (fw : floyd_warshall_result n,
            u : vertex u,
            v : vertex v) :<!refwrt,!exn> List vertex =
    if (fw.n) < vertex2int u then
      $raise FloydWarshallError ("vertex not found")
    else if (fw.n) < vertex2int v then
      $raise FloydWarshallError ("vertex not found")
    else
      if iseqz (get_next_vertex (fw, u, v)) then
        NIL
      else
        let
          (* Follow the "next" links from w until v is reached. The
             vertices visited so far are kept in reverse order. The
             loop is not proven to terminate (hence !ntm). *)
          fun
          loop (w : vertex,
                lst : List0 vertex) :<!ntm,!refwrt> List vertex =
            if w = v then
              list_vt2t (list_reverse lst)
            else
              let
                val () =
                  $effmask_exn assertloc (isneqz w)
                val () =
                  $effmask_exn assertloc (vertex2int w <= (fw.n))
                val w = get_next_vertex (fw, w, v)
              in
                loop (w, w :: lst)
              end
        in
          $effmask_ntm loop (u, u :: NIL)
        end

end (* local *)
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Demonstration: run Floyd-Warshall on a small example graph and
   print, for every ordered pair of distinct vertices, the distance
   and the path found (or " no path"). *)
implement
main0 () =
  let
    val example_graph =
      $list (edge_make (int2vertex 1, i2fp (~2), int2vertex 3),
             edge_make (int2vertex 3, i2fp (2), int2vertex 4),
             edge_make (int2vertex 4, i2fp (~1), int2vertex 2),
             edge_make (int2vertex 2, i2fp (4), int2vertex 1),
             edge_make (int2vertex 2, i2fp (3), int2vertex 3))

    val fw = floyd_warshall example_graph

    (* Ask for the vertex count through the accessor function rather
       than reading a field of the result's representation. *)
    val n = vertex_count fw
  in
    println! (" pair distance path");
    println! ("------------------------------------------");
    let
      var i : Pos
    in
      for (i := 1; i <= n; i := succ i)
        let
          var j : Pos
        in
          for (j := 1; j <= n; j := succ j)
            let
              val u = int2vertex i
              val v = int2vertex j
            in
              if u <> v then
                let
                  val s_edge =
                    tostring_directed_vertex_list ($list (u, v))
                  val distance_opt = get_distance (fw, u, v)
                in
                  print! (" ", s_edge, " ");
                  begin
                    case+ distance_opt of
                    | None () => print! " no path"
                    | Some distance =>
                      let
                        val path = get_path (fw, u, v)
                        val s_path =
                          tostring_directed_vertex_list path
                      in
                        (* An extra leading space is printed for
                           non-negative distances. *)
                        if int2floatpt (0) <= distance then
                          print! " ";
                        print! distance;
                        print! " ";
                        print! s_path
                      end
                  end;
                  println! ()
                end
            end
        end
    end
  end
|
|
|
|
(*------------------------------------------------------------------*)
|