%%%-------------------------------------------------------------------

:- module quickselect_task.

:- interface.
:- import_module io.
:- pred main(io, io).
:- mode main(di, uo) is det.

:- implementation.
:- import_module array.
:- import_module exception.
:- import_module int.
:- import_module list.
:- import_module random.
:- import_module random.sfc64.
:- import_module string.

%%%-------------------------------------------------------------------
%%%
%%% Partitioning a subarray into two halves: one with elements less
%%% than or equal to a pivot, the other with elements greater than or
%%% equal to a pivot.
%%%
%%% The implementation is tail-recursive.
%%%

%% partition(Less_than, Pivot, I_first, I_last, Arr0, Arr, I_pivot)
%%
%% Rearranges Arr0[I_first..I_last] (inclusive bounds) by swapping
%% elements, giving Arr, and returns a split index I_pivot with
%% I_first =< I_pivot =< I_last + 1 such that
%%
%%   * no element of Arr[I_first..I_pivot - 1] satisfies
%%     Less_than(Pivot, Element), and
%%   * no element of Arr[I_pivot..I_last] satisfies
%%     Less_than(Element, Pivot).
%%
%% This assumes Less_than is a strict ordering (never both X < Y and
%% Y < X). An empty range (I_last = I_first - 1) leaves the array
%% untouched and yields I_pivot = I_first.
:- pred partition(pred(T, T), T, int, int, array(T), array(T), int).
:- mode partition(pred(in, in) is semidet, in, in, in,
                  array_di, array_uo, out) is det.

partition(Less_than, Pivot, I_first, I_last, Arr0, Arr, I_pivot) :-
  %% The loop works with *exclusive* bounds, so start one step
  %% outside the range on either side.
  I = I_first - 1,
  J = I_last + 1,
  partition_loop(Less_than, Pivot, I, J, Arr0, Arr, I_pivot).

%% partition_loop(Less_than, Pivot, I, J, Arr0, Arr, Pivot_index)
%%
%% The worker for partition/7. I and J are exclusive bounds of the
%% part of the array that has not been classified yet: everything at
%% or before I is known not to be greater than Pivot, and everything
%% at or after J is known not to be less than Pivot. (The one
%% exception is the terminating call with I = J, where the element at
%% I is itself greater than Pivot and the split is placed at I.)
:- pred partition_loop(pred(T, T), T, int, int,
                       array(T), array(T), int).
:- mode partition_loop(pred(in, in) is semidet, in, in, in,
                       array_di, array_uo, out) is det.

partition_loop(Less_than, Pivot, I, J, Arr0, Arr, Pivot_index) :-
  ( if I = J then
      %% The previous round found nothing left to swap.
      Arr = Arr0,
      Pivot_index = I
    else
      %% Leftmost unclassified element that is greater than Pivot.
      I2 = search_right(Less_than, Pivot, I + 1, J, Arr0),
      ( if I2 = J then
          %% Nothing unclassified is greater than Pivot, so all of it
          %% belongs to the left part.
          Arr = Arr0,
          Pivot_index = J
        else
          %% Rightmost unclassified element that is less than Pivot
          %% (or I2 itself, if there is none, making the swap a
          %% no-op and ending the loop on the next call).
          J2 = search_left(Less_than, Pivot, I2, J - 1, Arr0),
          swap(I2, J2, Arr0, Arr1),
          partition_loop(Less_than, Pivot, I2, J2, Arr1, Arr,
                         Pivot_index)
      )
  ).

%% search_right(Less_than, Pivot, I, J, Arr0) = K
%%
%% K is the smallest index in I..J - 1 whose element satisfies
%% Less_than(Pivot, Element), or J if there is no such index. The
%% caller is expected to pass I =< J.
:- func search_right(pred(T, T), T, int, int, array(T)) = int.
:- mode search_right(pred(in, in) is semidet, in, in, in, in) = out
   is det.

search_right(Less_than, Pivot, I, J, Arr0) = K :-
  ( if I = J then
      K = I
    else if Less_than(Pivot, Arr0 ^ elem(I)) then
      K = I
    else
      K = search_right(Less_than, Pivot, I + 1, J, Arr0)
  ).

%% search_left(Less_than, Pivot, I, J, Arr0) = K
%%
%% K is the largest index in I + 1..J whose element satisfies
%% Less_than(Element, Pivot), or I if there is no such index. The
%% caller is expected to pass I =< J.
:- func search_left(pred(T, T), T, int, int, array(T)) = int.
:- mode search_left(pred(in, in) is semidet, in, in, in, in) = out
   is det.
search_left(Less_than, Pivot, I, J, Arr0) = K :-
  %% Walk J downwards towards I, stopping at the first element that
  %% is less than Pivot; give up (returning I) once J reaches I.
  ( if I = J then
      K = J
    else if Less_than(Arr0 ^ elem(J), Pivot) then
      K = J
    else
      K = search_left(Less_than, Pivot, I, J - 1, Arr0)
  ).

%%%-------------------------------------------------------------------
%%%
%%% Quickselect with a random pivot.
%%%
%%% The implementation is tail-recursive. One has to pass the routine
%%% a random number generator of type M, attached to the IO state.
%%%
%%% I use a random pivot to get O(n) worst case *expected* running
%%% time. Code using a random pivot is easy to write and read, and for
%%% most purposes comes close enough to a criterion set by Scheme's
%%% SRFI-132: "Runs in O(n) time." (See
%%% https://srfi.schemers.org/srfi-132/srfi-132.html)
%%%
%%% Of course we are not bound here by SRFI-132, but still I respect
%%% it as a guide.
%%%
%%% A "median of medians" pivot gives O(n) running time, but is more
%%% complicated. (That is, of course, assuming you are not writing
%%% your own random number generator and making it a complicated one.)
%%%

%% quickselect/8 selects, from the whole of Arr0, the element that
%% comes (K+1)th in the order given by Less_than; K = 0 picks the
%% element that comes first. Arr is the array after the
%% rearrangements made along the way. The string "out of range" is
%% thrown unless 0 =< K =< (number of elements) - 1.
:- pred quickselect(pred(T, T)::pred(in, in) is semidet,
                    int::in,
                    array(T)::array_di, array(T)::array_uo,
                    T::out, M::in, io::di, io::uo)
   is det <= urandom(M, io).

quickselect(Less_than, K, Arr0, Arr, Elem, M, !IO) :-
  array.bounds(Arr0, Lowest, Highest),
  quickselect(Less_than, Lowest, Highest, K, Arr0, Arr, Elem, M, !IO).

%% quickselect/10 does the same for the subarray
%% Arr0[I_first..I_last]: K counts from zero *within that range*, so
%% K = 0 picks the element of the range that comes first in the order
%% given by Less_than. The string "out of range" is thrown unless
%% 0 =< K =< I_last - I_first.
:- pred quickselect(pred(T, T)::pred(in, in) is semidet,
                    int::in, int::in, int::in,
                    array(T)::array_di, array(T)::array_uo,
                    T::out, M::in, io::di, io::uo)
   is det <= urandom(M, io).

quickselect(Less_than, I_first, I_last, K, Arr0, Arr, Elem, M, !IO) :-
  Largest_rank = I_last - I_first,
  ( if 0 =< K, K =< Largest_rank then
      %% The loop wants an absolute array index rather than a rank
      %% relative to the start of the range.
      Target_index = I_first + K,
      quickselect_loop(Less_than, I_first, I_last, Target_index,
                       Arr0, Arr, Elem, M, !IO)
    else
      throw("out of range")
  ).
%% quickselect_loop(Less_than, I_first, I_last, K, Arr0, Arr, Elem,
%%                  M, !IO)
%%
%% The worker for quickselect. K is an absolute array index with
%% I_first =< K =< I_last. The range I_first..I_last is narrowed
%% around K, one randomly chosen pivot at a time, until the element
%% belonging at index K is known; that element is returned as Elem.
:- pred quickselect_loop(pred(T, T)::pred(in, in) is semidet,
                         int::in, int::in, int::in,
                         array(T)::array_di, array(T)::array_uo,
                         T::out, M::in, io::di, io::uo)
   is det <= urandom(M, io).

quickselect_loop(Less_than, I_first, I_last, K, !Arr, Elem, M, !IO) :-
  ( if I_first = I_last then
      %% Only one candidate is left.
      array.lookup(!.Arr, I_first, Elem)
    else
      %% Pick the pivot uniformly from I_first..I_last.
      Range_size = I_last - I_first + 1,
      uniform_int_in_range(M, I_first, Range_size, I_pivot, !IO),
      array.lookup(!.Arr, I_pivot, Pivot),

      %% Fill the pivot's slot with the last element of the range
      %% (possibly the pivot itself). The slot at I_last is from now
      %% on treated as scratch space, and only I_first..I_last - 1
      %% gets partitioned.
      array.lookup(!.Arr, I_last, Elem_last),
      array.set(I_pivot, Elem_last, !Arr),
      partition(Less_than, Pivot, I_first, I_last - 1, !Arr, I_final),

      %% Everything less than the pivot now lies to the left of
      %% I_final. Drop the pivot in at I_final and park the element
      %% it displaces at I_last. When I_final = I_last the displaced
      %% "element" is just the scratch slot, and the pivot is the
      %% greatest element of the range; otherwise the displaced
      %% element is not less than the pivot, so moving it to the far
      %% right keeps the partitioning intact.
      array.lookup(!.Arr, I_final, Elem_to_move),
      array.set(I_last, Elem_to_move, !Arr),
      array.set(I_final, Pivot, !Arr),

      %% Continue on whichever side of the pivot contains K, or stop
      %% if the pivot landed exactly on K.
      ( if I_final < K then
          quickselect_loop(Less_than, I_final + 1, I_last, K,
                           !Arr, Elem, M, !IO)
        else if K < I_final then
          quickselect_loop(Less_than, I_first, I_final - 1, K,
                           !Arr, Elem, M, !IO)
        else
          array.lookup(!.Arr, I_final, Elem)
      )
  ).

%%%-------------------------------------------------------------------

%% The sample data used by main.
:- func example_numbers = list(int).
example_numbers = [9, 8, 7, 6, 5, 0, 1, 2, 3, 4].
%% Runs the demonstration: for every rank K from 1 up to the number
%% of example numbers, print the Kth element of example_numbers, first
%% with < and then with > as the order predicate.
main(!IO) :-
  sfc64.init(P, S),
  make_io_urandom(P, S, M, !IO),
  Print_kth_greatest =
    (pred(K::in, di, uo) is det -->
       print_kth_greatest(K, example_numbers, M)),
  Print_kth_least =
    (pred(K::in, di, uo) is det -->
       print_kth_least(K, example_numbers, M)),
  %% Derive the highest rank from the data instead of hard-coding it,
  %% so that editing example_numbers can neither skip elements nor
  %% ask for a rank that is out of range.
  N = list.length(example_numbers),
  print("With < as order predicate: ", !IO),
  foldl(Print_kth_least, 1 `..` N, !IO),
  print_line("", !IO),
  print("With > as order predicate: ", !IO),
  foldl(Print_kth_greatest, 1 `..` N, !IO),
  print_line("", !IO).

%% Prints a space followed by the Kth least element of Numbers_list
%% (K counts from 1). The selection is done on a fresh array copy of
%% the list. K must lie in 1..length of the list; otherwise
%% quickselect throws "out of range".
:- pred print_kth_least(int::in, list(int)::in,
                        M::in, io::di, io::uo)
   is det <= urandom(M, io).

print_kth_least(K, Numbers_list, M, !IO) :-
  array.from_list(Numbers_list, Arr0),
  quickselect(<, K - 1, Arr0, _, Elem, M, !IO),
  print(" ", !IO),
  print(Elem, !IO).

%% Prints a space followed by the Kth greatest element of
%% Numbers_list (K counts from 1). The selection is done on a fresh
%% array copy of the list. K must lie in 1..length of the list;
%% otherwise quickselect throws "out of range".
:- pred print_kth_greatest(int::in, list(int)::in,
                           M::in, io::di, io::uo)
   is det <= urandom(M, io).

print_kth_greatest(K, Numbers_list, M, !IO) :-
  array.from_list(Numbers_list, Arr0),
  %% Notice that the "Less_than" predicate is actually "greater
  %% than". :) One can think of this as meaning that a greater number
  %% has an *ordinal* that is "less than"; that is, it "comes before"
  %% in the order.
  quickselect(>, K - 1, Arr0, _, Elem, M, !IO),
  print(" ", !IO),
  print(Elem, !IO).

%%%-------------------------------------------------------------------
%%% local variables:
%%% mode: mercury
%%% prolog-indent-width: 2
%%% end: