(* Remove duplicate elements.

   The best sorting algorithms, it is said, are O(n log n) and require
   an order predicate.

   But this is true only for a general sorting routine. A radix sort
   for fixed-size integers is O(n), and requires no order
   predicate. Here I use such a radix sort. *)

#include "share/atspre_staload.hats"

staload UN = "prelude/SATS/unsafe.sats"

(* How the remove_dups template function will be called.

   src : the input array of n unsigned integers (it is not modified;
         the implementation sorts a copy).
   n   : the number of elements in src and the capacity of dst.
   dst : receives the distinct values in its first m slots.
   m   : set on return to the number of distinct values; the return
         type guarantees 0 <= m <= n. *)
extern fn {tk : tkind}
remove_dups
          {n   : int}
          (src : arrayref (g0uint tk, n),
           n   : size_t n,
           dst : arrayref (g0uint tk, n),
           m   : &size_t? >> size_t m)
    : #[m : nat | m <= n]
      void

(*------------------------------------------------------------------*)
(* A radix sort for unsigned integers, copied from my contribution to
   the radix sort task. *)

(* Sort arr in place. The element type a and the key kind tk must be
   given explicitly at the call site, and a matching implementation
   of g0uint_radix_sort$key must be in scope. *)
extern fn {a  : vt@ype}
          {tk : tkind}
g0uint_radix_sort
          {n   : int}
          (arr : &array (a, n) >> _,
           n   : size_t n)
    : void

(* User-supplied hook: return the unsigned integer key of arr[i]. *)
extern fn {a  : vt@ype}
          {tk : tkind}
g0uint_radix_sort$key
          {n   : int}
          {i   : nat | i < n}
          (arr : &RD(array (a, n)),
           i   : size_t i)
    :<> g0uint tk

(* Turn per-bin counts into starting offsets: every non-empty bin is
   overwritten with the sum of the counts of the bins before it. Empty
   bins are left untouched (they stay zero), because no element will
   ever be routed to them. *)
fn {}
bin_sizes_to_indices
          (bin_indices : &array (size_t, 256) >> _)
    : void =
  let
    fun
    loop {i           : int | i <= 256}
         {accum       : int}
         .<256 - i>.
         (bin_indices : &array (size_t, 256) >> _,
          i           : size_t i,
          accum       : size_t accum)
        : void =
      if i <> i2sz 256 then
        let
          prval () = lemma_g1uint_param i
          val elem = bin_indices[i]
        in
          if elem = i2sz 0 then
            loop (bin_indices, succ i, accum)
          else
            begin
              bin_indices[i] := accum;
              loop (bin_indices, succ i, accum + g1ofg0 elem)
            end
        end
  in
    loop (bin_indices, i2sz 0, i2sz 0)
  end

(* Histogram pass. Zero all 256 bins, then for each element take the
   8-bit digit of its key at bit offset shift and bump that bin.

   all_expended is set to true up front and then combined (`*' on
   booleans is conjunction) with "key >> shift is zero" for every
   element, so on return it is true only if no key has any bits left
   at or above this digit. *)
fn {a  : vt@ype}
   {tk : tkind}
count_entries
          {n            : int}
          {shift        : nat}
          (arr          : &RD(array (a, n)),
           n            : size_t n,
           bin_indices  : &array (size_t?, 256)
                            >> array (size_t, 256),
           all_expended : &bool? >> bool,
           shift        : int shift)
    : void =
  let
    fun
    loop {i : int | i <= n}
         .<n - i>.
         (arr          : &RD(array (a, n)),
          bin_indices  : &array (size_t, 256) >> _,
          all_expended : &bool >> bool,
          i            : size_t i)
        : void =
      if i <> n then
        let
          prval () = lemma_g1uint_param i
          val key : g0uint tk = g0uint_radix_sort$key<a><tk> (arr, i)
          val key_shifted = key >> shift
          val digit = ($UN.cast{uint} key_shifted) land 255U
          val [digit : int] digit = g1ofg0 digit
          (* The mask with 255U bounds the digit, but the type checker
             cannot see that, so the range is asserted as an axiom. *)
          extern praxi set_range :
            () -<prf> [0 <= digit; digit <= 255] void
          prval () = set_range ()
          val count = bin_indices[digit]
          val () = bin_indices[digit] := succ count
        in
          all_expended := all_expended * iseqz key_shifted;
          loop (arr, bin_indices, all_expended, succ i)
        end

    prval () = lemma_array_param arr
  in
    array_initize_elt (bin_indices, i2sz 256, i2sz 0);
    all_expended := true;
    loop (arr, bin_indices, all_expended, i2sz 0)
  end

(* One radix pass on the 8-bit digit at bit offset shift.

   If count_entries reports all_expended, nothing is moved and arr2 is
   left untouched. Otherwise every element of arr1 is copied, in
   order, to the next free slot of its bin in arr2. *)
fn {a  : vt@ype}
   {tk : tkind}
sort_by_digit
          {n            : int}
          {shift        : nat}
          (arr1         : &RD(array (a, n)),
           arr2         : &array (a, n) >> _,
           n            : size_t n,
           all_expended : &bool? >> bool,
           shift        : int shift)
    : void =
  let
    var bin_indices : array (size_t, 256)
  in
    count_entries<a><tk> (arr1, n, bin_indices, all_expended, shift);
    if all_expended then
      ()
    else
      let
        fun
        rearrange {i : int | i <= n}
                  .<n - i>.
                  (arr1        : &RD(array (a, n)),
                   arr2        : &array (a, n) >> _,
                   bin_indices : &array (size_t, 256) >> _,
                   i           : size_t i)
            : void =
          if i <> n then
            let
              prval () = lemma_g1uint_param i
              val key = g0uint_radix_sort$key<a><tk> (arr1, i)
              val key_shifted = key >> shift
              val digit = ($UN.cast{uint} key_shifted) land 255U
              val [digit : int] digit = g1ofg0 digit
              extern praxi set_range :
                () -<prf> [0 <= digit; digit <= 255] void
              prval () = set_range ()
              val [j : int] j = g1ofg0 bin_indices[digit]

              (* One might wish to get rid of this assertion somehow,
                 to eliminate the branch, should it prove a
                 problem. *)
              val () = $effmask_exn assertloc (j < n)

              (* Elements may be linear (vt@ype), so they are moved
                 with a raw byte copy rather than by assignment. *)
              val p_dst = ptr_add<a> (addr@ arr2, j)
              and p_src = ptr_add<a> (addr@ arr1, i)
              val _ = $extfcall (ptr, "memcpy", p_dst, p_src,
                                 sizeof<a>)
              val () = bin_indices[digit] := succ (g0ofg1 j)
            in
              rearrange (arr1, arr2, bin_indices, succ i)
            end

        prval () = lemma_array_param arr1
      in
        bin_sizes_to_indices<> bin_indices;
        rearrange (arr1, arr2, bin_indices, i2sz 0)
      end
  end

(* Sort arr1, using arr2 as scratch space of the same size.

   Passes alternate direction (arr1 -> arr2, then arr2 -> arr1, ...),
   one 8-bit digit per pass, for at most sizeof<g0uint tk> digits. The
   loop stops early once a pass reports that all keys are expended.
   Whenever the sorted data would be left in arr2, it is copied back
   so that the result always ends up in arr1. *)
fn {a  : vt@ype}
   {tk : tkind}
g0uint_sort {n    : pos}
            (arr1 : &array (a, n) >> _,
             arr2 : &array (a, n) >> _,
             n    : size_t n)
    : void =
  let
    fun
    loop {idigit_max, idigit : nat | idigit <= idigit_max}
         .<idigit_max - idigit>.
         (arr1       : &array (a, n) >> _,
          arr2       : &array (a, n) >> _,
          from1to2   : bool,
          idigit_max : int idigit_max,
          idigit     : int idigit)
        : void =
      if idigit = idigit_max then
        begin
          (* The last pass went arr1 -> arr2: copy the result back. *)
          if ~from1to2 then
            let
              val _ =
                $extfcall (ptr, "memcpy", addr@ arr1, addr@ arr2,
                           sizeof<a> * n)
            in
            end
        end
      else if from1to2 then
        let
          var all_expended : bool
        in
          sort_by_digit<a><tk> (arr1, arr2, n, all_expended,
                                8 * idigit);
          if all_expended then
            ()
          else
            loop (arr1, arr2, false, idigit_max, succ idigit)
        end
      else
        let
          var all_expended : bool
        in
          sort_by_digit<a><tk> (arr2, arr1, n, all_expended,
                                8 * idigit);
          if all_expended then
            (* Nothing was moved, so the sorted data is still in
               arr2: copy it back to arr1. *)
            let
              val _ =
                $extfcall (ptr, "memcpy", addr@ arr1, addr@ arr2,
                           sizeof<a> * n)
            in
            end
          else
            loop (arr1, arr2, true, idigit_max, succ idigit)
        end
  in
    loop (arr1, arr2, true, sz2i sizeof<g0uint tk>, 0)
  end

(* Arrays of at most this many elements use a local scratch array
   instead of an allocated one. *)
#define SIZE_THRESHOLD 256

(* Proof-level reinterpretation of an array's element type. It is used
   only to switch the scratch buffer between its uninitialized (a?)
   and initialized (a) views. *)
extern praxi
unsafe_cast_array
          {a   : vt@ype}
          {b   : vt@ype}
          {n   : int}
          (arr : &array (b, n) >> array (a, n))
    : void

(* The public entry point: do nothing for n = 0; otherwise obtain a
   scratch array (a local `var' array for n <= SIZE_THRESHOLD, else
   one from array_ptr_alloc that is freed afterwards) and run
   g0uint_sort. *)
implement {a} {tk}
g0uint_radix_sort {n} (arr, n) =
  if n <> 0 then
    let
      prval () = lemma_array_param arr

      fn
      sort {n    : pos}
           (arr1 : &array (a, n) >> _,
            arr2 : &array (a, n) >> _,
            n    : size_t n)
          : void =
        g0uint_sort<a><tk> (arr1, arr2, n)
    in
      if n <= SIZE_THRESHOLD then
        let
          var arr2 : array (a, SIZE_THRESHOLD)
          (* Use only the first n slots of the local array. *)
          prval @(pf_left, pf_right) =
            array_v_split {a?} {..} {SIZE_THRESHOLD} {n} (view@ arr2)
          prval () = view@ arr2 := pf_left
          prval () = unsafe_cast_array{a} arr2

          val () = sort (arr, arr2, n)

          (* Restore the original view so the variable can be
             released at scope exit. *)
          prval () = unsafe_cast_array{a?} arr2
          prval () = view@ arr2 :=
            array_v_unsplit (view@ arr2, pf_right)
        in
        end
      else
        let
          val @(pf_arr2, pfgc_arr2 | p_arr2) = array_ptr_alloc<a> n
          macdef arr2 = !p_arr2
          prval () = unsafe_cast_array{a} arr2

          val () = sort (arr, arr2, n)

          prval () = unsafe_cast_array{a?} arr2
          val () = array_ptr_free (pf_arr2, pfgc_arr2 | p_arr2)
        in
        end
    end

(*------------------------------------------------------------------*)
(* An implementation of the remove_dups template function, which also
   sorts the elements.

   For n = 0 it sets m to 0 and leaves dst alone. Otherwise it sorts a
   copy of src, writes the first element of each run of equal values
   to dst[0 .. m-1] in ascending order, and frees the copy. Slots of
   dst from m onward are not written. *)

implement {tk}
remove_dups {n} (src, n, dst, m) =
  if n = i2sz 0 then
    m := i2sz 0
  else
    let
      prval () = lemma_arrayref_param src (* Prove 0 <= n. *)

      (* Sort a copy of src. *)
      val arrptr = arrayref_copy (src, n)
      val @(pf_arr | p_arr) = arrayptr_takeout_viewptr arrptr
      val () = g0uint_radix_sort<g0uint tk><tk> (!p_arr, n)
      prval () = arrayptr_addback (pf_arr | arrptr)

      (* Copy only the first element of each run of equals. *)
      val () = dst[0] := arrptr[0]
      fun
      loop {i : int | 1 <= i; i <= n}
           {j : int | 1 <= j; j <= i}
           .<n - i>.
           (arrptr : !arrayptr (g0uint tk, n),
            i      : size_t i,   (* Next sorted element to examine. *)
            j      : size_t j)   (* Next free slot in dst. *)
          : [m : int | 1 <= m; m <= n] size_t m =
        if i = n then
          j
        else if arrptr[pred i] = arrptr[i] then
          loop (arrptr, succ i, j)
        else
          begin
            dst[j] := arrptr[i];
            loop (arrptr, succ i, succ j)
          end

      val () = m := loop (arrptr, i2sz 1, i2sz 1)

      val () = arrayptr_free arrptr
    in
    end

(*------------------------------------------------------------------*)
(* A demonstration. *)

implement
main0 () =
  let
    (* For plain unsigned integers, the element is its own key. *)
    implement
    g0uint_radix_sort$key<uint><uintknd> (arr, i) =
      arr[i]

    val src =
      arrayref_make_list
        (10, $list (1U, 3U, 2U, 5U, 1U, 1U, 4U, 4U, 2U, 3U))
    val dst = arrayref_make_elt (i2sz 10, 123456789U)
    var m : size_t
  in
    remove_dups<uintknd> (src, i2sz 10, dst, m);
    let
      prval [m : int] EQINT () = eqint_make_guint m
      var i : natLte m
    in
      (* Print only the m distinct values, not the filler. *)
      for (i := 0; i2sz i <> m; i := succ i)
        print! (" ", dst[i]);
      println! ()
    end
  end

(*------------------------------------------------------------------*)