638 lines
19 KiB
Plaintext
638 lines
19 KiB
Plaintext
(*------------------------------------------------------------------*)
(*

  For linear linked lists, using a random pivot:

    * stable three-way "separation" (a variant of quickselect)
    * quickselect
    * stable quicksort

  Also a couple of routines for splitting lists according to a
  predicate.

  Linear list operations are destructive, but they can avoid many
  unnecessary allocations. They also do not require a garbage
  collector.

*)
|
|
|
|
#include "share/atspre_staload.hats"
|
|
|
|
staload UN = "prelude/SATS/unsafe.sats"
|
|
|
|
#define NIL list_vt_nil ()
|
|
#define :: list_vt_cons
|
|
|
|
(*------------------------------------------------------------------*)
|
|
(* Pivot selection uses a simple linear congruential generator (LCG)
   whose 64-bit state is kept in the file-level variable `seed'. *)

(* The multiplier lcg_a comes from Steele, Guy; Vigna, Sebastiano (28
   September 2021). "Computationally easy, spectrally good multipliers
   for congruential pseudorandom number generators".
   arXiv:2001.05304v3 [cs.DS] *)
macdef lcg_a = $UN.cast{uint64} 0xf1357aea2e62a9c5LLU

(* The increment lcg_c must be odd. *)
macdef lcg_c = $UN.cast{uint64} 0xbaceba11beefbeadLLU

(* Generator state. It starts at zero, so the sequence of numbers
   drawn is the same on every run. The state is reached through the
   raw pointer p_seed. *)
var seed : uint64 = $UN.cast 0
val p_seed = addr@ seed

(* Return a pseudorandom number r with 0.0 <= r < 1.0 and advance the
   generator by one step. The value returned is derived from the
   state as it was BEFORE this call advanced it. *)
fn
random_double () :<!wrt> double =
  let
    val state = $UN.ptr0_get<uint64> (p_seed)

    (* Advance the generator. The arithmetic is modulo 2**64, by
       virtue of standard C behavior for uint64_t. *)
    val () = $UN.ptr0_set<uint64> (p_seed, (lcg_a * state) + lcg_c)

    (* Keep the high 48 bits of the state and scale by 2**48. A
       48-bit integer fits in the significand of an IEEE "binary64"
       double, so the quotient lies in [0.0, 1.0). *)
    val numerator = $UN.cast{double} (state >> 16)
    val denominator = $UN.cast{double} (1LLU << 48)
  in
    numerator / denominator
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* list_vt_span: destructively split a list in two. The first result
   is the leading run of elements that satisfy `pred'; the second is
   the rest of the list, starting at the first element that does not.
   The two results together hold all n elements. (This is similar to
   "span!" in SRFI-1.) *)
extern fun {a : vt@ype}
list_vt_span {n : int}
             (pred : &((&a) -<cloptr1> bool),
              lst  : list_vt (a, n))
    : [n1, n2 : nat | n1 + n2 == n]
      @(list_vt (a, n1),
        list_vt (a, n2))

(* list_vt_three_way_partition: destructive, stable partition. The
   results are, in order: the elements that compare less than the
   pivot, those that compare equal to it, and those that compare
   greater. `compare' is called as compare (element, pivot) and only
   the sign of its result matters. *)
extern fun {a : vt@ype}
list_vt_three_way_partition
          {n : int}
          (compare : &((&a, &a) -<cloptr1> int),
           pivot   : &a,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
      @(list_vt (a, n1),
        list_vt (a, n2),
        list_vt (a, n3))

(* list_vt_three_way_separation: destructive, stable partition around
   the element of rank k. k counts from zero: the result satisfies
   n1 <= k < n1 + n2, so the middle list contains the (k+1)st least
   element. Each list is returned together with its length. *)
extern fun {a : vt@ype}
list_vt_three_way_separation
          {n, k : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n;
                          n1 <= k; k < n1 + n2]
      @(int n1, list_vt (a, n1),
        int n2, list_vt (a, n2),
        int n3, list_vt (a, n3))

(* list_vt_select_linear: destructive quickselect for linear
   elements. Consumes the list and returns the element of zero-based
   rank k. *)
extern fun {a : vt@ype}
list_vt_select_linear
          {n, k : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n)) : a

(* Hook for list_vt_select_linear: it is invoked to dispose of each
   element that list_vt_select_linear discards. *)
extern fun {a : vt@ype}
list_vt_select_linear$clear (x : &a >> a?) : void

(* list_vt_select: destructive quickselect for non-linear elements.
   Same contract as list_vt_select_linear. *)
extern fun {a : t@ype}
list_vt_select
          {n, k : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n)) : a

(* list_vt_stable_sort: stable quicksort. Returns the length of the
   list along with the sorted list. *)
extern fun {a : vt@ype}
list_vt_stable_sort
          {n : int}
          (compare : &((&a, &a) -<cloptr1> int),
           lst     : list_vt (a, n))
    : @(int n, list_vt (a, n))
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
implement {a}
list_vt_span {n} (pred, lst) =
  (* The list is split in place: no cons-pair is allocated or freed.
     The cons-pair holding the first element that fails the predicate
     becomes the head of the second result. *)
  let
    (* Walk down the list for as long as the predicate holds. On
       return, `front' is the leading run of satisfying elements and
       `back' is everything that followed that run. *)
    fun
    scan {n : nat} .<n>.
         (pred  : &((&a) -<cloptr1> bool),
          front : &list_vt (a, n) >> list_vt (a, m),
          back  : &List_vt a? >> list_vt (a, n - m))
        : #[m : nat | m <= n] void =
      case+ front of
      | NIL => back := NIL      (* Ran off the end: nothing remains. *)
      | @ head :: others =>
        if pred (head) then
          (* head belongs to the leading run. Continue the scan in
             the tail field of this cons-pair, then fold it back. *)
          let
            val () = scan {n - 1} (pred, others, back)
            prval () = fold@ front
          in
          end
        else
          (* head is the first element that fails the predicate. Hand
             the list from here onwards to `back' and terminate the
             leading run with an empty list. *)
          let
            prval () = fold@ front
          in
            back := front;
            front := NIL
          end

    prval () = lemma_list_vt_param lst

    var front = lst
    var back : List_vt a?
    val () = scan {n} (pred, front, back)
  in
    @(front, back)
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
implement {a}
list_vt_three_way_partition {n} (compare, pivot, lst) =
  //
  // WARNING: This implementation is NOT tail-recursive.
  //
  (* Method: compare the head of the list with the pivot, use
     list_vt_span to cut off the whole leading run of elements whose
     comparison has the same sign, partition the remainder
     recursively, and put the run in front of the matching result
     list. *)
  let
    (* Sign of the comparison for the run currently being cut off. *)
    var run_sign : int = 0

    (* The closure below cannot refer to by-reference parameters or
       to a local variable directly, so it is handed raw pointers. *)
    val compare_ptr = addr@ compare
    val pivot_ptr = addr@ pivot
    val run_sign_ptr = addr@ run_sign

    var pred =
      (* A linear closure; it is freed explicitly before returning.
         It yields true iff the sign of compare (elem, pivot) is the
         same as the sign stored in run_sign. *)
      lam (elem : &a) : bool =<cloptr1>
        let
          val @(pf_cmp, fpf_cmp | compare_ptr) =
            $UN.ptr0_vtake{(&a, &a) -<cloptr1> int} compare_ptr
          val @(pf_piv, fpf_piv | pivot_ptr) =
            $UN.ptr0_vtake{a} pivot_ptr
          val @(pf_sgn, fpf_sgn | run_sign_ptr) =
            $UN.ptr0_vtake{int} run_sign_ptr

          macdef compare = !compare_ptr
          macdef pivot = !pivot_ptr

          val sign = compare (elem, pivot)
          val wanted = !run_sign_ptr
          val same_sign =
            (if sign < 0 then
               wanted < 0
             else if sign = 0 then
               wanted = 0
             else
               wanted > 0) : bool

          prval () = fpf_cmp pf_cmp
          prval () = fpf_piv pf_piv
          prval () = fpf_sgn pf_sgn
        in
          same_sign
        end

    fun
    recurs {n : nat}
           (compare  : &((&a, &a) -<cloptr1> int),
            pred     : &((&a) -<cloptr1> bool),
            pivot    : &a,
            run_sign : &int,
            lst      : list_vt (a, n))
        : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
          @(list_vt (a, n1),
            list_vt (a, n2),
            list_vt (a, n3)) =
      case+ lst of
      | ~ NIL => @(NIL, NIL, NIL)
      | @ head :: others =>
        let
          macdef append = list_vt_append<a>

          (* The head of the list decides which sign `pred' accepts
             for this run. *)
          val sign = compare (head, pivot)
          val () = run_sign := sign
          prval () = fold@ lst

          val @(run, remainder) = list_vt_span<a> (pred, lst)
          val @(lows, mids, highs) =
            recurs (compare, pred, pivot, run_sign, remainder)
        in
          (* Prepending the run keeps the original relative order. *)
          if sign < 0 then
            @(run \append lows, mids, highs)
          else if sign = 0 then
            @(lows, run \append mids, highs)
          else
            @(lows, mids, run \append highs)
        end

    prval () = lemma_list_vt_param lst
    val results = recurs (compare, pred, pivot, run_sign, lst)

    val () = cloptr_free ($UN.castvwtp0{cloptr0} pred)
  in
    results
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Pick a pivot at a pseudorandom position of the list and do a
   stable three-way partition around it. Returns the less-than,
   equal-to and greater-than lists, each preceded by its length. The
   pivot itself ends up in the equal-to list, at its original
   position relative to the other elements.

   The list must not be empty: for n = 0 the computed index is 0 and
   the run-time assertion `index < n' fails. *)
fn {a : vt@ype}
three_way_partition_with_random_pivot
          {n : nat}
          (compare : &((&a, &a) -<cloptr1> int),
           n       : int n,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
      @(int n1, list_vt (a, n1),
        int n2, list_vt (a, n2),
        int n3, list_vt (a, n3)) =
  let
    macdef append = list_vt_append<a>

    (* Scale a number 0.0 <= r < 1.0 up to an index into the list. *)
    val r = random_double ()
    val index = $UN.cast{Size_t} (r * $UN.cast{double} n)
    prval () = lemma_g1uint_param index
    val () = assertloc (index < i2sz n)
    val index = sz2i index

    (* Cut the list at the pivot and take the pivot out of it. *)
    val @(front, back) = list_vt_split_at<a> (lst, index)
    val+ ~ (chosen :: back) = back
    var pivot : a = chosen

    (* Partition what came before the pivot and what came after it
       separately, so that the pivot can be put back in between. *)
    val @(lows_front, mids_front, highs_front) =
      list_vt_three_way_partition<a> (compare, pivot, front)
    val @(lows_back, mids_back, highs_back) =
      list_vt_three_way_partition<a> (compare, pivot, back)

    val lows = lows_front \append lows_back
    val mids = mids_front \append (pivot :: mids_back)
    val highs = highs_front \append highs_back

    val n_lows = length<a> lows
    val n_mids = length<a> mids
    val n_highs = n - n_lows - n_mids
  in
    @(n_lows, lows, n_mids, mids, n_highs, highs)
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
implement {a}
list_vt_three_way_separation {n, k} (compare, k, lst) =
  (* A quickselect with random pivot that keeps every element. The
     result is a three-way partition whose middle part contains the
     (k+1)st least element. *)
  let
    macdef append = list_vt_append<a>

    (* Invariant: left, middle and right are a three-way partition of
       the whole list. As long as rank k falls outside the middle
       part, re-partition the part that contains it and merge the
       pieces that are already on the correct side. *)
    fun
    narrow {n1, n2, n3, k : nat | 0 <= k; k < n;
                                  n1 + n2 + n3 == n}
           (compare : &((&a, &a) -<cloptr1> int),
            k       : int k,
            n1      : int n1,
            left    : list_vt (a, n1),
            n2      : int n2,
            middle  : list_vt (a, n2),
            n3      : int n3,
            right   : list_vt (a, n3))
        : [n1, n2, n3 : nat | n1 + n2 + n3 == n;
                              n1 <= k; k < n1 + n2]
          @(int n1, list_vt (a, n1),
            int n2, list_vt (a, n2),
            int n3, list_vt (a, n3)) =
      if k < n1 then
        (* Rank k is in the left part. Everything else goes to the
           right of the new partition. *)
        let
          val @(m1, lows, m2, mids, m3, highs) =
            three_way_partition_with_random_pivot<a>
              (compare, n1, left)
          val above = highs \append (middle \append right)
        in
          narrow (compare, k, m1, lows, m2, mids,
                  m3 + n2 + n3, above)
        end
      else if n1 + n2 <= k then
        (* Rank k is in the right part. Everything else goes to the
           left of the new partition. *)
        let
          val @(m1, lows, m2, mids, m3, highs) =
            three_way_partition_with_random_pivot<a>
              (compare, n3, right)
          val below = left \append (middle \append lows)
        in
          narrow (compare, k, n1 + n2 + m1, below,
                  m2, mids, m3, highs)
        end
      else
        (* n1 <= k < n1 + n2: done. *)
        @(n1, left, n2, middle, n3, right)

    prval () = lemma_list_vt_param lst

    val @(n1, left, n2, middle, n3, right) =
      three_way_partition_with_random_pivot<a>
        (compare, length<a> lst, lst)
  in
    narrow (compare, k, n1, left, n2, middle, n3, right)
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
implement {a}
list_vt_select_linear {n, k} (compare, k, lst) =
  (* A quickselect with random pivot. It works like
     list_vt_three_way_separation, except that the parts of the list
     that cannot contain the answer are freed as soon as that is
     known. *)
  let
    (* Elements of freed lists are disposed of with the hook
       list_vt_select_linear$clear. *)
    implement
    list_vt_freelin$clear<a> (x) =
      $effmask_all list_vt_select_linear$clear<a> (x)

    (* k is the zero-based rank of the wanted element within the
       concatenation of left, middle and right. *)
    fun
    search {n1, n2, n3, k : nat | 0 <= k; k < n1 + n2 + n3}
           (compare : &((&a, &a) -<cloptr1> int),
            k       : int k,
            n1      : int n1,
            left    : list_vt (a, n1),
            n2      : int n2,
            middle  : list_vt (a, n2),
            n3      : int n3,
            right   : list_vt (a, n3)) : a =
      if k < n1 then
        (* The answer is in the left part; its rank there is still
           k. *)
        let
          val () = list_vt_freelin<a> middle
          val () = list_vt_freelin<a> right
          val @(m1, lows, m2, mids, m3, highs) =
            three_way_partition_with_random_pivot<a>
              (compare, n1, left)
        in
          search (compare, k, m1, lows, m2, mids, m3, highs)
        end
      else if n1 + n2 <= k then
        (* The answer is in the right part; n1 + n2 smaller-ranked
           elements are being thrown away, so adjust the rank. *)
        let
          val () = list_vt_freelin<a> left
          val () = list_vt_freelin<a> middle
          val @(m1, lows, m2, mids, m3, highs) =
            three_way_partition_with_random_pivot<a>
              (compare, n3, right)
        in
          search (compare, k - n1 - n2,
                  m1, lows, m2, mids, m3, highs)
        end
      else
        (* The answer is in the middle part, at position k - n1. Take
           it out and free everything around it. *)
        let
          val () = list_vt_freelin<a> left
          val () = list_vt_freelin<a> right
          val @(before, from_k) =
            list_vt_split_at<a> (middle, k - n1)
          val () = list_vt_freelin<a> before
          val+ ~ (found :: after) = from_k
          val () = list_vt_freelin<a> after
        in
          found
        end

    prval () = lemma_list_vt_param lst

    val @(n1, left, n2, middle, n3, right) =
      three_way_partition_with_random_pivot<a>
        (compare, length<a> lst, lst)
  in
    search (compare, k, n1, left, n2, middle, n3, right)
  end
|
|
|
|
implement {a}
list_vt_select {n, k} (compare, k, lst) =
  (* Quickselect for non-linear elements: delegate to
     list_vt_select_linear with a disposal hook that does nothing. *)
  let
    implement
    list_vt_select_linear$clear<a> (elem) = ()
  in
    list_vt_select_linear<a> {n, k} (compare, k, lst)
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
implement {a}
list_vt_stable_sort {n} (compare, lst) =
  (* This is a stable quicksort with random pivot.

     compare : three-way comparison; only the sign of its result is
               used.
     lst     : the list to sort; it is consumed.
     returns : the length of the list and the sorted list.

     Lists of length 0 or 1 are returned as they are. In particular
     an empty list is never handed to
     three_way_partition_with_random_pivot, whose run-time assertion
     on the pivot index fails for an empty list. *)
  let
    macdef append = list_vt_append<a>

    (* Sort one list of known length m. A list of more than one
       element is partitioned around a random pivot; the less-than
       and greater-than parts are sorted recursively (the less-than
       part first), while the equal-to part is already in its final,
       stable order. *)
    fun
    sort_part {m : nat}
              (compare : &((&a, &a) -<cloptr1> int),
               m       : int m,
               part    : list_vt (a, m))
        : list_vt (a, m) =
      if 1 < m then
        let
          val @(m1, lows, m2, mids, m3, highs) =
            three_way_partition_with_random_pivot<a>
              (compare, m, part)
          val lows = sort_part (compare, m1, lows)
          val highs = sort_part (compare, m3, highs)
        in
          lows \append (mids \append highs)
        end
      else
        part

    prval () = lemma_list_vt_param lst

    val len = length<a> lst
    val sorted = sort_part (compare, len, lst)
  in
    @(len, sorted)
  end
|
|
|
|
(*------------------------------------------------------------------*)
|
|
|
|
(* Print the k-th element (k counts from 1) of an integer list with
   respect to an ordering chosen by `direction': with direction = 1
   smaller integers rank first, with direction = ~1 larger integers
   rank first. The list itself is left untouched; the selection runs
   on a copy. It is a run-time error (assertion failure) if k is not
   in the range 1 .. length of the list. *)
fn
print_kth (direction : int,
           k         : int,
           lst       : !List_vt int) : void =
  let
    var compare =
      lam (x : &int, y : &int) : int =<cloptr1>
        if x = y then
          0
        else if x < y then
          ~direction
        else
          direction

    (* list_vt_select is destructive, so give it a scratch copy. *)
    val scratch = copy<int> lst
    val len = length<int> scratch

    val rank = g1ofg0 k
    val () = assertloc (1 <= rank)
    val () = assertloc (rank <= len)

    (* list_vt_select wants a zero-based rank. *)
    val found = list_vt_select<int> (compare, rank - 1, scratch)

    val () = cloptr_free ($UN.castvwtp0{cloptr0} compare)
  in
    print! (found)
  end
|
|
|
|
(* Demonstration of quickselect: for a fixed list of the integers
   0 .. 9, print the 1st through 10th element under each of the two
   orderings supported by print_kth, one line per ordering. *)
fn
demonstrate_quickselect () : void =
  let
    (* Print the elements of rank `rank' through 10, separated by
       single spaces. *)
    fun
    show_ranks (direction : int,
                rank      : int,
                lst       : !List_vt int) : void =
      if rank <= 10 then
        let
          val () = (if 1 < rank then print! (" ") else ())
          val () = print_kth (direction, rank, lst)
        in
          show_ranks (direction, rank + 1, lst)
        end
      else
        ()

    var example_for_select = $list_vt (9, 8, 7, 6, 5, 0, 1, 2, 3, 4)

    val () = print! ("With < as order predicate: ")
    val () = show_ranks (1, 1, example_for_select)
    val () = println! ()

    val () = print! ("With > as order predicate: ")
    val () = show_ranks (~1, 1, example_for_select)
    val () = println! ()

    val () = list_vt_free<int> example_for_select
  in
  end
|
|
|
|
(* Demonstration of the stable quicksort: sort a list of words by
   their first character only and print the result. Words that share
   a first character stay in their original relative order, which is
   what makes the stability visible. *)
fn
demonstrate_quicksort () : void =
  let
    var example_for_sort =
      $list_vt ("elephant", "duck", "giraffe", "deer",
                "earwig", "dolphin", "wildebeest", "pronghorn",
                "woodlouse", "whip-poor-will")

    (* Compare two non-empty strings by their first character. *)
    var compare =
      lam (x : &stringGt 0,
           y : &stringGt 0) : int =<cloptr1>
        let
          val cx = x[0]
          val cy = y[0]
        in
          if cx < cy then
            ~1
          else if cx = cy then
            0
          else
            1
        end

    val () = println! ("stable sort by first character:")

    (* The sort is destructive, so it is given a copy. *)
    val @(_, sorted_lst) =
      list_vt_stable_sort<stringGt 0>
        (compare, copy<stringGt 0> example_for_sort)
    val () = println! ($UN.castvwtp1{List0 string} sorted_lst)

    val () = list_vt_free<string> sorted_lst
    val () = list_vt_free<string> example_for_sort
  in
    cloptr_free ($UN.castvwtp0{cloptr0} compare)
  end
|
|
|
|
implement
main0 (argc, argv) =
  (* Run the demonstration named by the first command-line argument,
     "quickselect" or "quicksort". With a missing or unrecognized
     argument, print a usage message and exit with status 1.

     Currently there is no demonstration of
     list_vt_three_way_separation. *)
  let
    val usage = "Please choose \"quickselect\" or \"quicksort\"."
  in
    if argc < 2 then
      begin
        println! (usage);
        exit (1)
      end
    else
      let
        val demo_name = $UN.cast{string} argv[1]
      in
        if demo_name = "quickselect" then
          demonstrate_quickselect ()
        else if demo_name = "quicksort" then
          demonstrate_quicksort ()
        else
          begin
            println! (usage);
            exit (1)
          end
      end
  end
|
|
|
|
(*------------------------------------------------------------------*)
|