RosettaCodeData/Task/Numerical-integration/ATS/numerical-integration.ats

189 lines
6.1 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "share/atspre_staload.hats"
%{^
#include <math.h>
%}
(* FILEstar is the C type "FILE *". FILEref2star is a cast (castfn)
   from an ATS FILEref to that type, so that an ATS output stream can
   be handed to C's fprintf through $extfcall. *)
typedef FILEstar = $extype"FILE *"
extern castfn FILEref2star : FILEref -<> FILEstar
(* This type declaration is for composite quadrature functions for
   all the different g0float typekinds. The arguments are the two
   endpoints of the interval and the number of subintervals, which
   must be at least 2; the result is the estimate of the integral.
   The function must either prove termination or mask the
   requirement. (All of ours will prove termination.) The function
   to be integrated will not be passed as an argument, but inlined
   via the template mechanism. (This design is more general. It can
   easily be used to write a quadrature function that takes the
   argument, but also can be used for faster code that requires no
   function call.) *)
typedef composite_quadrature (tk : tkind) =
(g0float tk, g0float tk, intGte 2) -<> g0float tk
(* The function to be integrated. It is a template rather than an
   argument of the rules: a caller provides it with an `implement`
   that is in scope at the point where a rule is invoked. *)
extern fn {tk : tkind}
composite_quadrature$func : g0float tk -<> g0float tk
(* The quadrature rules. Each one takes the interval endpoints and
   the number of subintervals, and returns the estimated integral of
   composite_quadrature$func over that interval. *)
extern fn {tk : tkind} left_rule : composite_quadrature tk
extern fn {tk : tkind} right_rule : composite_quadrature tk
extern fn {tk : tkind} midpoint_rule : composite_quadrature tk
extern fn {tk : tkind} trapezium_rule : composite_quadrature tk
extern fn {tk : tkind} simpson_rule : composite_quadrature tk
(* Hook used by _one_point_rule: it is given the step width and
   returns the first point at which the integrand is sampled. The
   left, right and midpoint rules differ only in how they implement
   this template. *)
extern fn {tk : tkind}
_one_point_rule$init_x :
g0float tk -<> g0float tk
(* Shared engine of the left, right and midpoint rules. It samples
   the integrand at n equally spaced points, beginning at the point
   that _one_point_rule$init_x derives from the step width, and
   scales the accumulated total by that width. *)
fn {tk : tkind}
_one_point_rule : composite_quadrature tk =
  lam (a, b, n) =>
    let
      (* Bind the static value of n so that it can appear in the
         loop's constraints and termination metric. *)
      prval [n : int] EQINT () = eqint_make_gint n

      macdef integrand = composite_quadrature$func

      val step = (b - a) / g0i2f n
      val x_first = _one_point_rule$init_x<tk> step

      (* k is the number of samples already taken. The metric n - k
         shrinks on every call, which proves termination. *)
      fun
      accumulate {k : nat | k <= n} .<n - k>.
                 (k : int k,
                  acc : g0float tk) :<> g0float tk =
        if k < n then
          accumulate (succ k, acc + integrand(x_first + (g0i2f k * step)))
        else
          acc
    in
      accumulate (0, g0i2f 0) * step
    end
(* The left rule, for any floating point type: sampling begins at
   the left endpoint itself, whatever the step width. *)
implement {tk}
left_rule (a, b, n) =
  _one_point_rule<tk> (a, b, n) where
  {
    implement _one_point_rule$init_x<tk> (_) = a
  }
(* The right rule, for any floating point type: sampling begins one
   full step to the right of the left endpoint. *)
implement {tk}
right_rule (a, b, n) =
  _one_point_rule<tk> (a, b, n) where
  {
    implement _one_point_rule$init_x<tk> (step) = a + step
  }
(* The midpoint rule, for any floating point type: sampling begins
   half a step to the right of the left endpoint. *)
implement {tk}
midpoint_rule (a, b, n) =
  _one_point_rule<tk> (a, b, n) where
  {
    implement _one_point_rule$init_x<tk> (step) = a + (step / g0i2f 2)
  }
(* The trapezium rule, for any floating point type. *)
implement {tk}
trapezium_rule : composite_quadrature tk =
  lam (a, b, n) =>
    let
      (* Bind the static value of n so that it can appear in the
         loop's constraints and termination metric. *)
      prval [n : int] EQINT () = eqint_make_gint n

      macdef integrand = composite_quadrature$func

      val step = (b - a) / g0i2f n

      (* Adds up the integrand at the interior points a + k * step,
         for k = 1 .. n - 1. The metric n - k proves termination. *)
      fun
      interior {k : pos | k <= n} .<n - k>.
               (k : int k,
                acc : g0float tk) :<> g0float tk =
        if k < n then
          interior (succ k, acc + integrand(a + (g0i2f k * step)))
        else
          acc

      val inner = interior (1, g0i2f 0)
    in
      (* Interior points count twice, the two endpoints once. *)
      ((integrand(a) + inner + inner + integrand(b)) * step) / g0i2f 2
    end
(* Simpson's 1/3 rule, for any floating point type. *)
implement {tk}
simpson_rule : composite_quadrature tk =
  lam (a, b, n) =>
    let
      (* The Simpson estimate is a weighted average of the trapezium
         and midpoint estimates, with the midpoint estimate counted
         twice. Those two rules sample the function at different
         points, so no evaluation is repeated. *)
      val by_trapezium = trapezium_rule<tk> (a, b, n)
      val by_midpoint = midpoint_rule<tk> (a, b, n)
    in
      (by_trapezium + by_midpoint + by_midpoint) / (g0i2f 3)
    end
(* The quadrature rule that fprint_result evaluates. Like the
   integrand, it is supplied through a template implementation that
   is in scope at the call, not through an argument. *)
extern fn {tk : tkind}
fprint_result$rule : composite_quadrature tk
(* Evaluates fprint_result$rule on (a, b, n) and writes one line to
   outf: the message, the computed integral, and the difference
   between the computed integral and the nominal value. *)
extern fn {tk : tkind}
fprint_result (outf : FILEref,
message : string,
a : g0float tk,
b : g0float tk,
n : intGte 2,
nominal : g0float tk) : void
(* fprint_result for double precision. The numbers are formatted by
   C's fprintf, reached through $extfcall; the surrounding text is
   written with fprint!. *)
implement
fprint_result<dblknd> (outf, message, a, b, n, nominal) =
  let
    val estimate = fprint_result$rule<dblknd> (a, b, n)
    val deviation = estimate - nominal

    (* The same stream, viewed as a C "FILE *" for fprintf. *)
    val cfile = FILEref2star outf

    val () = fprint! (outf, " ", message, " ")
    val _ = $extfcall (int, "fprintf", cfile, "%18.15le", estimate)
    val () = fprint! (outf, " (nominal + ")
    val _ = $extfcall (int, "fprintf", cfile, "% .6le", deviation)
  in
    fprint! (outf, ")\n")
  end
(* Prints one result line for each of the five rules, all applied to
   the same interval [a, b], the same n and the same nominal value.
   NOTE(review): this is generic in tk, but the only implementation
   of fprint_result visible in this file is the one for dblknd. *)
fn {tk : tkind}
fprint_rule_results (outf : FILEref,
a : g0float tk,
b : g0float tk,
n : intGte 2,
nominal : g0float tk) : void =
let
(* Each `implement` below selects the rule used by the fprint_result
   call that follows it; the order of these declarations matters. *)
implement fprint_result$rule<tk> (a, b, n) = left_rule<tk> (a, b, n)
val () = fprint_result (outf, "left rule ", a, b, n, nominal)
implement fprint_result$rule<tk> (a, b, n) = right_rule<tk> (a, b, n)
val () = fprint_result (outf, "right rule ", a, b, n, nominal)
implement fprint_result$rule<tk> (a, b, n) = midpoint_rule<tk> (a, b, n)
val () = fprint_result (outf, "midpoint rule ", a, b, n, nominal)
implement fprint_result$rule<tk> (a, b, n) = trapezium_rule<tk> (a, b, n)
val () = fprint_result (outf, "trapezium rule ", a, b, n, nominal)
implement fprint_result$rule<tk> (a, b, n) = simpson_rule<tk> (a, b, n)
val () = fprint_result (outf, "Simpson rule ", a, b, n, nominal)
in
end
(* Runs the four test integrals in double precision and prints the
   results of all five rules for each of them to standard output.
   Each `implement composite_quadrature$func<dblknd>` selects the
   integrand for the fprint_rule_results call that follows it, so
   the order of the declarations matters. Returns 0. *)
implement
main () =
let
val outf = stdout_ref
(* Cube of x; the nominal value of the integral is 1/4. *)
val () = fprint! (outf, "\nx³ in [0,1] with n = 100\n")
implement composite_quadrature$func<dblknd> x = x * x * x
val () = fprint_rule_results<dblknd> (outf, 0.0, 1.0, 100, 0.25)
(* Reciprocal of x; the nominal value is computed with C's log. *)
val () = fprint! (outf, "\n1/x in [1,100] with n = 1000\n")
implement composite_quadrature$func<dblknd> x = g0i2f 1 / x
val () = fprint_rule_results<dblknd> (outf, 1.0, 100.0, 1000,
$extfcall (double, "log", 100.0))
(* The identity function, on two different intervals. *)
val () = fprint! (outf, "\nx in [0,5000] with n = 5000000\n")
implement composite_quadrature$func<dblknd> x = x
val () = fprint_rule_results<dblknd> (outf, 0.0, 5000.0, 5000000,
12500000.0)
val () = fprint! (outf, "\nx in [0,6000] with n = 6000000\n")
implement composite_quadrature$func<dblknd> x = x
val () = fprint_rule_results<dblknd> (outf, 0.0, 6000.0, 6000000,
18000000.0)
val () = fprint! (outf, "\n")
in
0
end