(* RosettaCodeData/Task/Floyd-Warshall-algorithm/OCaml/floyd-warshall-algorithm.ml
   (211 lines, 4.5 KiB, OCaml) *)

(*
Floyd-Warshall algorithm.
See https://en.wikipedia.org/w/index.php?title=Floyd%E2%80%93Warshall_algorithm&oldid=1082310013
*)
module Square_array =
  (* Square n-by-n arrays with 1-based indexing. Elements are stored in a
     single flat array: element (i, j) lives at offset
     (i - 1) + n * (j - 1), i.e. column-major with i as the row. *)
  struct
    type 'a t =
      {
        n : int;        (* side length *)
        r : 'a Array.t  (* flat storage, n * n elements *)
      }

    (* [make n fill] is an n-by-n array with every element set to [fill].
       Raises [Invalid_argument] if [n] is negative. Without this check, a
       negative [n] would still allocate [n * n] cells and produce an array
       whose recorded size makes no sense. *)
    let make n fill =
      if n < 0 then invalid_arg "Square_array.make: negative size";
      let r = Array.make (n * n) fill in
      { n = n; r = r }

    (* Flat offset of (i, j). Each coordinate is range-checked on its own.
       Checking only the flat offset (which is all Array.get does) would
       let e.g. (0, 2) silently alias (n, 1).
       Raises [Invalid_argument] if i or j is outside 1..n. *)
    let index arr (i, j) =
      if i < 1 || i > arr.n || j < 1 || j > arr.n then
        invalid_arg "Square_array: index out of bounds"
      else
        (i - 1) + (arr.n * (j - 1))

    (* [get arr (i, j)] reads element (i, j), 1-based. *)
    let get arr (i, j) =
      Array.get arr.r (index arr (i, j))

    (* [set arr (i, j) x] overwrites element (i, j), 1-based, with [x]. *)
    let set arr (i, j) x =
      Array.set arr.r (index arr (i, j)) x
  end
module Vertex =
  (* Vertices are positive integers; 0 is reserved as the "nil" vertex,
     meaning "no vertex". *)
  struct
    type t = int
    let nil = 0

    (* Print a single vertex number on stdout. *)
    let print_vertex u = print_int u

    (* Print the vertices of [lst] on stdout, separated by " -> ", with no
       trailing newline. An empty list prints nothing. *)
    let print_directed_list lst =
      List.iteri
        (fun pos u ->
           if pos > 0 then print_string " -> ";
           print_vertex u)
        lst
  end
module Edge =
  (* A directed, weighted graph edge from [u] to [v]. *)
  struct
    type t =
      {
        u : Vertex.t;      (* source vertex *)
        weight : Float.t;  (* edge weight; may be negative *)
        v : Vertex.t       (* destination vertex *)
      }

    (* [make u weight v] builds the edge u --weight--> v. *)
    let make u weight v = { u; weight; v }
  end
module Paths =
  (* The "next vertex" array and its operations. Entry (u, v) holds the
     vertex that follows u on the recorded path from u to v, or Vertex.nil
     when no path from u to v has been recorded. *)
  struct
    type t = Vertex.t Square_array.t
    let make n =
      Square_array.make n Vertex.nil
    let get = Square_array.get
    let set = Square_array.set

    (* [path paths u v] returns the vertices of the recorded path from [u]
       to [v], both endpoints included, or [] if none is recorded. In the
       finest tradition of the standard List module, this implementation
       is *not* tail recursive.

       A simple path visits at most n vertices. If following "next"
       entries takes more steps than that, the entries form a cycle (which
       Floyd-Warshall can leave behind when the graph has a negative
       cycle). Reconstruction would then never reach [v], so raise
       [Failure] instead of recursing forever. *)
    let path paths u v =
      if Square_array.get paths (u, v) = Vertex.nil then
        []
      else
        let limit = paths.Square_array.n in
        (* [steps] counts the vertices emitted so far, including [w]. *)
        let rec build_path w steps =
          if w = v then
            [v]
          else if steps >= limit then
            failwith "Paths.path: next-vertex cycle (negative cycle in graph?)"
          else
            w :: build_path (Square_array.get paths (w, v)) (steps + 1)
        in
        build_path u 1

    (* Print the recorded path from [u] to [v] as "u -> ... -> v". *)
    let print_path paths u v =
      Vertex.print_directed_list (path paths u v)
  end
module Distances =
  (* The "distance" array: entry (u, v) is the length of the shortest path
     found so far from u to v, and infinity while none has been found. *)
  struct
    type t = Float.t Square_array.t

    (* An n-by-n distance array with every entry initially infinite. *)
    let make n = Square_array.make n Float.infinity

    (* Read / overwrite the distance stored for the pair [ij]. *)
    let get dist ij = Square_array.get dist ij
    let set dist ij x = Square_array.set dist ij x
  end
let find_max_vertex edges =
  (* Largest vertex number mentioned by any edge (as either endpoint), or
     Vertex.nil for an empty list. Uses List.fold_left, which is tail
     recursive, so long edge lists cannot overflow the stack. *)
  List.fold_left
    (fun acc edge -> max acc (max Edge.(edge.u) Edge.(edge.v)))
    Vertex.nil edges
let floyd_warshall edges =
  (* All-pairs shortest paths over the directed graph given by [edges].
     Vertices are numbered 1..n, where n is the largest vertex mentioned
     by any edge. Returns the 3-tuple (n, dist, next):
     - [dist] holds shortest-path lengths, infinity where no path exists;
     - [next] holds the first hop of each path, for use with Paths.path.

     This implementation assumes IEEE floating point. The OCaml Float
     module explicitly specifies 64-bit IEEE floating point.

     Raises [Assert_failure] if [edges] is empty (unless compiled with
     -noassert). Raises [Invalid_argument] if an edge mentions a vertex
     below 1; such a vertex would otherwise index outside the arrays. *)
  assert (edges <> []);
  let n = find_max_vertex edges in
  let dist = Distances.make n in
  let next = Paths.make n in
  (* Initialization from the edge list. When several edges join the same
     ordered pair, keep the lightest one, so the result does not depend
     on the order of [edges]. *)
  List.iter
    (fun { Edge.u; weight; v } ->
       if u < 1 || v < 1 then
         invalid_arg "floyd_warshall: vertices must be positive integers";
       if weight < Distances.get dist (u, v) then
         begin
           Distances.set dist (u, v) weight;
           Paths.set next (u, v) v
         end)
    edges;
  for i = 1 to n do
    (* Distance from a vertex to itself = 0.0 *)
    Distances.set dist (i, i) 0.0;
    Paths.set next (i, i) i
  done;
  (* Perform the algorithm. Absent negative cycles, after round k,
     dist (i, j) is the shortest i->j distance whose intermediate vertices
     all lie in 1..k. *)
  for k = 1 to n do
    for i = 1 to n do
      for j = 1 to n do
        let dist_ij = Distances.get dist (i, j) in
        let dist_ik = Distances.get dist (i, k) in
        let dist_kj = Distances.get dist (k, j) in
        let dist_ikj = dist_ik +. dist_kj in
        if dist_ikj < dist_ij then
          begin
            Distances.set dist (i, j) dist_ikj;
            Paths.set next (i, j) (Paths.get next (i, k))
          end
      done
    done
  done;
  (* Return the results, as a 3-tuple. *)
  (n, dist, next)
(* Example graph for this task: five directed, weighted edges over
   vertices 1..4, written as (source, weight, destination). *)
let example_graph =
  List.map
    (fun (u, weight, v) -> Edge.make u weight v)
    [ (1, -2.0, 3);
      (3, 2.0, 4);
      (4, -1.0, 2);
      (2, 4.0, 1);
      (2, 3.0, 3) ]
;;
let (n, dist, next) =
  floyd_warshall example_graph
;;
(* Report every ordered pair of distinct vertices: the pair, its shortest
   distance (formatted with %4.1f) and the reconstructed path. Each line
   is flushed as it is completed. *)
let () =
  Printf.printf " pair distance path\n%!";
  Printf.printf "---------------------------------------\n%!";
  for u = 1 to n do
    for v = 1 to n do
      if u <> v then
        begin
          Printf.printf " %d -> %d %4.1f " u v (Distances.get dist (u, v));
          Paths.print_path next u v;
          Printf.printf "\n%!"
        end
    done
  done
;;