(* The Rosetta Code code generator in ATS2. *)

(* Usage: gen [INPUTFILE [OUTPUTFILE]]
   If INPUTFILE or OUTPUTFILE is "-" or missing, then standard input
   or standard output is used, respectively. *)

(* Note: you might wish to add code to catch exceptions and print nice
   messages. *)

(*------------------------------------------------------------------*)

#define ATS_DYNLOADFLAG 0

#include "share/atspre_staload.hats"
staload UN = "prelude/SATS/unsafe.sats"

(* Shorthand for the linear-list constructors used throughout. *)
#define NIL list_vt_nil ()
#define :: list_vt_cons

%{^
/* alloca(3) is needed for ATS exceptions. */
#include <alloca.h>
%}

(* Raised when the code generator meets an AST shape it should never
   see (a NullNode stored inside a non-nil node). *)
exception internal_error of ()
(* Raised with the offending token when a line starts with an unknown
   node-type name. *)
exception bad_ast_node_type of string
(* Raised when the input ends while a subtree is still expected. *)
exception premature_end_of_input of ()
(* Raised with the offending field (or "" if the field is absent) when
   an Integer node's value is not a run of decimal digits. *)
exception bad_number_field of string
(* Raised when an Identifier node has no name after the node type. *)
exception missing_identifier_field of ()
(* Raised with the offending text (or "" if absent) when a String
   node's value is not enclosed in double quotes. *)
exception bad_quoted_string of string

(* Some implementations that are likely missing from the prelude. *)

(* NOTE(review): this instance is not exercised by the code in this
   file; its template arguments are a best guess -- confirm. *)
implement
g0uint2int<sizeknd,llintknd> x =
  $UN.cast x

(* size_t -> ullint; used to store string/identifier indices in
   node_arg. *)
implement
g0uint2uint<sizeknd,ullintknd> x =
  $UN.cast x

(* ullint -> llint; used for instruction arguments and jump
   offsets. *)
implement
g0uint2int<ullintknd,llintknd> x =
  $UN.cast x

(*------------------------------------------------------------------*)

(* Predicate template: should character c be skipped?  Each caller
   of skip_characters supplies its own implementation. *)
extern fn {}
skip_characters$skipworthy (c : char) :<> bool

(* Starting at index i of s, advance past every character for which
   skip_characters$skipworthy is true.  Returns the index of the first
   character that is not skipworthy, or the end of the string. *)
fn {}
skip_characters {n : int}
                {i : nat | i <= n}
                (s : string n,
                 i : size_t i)
    :<> [j : int | i <= j; j <= n]
        size_t j =
  let
    fun
    loop {k : int | i <= k; k <= n}
         .<n - k>.
         (k : size_t k)
        :<> [j : int | k <= j; j <= n]
            size_t j =
      if string_is_atend (s, k) then
        k
      else if ~skip_characters$skipworthy (string_get_at (s, k)) then
        k
      else
        loop (succ k)
  in
    loop i
  end

(* Skip characters for which isspace is true. *)
fn
skip_whitespace {n : int}
                {i : nat | i <= n}
                (s : string n,
                 i : size_t i)
    :<> [j : int | i <= j; j <= n]
        size_t j =
  let
    implement
    skip_characters$skipworthy<> c =
      isspace c
  in
    skip_characters<> (s, i)
  end

(* Skip characters for which isspace is false. *)
fn
skip_nonwhitespace {n : int}
                   {i : nat | i <= n}
                   (s : string n,
                    i : size_t i)
    :<> [j : int | i <= j; j <= n]
        size_t j =
  let
    implement
    skip_characters$skipworthy<> c =
      ~isspace c
  in
    skip_characters<> (s, i)
  end

(* Skip everything up to (but not including) the next double quote. *)
fn
skip_nonquote {n : int}
              {i : nat | i <= n}
              (s : string n,
               i : size_t i)
    :<> [j : int | i <= j; j <= n]
        size_t j =
  let
    implement
    skip_characters$skipworthy<> c =
      c <> '"'
  in
    skip_characters<> (s, i)
  end

(* Skip every remaining character; returns the end-of-string index. *)
fn
skip_to_end {n : int}
            {i : nat | i <= n}
            (s : string n,
             i : size_t i)
    :<> [j : int | i <= j; j <= n]
        size_t j =
  let
    implement
    skip_characters$skipworthy<> c =
      true
  in
    skip_characters<> (s, i)
  end

(*------------------------------------------------------------------*)

(* Is the substring s[i .. j-1] equal to the whole of string t?
   The lengths are compared first, so a mere prefix match (such as
   "Less" against "LessEqual") is rejected. *)
fn
substring_equals {n    : int}
                 {i, j : nat | i <= j; j <= n}
                 (s : string n,
                  i : size_t i,
                  j : size_t j,
                  t : string)
    :<> bool =
  let
    val m = strlen t
  in
    if j - i <> m then
      false                 (* The substring is the wrong length. *)
    else
      let
        val p_s = ptrcast s
        and p_t = ptrcast t
      in
        (* Equal lengths, so comparing m bytes compares everything. *)
        0 = $extfcall (int, "strncmp", ptr_add<char> (p_s, i), p_t, m)
      end
  end

(*------------------------------------------------------------------*)

(* The node types that may appear in the flattened AST input.
   NullNode stands for the ";" line that marks an empty subtree. *)
datatype node_type_t =
| NullNode
| Identifier
| String
| Integer
| Sequence
| If
| Prtc
| Prts
| Prti
| While
| Assign
| Negate
| Not
| Multiply
| Divide
| Mod
| Add
| Subtract
| Less
| LessEqual
| Greater
| GreaterEqual
| Equal
| NotEqual
| And
| Or

(* Placeholder stored in node_arg of nodes that carry no argument. *)
#define ARBITRARY_NODE_ARG 1234

(* An AST node.  node_arg holds the number for Integer nodes and the
   index returned by collect_string for Identifier and String nodes;
   for every other node type it is ARBITRARY_NODE_ARG. *)
datatype ast_node_t =
| ast_node_t_nil
| ast_node_t_nonnil of node_contents_t
where node_contents_t =
  @{
    node_type = node_type_t,
    node_arg = ullint,
    node_left = ast_node_t,
    node_right = ast_node_t
  }

(* Read the node-type token that starts at or after index i of s.
   Returns the node type and the index just past the token.
   Raises bad_ast_node_type (carrying the token) if it is not one of
   the recognized names. *)
fn
get_node_type {n : int}
              {i : nat | i <= n}
              (s : string n,
               i : size_t i)
    : [j : int | i <= j; j <= n]
      @(node_type_t, size_t j) =
  let
    val i_start = skip_whitespace (s, i)
    val i_end = skip_nonwhitespace (s, i_start)

    (* Does the token equal the given name? *)
    macdef eq t =
      substring_equals (s, i_start, i_end, ,(t))

    val node_type =
      if eq ";" then
        NullNode
      else if eq "Identifier" then
        Identifier
      else if eq "String" then
        String
      else if eq "Integer" then
        Integer
      else if eq "Sequence" then
        Sequence
      else if eq "If" then
        If
      else if eq "Prtc" then
        Prtc
      else if eq "Prts" then
        Prts
      else if eq "Prti" then
        Prti
      else if eq "While" then
        While
      else if eq "Assign" then
        Assign
      else if eq "Negate" then
        Negate
      else if eq "Not" then
        Not
      else if eq "Multiply" then
        Multiply
      else if eq "Divide" then
        Divide
      else if eq "Mod" then
        Mod
      else if eq "Add" then
        Add
      else if eq "Subtract" then
        Subtract
      else if eq "Less" then
        Less
      else if eq "LessEqual" then
        LessEqual
      else if eq "Greater" then
        Greater
      else if eq "GreaterEqual" then
        GreaterEqual
      else if eq "Equal" then
        Equal
      else if eq "NotEqual" then
        NotEqual
      else if eq "And" then
        And
      else if eq "Or" then
        Or
      else
        let
          val s_bad =
            strnptr2string
              (string_make_substring (s, i_start, i_end - i_start))
        in
          $raise bad_ast_node_type s_bad
        end
  in
    @(node_type, i_end)
  end

(* Read an unsigned decimal number that starts at or after index i.
   Returns the value and the index just past the field.
   Raises bad_number_field with "" if the field is missing, or with
   the field text if it contains a non-digit.  Overflow of ullint is
   not detected. *)
fn
get_unsigned {n : int}
             {i : nat | i <= n}
             (s : string n,
              i : size_t i)
    : [j : int | i <= j; j <= n]
      @(ullint, size_t j) =
  let
    val i = skip_whitespace (s, i)
    val [j : int] j = skip_nonwhitespace (s, i)
  in
    if j = i then
      $raise bad_number_field ""
    else
      let
        (* Accumulate the digits of s[k .. j-1] onto v. *)
        fun
        loop {k : int | i <= k; k <= j}
             (k : size_t k,
              v : ullint)
            : ullint =
          if k = j then
            v
          else
            let
              val c = string_get_at (s, k)
            in
              if ~isdigit c then
                let
                  val s_bad =
                    strnptr2string
                      (string_make_substring (s, i, j - i))
                in
                  $raise bad_number_field s_bad
                end
              else
                let
                  val digit = char2int1 c - char2int1 '0'
                  val () = assertloc (0 <= digit)
                in
                  loop (succ k, (g1i2u 10 * v) + g1i2u digit)
                end
            end
      in
        @(loop (i, g0i2u 0), j)
      end
  end

(* Read the whitespace-delimited identifier that starts at or after
   index i.  Returns a copy of it and the index just past it.
   Raises missing_identifier_field if there is none. *)
fn
get_identifier {n : int}
               {i : nat | i <= n}
               (s : string n,
                i : size_t i)
    : [j : int | i <= j; j <= n]
      @(string, size_t j) =
  let
    val i = skip_whitespace (s, i)
    val j = skip_nonwhitespace (s, i)
  in
    if i = j then
      $raise missing_identifier_field ()
    else
      let
        val ident =
          strnptr2string (string_make_substring (s, i, j - i))
      in
        @(ident, j)
      end
  end

(* Read a double-quoted string that starts at or after index i.
   The returned copy includes both quote marks; escape sequences are
   left untouched.  Also returns the index just past the closing
   quote.  Raises bad_quoted_string with "" if the field is missing,
   or with the offending text if the opening or closing quote is
   absent. *)
fn
get_quoted_string {n : int}
                  {i : nat | i <= n}
                  (s : string n,
                   i : size_t i)
    : [j : int | i <= j; j <= n]
      @(string, size_t j) =
  let
    val i = skip_whitespace (s, i)
  in
    if string_is_atend (s, i) then
      $raise bad_quoted_string ""
    else if string_get_at (s, i) <> '"' then
      let
        val j = skip_to_end (s, i)
        val s_bad =
          strnptr2string (string_make_substring (s, i, j - i))
      in
        $raise bad_quoted_string s_bad
      end
    else
      let
        val j = skip_nonquote (s, succ i)
      in
        if string_is_atend (s, j) then
          (* No closing quote before the end of the line. *)
          let
            val s_bad =
              strnptr2string (string_make_substring (s, i, j - i))
          in
            $raise bad_quoted_string s_bad
          end
        else
          let
            val quoted_string =
              strnptr2string
                (string_make_substring (s, i, succ j - i))
          in
            @(quoted_string, succ j)
          end
      end
  end

(* Return the zero-based position of str in strings, appending str to
   the end of the list first if it is not already there.  The list
   therefore either keeps its length or grows by exactly one. *)
fn
collect_string {n : int}
               (str     : string,
                strings : &list_vt (string, n) >> list_vt (string, m))
    : #[m : int | m == n || m == n + 1]
      [str_num : nat | str_num <= m]
      size_t str_num =
  (* This implementation uses ‘list_vt’ instead of ‘list’, so
     appending elements to the end of the list will be both efficient
     and safe. It would also have been reasonable to build a ‘list’
     backwards and then make a reversed copy. *)
  let
    (* strings1 is the tail of the list that starts at position i. *)
    fun
    find_or_extend
              {i : nat | i <= n}
              .<n - i>.
              (strings1 : &list_vt (string, n - i)
                            >> list_vt (string, m),
               i        : size_t i)
        : #[m : int | m == n - i || m == n - i + 1]
          [j : nat | j <= n]
          size_t j =
      case+ strings1 of
      | ~ NIL =>
        let (* The string is not there. Extend the list. *)
          prval () = prop_verify {i == n} ()
        in
          strings1 := (str :: NIL);
          i
        end
      | @ (head :: tail) =>
        if head = str then
          let (* The string is found. *)
            prval () = fold@ strings1
          in
            i
          end
        else
          let (* Continue looking. *)
            val j = find_or_extend (tail, succ i)
            prval () = fold@ strings1
          in
            j
          end

    (* Establishes 0 <= n for the constraint solver. *)
    prval () = lemma_list_vt_param strings
  in
    find_or_extend (strings, i2sz 0)
  end

(* Read the flattened AST from inpf, one node per line, and rebuild
   the tree.  Identifier names and string literals are collected into
   idents and strings; the corresponding nodes store the list index.
   Raises premature_end_of_input if the file ends where a node is
   expected, plus whatever the field parsers raise. *)
fn
load_ast (inpf    : FILEref,
          idents  : &List_vt string >> _,
          strings : &List_vt string >> _)
    : ast_node_t =
  let
    fun
    recurs (idents  : &List_vt string >> _,
            strings : &List_vt string >> _)
        : ast_node_t =
      if fileref_is_eof inpf then
        $raise premature_end_of_input ()
      else
        let
          val s = strptr2string (fileref_get_line_string inpf)
          prval () = lemma_string_param s (* String length >= 0. *)

          val i = i2sz 0
          val @(node_type, i) = get_node_type (s, i)
        in
          case+ node_type of
          | NullNode () => ast_node_t_nil ()
          | Integer () =>
            let
              val @(number, _) = get_unsigned (s, i)
            in
              ast_node_t_nonnil
                @{
                  node_type = node_type,
                  node_arg = number,
                  node_left = ast_node_t_nil,
                  node_right = ast_node_t_nil
                }
            end
          | Identifier () =>
            let
              val @(ident, _) = get_identifier (s, i)
              val arg = collect_string (ident, idents)
            in
              ast_node_t_nonnil
                @{
                  node_type = node_type,
                  node_arg = g0u2u arg,
                  node_left = ast_node_t_nil,
                  node_right = ast_node_t_nil
                }
            end
          | String () =>
            let
              val @(quoted_string, _) = get_quoted_string (s, i)
              val arg = collect_string (quoted_string, strings)
            in
              ast_node_t_nonnil
                @{
                  node_type = node_type,
                  node_arg = g0u2u arg,
                  node_left = ast_node_t_nil,
                  node_right = ast_node_t_nil
                }
            end
          | _ =>
            (* An interior node: its two subtrees follow, left
               subtree first. *)
            let
              val node_left = recurs (idents, strings)
              val node_right = recurs (idents, strings)
            in
              ast_node_t_nonnil
                @{
                  node_type = node_type,
                  node_arg = g1i2u ARBITRARY_NODE_ARG,
                  node_left = node_left,
                  node_right = node_right
                }
            end
        end
  in
    recurs (idents, strings)
  end

(* Print each string of the list on its own line, in list order. *)
fn
print_strings {n : int}
              (outf    : FILEref,
               strings : !list_vt (string, n))
    : void =
  let
    fun
    loop {m : nat}
         .<m>.
         (strings1 : !list_vt (string, m))
        : void =
      case+ strings1 of
      | NIL => ()
      | head :: tail =>
        begin
          fprintln! (outf, head);
          loop tail
        end

    prval () = lemma_list_vt_param strings
  in
    loop strings
  end

(*------------------------------------------------------------------*)

(* Placeholders: the first for instructions that take no argument,
   the second for jumps whose target has not been filled in yet. *)
#define ARBITRARY_INSTRUCTION_ARG 1234
#define ARBITRARY_JUMP_ARG 123456789

typedef instruction_t =
  @{
    address = ullint,
    opcode = string,
    arg = llint
  }

(* Instructions are kept behind references so that a jump can be
   patched after it has been added to the code list. *)
typedef code_t = ref instruction_t

(* A pending jump: a pointer to the instruction, the view of it, and
   the proof function that gives the view back (see ref_vtakeout). *)
vtypedef pjump_t (p : addr) =
  (instruction_t @ p,
   instruction_t @ p -<lin,prf> void |
   ptr p)
vtypedef pjump_t = [p : addr] pjump_t p

(* Cons an instruction at address pc onto code, then advance pc by
   the instruction's size in bytes. *)
fn
add_instruction (opcode : string,
                 arg    : llint,
                 size   : uint,
                 code   : &List0_vt code_t >> List1_vt code_t,
                 pc     : &ullint >> _)
    : void =
  let
    val instr =
      @{
        address = pc,
        opcode = opcode,
        arg = arg
      }
  in
    code := (ref instr :: code);
    pc := pc + g0u2u size
  end

(* Cons a 5-byte jump with a placeholder offset onto code and return
   a handle through which fill_jump can later patch it. *)
fn
add_jump (opcode : string,
          code   : &List0_vt code_t >> List1_vt code_t,
          pc     : &ullint >> _)
    : pjump_t =
  let
    val instr =
      @{
        address = pc,
        opcode = opcode,
        arg = g1i2i ARBITRARY_JUMP_ARG
      }
    val ref_instr = ref instr
  in
    code := (ref_instr :: code);
    pc := pc + g0u2u 5U;
    ref_vtakeout ref_instr
  end

(* Patch a pending jump so that it goes to address, and consume the
   handle.  The stored argument is the offset of the target relative
   to the byte after the opcode (the instruction's address plus
   one). *)
fn
fill_jump (pjump   : pjump_t,
           address : ullint)
    : void =
  let
    val @(pf, fpf | p) = pjump
    val instr0 = !p
    val jump_offset : llint =
      let
        val from = succ (instr0.address)
        and to = address
      in
        (* The operands are unsigned, so subtract the smaller from
           the larger and restore the sign afterwards. *)
        if from <= to then
          g0u2i (to - from)
        else
          ~g0u2i (from - to)
      end
    val instr1 =
      @{
        address = instr0.address,
        opcode = instr0.opcode,
        arg = jump_offset
      }
    val () = !p := instr1
    prval () = fpf pf
  in
  end

(* Add a jump whose target address is already known. *)
fn
add_filled_jump (opcode  : string,
                 address : ullint,
                 code    : &List0_vt code_t >> List1_vt code_t,
                 pc      : &ullint >> _)
    : void =
  let
    val pjump = add_jump (opcode, code, pc)
  in
    fill_jump (pjump, address)
  end

(* Translate the AST into a list of instructions in increasing
   address order, terminated by "halt".  Raises internal_error if a
   non-nil node has type NullNode. *)
fn
generate_code (ast : ast_node_t)
    : List_vt code_t =
  let
    fnx
    traverse (ast  : ast_node_t,
              code : &List0_vt code_t >> _,
              pc   : &ullint >> _)
        : void =
      (* Generate the code by consing a list. *)
      case+ ast of
      | ast_node_t_nil () => ()
      | ast_node_t_nonnil contents =>
        begin
          case+ contents.node_type of
          | NullNode () => $raise internal_error ()

          | If () => if_then (contents, code, pc)
          | While () => while_do (contents, code, pc)

          | Sequence () => sequence (contents, code, pc)
          | Assign () => assign (contents, code, pc)

          | Identifier () => immediate ("fetch", contents, code, pc)
          | Integer () => immediate ("push", contents, code, pc)
          | String () => immediate ("push", contents, code, pc)

          | Prtc () => unary_op ("prtc", contents, code, pc)
          | Prti () => unary_op ("prti", contents, code, pc)
          | Prts () => unary_op ("prts", contents, code, pc)
          | Negate () => unary_op ("neg", contents, code, pc)
          | Not () => unary_op ("not", contents, code, pc)

          | Multiply () => binary_op ("mul", contents, code, pc)
          | Divide () => binary_op ("div", contents, code, pc)
          | Mod () => binary_op ("mod", contents, code, pc)
          | Add () => binary_op ("add", contents, code, pc)
          | Subtract () => binary_op ("sub", contents, code, pc)
          | Less () => binary_op ("lt", contents, code, pc)
          | LessEqual () => binary_op ("le", contents, code, pc)
          | Greater () => binary_op ("gt", contents, code, pc)
          | GreaterEqual () => binary_op ("ge", contents, code, pc)
          | Equal () => binary_op ("eq", contents, code, pc)
          | NotEqual () => binary_op ("ne", contents, code, pc)
          | And () => binary_op ("and", contents, code, pc)
          | Or () => binary_op ("or", contents, code, pc)
        end
    and
    if_then (contents : node_contents_t,
             code     : &List0_vt code_t >> _,
             pc       : &ullint >> _)
        : void =
      (* The left child is the condition; the right child is a node
         whose own children are the true and false branches. *)
      case- (contents.node_right) of
      | ast_node_t_nonnil contents1 =>
        let
          val condition = (contents.node_left)
          and true_branch = (contents1.node_left)
          and false_branch = (contents1.node_right)

          (* Generate code to evaluate the condition. *)
          val () = traverse (condition, code, pc)

          (* Generate a conditional jump. Where it goes to will be
             filled in later. *)
          val pjump = add_jump ("jz", code, pc)

          (* Generate code for the true branch. *)
          val () = traverse (true_branch, code, pc)
        in
          case+ false_branch of
          | ast_node_t_nil () =>
            begin               (* There is no false branch. *)
              (* Fill in the conditional jump to come here. *)
              fill_jump (pjump, pc)
            end
          | ast_node_t_nonnil _ =>
            let                 (* There is a false branch. *)
              (* Generate an unconditional jump. Where it goes to
                 will be filled in later. *)
              val pjump1 = add_jump ("jmp", code, pc)

              (* Fill in the conditional jump to come here. *)
              val () = fill_jump (pjump, pc)

              (* Generate code for the false branch. *)
              val () = traverse (false_branch, code, pc)

              (* Fill in the unconditional jump to come here. *)
              val () = fill_jump (pjump1, pc)
            in
            end
        end
    and
    while_do (contents : node_contents_t,
              code     : &List0_vt code_t >> _,
              pc       : &ullint >> _)
        : void =
      (* I would prefer to implement ‘while’ by putting the
         conditional jump at the end, and jumping to it to get into
         the loop. However, we need to generate not the code of our
         choice, but the reference result. The reference result has
         the conditional jump at the top. *)
      let
        (* Where to jump from the bottom of the loop. *)
        val loop_top_address = pc

        (* Generate code to evaluate the condition. *)
        val () = traverse (contents.node_left, code, pc)

        (* Generate a conditional jump. It will be filled in later to
           go past the end of the loop. *)
        val pjump = add_jump ("jz", code, pc)

        (* Generate code for the loop body. *)
        val () = traverse (contents.node_right, code, pc)

        (* Generate a jump to the top of the loop. *)
        val () = add_filled_jump ("jmp", loop_top_address, code, pc)

        (* Fill in the conditional jump to come here. *)
        val () = fill_jump (pjump, pc)
      in
      end
    and
    sequence (contents : node_contents_t,
              code     : &List0_vt code_t >> _,
              pc       : &ullint >> _)
        : void =
      (* Left subtree first, then right. *)
      begin
        traverse (contents.node_left, code, pc);
        traverse (contents.node_right, code, pc)
      end
    and
    assign (contents : node_contents_t,
            code     : &List0_vt code_t >> _,
            pc       : &ullint >> _)
        : void =
      (* Evaluate the right-hand side, then store into the variable
         whose number is the node_arg of the left child. *)
      case- contents.node_left of
      | ast_node_t_nonnil ident_contents =>
        let
          val variable_no = ident_contents.node_arg
        in
          traverse (contents.node_right, code, pc);
          add_instruction ("store", g0u2i variable_no, 5U, code, pc)
        end
    and
    immediate (opcode   : string,
               contents : node_contents_t,
               code     : &List0_vt code_t >> _,
               pc       : &ullint >> _)
        : void =
      (* A 5-byte instruction whose argument is the node's node_arg. *)
      add_instruction (opcode, g0u2i (contents.node_arg), 5U,
                       code, pc)
    and
    unary_op (opcode   : string,
              contents : node_contents_t,
              code     : &List0_vt code_t >> _,
              pc       : &ullint >> _)
        : void =
      (* Evaluate the operand (left child), then emit a 1-byte
         instruction. *)
      begin
        traverse (contents.node_left, code, pc);
        add_instruction (opcode, g0i2i ARBITRARY_INSTRUCTION_ARG, 1U,
                         code, pc)
      end
    and
    binary_op (opcode   : string,
               contents : node_contents_t,
               code     : &List0_vt code_t >> _,
               pc       : &ullint >> _)
        : void =
      (* Evaluate both operands, left first, then emit a 1-byte
         instruction. *)
      begin
        traverse (contents.node_left, code, pc);
        traverse (contents.node_right, code, pc);
        add_instruction (opcode, g0i2i ARBITRARY_INSTRUCTION_ARG, 1U,
                         code, pc)
      end

    var code : List_vt code_t = NIL
    var pc : ullint = g0i2u 0
  in
    traverse (ast, code, pc);
    add_instruction ("halt", g0i2i ARBITRARY_INSTRUCTION_ARG, 1U,
                     code, pc);

    (* The code is a cons-list, in decreasing-address order, so
       reverse it to put the instructions in increasing-address
       order. *)
    list_vt_reverse code
  end

(* Print one instruction per line: the address, the opcode, and an
   opcode-dependent rendering of the argument.  "push" prints the bare
   argument; "fetch" and "store" print it in square brackets; "jmp"
   and "jz" print the relative offset in parentheses followed by the
   absolute target address.  Other opcodes print no argument. *)
fn
print_code (outf : FILEref,
            code : !List_vt code_t)
    : void =
  let
    fun
    loop {n : nat}
         .<n>.
         (code : !list_vt (code_t, n))
        : void =
      case+ code of
      | NIL => ()
      | ref_instr :: tail =>
        let
          val @{
                address = address,
                opcode = opcode,
                arg = arg
              } = !ref_instr
        in
          fprint! (outf, address, " ");
          fprint! (outf, opcode);
          if opcode = "push" then
            fprint! (outf, " ", arg)
          else if opcode = "fetch" || opcode = "store" then
            fprint! (outf, " [", arg, "]")
          else if opcode = "jmp" || opcode = "jz" then
            begin
              fprint! (outf, " (", arg, ") ");
              (* Offsets are relative to the byte after the opcode;
                 split on the sign so the arithmetic stays
                 unsigned. *)
              if arg < g1i2i 0 then
                let
                  val offset : ullint = g0i2u (~arg)
                  val () = assertloc (offset <= succ address)
                in
                  fprint! (outf, succ address - offset)
                end
              else
                let
                  val offset : ullint = g0i2u arg
                in
                  fprint! (outf, succ address + offset)
                end
            end;
          fprintln! (outf);
          loop tail
        end

    prval () = lemma_list_vt_param code
  in
    loop code
  end

(*------------------------------------------------------------------*)

(* Read the AST from inpf, generate code, and write the header line,
   the string table and the instruction listing to outf.  Returns 0;
   failures are reported by the exceptions raised along the way. *)
fn
main_program (inpf : FILEref,
              outf : FILEref)
    : int =
  let
    var idents : List_vt string = NIL
    var strings : List_vt string = NIL

    val ast = load_ast (inpf, idents, strings)
    val code = generate_code ast

    val () = fprintln! (outf, "Datasize: ", length idents,
                        " Strings: ", length strings)
    val () = print_strings (outf, strings)
    val () = print_code (outf, code)

    (* The lists are linear values and must be freed explicitly. *)
    val () = free idents
    and () = free strings
    and () = free code
  in
    0
  end

(* Entry point: select the input and output files from the command
   line ("-" or a missing argument means stdin/stdout) and run
   main_program. *)
implement
main (argc, argv) =
  let
    val inpfname =
      if 2 <= argc then
        $UN.cast{string} argv[1]
      else
        "-"
    val outfname =
      if 3 <= argc then
        $UN.cast{string} argv[2]
      else
        "-"

    val inpf =
      if (inpfname : string) = "-" then
        stdin_ref
      else
        fileref_open_exn (inpfname, file_mode_r)

    val outf =
      if (outfname : string) = "-" then
        stdout_ref
      else
        fileref_open_exn (outfname, file_mode_w)
  in
    main_program (inpf, outf)
  end

(*------------------------------------------------------------------*)