(*------------------------------------------------------------------*)
(*

For linear linked lists, using a random pivot:

   * stable three-way "separation" (a variant of quickselect)
   * quickselect
   * stable quicksort

Also a couple of routines for splitting lists according to a
predicate.

Linear list operations are destructive but may avoid doing many
unnecessary allocations. Also they do not require a garbage
collector.

*)

#include "share/atspre_staload.hats"

staload UN = "prelude/SATS/unsafe.sats"

#define NIL list_vt_nil ()
#define ::  list_vt_cons

(*------------------------------------------------------------------*)
(* A simple linear congruential generator for pivot selection.      *)

(* The multiplier lcg_a comes from Steele, Guy; Vigna, Sebastiano (28
   September 2021). "Computationally easy, spectrally good multipliers
   for congruential pseudorandom number generators".
   arXiv:2001.05304v3 [cs.DS] *)
macdef lcg_a = $UN.cast{uint64} 0xf1357aea2e62a9c5LLU

(* lcg_c must be odd. *)
macdef lcg_c = $UN.cast{uint64} 0xbaceba11beefbeadLLU

(* The generator state. It always starts at zero, so the sequence of
   pivots is the same on every run. The state is reached through the
   raw pointer p_seed. *)
var seed : uint64 = $UN.cast 0
val p_seed = addr@ seed

(* Return a pseudorandom number 0.0 <= randnum < 1.0 computed from the
   current seed, then advance the seed by one step of the generator. *)
fn
random_double () : double =
  let
    val (pf, fpf | p_seed) = $UN.ptr0_vtake{uint64} p_seed
    val old_seed = ptr_get (pf | p_seed)

    (* IEEE "binary64" or "double" has 52 bits of precision. We will
       take the high 48 bits of the seed and divide it by 2**48, to
       get a number 0.0 <= randnum < 1.0 *)
    val high_48_bits = $UN.cast{double} (old_seed >> 16)
    val divisor = $UN.cast{double} (1LLU << 48)
    val randnum = high_48_bits / divisor

    (* The following operation is modulo 2**64, by virtue of standard
       C behavior for uint64_t. *)
    val new_seed = (lcg_a * old_seed) + lcg_c

    val () = ptr_set (pf | p_seed, new_seed)
    prval () = fpf pf
  in
    randnum
  end

(*------------------------------------------------------------------*)

(* Destructive split into two lists: a list of leading elements that
   satisfy a predicate, and the tail of that split. (This is similar
   to "span!" in SRFI-1.) *)
extern fun {a : vt@ype}
list_vt_span {n    : int}
             (pred : &((&a) -<cloptr1> bool),
              lst  : list_vt (a, n))
    : [n1, n2 : nat | n1 + n2 == n]
      @(list_vt (a, n1), list_vt (a, n2))

(* Destructive, stable partition into elements less than the pivot,
   elements equal to the pivot, and elements greater than the
   pivot. *)
extern fun {a : vt@ype}
list_vt_three_way_partition
          {n       : int}
          (compare : &((&a, &a) -<cloptr1> int),
           pivot   : &a,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
      @(list_vt (a, n1), list_vt (a, n2), list_vt (a, n3))

(* Destructive, stable partition into elements less than the kth least
   element, elements equal to it, and elements greater than it. *)
extern fun {a : vt@ype}
list_vt_three_way_separation
          {n, k    : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n;
                          n1 <= k; k < n1 + n2]
      @(int n1, list_vt (a, n1),
        int n2, list_vt (a, n2),
        int n3, list_vt (a, n3))

(* Destructive quickselect for linear elements. The elements that are
   not returned are disposed of with list_vt_select_linear$clear. *)
extern fun {a : vt@ype}
list_vt_select_linear
          {n, k    : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n))
    : a

extern fun {a : vt@ype}
list_vt_select_linear$clear (x : &a >> a?) : void

(* Destructive quickselect for non-linear elements. *)
extern fun {a : t@ype}
list_vt_select
          {n, k    : int | 0 <= k; k < n}
          (compare : &((&a, &a) -<cloptr1> int),
           k       : int k,
           lst     : list_vt (a, n))
    : a

(* Stable quicksort. Also returns the length. *)
extern fun {a : vt@ype}
list_vt_stable_sort
          {n       : int}
          (compare : &((&a, &a) -<cloptr1> int),
           lst     : list_vt (a, n))
    : @(int n, list_vt (a, n))

(*------------------------------------------------------------------*)

implement {a}
list_vt_span {n} (pred, lst) =
  let
    (* Walk the list in place. On return, cursor holds the leading m
       elements that satisfy pred, and tail holds the other n - m. *)
    fun
    loop {n : nat} .<n>.
         (pred   : &((&a) -<cloptr1> bool),
          cursor : &list_vt (a, n) >> list_vt (a, m),
          tail   : &List_vt a? >> list_vt (a, n - m))
        : #[m : nat | m <= n] void =
      case+ cursor of
      | NIL => tail := NIL
      | @ elem :: rest =>
        if pred (elem) then
          (* elem satisfies the predicate. Move the cursor to the next
             cons-pair in the list. *)
          let
            val () = loop {n - 1} (pred, rest, tail)
            prval () = fold@ cursor
          in
          end
        else
          (* elem does not satisfy the predicate. Split the list at
             the cursor. *)
          let
            prval () = fold@ cursor
            val () = tail := cursor
            val () = cursor := NIL
          in
          end

    prval () = lemma_list_vt_param lst

    var cursor = lst
    var tail : List_vt a?
    val () = loop {n} (pred, cursor, tail)
  in
    @(cursor, tail)
  end

(*------------------------------------------------------------------*)

implement {a}
list_vt_three_way_partition {n} (compare, pivot, lst) =
  //
  // WARNING: This implementation is NOT tail-recursive.
  //
  // The list is consumed as a sequence of runs. Each run is the
  // longest prefix whose comparison with the pivot has the same sign
  // as that of its first element. The run is appended in front of
  // the matching partition of the rest of the list, which keeps the
  // partition stable.
  //
  let
    var current_sign : int = 0

    (* The closure below reaches these three variables through raw
       pointers. *)
    val p_compare = addr@ compare
    val p_pivot = addr@ pivot
    val p_current_sign = addr@ current_sign

    var pred = (* A linear closure. *)
      lam (elem : &a) : bool =<cloptr1>
        (* Return true iff the sign of the comparison of elem with the
           pivot matches the current_sign. *)
        let
          val @(pf_compare, fpf_compare | p_compare) =
            $UN.ptr0_vtake{(&a, &a) -<cloptr1> int} p_compare
          val @(pf_pivot, fpf_pivot | p_pivot) =
            $UN.ptr0_vtake{a} p_pivot
          val @(pf_current_sign, fpf_current_sign | p_current_sign) =
            $UN.ptr0_vtake{int} p_current_sign

          macdef compare = !p_compare
          macdef pivot = !p_pivot
          macdef current_sign = !p_current_sign

          val sign = compare (elem, pivot)
          val truth = (sign < 0 && current_sign < 0) ||
                      (sign = 0 && current_sign = 0) ||
                      (sign > 0 && current_sign > 0)

          prval () = fpf_compare pf_compare
          prval () = fpf_pivot pf_pivot
          prval () = fpf_current_sign pf_current_sign
        in
          truth
        end

    fun
    recurs {n : nat}
           (compare      : &((&a, &a) -<cloptr1> int),
            pred         : &((&a) -<cloptr1> bool),
            pivot        : &a,
            current_sign : &int,
            lst          : list_vt (a, n))
        : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
          @(list_vt (a, n1), list_vt (a, n2), list_vt (a, n3)) =
      case+ lst of
      | ~ NIL => @(NIL, NIL, NIL)
      | @ elem :: tail =>
        let
          macdef append = list_vt_append

          (* The sign for the head element selects both the run that
             pred will accept and the partition the run goes to. *)
          val cmp = compare (elem, pivot)
          val () = current_sign := cmp
          prval () = fold@ lst
          val @(matches, rest) = list_vt_span (pred, lst)
          val @(left, middle, right) =
            recurs (compare, pred, pivot, current_sign, rest)
        in
          if cmp < 0 then
            @(matches \append left, middle, right)
          else if cmp = 0 then
            @(left, matches \append middle, right)
          else
            @(left, middle, matches \append right)
        end

    prval () = lemma_list_vt_param lst
    val retvals = recurs (compare, pred, pivot, current_sign, lst)

    val () = cloptr_free ($UN.castvwtp0{cloptr0} pred)
  in
    retvals
  end

(*------------------------------------------------------------------*)

(* Pick a pivot position with random_double, remove that element from
   the list, and three-way partition the elements before it and the
   elements after it. The pivot is put back into the middle partition
   between the equal elements that preceded it and those that followed
   it. The lengths of the three partitions are returned as well.

   The list must not be empty: for n = 0 the assertloc below fails. *)
fn {a : vt@ype}
three_way_partition_with_random_pivot
          {n       : nat}
          (compare : &((&a, &a) -<cloptr1> int),
           n       : int n,
           lst     : list_vt (a, n))
    : [n1, n2, n3 : nat | n1 + n2 + n3 == n]
      @(int n1, list_vt (a, n1),
        int n2, list_vt (a, n2),
        int n3, list_vt (a, n3)) =
  let
    macdef append = list_vt_append

    var pivot : a

    val randnum = random_double ()
    val i_pivot = $UN.cast{Size_t} (randnum * $UN.cast{double} n)
    prval () = lemma_g1uint_param i_pivot
    val () = assertloc (i_pivot < i2sz n)
    val i_pivot = sz2i i_pivot

    val @(left, right) = list_vt_split_at (lst, i_pivot)
    val+ ~ (pivot_val :: right) = right
    val () = pivot := pivot_val

    val @(left1, middle1, right1) =
      list_vt_three_way_partition (compare, pivot, left)
    val @(left2, middle2, right2) =
      list_vt_three_way_partition (compare, pivot, right)

    val left = left1 \append left2
    val middle = middle1 \append (pivot :: middle2)
    val right = right1 \append right2

    val n1 = length left
    val n2 = length middle
    val n3 = n - n1 - n2
  in
    @(n1, left, n2, middle, n3, right)
  end

(*------------------------------------------------------------------*)

implement {a}
list_vt_three_way_separation {n, k} (compare, k, lst) =
  (* This is a quickselect with random pivot, returning a three-way
     partition, in which the middle partition contains the (k+1)st
     least element. *)
  let
    macdef append = list_vt_append

    (* Re-partition whichever side contains position k, folding the
       other two partitions into the opposite side, until position k
       falls in the middle partition. *)
    fun
    loop {n1, n2, n3, k : nat | 0 <= k; k < n;
                                n1 + n2 + n3 == n}
         (compare : &((&a, &a) -<cloptr1> int),
          k       : int k,
          n1      : int n1,
          left    : list_vt (a, n1),
          n2      : int n2,
          middle  : list_vt (a, n2),
          n3      : int n3,
          right   : list_vt (a, n3))
        : [n1, n2, n3 : nat | n1 + n2 + n3 == n;
                              n1 <= k; k < n1 + n2]
          @(int n1, list_vt (a, n1),
            int n2, list_vt (a, n2),
            int n3, list_vt (a, n3)) =
      if k < n1 then
        let
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot (compare, n1, left)
        in
          loop (compare, k, m1, left1, m2, middle1,
                m3 + n2 + n3,
                right1 \append (middle \append right))
        end
      else if n1 + n2 <= k then
        let
          val @(m1, left2, m2, middle2, m3, right2) =
            three_way_partition_with_random_pivot (compare, n3, right)
        in
          loop (compare, k,
                n1 + n2 + m1,
                left \append (middle \append left2),
                m2, middle2, m3, right2)
        end
      else
        @(n1, left, n2, middle, n3, right)

    prval () = lemma_list_vt_param lst

    val @(n1, left, n2, middle, n3, right) =
      three_way_partition_with_random_pivot
        (compare, length lst, lst)
  in
    loop (compare, k, n1, left, n2, middle, n3, right)
  end

(*------------------------------------------------------------------*)

implement {a}
list_vt_select_linear {n, k} (compare, k, lst) =
  (* This is a quickselect with random pivot. It is like
     list_vt_three_way_separation, but throws away parts of the list
     that will not be needed later on. *)
  let
    (* Elements that are thrown away are cleared with the caller's
       list_vt_select_linear$clear. *)
    implement
    list_vt_freelin$clear<a> (x) =
      $effmask_all list_vt_select_linear$clear<a> (x)

    macdef append = list_vt_append

    fun
    loop {n1, n2, n3, k : nat | 0 <= k; k < n1 + n2 + n3}
         (compare : &((&a, &a) -<cloptr1> int),
          k       : int k,
          n1      : int n1,
          left    : list_vt (a, n1),
          n2      : int n2,
          middle  : list_vt (a, n2),
          n3      : int n3,
          right   : list_vt (a, n3))
        : a =
      if k < n1 then
        let
          val () = list_vt_freelin middle
          val () = list_vt_freelin right
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot (compare, n1, left)
        in
          loop (compare, k, m1, left1, m2, middle1, m3, right1)
        end
      else if n1 + n2 <= k then
        let
          val () = list_vt_freelin left
          val () = list_vt_freelin middle
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot (compare, n3, right)
        in
          (* The left and middle partitions are gone, so k is now
             counted from the start of the old right partition. *)
          loop (compare, k - n1 - n2,
                m1, left1, m2, middle1, m3, right1)
        end
      else
        let
          (* Position k is in the middle partition. Extract that one
             element and free everything else. *)
          val () = list_vt_freelin left
          val () = list_vt_freelin right
          val @(middle1, middle2) = list_vt_split_at (middle, k - n1)
          val () = list_vt_freelin middle1
          val+ ~ (element :: middle2) = middle2
          val () = list_vt_freelin middle2
        in
          element
        end

    prval () = lemma_list_vt_param lst

    val @(n1, left, n2, middle, n3, right) =
      three_way_partition_with_random_pivot
        (compare, length lst, lst)
  in
    loop (compare, k, n1, left, n2, middle, n3, right)
  end

implement {a}
list_vt_select {n, k} (compare, k, lst) =
  let
    (* Non-linear elements need no clean-up. *)
    implement
    list_vt_select_linear$clear<a> (x) = ()
  in
    list_vt_select_linear {n, k} (compare, k, lst)
  end

(*------------------------------------------------------------------*)

implement {a}
list_vt_stable_sort {n} (compare, lst) =
  (* This is a stable quicksort with random pivot. *)
  let
    macdef append = list_vt_append

    (* Sort the left and right partitions, but only those that have
       more than one element, and join the three partitions. *)
    fun
    recurs {n : int}
           {n1, n2, n3 : nat | n1 + n2 + n3 == n}
           (compare : &((&a, &a) -<cloptr1> int),
            n1      : int n1,
            left    : list_vt (a, n1),
            n2      : int n2,
            middle  : list_vt (a, n2),
            n3      : int n3,
            right   : list_vt (a, n3))
        : @(int n, list_vt (a, n)) =
      if 1 < n1 then
        let
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot (compare, n1, left)
          val @(_, left) =
            recurs {n1} (compare, m1, left1, m2, middle1, m3, right1)
        in
          if 1 < n3 then
            let
              val @(m1, left1, m2, middle1, m3, right1) =
                three_way_partition_with_random_pivot
                  (compare, n3, right)
              val @(_, right) =
                recurs {n3} (compare, m1, left1, m2, middle1,
                             m3, right1)
            in
              @(n1 + n2 + n3, left \append (middle \append right))
            end
          else
            @(n1 + n2 + n3, left \append (middle \append right))
        end
      else if 1 < n3 then
        let
          val @(m1, left1, m2, middle1, m3, right1) =
            three_way_partition_with_random_pivot (compare, n3, right)
          val @(_, right) =
            recurs {n3} (compare, m1, left1, m2, middle1, m3, right1)
        in
          @(n1 + n2 + n3, left \append (middle \append right))
        end
      else
        @(n1 + n2 + n3, left \append (middle \append right))

    prval () = lemma_list_vt_param lst
    val n = length lst
  in
    if n = 0 then
      (* An empty list is already sorted. It must not be handed to
         three_way_partition_with_random_pivot, whose run-time
         assertion (pivot index < n) cannot hold for n = 0. *)
      @(n, lst)
    else
      let
        val @(n1, left, n2, middle, n3, right) =
          three_way_partition_with_random_pivot (compare, n, lst)
      in
        recurs {n} (compare, n1, left, n2, middle, n3, right)
      end
  end

(*------------------------------------------------------------------*)

(* Print the kth least element (k counted from 1) of a copy of lst.
   direction = 1 orders the integers ascending; direction = ~1 orders
   them descending. The program aborts (assertloc) unless
   1 <= k <= length lst. *)
fn
print_kth (direction : int,
           k         : int,
           lst       : !List_vt int)
    : void =
  let
    var compare =
      lam (x : &int, y : &int) : int =<cloptr1>
        if x < y then
          ~direction
        else if x = y then
          0
        else
          direction

    (* The selection is destructive, so work on a copy. *)
    val lst = copy lst
    val n = length lst
    val k = g1ofg0 k
    val () = assertloc (1 <= k)
    val () = assertloc (k <= n)
    val element = list_vt_select (compare, k - 1, lst)

    val () = cloptr_free ($UN.castvwtp0{cloptr0} compare)
  in
    print! (element)
  end

fn
demonstrate_quickselect () : void =
  let
    var example_for_select = $list_vt (9, 8, 7, 6, 5, 0, 1, 2, 3, 4)

    val () = print! ("With < as order predicate:  ")
    val () = print_kth (1, 1, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 2, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 3, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 4, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 5, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 6, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 7, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 8, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 9, example_for_select)
    val () = print! (" ")
    val () = print_kth (1, 10, example_for_select)
    val () = println! ()

    val () = print! ("With > as order predicate:  ")
    val () = print_kth (~1, 1, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 2, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 3, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 4, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 5, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 6, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 7, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 8, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 9, example_for_select)
    val () = print! (" ")
    val () = print_kth (~1, 10, example_for_select)
    val () = println! ()

    val () = list_vt_free example_for_select
  in
  end

fn
demonstrate_quicksort () : void =
  let
    var example_for_sort =
      $list_vt ("elephant", "duck", "giraffe", "deer",
                "earwig", "dolphin", "wildebeest", "pronghorn",
                "woodlouse", "whip-poor-will")

    (* Compare only the first characters, so that the stability of
       the sort shows in the output. *)
    var compare =
      lam (x : &stringGt 0, y : &stringGt 0) : int =<cloptr1>
        if x[0] < y[0] then
          ~1
        else if x[0] = y[0] then
          0
        else
          1

    val () = println! ("stable sort by first character:")
    val @(_, sorted_lst) =
      list_vt_stable_sort (compare, copy example_for_sort)
    val () = println! ($UN.castvwtp1{List0 string} sorted_lst)
  in
    list_vt_free sorted_lst;
    list_vt_free example_for_sort;
    cloptr_free ($UN.castvwtp0{cloptr0} compare)
  end

implement
main0 (argc, argv) =
  let
    (* Currently there is no demonstration of
       list_vt_three_way_separation. *)
    val demo_name =
      begin
        if 2 <= argc then
          $UN.cast{string} argv[1]
        else
          begin
            println!
              ("Please choose \"quickselect\" or \"quicksort\".");
            exit (1)
          end
      end : string
  in
    if demo_name = "quickselect" then
      demonstrate_quickselect ()
    else if demo_name = "quicksort" then
      demonstrate_quicksort ()
    else
      begin
        println! ("Please choose \"quickselect\" or \"quicksort\".");
        exit (1)
      end
  end

(*------------------------------------------------------------------*)