(* This is Algorithm 4.3.1M in Volume 2 of Knuth, ‘The Art of Computer
   Programming’. *)

#include "share/atspre_staload.hats"

(* Shorthand for the prelude list constructors, so that lists can be
   written with NIL and an infix ‘::’. *)
#define NIL list_nil ()
#define ::  list_cons

(********************** FOR BINARY ARITHMETIC ***********************)

(* We need to choose a radix for the multiplication, small enough that
   intermediate results can be represented, but big for efficiency. To
   stay within the POSIX types, I choose 2**32 as my radix. Thus
   ‘digits’ are stored in uint32 and intermediate results are stored
   in uint64.

   A number is stored as an array of uint32, with the least
   significant uint32 first. *)

(* long_multiplication (m, n, u, v, w) multiplies the m-digit number u
   by the n-digit number v and stores the (m + n)-digit product in w.
   The array w may be uninitialized on entry; its type records that it
   is fully initialized on return. *)
extern fn
long_multiplication             (* Multiply u and v, giving w. *)
          {m, n : int}
          (m    : size_t m,
           n    : size_t n,
           u    : &array (uint32, m),
           v    : &array (uint32, n),
           w    : &array (uint32?, m + n) >> array (uint32, m + n))
    :<!refwrt> void

(* The generated C code needs <stdint.h> for the UINT32_C macro that
   is used in the definition of mask32 below. *)
%{^
#include <stdint.h>
%}

(* Conversions between int and the fixed-width unsigned types. These
   are declared as cast functions (castfn). *)
extern castfn i2u32 : int -<> uint32
extern castfn u32_2i : uint32 -<> int
extern castfn i2u64 : int -<> uint64
extern castfn u32u64 : uint32 -<> uint64
extern castfn u64u32 : uint64 -<> uint32

(* Constants used by the binary arithmetic routines. *)
macdef zero32 = i2u32 0
macdef zero64 = i2u64 0
macdef one32 = i2u32 1
macdef ten32 = i2u32 10

(* All 32 bits set. *)
macdef mask32 = $extval (uint32, "UINT32_C (0xFFFFFFFF)")

(* The following implementation is precisely the algorithm suggested
   by Knuth, although specialized for b=2**32 and for unsigned
   integers of precisely 32 bits.

   The outer loop (jloop) runs over the digits of v and the inner loop
   (iloop) over the digits of u. Each inner step adds one digit
   product, the digit already in w, and the carry k. *)
implement
long_multiplication {m, n} (m, n, u, v, w) =
  let
    (* Establish that the arrays have non-negative lengths. *)
    prval () = lemma_array_param u
    prval () = lemma_array_param v

    (* Knuth initializes only part of the w array. However, if we
       initialize ALL of w now, then we will not have to deal with
       complicated array views later. *)
    val () = array_initize_elt<uint32> (w, m + n, zero32)

    (* The following function includes proof of termination. *)
    fun
    jloop {j : nat | j <= n} .<n - j>.
          (u : &array (uint32, m),
           v : &array (uint32, n),
           w : &array (uint32, m + n),
           j : size_t j)
        :<!refwrt> void =
      if j = n then
        ()
      else if v[j] = zero32 then (* This branch is optional. *)
        begin
          w[j + m] := zero32;
          jloop (u, v, w, succ j)
        end
      else
        let
          fun
          iloop {i : nat | i <= m} .<m - i>.
                (u : &array (uint32, m),
                 v : &array (uint32, n),
                 w : &array (uint32, m + n),
                 i : size_t i,
                 k : uint64)    (* carry *)
              :<!refwrt> void =
            if i = m then
              (* The final carry becomes the next digit of w. *)
              w[j + m] := u64u32 k
            else
              let
                (* This sum cannot overflow uint64: the two digits and
                   the carry are each at most 2**32 - 1, and
                   (2**32 - 1)**2 + 2 * (2**32 - 1) = 2**64 - 1. *)
                val t = (u32u64 u[i] * u32u64 v[j])
                          + u32u64 w[i + j] + k
              in
                (* The mask here is not actually needed, if uint32
                   really is treated by the C compiler as 32 bits. *)
                w[i + j] := (u64u32 t) land mask32;

                (* The high 32 bits of t are the new carry. *)
                iloop (u, v, w, succ i, t >> 32)
              end
        in
          iloop (u, v, w, i2sz 0, zero64);
          jloop (u, v, w, succ j)
        end
  in
    jloop (u, v, w, i2sz 0)
  end

(* big_integer_iseqz (m, u) returns true if every one of the m digits
   of u is zero, and false otherwise. For m = 0 the result is true. *)
fn
big_integer_iseqz       (* Is a big integer equal to zero? *)
          {m : int}
          (m : size_t m,
           u : &array (uint32, m))
    :<!ref> bool =
  let
    prval () = lemma_array_param u

    (* Examine the digits u[n - 1], u[n - 2], ..., u[0], stopping at
       the first one that is not zero. *)
    fun
    loop {n : nat | n <= m} .<n>.
         (u : &array (uint32, m),
          n : size_t n)
        :<!ref> bool =
      if n = i2sz 0 then
        true
      else if u[pred n] = zero32 then
        loop (u, pred n)
      else
        false
  in
    loop (u, m)
  end

(* To print the number in decimal, we need division by 10. So here is
   ‘short division’: Exercise 4.3.1.16 in Volume 2 of Knuth.

   short_division (m, u, v, q, r) divides the m-digit number u by the
   single digit v, storing the m-digit quotient in q and the remainder
   in r. Both q and r may be uninitialized on entry.

   The divisor v must not be zero. This is checked by an assertion,
   so that a zero divisor stops the program with a source location
   instead of reaching the integer division below. *)
fn
short_division
          {m : int}
          (m : size_t m,
           u : &array (uint32, m),
           v : uint32,
           q : &array (uint32?, m) >> array (uint32, m),
           r : &uint32? >> uint32)
    :<!refwrt> void =
  let
    prval () = lemma_array_param u

    (* Refuse a zero divisor before any division is attempted. *)
    val () = $effmask_exn assertloc (v <> zero32)

    val () = array_initize_elt<uint32> (q, m, zero32)
    val () = r := zero32

    (* Work from the most significant digit down to the least. At each
       step the running remainder r supplies the high 32 bits of the
       two-digit dividend, and u[i] supplies the low 32 bits. *)
    fun
    loop {i1 : nat | i1 <= m} .<i1>.
         (u  : &array (uint32, m),
          q  : &array (uint32, m),
          i1 : size_t i1,
          r  : &uint32)
        :<!refwrt> void =
      if i1 <> i2sz 0 then
        let
          val i = pred i1
          val tmp = (u32u64 r << 32) lor (u32u64 u[i])
          val tmp_q = tmp / u32u64 v and tmp_r = tmp mod (u32u64 v)
        in
          q[i] := u64u32 tmp_q;
          r := u64u32 tmp_r;
          loop (u, q, i, r)
        end
  in
    loop (u, q, m, r)
  end

(* fprint_big_integer (f, m, u) writes the m-digit number u to f in
   decimal, with groups of three digits separated by a narrow no-break
   space (U+202F, written as its three UTF-8 bytes). A number that is
   zero is written as "0".

   The array u itself is not changed: it is copied into a scratch
   array, and the repeated divisions by ten alternate between that
   copy and a second scratch array. Both are freed before returning. *)
fn
fprint_big_integer
          {m : int}
          (f : FILEref,
           m : size_t m,
           u : &array (uint32, m))
    : void =
  let
    (* Divide v by ten, leaving the quotient in q, and put the
       remainder, as a decimal digit character, at the front of lst.
       Stop when the quotient is zero. Otherwise recurse with the two
       arrays exchanged, so that the quotient is the next dividend.
       The counter i is the position (0, 1, or 2) of the current digit
       within its group of three, counted from the right. *)
    fun
    loop1 (v   : &array (uint32, m),
           q   : &array (uint32, m),
           lst : List0 char,
           i   : uint)
        : List0 char =
      let
        var r : uint32
        val () = short_division (m, v, ten32, q, r)
        val r = g1ofg0 (u32_2i r)
        val () = assertloc ((0 <= r) * (r <= 9))
        val digit = int2digit r
      in
        if big_integer_iseqz (m, q) then
          digit :: lst
        else if i = 2U then
          (* Insert UTF-8 for narrow no-break space U+202F *)
          loop1 (q, v, '\xE2' :: '\x80' :: '\xAF' :: digit :: lst, 0U)
        else
          loop1 (q, v, digit :: lst, succ i)
      end

    (* Write the characters of lst to f, front to back. *)
    fun
    loop2 {n : nat} .<n>.
          (lst : list (char, n))
        : void =
      case+ lst of
      | NIL => ()
      | hd :: tl => (fprint! (f, hd); loop2 tl)
  in
    if big_integer_iseqz (m, u) then
      fprint! (f, "0")
    else
      let
        val @(pf, pfgc | p) = array_ptr_alloc<uint32> m
        val @(qf, qfgc | q) = array_ptr_alloc<uint32> m
        val () = array_copy<uint32> (!p, u, m)
        val () = array_initize_elt<uint32> (!q, m, zero32)
        val () = loop2 (loop1 (!p, !q, NIL, 0U))
        val () = array_ptr_free (pf, pfgc | p)
        val () = array_ptr_free (qf, qfgc | q)
      in
      end
  end

(* example_binary f multiplies 2**64 by itself in radix 2**32 and
   writes both operands and the product to f in decimal. The operand
   2**64 is the digit sequence (0, 0, 1), least significant first. *)
fn
example_binary (f : FILEref) : void =
  let
    var u = @[uint32][3] (zero32, zero32, one32)
    var v = @[uint32][3] (zero32, zero32, one32)
    var w : @[uint32][6]
  in
    long_multiplication (i2sz 3, i2sz 3, u, v, w);
    fprint! (f, "\nBinary long multiplication (b = 2³²)\n\n");
    fprint! (f, "u = ");
    fprint_big_integer (f, i2sz 3, u);
    fprint! (f, "\nv = ");
    fprint_big_integer (f, i2sz 3, v);
    fprint! (f, "\nu × v = ");
    fprint_big_integer (f, i2sz 6, w);
    fprint! (f, "\n")
  end

(* test_binary f multiplies 2**96 - 1 by itself in radix 2**32 and
   writes both operands and the product to f in decimal. Every digit
   of both operands is mask32, so that the multiplication exercises
   the carry logic. *)
fn
test_binary (f : FILEref) : void =
  let
    var u = @[uint32][3] (mask32, mask32, mask32)
    var v = @[uint32][3] (mask32, mask32, mask32)
    var w : @[uint32][6]
  in
    long_multiplication (i2sz 3, i2sz 3, u, v, w);
    fprint! (f, "\nThe example numbers specified in the task\n",
             "are actually VERY bad for testing binary\n",
             "multiplication, because they never need a carry.\n",
             "So here is a multiplication full of carries,\n",
             "with b = 2³²\n\n");
    fprint! (f, "u = ");
    fprint_big_integer (f, i2sz 3, u);
    fprint! (f, "\nv = ");
    fprint_big_integer (f, i2sz 3, v);
    fprint! (f, "\nu × v = ");
    fprint_big_integer (f, i2sz 6, w);
    fprint! (f, "\n")
  end

(************** FOR BINARY CODED DECIMAL ARITHMETIC *****************)

(* The following will operate on arrays of BCD digits, with the most
   significant digit first. Only the least four bits of a byte will be
   considered. This has at least two benefits: any ASCII digit is
   treated as its BCD equivalent, and SPACE is treated as zero. *)

(* bcd_multiplication (m, n, u, v, w) multiplies the m-digit decimal
   number u by the n-digit decimal number v and stores the
   (m + n)-digit product in w. The array w may be uninitialized on
   entry; its type records that it is fully initialized on return. *)
extern fn
bcd_multiplication              (* Multiply u and v, giving w. *)
          {m, n : int}
          (m    : size_t m,
           n    : size_t n,
           u    : &array (char, m),
           v    : &array (char, n),
           w    : &array (char?, m + n) >> array (char, m + n))
    :<!refwrt> void

(* char2bcd c returns the decimal digit, from 0 to 9, represented by
   the character c. Only the value of c modulo 16 is used. That value
   is then reduced modulo 10, so that the result is a valid digit
   whatever the input. *)
fn {}
char2bcd (c : char) :<> intBtwe (0, 9) =
  let
    val c = char2uchar1 (g1ofg0 c)
    val i = g1uint_of_uchar1 c
    val i = i mod 16U
    val i = i mod 10U           (* Guarantees the digit be BCD. *)
  in
    u2i i
  end

(* bcd2char i reinterprets the digit i, from 0 to 9, as a char. It is
   a cast function (castfn), declared here without a body. *)
extern castfn bcd2char (i : intBtwe (0, 9)) :<> char

(* The following implementation is precisely the algorithm suggested
   by Knuth, specialized for b=10.

   The loop counters i and j count digits from the least significant
   end, whereas the arrays store the most significant digit first. So
   the digit of u at position i is u[pred m - i], the digit of v at
   position j is v[pred n - j], and the digits of w at positions
   i + j and j + m are w[pred (m + n) - (i + j)] and w[pred n - j]. *)
implement
bcd_multiplication {m, n} (m, n, u, v, w) =
  let
    (* Establish that the arrays have non-negative lengths. *)
    prval () = lemma_array_param u
    prval () = lemma_array_param v

    (* Knuth initializes only part of the w array. However, if we
       initialize ALL of w now, then we will not have to deal with
       complicated array views later. *)
    val () = array_initize_elt<char> (w, m + n, '\0')

    (* The following function includes proof of termination. *)
    fun
    jloop {j : nat | j <= n} .<n - j>.
          (u : &array (char, m),
           v : &array (char, n),
           w : &array (char, m + n),
           j : size_t j)
        :<!refwrt> void =
      if j = n then
        ()
      else if char2bcd v[pred n - j] = 0 then (* Optional branch. *)
        begin
          w[pred n - j] := '\0';
          jloop (u, v, w, succ j)
        end
      else
        let
          fun
          iloop {i : nat | i <= m} .<m - i>.
                (u : &array (char, m),
                 v : &array (char, n),
                 w : &array (char, m + n),
                 i : size_t i,
                 k : intBtwe (0, 9))    (* carry *)
              :<!refwrt> void =
            if i = m then
              (* The final carry becomes the next digit of w. *)
              w[pred n - j] := bcd2char k
            else
              let
                val ui = char2bcd u[pred m - i]
                and vj = char2bcd v[pred n - j]
                and wij = char2bcd w[pred (m + n) - (i + j)]

                val t = (ui * vj) + wij + k

                (* This will prove that 0 <= t *)
                prval [ui : int] EQINT () = eqint_make_gint ui
                prval [vj : int] EQINT () = eqint_make_gint vj
                prval [t : int] EQINT () = eqint_make_gint t
                prval () = mul_gte_gte_gte {ui, vj} ()
                prval () = prop_verify {0 <= t} ()

                (* But I do not feel like proving that t / 10 <= 9. *)
                val t_div_10 = t \ndiv 10 and t_mod_10 = t \nmod 10
                val () = $effmask_exn assertloc (t_div_10 <= 9)
              in
                w[pred (m + n) - (i + j)] := bcd2char t_mod_10;
                iloop (u, v, w, succ i, t_div_10)
              end
        in
          iloop (u, v, w, i2sz 0, 0);
          jloop (u, v, w, succ j)
        end
  in
    jloop (u, v, w, i2sz 0)
  end

(* fprint_bcd (f, m, u) writes the m-digit BCD number u to f in
   decimal. Leading zeros are not written, and groups of three digits
   are separated by a narrow no-break space (U+202F, written as its
   three UTF-8 bytes). If every digit is zero, or if m = 0, then "0"
   is written. *)
fn
fprint_bcd {m : int}
           (f : FILEref,
            m : size_t m,
            u : &array (char, m))
    : void =
  let
    prval () = lemma_array_param u

    (* Starting at index i, return the index of the first digit that
       is not zero, or m if there is no such digit. *)
    fun
    skip_zeros {i : nat | i <= m} .<m - i>.
               (u : &array (char, m),
                i : size_t i)
        :<!ref> [i : nat | i <= m] size_t i =
      if i = m then
        i
      else if char2bcd u[i] = 0 then
        skip_zeros (u, succ i)
      else
        i

    val [i : int] i = skip_zeros (u, i2sz 0)

    (* Write the digits u[j], ..., u[m - 1]. A separator goes before
       any digit, other than the first one written, for which the
       number of digits still to be written is a multiple of three. *)
    fun
    loop {j : int | i <= j; j <= m} .<m - j>.
         (u : &array (char, m),
          j : size_t j)
        : void =
      if j <> m then
        begin
          if j <> i && (m - j) mod (i2sz 3) = i2sz 0 then
            (* Print UTF-8 for narrow no-break space U+202F *)
            fprint! (f, "\xE2\x80\xAF");
          fprint! (f, int2digit (char2bcd u[j]));
          loop (u, succ j)
        end
  in
    if i = m then
      fprint! (f, "0")
    else
      loop (u, i)
  end

(* string2bcd s allocates an array of as many chars as the string s
   has and copies the characters of s into it, unchanged and in the
   same order. The result is the array view, the view for freeing the
   memory, and the pointer. Releasing the array is left to the
   caller. *)
fn
string2bcd {n : int}
           (s : string n)
    : [p : agz]
      @(array_v (char, p, n), mfree_gc_v p | ptr p) =
  let
    val n = strlen s
    val @(pf, pfgc | p) = array_ptr_alloc<char> n

    (* Element i of the array is initialized to character i of s. The
       index reaches this template without a static bound, so it is
       checked at run time before the string is subscripted. *)
    implement
    array_initize$init<char> (i, x) =
      let
        val i = g1ofg0 i
        prval () = lemma_g1uint_param i
        val () = assertloc (i < n)
      in
        x := s[i]
      end

    val () = array_initize<char> (!p, n)
  in
    @(pf, pfgc | p)
  end

(* example_bcd f multiplies 18446744073709551616 (that is, 2**64) by
   itself in radix 10 and writes both operands and the product to f.
   The three arrays are allocated here and are freed before
   returning. *)
fn
example_bcd (f : FILEref) : void =
  let
    val s = g1ofg0 "18446744073709551616"

    val m = strlen s

    val @(pf_u, pfgc_u | p_u) = string2bcd s
    val @(pf_v, pfgc_v | p_v) = string2bcd s
    val @(pf_w, pfgc_w | p_w) = array_ptr_alloc<char> (m + m)

    (* Names for the arrays that the pointers refer to. *)
    macdef u = !p_u
    macdef v = !p_v
    macdef w = !p_w
  in
    bcd_multiplication (m, m, u, v, w);
    fprint! (f, "\nDecimal long multiplication (b = 10)\n\n");
    fprint! (f, "u = ");
    fprint_bcd (f, m, u);
    fprint! (f, "\nv = ");
    fprint_bcd (f, m, v);
    fprint! (f, "\nu × v = ");
    fprint_bcd (f, m + m, w);
    fprint! (f, "\n");
    array_ptr_free (pf_u, pfgc_u | p_u);
    array_ptr_free (pf_v, pfgc_v | p_v);
    array_ptr_free (pf_w, pfgc_w | p_w)
  end

(********************************************************************)

(* Run the binary example, the decimal example, and the binary carry
   test, in that order, writing to standard output with an empty line
   after each. The exit status is 0. *)
implement
main () =
  begin
    example_binary (stdout_ref);
    println! ();
    example_bcd (stdout_ref);
    println! ();
    test_binary (stdout_ref);
    println! ();
    0
  end