250 lines
8.0 KiB
Nim
250 lines
8.0 KiB
Nim
import ast_lexer
|
|
|
|
type NodeKind* = enum
|
|
nIdentifier = "Identifier"
|
|
nString = "String"
|
|
nInteger = "Integer"
|
|
nSequence = "Sequence"
|
|
nIf = "If"
|
|
nPrtc = "Prtc"
|
|
nPrts = "Prts"
|
|
nPrti = "Prti"
|
|
nWhile = "While"
|
|
nAssign = "Assign"
|
|
nNegate = "Negate"
|
|
nNot = "Not"
|
|
nMultiply = "Multiply"
|
|
nDivide = "Divide"
|
|
nMod = "Mod"
|
|
nAdd = "Add"
|
|
nSubtract = "Subtract"
|
|
nLess = "Less"
|
|
nLessEqual = "LessEqual"
|
|
nGreater = "Greater"
|
|
nGreaterEqual = "GreaterEqual"
|
|
nEqual = "Equal"
|
|
nNotEqual = "NotEqual"
|
|
nAnd = "And"
|
|
nOr = "Or"
|
|
|
|
type Node* = ref object
|
|
left*: Node
|
|
right*: Node
|
|
case kind*: NodeKind
|
|
of nString: stringVal*: string
|
|
of nInteger: intVal*: int
|
|
of nIdentifier: name*: string
|
|
else: nil
|
|
|
|
type Operator = range[tokMult..tokOr]
|
|
|
|
const
|
|
|
|
Precedences: array[Operator, int] = [13, # tokMult
|
|
13, # tokDiv
|
|
13, # tokMod
|
|
12, # tokAdd
|
|
12, # tokSub
|
|
10, # tokLess
|
|
10, # tokLessEq
|
|
10, # tokGreater
|
|
10, # tokGreaterEq
|
|
9, # tokEq
|
|
9, # tokNeq
|
|
14, # tokNot
|
|
-1, # tokAssign
|
|
5, # tokAnd
|
|
4] # tokOr
|
|
UnaryPrecedence = 14
|
|
BinaryOperators = {tokMult, tokDiv, tokMod, tokAdd, tokSub, tokLess, tokLessEq,
|
|
tokGreater, tokGreaterEq, tokEq, tokNotEq, tokAnd, tokOr}
|
|
|
|
# Mapping of operators from TokenKind to NodeKind.
|
|
NodeKinds: array[Operator, NodeKind] = [nMultiply, nDivide, nMod, nAdd, nSubtract,
|
|
nLess, nLessEqual, nGreater, nGreaterEqual,
|
|
nEqual, nNotEqual, nNot, nAssign, nAnd, nOr]
|
|
|
|
type SyntaxError* = object of CatchableError
|
|
|
|
|
|
####################################################################################################
|
|
|
|
template expect(token: Token; expected: TokenKind; errmsg: string) =
|
|
## Check if a token is of the expected kind.
|
|
## Raise a SyntaxError if this is not the case.
|
|
if token.kind != expected:
|
|
raise newException(SyntaxError, "line " & $token.ln & ": " & errmsg)
|
|
token = lexer.next()
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
proc newNode*(kind: NodeKind; left: Node; right: Node = nil): Node =
|
|
## Create a new node with given left and right children.
|
|
result = Node(kind: kind, left: left, right: right)
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
# Forward reference.
|
|
proc parExpr(lexer: var Lexer; token: var Token): Node
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
proc expr(lexer: var Lexer; token: var Token; p: int): Node =
|
|
## Parse an expression.
|
|
|
|
case token.kind
|
|
|
|
of tokLPar:
|
|
result = parExpr(lexer, token)
|
|
|
|
of tokAdd, tokSub, tokNot:
|
|
# Unary operators.
|
|
let savedToken = token
|
|
token = lexer.next()
|
|
let e = expr(lexer, token, UnaryPrecedence)
|
|
if savedToken.kind == tokAdd:
|
|
result = e
|
|
else:
|
|
result = newNode(if savedToken.kind == tokSub: nNegate else: nNot, e)
|
|
|
|
of tokIdent:
|
|
result = Node(kind: nIdentifier, name: token.ident)
|
|
token = lexer.next()
|
|
|
|
of tokInt:
|
|
result = Node(kind:nInteger, intVal: token.intVal)
|
|
token = lexer.next()
|
|
|
|
of tokChar:
|
|
result = Node(kind:nInteger, intVal: ord(token.charVal))
|
|
token = lexer.next()
|
|
|
|
else:
|
|
raise newException(SyntaxError, "Unexpected symbol at line " & $token.ln)
|
|
|
|
# Process the binary operators in the expression.
|
|
while token.kind in BinaryOperators and Precedences[token.kind] >= p:
|
|
let savedToken = token
|
|
token = lexer.next()
|
|
let q = Precedences[savedToken.kind] + 1 # No operator is right associative.
|
|
result = newNode(NodeKinds[savedToken.kind], result, expr(lexer, token, q))
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
proc parExpr(lexer: var Lexer; token: var Token): Node =
|
|
## Parse a parenthetized expression.
|
|
token.expect(tokLPar, "'(' expected")
|
|
result = expr(lexer, token, 0)
|
|
token.expect(tokRPar, "')' expected")
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
proc stmt(lexer: var Lexer; token: var Token): Node =
|
|
## Parse a statement.
|
|
|
|
case token.kind:
|
|
|
|
of tokIf:
|
|
token = lexer.next()
|
|
let e = parExpr(lexer, token)
|
|
let thenNode = stmt(lexer, token)
|
|
var elseNode: Node = nil
|
|
if token.kind == tokElse:
|
|
token = lexer.next()
|
|
elseNode = stmt(lexer, token)
|
|
result = newNode(nIf, e, newNode(nIf, thenNode, elseNode))
|
|
|
|
of tokPutc:
|
|
token = lexer.next()
|
|
result = newNode(nPrtc, parExpr(lexer, token))
|
|
token.expect(tokSemi, "';' expected")
|
|
|
|
of tokPrint:
|
|
token = lexer.next()
|
|
token.expect(tokLPar, "'(' expected")
|
|
while true:
|
|
var e: Node
|
|
if token.kind == tokString:
|
|
e = newNode(nPrts, Node(kind: nString, stringVal: token.stringVal))
|
|
token = lexer.next()
|
|
else:
|
|
e = newNode(nPrti, expr(lexer, token, 0))
|
|
result = newNode(nSequence, result, e)
|
|
if token.kind == tokComma:
|
|
token = lexer.next()
|
|
else:
|
|
break
|
|
token.expect(tokRPar, "')' expected")
|
|
token.expect(tokSemi, "';' expected")
|
|
|
|
of tokSemi:
|
|
token = lexer.next()
|
|
|
|
of tokIdent:
|
|
let v = Node(kind: nIdentifier, name: token.ident)
|
|
token = lexer.next()
|
|
token.expect(tokAssign, "'=' expected")
|
|
result = newNode(nAssign, v, expr(lexer, token, 0))
|
|
token.expect(tokSemi, "';' expected")
|
|
|
|
of tokWhile:
|
|
token = lexer.next()
|
|
let e = parExpr(lexer, token)
|
|
result = newNode(nWhile, e, stmt(lexer, token))
|
|
|
|
of tokLBrace:
|
|
token = lexer.next()
|
|
while token.kind notin {tokRBrace, tokEnd}:
|
|
result = newNode(nSequence, result, stmt(lexer, token))
|
|
token.expect(tokRBrace, "'}' expected")
|
|
|
|
of tokEnd:
|
|
discard
|
|
|
|
else:
|
|
raise newException(SyntaxError, "Unexpected symbol at line " & $token.ln)
|
|
|
|
#---------------------------------------------------------------------------------------------------
|
|
|
|
proc parse*(code: string): Node =
|
|
## Parse the code provided.
|
|
|
|
var lexer = initLexer(code)
|
|
var token = lexer.next()
|
|
while true:
|
|
result = newNode(nSequence, result, stmt(lexer, token))
|
|
if token.kind == tokEnd:
|
|
break
|
|
|
|
#———————————————————————————————————————————————————————————————————————————————————————————————————
|
|
|
|
when isMainModule:
|
|
|
|
import os, strformat, strutils
|
|
|
|
proc printAst(node: Node) =
|
|
## Print tha AST in linear form.
|
|
|
|
if node.isNil:
|
|
echo ';'
|
|
|
|
else:
|
|
stdout.write &"{$node.kind:<14}"
|
|
case node.kind
|
|
of nIdentifier:
|
|
echo node.name
|
|
of nInteger:
|
|
echo node.intVal
|
|
of nString:
|
|
# Need to escape and to replace hexadecimal \x0A by \n.
|
|
echo escape(node.stringVal).replace("\\x0A", "\\n")
|
|
else:
|
|
echo ""
|
|
node.left.printAst()
|
|
node.right.printAst()
|
|
|
|
|
|
let code = if paramCount() < 1: stdin.readAll() else: paramStr(1).readFile()
|
|
let tree = parse(code)
|
|
tree.printAst()
|