RosettaCodeData/Task/Quickselect-algorithm/ATS/quickselect-algorithm.ats

638 lines
19 KiB
Plaintext

(*------------------------------------------------------------------*)
(*
For linear linked lists, using a random pivot:
* stable three-way "separation" (a variant of quickselect)
* quickselect
* stable quicksort
Also a couple of routines for splitting lists according to a
predicate.
Linear list operations are destructive but may avoid doing many
unnecessary allocations. Also they do not require a garbage
collector.
*)
#include "share/atspre_staload.hats"
staload UN = "prelude/SATS/unsafe.sats"
#define NIL list_vt_nil ()
#define :: list_vt_cons
(*------------------------------------------------------------------*)
(* A simple linear congruential generator for pivot selection. Each
step replaces the state by lcg_a * state + lcg_c, with the arithmetic
wrapping around modulo 2**64. *)
(* The multiplier lcg_a comes from Steele, Guy; Vigna, Sebastiano (28
September 2021). "Computationally easy, spectrally good multipliers
for congruential pseudorandom number generators".
arXiv:2001.05304v3 [cs.DS] *)
macdef lcg_a = $UN.cast{uint64} 0xf1357aea2e62a9c5LLU
(* lcg_c must be odd. *)
macdef lcg_c = $UN.cast{uint64} 0xbaceba11beefbeadLLU
(* The generator state. It starts at zero and nothing visible here
reseeds it, so the sequence of pivots is the same from run to run. *)
var seed : uint64 = $UN.cast 0
(* The address of the state. random_double reads and updates the state
through this pointer, using unsafe pointer operations. *)
val p_seed = addr@ seed
(* Return a pseudorandom number r, 0.0 <= r < 1.0, computed from the
   current generator state, and advance the state by one LCG step. *)
fn
random_double () :<!wrt> double =
  let
    val (pf_state, fpf_state | p_state) =
      $UN.ptr0_vtake{uint64} p_seed
    val state = ptr_get<uint64> (pf_state | p_state)

    (* Step the generator. The multiply-and-add wraps around modulo
       2**64, by virtue of standard C behavior for uint64_t. *)
    val () =
      ptr_set<uint64> (pf_state | p_state, (lcg_a * state) + lcg_c)
    prval () = fpf_state pf_state

    (* The result is built from the state as it was BEFORE the step:
       its high 48 bits, divided by 2**48. An IEEE "binary64" double
       has a 53-bit significand, so both the 48-bit numerator and
       the power-of-two denominator are represented exactly. *)
    val numerator = $UN.cast{double} (state >> 16)
    val denominator = $UN.cast{double} (1LLU << 48)
  in
    numerator / denominator
  end
(*------------------------------------------------------------------*)
(* Destructive split into two lists: a list of leading elements that
satisfy a predicate, and the tail of that split. (This is similar
to "span!" in SRFI-1.) *)
(* pred is a linear closure, passed by reference, that inspects an
element through a reference. lst is consumed. The result is the pair
@(leading elements, tail); the two lengths n1 and n2 add up to the
original length n. *)
extern fun {a : vt@ype}
list_vt_span {n : int}
(pred : &((&a) -<cloptr1> bool),
lst : list_vt (a, n))
: [n1, n2 : nat | n1 + n2 == n]
@(list_vt (a, n1),
list_vt (a, n2))
(* Destructive, stable partition into elements less than the pivot,
elements equal to the pivot, and elements greater than the
pivot. *)
(* compare is called as compare (element, pivot): a negative result
places the element in the first list, zero in the second, and a
positive result in the third. lst is consumed; the pivot is only
referenced. The three result lengths add up to n. *)
extern fun {a : vt@ype}
list_vt_three_way_partition
{n : int}
(compare : &((&a, &a) -<cloptr1> int),
pivot : &a,
lst : list_vt (a, n))
: [n1, n2, n3 : nat | n1 + n2 + n3 == n]
@(list_vt (a, n1),
list_vt (a, n2),
list_vt (a, n3))
(* Destructive, stable partition into elements less than the kth least
element, elements equal to it, and elements greater than it. *)
(* k is zero-based and must satisfy 0 <= k < n, so the list cannot be
empty. Each result list is returned together with its length. The
constraints n1 <= k and k < n1 + n2 state that position k of the
sorted order falls inside the middle list. *)
extern fun {a : vt@ype}
list_vt_three_way_separation
{n, k : int | 0 <= k; k < n}
(compare : &((&a, &a) -<cloptr1> int),
k : int k,
lst : list_vt (a, n))
: [n1, n2, n3 : nat | n1 + n2 + n3 == n;
n1 <= k; k < n1 + n2]
@(int n1, list_vt (a, n1),
int n2, list_vt (a, n2),
int n3, list_vt (a, n3))
(* Destructive quickselect for linear elements. *)
(* Returns the element at zero-based position k of the sorted order
(0 <= k < n). lst is consumed: the selected element is handed to the
caller and every other element is disposed of through the
list_vt_select_linear$clear template below. *)
extern fun {a : vt@ype}
list_vt_select_linear
{n, k : int | 0 <= k; k < n}
(compare : &((&a, &a) -<cloptr1> int),
k : int k,
lst : list_vt (a, n)) : a
(* Template hook: how to dispose of one element that quickselect
discards. A user of list_vt_select_linear supplies the
implementation. *)
extern fun {a : vt@ype}
list_vt_select_linear$clear (x : &a >> a?) : void
(* Destructive quickselect for non-linear elements. *)
(* Same contract as list_vt_select_linear (zero-based k, 0 <= k < n,
lst is consumed), but for element types of sort t@ype, for which no
user-supplied clear routine is needed. *)
extern fun {a : t@ype}
list_vt_select
{n, k : int | 0 <= k; k < n}
(compare : &((&a, &a) -<cloptr1> int),
k : int k,
lst : list_vt (a, n)) : a
(* Stable quicksort. Also returns the length. *)
(* lst is consumed and the sorted list, which has the same length n,
is returned together with that length. compare returns a negative,
zero, or positive int for less-than, equal, or greater-than. *)
extern fun {a : vt@ype}
list_vt_stable_sort
{n : int}
(compare : &((&a, &a) -<cloptr1> int),
lst : list_vt (a, n))
: @(int n, list_vt (a, n))
(*------------------------------------------------------------------*)
(* Implementation of list_vt_span. The list is split in place: no cons
cells are allocated or freed. A by-reference "cursor" walks down the
list; at the first element that fails the predicate, the remainder is
moved into "tail" and the cursor position is overwritten with NIL,
which terminates the leading part. *)
implement {a}
list_vt_span {n} (pred, lst) =
let
(* On return, cursor holds the m leading elements that satisfy pred
and tail holds the remaining n - m elements. *)
fun
loop {n : nat} .<n>.
(pred : &((&a) -<cloptr1> bool),
cursor : &list_vt (a, n) >> list_vt (a, m),
tail : &List_vt a? >> list_vt (a, n - m))
: #[m : nat | m <= n] void =
case+ cursor of
(* End of list: every element satisfied pred, so the tail is empty. *)
| NIL => tail := NIL
(* The cons cell is unfolded so that "rest" can itself be passed by
reference as the next cursor. *)
| @ elem :: rest =>
if pred (elem) then
(* elem satisfies the predicate. Move the cursor to the next
cons-pair in the list. *)
let
val () = loop {n - 1} (pred, rest, tail)
(* Fold the cell back up now that rest has its final contents. *)
prval () = fold@ cursor
in
end
else
(* elem does not satisfy the predicate. Split the list at
the cursor. *)
let
prval () = fold@ cursor
(* Hand the whole remainder, elem included, over to tail ... *)
val () = tail := cursor
(* ... and cut the leading part off at this position. *)
val () = cursor := NIL
in
end
(* Establishes that n is a natural number, as loop requires. *)
prval () = lemma_list_vt_param lst
var cursor = lst
var tail : List_vt a?
val () = loop {n} (pred, cursor, tail)
in
@(cursor, tail)
end
(*------------------------------------------------------------------*)
(* Implementation of list_vt_three_way_partition.

Strategy: look at how the head of the list compares with the pivot
(less, equal, or greater), use list_vt_span to detach the whole
leading run of elements that compare the same way, partition the
remainder recursively, and then put the detached run at the FRONT of
the matching result list. Runs are re-attached in their original
order, which is what makes the partition stable. There is one level
of recursion per run. *)
implement {a}
list_vt_three_way_partition {n} (compare, pivot, lst) =
//
// WARNING: This implementation is NOT tail-recursive.
//
let
(* The sign class (negative, zero, positive) that pred currently
accepts. recurs overwrites it before each call to list_vt_span. *)
var current_sign : int = 0
(* The closure below captures these raw addresses, rather than the
by-reference arguments themselves, and regains access to the
referents with the unsafe ptr0_vtake. *)
val p_compare = addr@ compare
val p_pivot = addr@ pivot
val p_current_sign = addr@ current_sign
var pred = (* A linear closure. *)
lam (elem : &a) : bool =<cloptr1>
(* Return true iff the sign of the comparison of elem with the
pivot matches the current_sign. *)
let
val @(pf_compare, fpf_compare | p_compare) =
$UN.ptr0_vtake{(&a, &a) -<cloptr1> int} p_compare
val @(pf_pivot, fpf_pivot | p_pivot) =
$UN.ptr0_vtake{a} p_pivot
val @(pf_current_sign, fpf_current_sign | p_current_sign) =
$UN.ptr0_vtake{int} p_current_sign
(* Within this closure body, these names now stand for the
dereferenced pointers. *)
macdef compare = !p_compare
macdef pivot = !p_pivot
macdef current_sign = !p_current_sign
val sign = compare (elem, pivot)
val truth =
(sign < 0 && current_sign < 0) ||
(sign = 0 && current_sign = 0) ||
(sign > 0 && current_sign > 0)
(* Give back the views that were borrowed above. *)
prval () = fpf_compare pf_compare
prval () = fpf_pivot pf_pivot
prval () = fpf_current_sign pf_current_sign
in
truth
end
fun
recurs {n : nat}
(compare : &((&a, &a) -<cloptr1> int),
pred : &((&a) -<cloptr1> bool),
pivot : &a,
current_sign : &int,
lst : list_vt (a, n))
: [n1, n2, n3 : nat | n1 + n2 + n3 == n]
@(list_vt (a, n1),
list_vt (a, n2),
list_vt (a, n3)) =
case+ lst of
| ~ NIL => @(NIL, NIL, NIL)
| @ elem :: tail =>
let
macdef append = list_vt_append<a>
(* Classify the head element, and tell pred to accept
elements of that same class. *)
val cmp = compare (elem, pivot)
val () = current_sign := cmp
prval () = fold@ lst
(* "matches" is the leading run of elements in the same class
as the head; the head itself is part of it. *)
val @(matches, rest) = list_vt_span<a> (pred, lst)
val @(left, middle, right) =
recurs (compare, pred, pivot, current_sign, rest)
in
(* Put the run in front of the later elements of its class. *)
if cmp < 0 then
@(matches \append left, middle, right)
else if cmp = 0 then
@(left, matches \append middle, right)
else
@(left, middle, matches \append right)
end
(* Establishes that n is a natural number, as recurs requires. *)
prval () = lemma_list_vt_param lst
val retvals = recurs (compare, pred, pivot, current_sign, lst)
(* The closure is linear and has to be freed explicitly. *)
val () = cloptr_free ($UN.castvwtp0{cloptr0} pred)
in
retvals
end
(*------------------------------------------------------------------*)
(* Choose a pivot at a pseudorandom position of lst, and do a stable
   three-way partition of lst around it. The result gives the
   elements less than, equal to, and greater than the pivot, each
   list accompanied by its length. The pivot itself ends up in the
   middle list, at the position that keeps the partition stable.

   lst is consumed. Although the type only demands n : nat, the list
   must not be empty at run time: for n = 0 the assertion below
   fails. *)
fn {a : vt@ype}
three_way_partition_with_random_pivot
          {n : nat}
          (compare : &((&a, &a) -<cloptr1> int),
           n       : int n,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
      @(int n1, list_vt (a, n1),
        int n2, list_vt (a, n2),
        int n3, list_vt (a, n3)) =
  let
    macdef concat = list_vt_append<a>

    var pivot : a

    (* Scale a random number from [0.0, 1.0) up to a list position,
       and check at run time that the position really is in range. *)
    val r = random_double ()
    val position = $UN.cast{Size_t} (r * $UN.cast{double} n)
    prval () = lemma_g1uint_param position
    val () = assertloc (position < i2sz n)
    val position = sz2i position

    (* Cut the list just before the chosen position, and detach the
       pivot from the head of the second part. *)
    val @(before, after) = list_vt_split_at<a> (lst, position)
    val+ ~ (chosen :: after) = after
    val () = pivot := chosen

    (* Partition what came before the pivot, then what came after. *)
    val @(less_a, equal_a, greater_a) =
      list_vt_three_way_partition<a> (compare, pivot, before)
    val @(less_b, equal_b, greater_b) =
      list_vt_three_way_partition<a> (compare, pivot, after)

    (* Reassemble, with the pivot between the equal elements that
       preceded it and those that followed it. *)
    val less = less_a \concat less_b
    val equal = equal_a \concat (pivot :: equal_b)
    val greater = greater_a \concat greater_b

    val n_less = length<a> less
    val n_equal = length<a> equal
  in
    @(n_less, less, n_equal, equal, n - n_less - n_equal, greater)
  end
(*------------------------------------------------------------------*)
implement {a}
list_vt_three_way_separation {n, k} (compare, k, lst) =
(* This is a quickselect with random pivot, returning a three-way
partition, in which the middle partition contains the (k+1)st
least element. *)
(* Unlike list_vt_select_linear, nothing is thrown away: the three
lists handed to loop always add up to the whole original list, and
position k is always measured against the whole list. *)
let
macdef append = list_vt_append<a>
(* left, middle, right are a stable three-way partition of the whole
list around some pivot, with lengths n1, n2, n3. The loop refines the
partition until position k lies within middle. The n in the
constraints is the length of the original list. *)
fun
loop {n1, n2, n3, k : nat | 0 <= k; k < n;
n1 + n2 + n3 == n}
(compare : &((&a, &a) -<cloptr1> int),
k : int k,
n1 : int n1,
left : list_vt (a, n1),
n2 : int n2,
middle : list_vt (a, n2),
n3 : int n3,
right : list_vt (a, n3))
: [n1, n2, n3 : nat | n1 + n2 + n3 == n;
n1 <= k; k < n1 + n2]
@(int n1, list_vt (a, n1),
int n2, list_vt (a, n2),
int n3, list_vt (a, n3)) =
if k < n1 then
(* Position k lies within left. Partition left again; the old middle
and right are appended after the new right part. *)
let
val @(m1, left1, m2, middle1, m3, right1) =
three_way_partition_with_random_pivot<a>
(compare, n1, left)
in
loop (compare, k, m1, left1, m2, middle1,
m3 + n2 + n3,
right1 \append (middle \append right))
end
else if n1 + n2 <= k then
(* Position k lies within right. Partition right again; the old left
and middle go in front of the new left part. *)
let
val @(m1, left2, m2, middle2, m3, right2) =
three_way_partition_with_random_pivot<a>
(compare, n3, right)
in
loop (compare, k, n1 + n2 + m1,
left \append (middle \append left2),
m2, middle2, m3, right2)
end
else
(* n1 <= k < n1 + n2: position k lies within middle. Done. *)
@(n1, left, n2, middle, n3, right)
(* Establishes that n is a natural number. *)
prval () = lemma_list_vt_param lst
val @(n1, left, n2, middle, n3, right) =
three_way_partition_with_random_pivot<a>
(compare, length<a> lst, lst)
in
loop (compare, k, n1, left, n2, middle, n3, right)
end
(*------------------------------------------------------------------*)
implement {a}
list_vt_select_linear {n, k} (compare, k, lst) =
(* This is a quickselect with random pivot. It is like
list_vt_three_way_separation, but throws away parts of the list that
will not be needed later on. *)
let
(* Route the per-element cleanup that list_vt_freelin performs to the
user-supplied list_vt_select_linear$clear. *)
implement
list_vt_freelin$clear<a> (x) =
$effmask_all list_vt_select_linear$clear<a> (x)
macdef append = list_vt_append<a>
(* left, middle, right are a three-way partition of the elements still
under consideration, and k is the zero-based position, within those
elements, of the one wanted. *)
fun
loop {n1, n2, n3, k : nat | 0 <= k; k < n1 + n2 + n3}
(compare : &((&a, &a) -<cloptr1> int),
k : int k,
n1 : int n1,
left : list_vt (a, n1),
n2 : int n2,
middle : list_vt (a, n2),
n3 : int n3,
right : list_vt (a, n3)) : a =
if k < n1 then
(* The wanted element is in left. Free the other two lists and
partition left again; k needs no adjustment. *)
let
val () = list_vt_freelin<a> middle
val () = list_vt_freelin<a> right
val @(m1, left1, m2, middle1, m3, right1) =
three_way_partition_with_random_pivot<a>
(compare, n1, left)
in
loop (compare, k, m1, left1, m2, middle1, m3, right1)
end
else if n1 + n2 <= k then
(* The wanted element is in right. Free the other two lists and
partition right again; k is re-based by the n1 + n2 elements that
were dropped from in front of it. *)
let
val () = list_vt_freelin<a> left
val () = list_vt_freelin<a> middle
val @(m1, left1, m2, middle1, m3, right1) =
three_way_partition_with_random_pivot<a>
(compare, n3, right)
in
loop (compare, k - n1 - n2,
m1, left1, m2, middle1, m3, right1)
end
else
(* The wanted element is in middle, at position k - n1. Detach it
and free everything else. *)
let
val () = list_vt_freelin<a> left
val () = list_vt_freelin<a> right
val @(middle1, middle2) =
list_vt_split_at<a> (middle, k - n1)
val () = list_vt_freelin<a> middle1
val+ ~ (element :: middle2) = middle2
val () = list_vt_freelin<a> middle2
in
element
end
(* Establishes that n is a natural number. *)
prval () = lemma_list_vt_param lst
val @(n1, left, n2, middle, n3, right) =
three_way_partition_with_random_pivot<a>
(compare, length<a> lst, lst)
in
loop (compare, k, n1, left, n2, middle, n3, right)
end
(* Quickselect for non-linear elements: a thin wrapper around
list_vt_select_linear. *)
implement {a}
list_vt_select {n, k} (compare, k, lst) =
let
(* Here the per-element cleanup hook is implemented as a no-op. *)
implement
list_vt_select_linear$clear<a> (x) = ()
in
list_vt_select_linear<a> {n, k} (compare, k, lst)
end
(*------------------------------------------------------------------*)
implement {a}
list_vt_stable_sort {n} (compare, lst) =
  (* This is a stable quicksort with random pivot.

     lst is consumed; the result is the length of the list together
     with the sorted list.

     An empty list is returned as is. It must not be passed on to
     three_way_partition_with_random_pivot, because that routine
     has to pick a pivot and fails a run-time assertion when there
     is no element to pick. *)
  let
    macdef append = list_vt_append<a>

    (* Given a three-way partition around some pivot, sort the left
       and right parts (only if they have more than one element)
       and concatenate left, middle, right. The middle part needs
       no sorting, its elements being equal to one another. *)
    fun
    recurs {n : int}
           {n1, n2, n3 : nat | n1 + n2 + n3 == n}
           (compare : &((&a, &a) -<cloptr1> int),
            n1      : int n1,
            left    : list_vt (a, n1),
            n2      : int n2,
            middle  : list_vt (a, n2),
            n3      : int n3,
            right   : list_vt (a, n3))
        : @(int n, list_vt (a, n)) =
      if 1 < n1 then
        let
          (* Sort the left part. *)
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot<a>
              (compare, n1, left)
          val @(_, left) =
            recurs {n1} (compare, m1, left1, m2, middle1, m3, right1)
        in
          if 1 < n3 then
            let
              (* Sort the right part as well. *)
              val @(m1, left1, m2, middle1, m3, right1) =
                three_way_partition_with_random_pivot<a>
                  (compare, n3, right)
              val @(_, right) =
                recurs {n3} (compare, m1, left1, m2, middle1,
                             m3, right1)
            in
              @(n1 + n2 + n3, left \append (middle \append right))
            end
          else
            @(n1 + n2 + n3, left \append (middle \append right))
        end
      else if 1 < n3 then
        let
          (* Only the right part needs sorting. *)
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot<a>
              (compare, n3, right)
          val @(_, right) =
            recurs {n3} (compare, m1, left1, m2, middle1, m3, right1)
        in
          @(n1 + n2 + n3, left \append (middle \append right))
        end
      else
        @(n1 + n2 + n3, left \append (middle \append right))

    (* Establishes that n is a natural number. *)
    prval () = lemma_list_vt_param lst
    val len = length<a> lst
  in
    if len = 0 then
      (* Nothing to sort, and no pivot to choose. *)
      @(len, lst)
    else
      let
        val @(n1, left, n2, middle, n3, right) =
          three_way_partition_with_random_pivot<a> (compare, len, lst)
      in
        recurs {n} (compare, n1, left, n2, middle, n3, right)
      end
  end
(*------------------------------------------------------------------*)
(* Print the kth element (counting from 1) of lst in sorted order.
   The order is ascending when direction is 1 and descending when it
   is ~1. lst itself is left untouched: the destructive selection
   works on a copy. A k outside 1 .. length fails an assertion. *)
fn
print_kth (direction : int,
           k         : int,
           lst       : !List_vt int) : void =
  let
    (* The comparison is a linear closure, because it captures
       direction. It is freed explicitly below. *)
    var order =
      lam (lhs : &int, rhs : &int) : int =<cloptr1>
        if lhs < rhs then
          ~direction
        else if lhs = rhs then
          0
        else
          direction

    val scratch = copy<int> lst
    val len = length<int> scratch

    val rank = g1ofg0 k
    val () = assertloc (1 <= rank)
    val () = assertloc (rank <= len)

    (* list_vt_select counts positions from zero. *)
    val kth = list_vt_select<int> (order, rank - 1, scratch)
    val () = cloptr_free ($UN.castvwtp0{cloptr0} order)
  in
    print! (kth)
  end
(* Print the 1st through 10th elements of a ten-element example list,
   first in ascending and then in descending order, one line for
   each order. *)
fn
demonstrate_quickselect () : void =
  let
    (* Print the kth through the 10th selections from lst, separated
       from one another by single spaces. *)
    fun
    print_from (direction : int,
                k         : int,
                lst       : !List_vt int) : void =
      if k <= 10 then
        let
          val () = if 1 < k then print! (" ")
          val () = print_kth (direction, k, lst)
        in
          print_from (direction, k + 1, lst)
        end

    var example_for_select = $list_vt (9, 8, 7, 6, 5, 0, 1, 2, 3, 4)

    val () = print! ("With < as order predicate: ")
    val () = print_from (1, 1, example_for_select)
    val () = println! ()

    val () = print! ("With > as order predicate: ")
    val () = print_from (~1, 1, example_for_select)
    val () = println! ()

    val () = list_vt_free<int> example_for_select
  in
  end
(* Sort a list of words by first character only, and print the
   result. Words that share a first character compare as equal, so
   the output shows whether their original relative order has been
   kept. *)
fn
demonstrate_quicksort () : void =
  let
    (* Compare two non-empty strings by their first characters. *)
    var by_first_char =
      lam (lhs : &stringGt 0,
           rhs : &stringGt 0) : int =<cloptr1>
        if lhs[0] < rhs[0] then
          ~1
        else if lhs[0] = rhs[0] then
          0
        else
          1

    var example_for_sort =
      $list_vt ("elephant", "duck", "giraffe", "deer",
                "earwig", "dolphin", "wildebeest", "pronghorn",
                "woodlouse", "whip-poor-will")

    val () = println! ("stable sort by first character:")

    (* The sort is destructive, so it is given a copy. *)
    val @(_, sorted_lst) =
      list_vt_stable_sort<stringGt 0>
        (by_first_char, copy<stringGt 0> example_for_sort)
    val () = println! ($UN.castvwtp1{List0 string} sorted_lst)

    (* Release the two lists and the linear closure. *)
    val () = list_vt_free<string> sorted_lst
    val () = list_vt_free<string> example_for_sort
    val () = cloptr_free ($UN.castvwtp0{cloptr0} by_first_char)
  in
  end
(* Program entry point. The first command line argument chooses the
   demonstration to run: "quickselect" or "quicksort". If it is
   missing, or is anything else, a usage message is printed and the
   program exits with status 1. *)
implement
main0 (argc, argv) =
  let
    (* Currently there is no demonstration of
       list_vt_three_way_separation. *)
    val usage = "Please choose \"quickselect\" or \"quicksort\"."

    val demo_name =
      begin
        if argc < 2 then
          begin
            println! (usage);
            exit (1)
          end
        else
          $UN.cast{string} argv[1]
      end : string
  in
    if demo_name = "quickselect" then
      demonstrate_quickselect ()
    else if demo_name = "quicksort" then
      demonstrate_quicksort ()
    else
      begin
        println! (usage);
        exit (1)
      end
  end
(*------------------------------------------------------------------*)