(*
  Floyd-Warshall algorithm.

  See https://en.wikipedia.org/w/index.php?title=Floyd%E2%80%93Warshall_algorithm&oldid=1082310013

  -------------------------
  WHY A SECOND ATS VERSION?
  -------------------------

  From the first ATS version, I derived a version in OCaml, which
  modularized the code. From the OCaml, I produced a Standard ML
  implementation that also made the types abstract.

  Now I am returning to the ATS, to backport (among other things) the
  abstraction of types. In fact I increase the abstraction, in a way
  that protects the programmer against accidentally using the
  "uninitialized" entries of the "distance" array.

  Thus one can follow the chain of improvement, and also compare how
  type abstraction is done in Standard ML and in ATS. In ATS, type
  abstraction can be done using "assume" statements or type casts.
*)

#include "share/atspre_staload.hats"

#define NIL list_nil ()
#define :: list_cons

(* A statically positive machine integer. *)
typedef Pos = [i : pos] int i

(*------------------------------------------------------------------*)

(* You can change floatpt from "float" to "double" or another type,
   if you wish. *)
typedef floatpt = float

(* Cast an int to the floating point type used for edge weights and
   distances. *)
extern castfn int2floatpt : int -<> floatpt
overload i2fp with int2floatpt

(*------------------------------------------------------------------*)

(* Square arrays with 1-based indexing. *)

local

  typedef _square_array (t : t@ype+, n : int) =
    (* '{ ... } with a "'" means the type is pointer to a record
       allocated by the garbage collector. *)
    '{
      side_length = int n,
      elements = arrayref (t, n * n)
    }

in

  abstype square_array (t : t@ype+, n : int)
  assume square_array (t, n) = _square_array (t, n)

  (* Axiom (asserted, not proved here): for 1-based indices i and j
     that are both at most n, the linear offset
     (i - 1) + ((j - 1) * n) lies in the half-open range [0, n * n). *)
  extern praxi
  lemma_square_array_indices
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            () :
    [0 <= (i - 1) + ((j - 1) * n);
     (i - 1) + ((j - 1) * n) < n * n]
    void

  (* Make an n-by-n square array with every element set to "fill". *)
  fn {t : t@ype}
  square_array_make
            {n : nat}
            (n : int n,
             fill : t) : square_array (t, n) =
    let
      (* Tells the type checker that n * n is non-negative. *)
      prval () = mul_gte_gte_gte {n, n} ()
    in
      '{
        side_length = n,
        elements = arrayref_make_elt (i2sz (n * n), fill)
      }
    end

  (* Fetch the element at 1-based position (i, j). The element is
     stored at offset (i - 1) + ((j - 1) * side_length). *)
  fn {t : t@ype}
  square_array_get_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (arr : square_array (t, n),
             i : int i,
             j : int j) : t =
    let
      prval () = lemma_square_array_indices {n} {i, j} ()
    in
      arrayref_get_at (arr.elements,
                       (i - 1) + ((j - 1) * arr.side_length))
    end

  (* Store x at 1-based position (i, j), using the same offset
     computation as square_array_get_at. *)
  fn {t : t@ype}
  square_array_set_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (arr : square_array (t, n),
             i : int i,
             j : int j,
             x : t) : void =
    let
      prval () = lemma_square_array_indices {n} {i, j} ()
    in
      arrayref_set_at (arr.elements,
                       (i - 1) + ((j - 1) * arr.side_length),
                       x)
    end

  overload [] with square_array_get_at
  overload [] with square_array_set_at

end (* local *)

(*------------------------------------------------------------------*)

(* A vertex made more abstract than simply identifying it with an
   integer. *)

(* The following "abst@ype" tells the compiler that "vertex" is the
   same size as "int" (as opposed to the size of a pointer, which
   "abstype" assumes). It does *not* identify "vertex" with "int". *)
abst@ype vertex (i : int) = int
typedef vertex = [i : nat] vertex i

(* These casts let us convert between int and the abstract type. *)
extern castfn int2vertex : {i : nat} int i -<> vertex i
extern castfn vertex2int : {i : nat} vertex i -<> int i

(* Vertex number zero serves as the "no vertex" marker. *)
macdef nil_vertex = int2vertex 0

(* Is u the nil vertex (number zero)? *)
fn
vertex_is_nil {u : nat}
              (u : vertex u) :<> bool (u == 0) =
  vertex2int u = vertex2int nil_vertex

(* Is u anything other than the nil vertex? *)
fn
vertex_isnot_nil {u : nat}
                 (u : vertex u) :<> bool (u != 0) =
  ~vertex_is_nil u

(* Do u and v have the same vertex number? *)
fn
vertex_eq {u, v : nat}
          (u : vertex u,
           v : vertex v) :<> bool (u == v) =
  vertex2int u = vertex2int v

(* Do u and v have different vertex numbers? *)
fn
vertex_neq {u, v : nat}
           (u : vertex u,
            v : vertex v) :<> bool (u <> v) =
  ~vertex_eq (u, v)

(* The vertex with the greater number. *)
fn
vertex_max {u, v : nat}
           (u : vertex u,
            v : vertex v) :<> vertex (max (u, v)) =
  int2vertex (max (vertex2int u, vertex2int v))

(* The vertex number, as a string. *)
fn
tostring_vertex (u : vertex) :<> string =
  tostring_int (vertex2int u)

(* Render a list of vertices as "a -> b -> c". An empty list gives
   the empty string. *)
fn
tostring_directed_vertex_list (lst : List vertex) : string =
  let
    (* "s" accumulates the text so far; the list length n is the
       termination metric. *)
    fun
    loop {n : nat} .<n>.
         (lst : list (vertex, n),
          s : string) : string =
      case+ lst of
      | NIL => s
      | u :: tail =>
        let
          val s_u = tostring_vertex u
        in
          if s = "" then
            loop (tail, s_u)
          else
            let
              val s1 = strptr2string (string_append (s, " -> ", s_u))
            in
              loop (tail, s1)
            end
        end

    (* Establishes that the length of lst is non-negative. *)
    prval () = lemma_list_param lst
  in
    loop (lst, "")
  end

overload iseqz with vertex_is_nil
overload isneqz with vertex_isnot_nil
overload = with vertex_eq
overload <> with vertex_neq
overload max with vertex_max

(*------------------------------------------------------------------*)

(* Graph edges, with weights. *)

local

  typedef _edge (u : int, v : int) =
    (* The type is pointer to a tuple allocated by the garbage
       collector. *)
    [1 <= u; 1 <= v]
    '(vertex u, floatpt, vertex v)

in

  abstype edge (u : int, v : int)
  typedef edge = [u, v : pos] edge (u, v)
  assume edge (u, v) = _edge (u, v)

  (* Make a directed edge from u to v with the given weight. Both
     vertex numbers must be positive. *)
  fn
  edge_make {u, v : pos}
            (u : vertex u,
             weight : floatpt,
             v : vertex v) :<> edge (u, v) =
    '(u, weight, v)

  (* The vertex the edge starts at. *)
  fn
  edge_first {u, v : pos}
             (edge : edge (u, v)) :<> vertex u =
    edge.0

  (* The weight of the edge. *)
  fn
  edge_weight (edge : edge) :<> floatpt =
    edge.1

  (* The vertex the edge ends at. *)
  fn
  edge_second {u, v : pos}
              (edge : edge (u, v)) :<> vertex v =
    edge.2

  (* The greatest vertex appearing in any edge of the list, or
     nil_vertex if the list is empty. *)
  fn
  max_vertex_in_edge_list (lst : List edge) :<> vertex =
    let
      (* "x" is the running maximum; the list length n is the
         termination metric, which the ":<>" annotation requires. *)
      fun
      loop {n : nat} .<n>.
           (lst : list (edge, n),
            x : vertex) :<> vertex =
        case+ lst of
        | NIL => x
        | edge :: tail =>
          loop (tail, max (max (edge_first edge, edge_second edge), x))

      (* Establishes that the length of lst is non-negative. *)
      prval () = lemma_list_param lst
    in
      loop (lst, nil_vertex)
    end

end (* local *)

(*------------------------------------------------------------------*)

(* Floyd-Warshall. *)

local

  (* n is the vertex count. An entry of "dist" is meaningful only
     where the matching entry of "next" is not the nil vertex. *)
  typedef _floyd_warshall_result (n : int) =
    '{
      n = int n,
      dist = square_array (floatpt, n),
      next = square_array (vertex, n)
    }

  (* Monomorphic accessors for the "dist" and "next" arrays. *)

  fn {}
  _dist_get_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (dist : square_array (floatpt, n),
             i : int i,
             j : int j) : floatpt =
    square_array_get_at (dist, i, j)

  fn
  _dist_set_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (dist : square_array (floatpt, n),
             i : int i,
             j : int j,
             x : floatpt) : void =
    square_array_set_at (dist, i, j, x)

  fn {}
  _next_get_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (next : square_array (vertex, n),
             i : int i,
             j : int j) : vertex =
    square_array_get_at (next, i, j)

  fn
  _next_set_at
            {n : pos}
            {i, j : pos | i <= n; j <= n}
            (next : square_array (vertex, n),
             i : int i,
             j : int j,
             x : vertex) : void =
    square_array_set_at (next, i, j, x)

in

  abstype floyd_warshall_result (n : int)
  typedef floyd_warshall_result = [n : nat] floyd_warshall_result n
  assume floyd_warshall_result n = _floyd_warshall_result n

  exception FloydWarshallError of (string)

  (* The number of vertices covered by the result. *)
  fn
  vertex_count {n : pos}
               (fw : floyd_warshall_result n) :<> int n =
    fw.n

  (* The computed distance from vertex i to vertex j, or None () if
     the "next" entry for the pair is the nil vertex. *)
  fn
  get_distance {n : pos}
               {i, j : pos | i <= n; j <= n}
               (fw : floyd_warshall_result n,
                i : vertex i,
                j : vertex j) : Option floatpt =
    (* Notice there is *no way* to return one of the "uninitialized"
       values in the "dist" array (which were actually set to a
       meaningless value, or could have been set to positive
       infinity). Instead you get "None()".

       This kind of behavior is better than returning "positive
       infinity", because it does not depend on any particular sort
       of floating point. Indeed, in Ada you could use fixed
       point. *)
    let
      val i = vertex2int i
      val j = vertex2int j
      val u = _next_get_at (fw.next, i, j)
    in
      if iseqz u then
        None ()               (* There is no finite path. *)
      else
        Some (_dist_get_at (fw.dist, i, j))
    end

  (* The "next" entry for the pair (i, j): the vertex that follows i
     on the recorded path to j, or the nil vertex if none was
     recorded. *)
  fn
  get_next_vertex {n : pos}
                  {i, j : pos | i <= n; j <= n}
                  (fw : floyd_warshall_result n,
                   i : vertex i,
                   j : vertex j) : vertex =
    _next_get_at (fw.next, vertex2int i, vertex2int j)

  (* Run Floyd-Warshall over the given weighted, directed edges. The
     vertex count n is the greatest vertex number found in the edge
     list. Raises FloydWarshallError ("no vertices") if that number
     is zero. *)
  fn
  floyd_warshall (edges : List edge)
      :<1> [n : pos] floyd_warshall_result n =
    let
      val n = vertex2int (max_vertex_in_edge_list edges)
    in
      if n = 0 then
        $raise FloydWarshallError ("no vertices")
      else
        let
          (* Placeholder for "dist" entries that have no path yet.
             get_distance never hands this value out. *)
          macdef arbitrary_floatpt = i2fp (12345)

          val dist = square_array_make (n, arbitrary_floatpt)
          val next = square_array_make (n, nil_vertex)
        in
          (* Initialize. *)

          let
            var i : Pos
          in
            for (i := 1; i <= n; i := succ i)
              let
                var j : Pos
              in
                for (j := 1; j <= n; j := succ j)
                  next[i, j] := nil_vertex
              end
          end;
          let
            var p : List edge
          in
            for (p := edges; list_is_cons p; p := list_tail p)
              let
                val edge = list_head p
                (* The run-time assertions give the type checker the
                   bounds 1 <= u <= n and 1 <= v <= n that the array
                   indexing below demands. *)
                val u = edge_first edge
                val () = assertloc (isneqz u)
                val () = assertloc (vertex2int u <= n)
                val v = edge_second edge
                val () = assertloc (isneqz v)
                val () = assertloc (vertex2int v <= n)
              in
                dist[vertex2int u, vertex2int v] := edge_weight edge;
                next[vertex2int u, vertex2int v] := v
              end
          end;
          let
            var i : Pos
          in
            for (i := 1; i <= n; i := succ i)
              begin
                (* Distance from a vertex to itself is zero. *)
                dist[i, i] := int2floatpt (0);
                next[i, i] := int2vertex i
              end
          end;

          (* Perform the algorithm. *)

          let
            var k : Pos
          in
            for (k := 1; k <= n; k := succ k)
              let
                var i : Pos
              in
                for (i := 1; i <= n; i := succ i)
                  let
                    var j : Pos
                  in
                    for (j := 1; j <= n; j := succ j)
                      (* Going through k is possible only if both
                         halves of the trip have a recorded path. *)
                      if isneqz next[i, k] && isneqz next[k, j] then
                        let
                          val dist_ikj = dist[i, k] + dist[k, j]
                        in
                          (* Take the route through k if no route
                             was known, or if it is shorter. *)
                          if iseqz next[i, j]
                                || dist_ikj < dist[i, j] then
                            begin
                              dist[i, j] := dist_ikj;
                              next[i, j] := next[i, k]
                            end
                        end
                  end
              end
          end;

          (* Return the result. *)

          '{
            n = n,
            dist = dist,
            next = next
          }
        end
    end

  (* The recorded path from u to v, as a list of vertices starting
     with u and ending with v. Returns NIL if there is no path.
     Raises FloydWarshallError ("vertex not found") if u or v is
     greater than the vertex count of fw. *)
  fn
  get_path {n : int}
           {u, v : pos}
           (fw : floyd_warshall_result n,
            u : vertex u,
            v : vertex v) : List vertex =
    if (fw.n) < vertex2int u then
      $raise FloydWarshallError ("vertex not found")
    else if (fw.n) < vertex2int v then
      $raise FloydWarshallError ("vertex not found")
    else if iseqz (get_next_vertex (fw, u, v)) then
      NIL
    else
      let
        (* "lst" holds the path so far in reverse order; it is
           reversed once v has been reached. *)
        fun
        loop (w : vertex,
              lst : List0 vertex) : List vertex =
          if w = v then
            list_vt2t (list_reverse lst)
          else
            let
              (* Run-time checks that give the type checker the
                 bounds get_next_vertex demands of w. *)
              val () = $effmask_exn assertloc (isneqz w)
              val () = $effmask_exn assertloc (vertex2int w <= (fw.n))
              val w = get_next_vertex (fw, w, v)
            in
              loop (w, w :: lst)
            end
      in
        $effmask_ntm loop (u, u :: NIL)
      end

end (* local *)

(*------------------------------------------------------------------*)

(* Run the algorithm on an example graph and print, for every ordered
   pair of distinct vertices, the distance and the path, or a note
   that there is no path. *)
implement
main0 () =
  let
    val example_graph =
      $list (edge_make (int2vertex 1, i2fp (~2), int2vertex 3),
             edge_make (int2vertex 3, i2fp (2), int2vertex 4),
             edge_make (int2vertex 4, i2fp (~1), int2vertex 2),
             edge_make (int2vertex 2, i2fp (4), int2vertex 1),
             edge_make (int2vertex 2, i2fp (3), int2vertex 3))

    val fw = floyd_warshall example_graph
  in
    println! (" pair distance path");
    println! ("------------------------------------------");
    let
      var i : Pos
    in
      for (i := 1; i <= (fw.n); i := succ i)
        let
          var j : Pos
        in
          for (j := 1; j <= (fw.n); j := succ j)
            let
              val u = int2vertex i
              val v = int2vertex j
            in
              if u <> v then
                let
                  val s_edge =
                    tostring_directed_vertex_list ($list (u, v))
                  val distance_opt = get_distance (fw, u, v)
                in
                  print! (" ", s_edge, " ");
                  begin
                    case+ distance_opt of
                    | None () => print! " no path"
                    | Some distance =>
                      let
                        val path = get_path (fw, u, v)
                        val s_path = tostring_directed_vertex_list path
                      in
                        (* Pad non-negative distances where negative
                           ones carry a minus sign. *)
                        if int2floatpt (0) <= distance then
                          print! " ";
                        print! distance;
                        print! " ";
                        print! s_path
                      end
                  end;
                  println! ()
                end
            end
        end
    end
  end

(*------------------------------------------------------------------*)