03 基于 AST 的四则运算解释器 | 《Let’s Build A Simple Interpreter》
本文最后更新于:2021年8月26日
IR 与 AST
在上一篇文章中,interpreter 的代码和 parser 的代码是混在一起的,且 interpreter 在 parser 识别出一个如加减乘除之类的语言结构之后会立刻对它进行求值。这种 interpreter 被称为语法导向解释器(syntax-directed interpreter)。它们通常只对输入做一遍(single pass)处理,且只适合基础的语言应用。为了分析更复杂的编程语言(如 Pascal)的结构,需要建立一个中间表示(intermediate representation, IR)。parser 负责构建 IR,而 interpreter 则负责解释由 IR 所代表的输入。
树是一个表示 IR 非常合适的数据结构。本系列中用到的 IR 被称为抽象语法树(abstract-syntax tree, AST)。下面是表达式 2 * 7 + 3 的带有解释的树形表示:
关于运算的优先级,可以看下图中这个例子,左边是表达式 2 * 7 + 3 的 AST,在右边是表达式 2 * (7 + 3) 的 AST:
可以看出更高优先级的操作符在树中的位置更低。后序遍历(左子树->右子树->根节点)这个树可以对它所代表的表达式进行求值。
访问者模式
class NodeVisitor():
    """Base class that dispatches a node to a ``visit_<ClassName>`` method."""

    def visit(self, node):
        """Call the visitor method that matches the node's class.

        The method name is built from the node's class name, so a ``BinOp``
        node is handled by ``visit_BinOp``. When no such method exists,
        ``generic_visit`` is called instead.

        Returns whatever the selected visitor method returns.
        """
        # Build the method name, e.g. visit_BinOp
        method_name = 'visit_' + type(node).__name__
        # Look the method up on this visitor, falling back to generic_visit
        visitor = getattr(self, method_name, self.generic_visit)
        # Invoke it
        return visitor(node)

    def generic_visit(self, node):
        """Fallback for node types without a dedicated visitor; always raises."""
        raise Exception('No visit_{} method'.format(type(node).__name__))
class Interpreter(NodeVisitor):
    """Evaluate an arithmetic AST by visiting its nodes."""

    def __init__(self, parser):
        # The parser is kept so the tree can be requested from it later.
        self.parser = parser

    def visit_BinOp(self, node):
        """Evaluate both operands (left first) and combine them with the operator."""
        if node.op.type == PLUS:
            return self.visit(node.left) + self.visit(node.right)
        elif node.op.type == MINUS:
            return self.visit(node.left) - self.visit(node.right)
        elif node.op.type == MUL:
            return self.visit(node.left) * self.visit(node.right)
        elif node.op.type == DIV:
            # Floor division keeps integer operands producing an integer,
            # matching the complete listing of this interpreter.
            return self.visit(node.left) // self.visit(node.right)

    def visit_Num(self, node):
        """Return the value stored in a number node."""
        return node.value
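关于以上代码有两点值得在这里提一下:第一,操作 AST 结点的访问者的代码和 AST 结点自身解耦了。可以看到没有一个 AST 结点类(BinOp 和 Num)提供了操作存储在这些结点中数据的代码。该逻辑被封装在了实现 NodeVisitor 的 Interpreter 类中。第二,NodeVisitor 的 visit 方法没有用一长串 if 语句逐一判断结点类型,而是根据结点的类名拼出方法名(如 visit_BinOp),再通过 getattr 分派到对应的方法;找不到对应方法时则调用 generic_visit 抛出异常。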
Python 的标准模块 ast 也使用了相同的机制来遍历结点
一元操作符
一元操作符示例如下:
5 - - - 2 = 5 - (- (- (2))) = 5 - (- (-2)) = 5 - (2) = 3
  | | |
  | | unary minus(negation)
  | unary minus(negation)
  binary minus(subtraction)
支持一元操作符需要修改 factor 规则,修改前和修改后的规则如下:
修改前:
factor : INTEGER | LPAREN expr RPAREN
修改后:
factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN
修改后的 factor 规则使它引用了自己,这样就可以派生出类似 “- - - + - 3” 的表达式,一个包含很多一元操作符的合法表达式。
完整代码
""" SPI - Simple Pascal Interpreter """
###############################################################################
# #
# LEXER #
# #
###############################################################################
# Token type tags.
#
# Each constant is the string stored in a token's ``type`` attribute.
INTEGER = 'INTEGER'
PLUS = 'PLUS'
MINUS = 'MINUS'
MUL = 'MUL'
DIV = 'DIV'
LPAREN = '('
RPAREN = ')'
# The EOF (end-of-file) tag marks that no input is left for lexical analysis.
EOF = 'EOF'
class Token(object):
    """A lexical token: a type tag paired with the value it carries."""

    def __init__(self, type, value):
        # ``type`` is a token type tag; ``value`` is the payload stored with it.
        self.type = type
        self.value = value

    def __str__(self):
        """String representation of the class instance.

        Examples:
            Token(INTEGER, 3)
            Token(PLUS, '+')
            Token(MUL, '*')
        """
        return 'Token({type}, {value})'.format(
            type=self.type,
            value=repr(self.value)
        )

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return self.__str__()
class Lexer(object):
    """Break an arithmetic expression string into tokens, one at a time."""

    def __init__(self, text):
        # client string input, e.g. "4 + 2 * 3 - 6 / 2"
        self.text = text
        # self.pos is an index into self.text
        self.pos = 0
        # An empty string has no first character; None marks end of input,
        # the same convention advance() uses once it runs past the text.
        self.current_char = self.text[self.pos] if self.text else None

    def error(self):
        """Raise the exception used for a character the lexer cannot handle."""
        raise Exception('Invalid character')

    def advance(self):
        """Advance the `pos` pointer and set the `current_char` variable."""
        self.pos += 1
        if self.pos > len(self.text) - 1:
            self.current_char = None  # Indicates end of input
        else:
            self.current_char = self.text[self.pos]

    def skip_whitespace(self):
        """Consume consecutive whitespace characters."""
        while self.current_char is not None and self.current_char.isspace():
            self.advance()

    def integer(self):
        """Return a (multidigit) integer consumed from the input."""
        digits = []
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() rejects with ValueError.
        while self.current_char is not None and self.current_char.isdecimal():
            digits.append(self.current_char)
            self.advance()
        return int(''.join(digits))

    def get_next_token(self):
        """Lexical analyzer (also known as scanner or tokenizer)

        This method is responsible for breaking a sentence
        apart into tokens. One token at a time.

        Returns the next Token, or an EOF token once the input is used up.
        Raises Exception('Invalid character') for an unrecognised character.
        """
        while self.current_char is not None:
            if self.current_char.isspace():
                self.skip_whitespace()
                continue
            if self.current_char.isdecimal():
                return Token(INTEGER, self.integer())
            if self.current_char == '+':
                self.advance()
                return Token(PLUS, '+')
            if self.current_char == '-':
                self.advance()
                return Token(MINUS, '-')
            if self.current_char == '*':
                self.advance()
                return Token(MUL, '*')
            if self.current_char == '/':
                self.advance()
                return Token(DIV, '/')
            if self.current_char == '(':
                self.advance()
                return Token(LPAREN, '(')
            if self.current_char == ')':
                self.advance()
                return Token(RPAREN, ')')
            self.error()
        return Token(EOF, None)
###############################################################################
# #
# PARSER #
# #
###############################################################################
class AST(object):
    """Common base class for the abstract-syntax-tree node types."""
    pass
class BinOp(AST):
    """AST node for a binary operation with a left and a right operand."""

    def __init__(self, left, op, right):
        self.left = left
        # The operator token is reachable under both names.
        self.token = self.op = op
        self.right = right
class Num(AST):
    """AST leaf node holding a number taken from a token."""

    def __init__(self, token):
        self.token = token
        # Copy the token's value so visitors can read it directly.
        self.value = token.value
class UnaryOp(AST):
    """AST node for a unary operator applied to a single operand."""

    def __init__(self, op, expr):
        # The operator token is reachable under both names.
        self.token = self.op = op
        # The operand sub-tree the operator applies to.
        self.expr = expr
class Parser(object):
    """Recursive-descent parser that builds an AST from the lexer's tokens.

    Grammar:
        expr   : term ((PLUS | MINUS) term)*
        term   : factor ((MUL | DIV) factor)*
        factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN
    """

    def __init__(self, lexer):
        self.lexer = lexer
        # set current token to the first token taken from the input
        self.current_token = self.lexer.get_next_token()

    def error(self):
        """Raise the exception used for every syntax error."""
        raise Exception('Invalid syntax')

    def eat(self, token_type):
        """Consume the current token if it has the expected type.

        On a match the next token from the lexer becomes the current token;
        otherwise Exception('Invalid syntax') is raised.
        """
        if self.current_token.type == token_type:
            self.current_token = self.lexer.get_next_token()
        else:
            self.error()

    def factor(self):
        """factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN

        Raises Exception('Invalid syntax') when the current token cannot
        start a factor.
        """
        token = self.current_token
        if token.type in (PLUS, MINUS):
            self.eat(token.type)
            return UnaryOp(token, self.factor())
        if token.type == INTEGER:
            self.eat(INTEGER)
            return Num(token)
        if token.type == LPAREN:
            self.eat(LPAREN)
            node = self.expr()
            self.eat(RPAREN)
            return node
        # Any other token (e.g. EOF after "1 +", or ")" in "()") is a syntax
        # error; returning None here would put a None child into the tree.
        self.error()

    def term(self):
        """term : factor ((MUL | DIV) factor)*"""
        node = self.factor()
        while self.current_token.type in (MUL, DIV):
            token = self.current_token
            self.eat(token.type)
            # Folding into the left operand makes the operators
            # left-associative.
            node = BinOp(left=node, op=token, right=self.factor())
        return node

    def expr(self):
        """
        expr   : term ((PLUS | MINUS) term)*
        term   : factor ((MUL | DIV) factor)*
        factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN
        """
        node = self.term()
        while self.current_token.type in (PLUS, MINUS):
            token = self.current_token
            self.eat(token.type)
            node = BinOp(left=node, op=token, right=self.term())
        return node

    def parse(self):
        """Parse the whole input and return the root AST node.

        Returns None when the input contains no tokens at all (for example
        whitespace only). Raises Exception('Invalid syntax') on malformed
        input, including tokens left over after a complete expression.
        """
        if self.current_token.type == EOF:
            # Nothing to parse.
            return None
        node = self.expr()
        if self.current_token.type != EOF:
            self.error()
        return node
###############################################################################
# #
# INTERPRETER #
# #
###############################################################################
class NodeVisitor(object):
    """Base class that dispatches a node to a ``visit_<ClassName>`` method."""

    def visit(self, node):
        """Call the visitor method named after the node's class.

        Falls back to ``generic_visit`` when the visitor defines no
        ``visit_<ClassName>`` method. Returns the selected method's result.
        """
        method_name = 'visit_' + type(node).__name__
        visitor = getattr(self, method_name, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Fallback for node types without a dedicated visitor; always raises."""
        raise Exception('No visit_{} method'.format(type(node).__name__))
class Interpreter(NodeVisitor):
    """Evaluate the AST produced by a parser."""

    def __init__(self, parser):
        # The parser is asked for the tree when interpret() is called.
        self.parser = parser

    def visit_BinOp(self, node):
        """Evaluate both operands (left first) and apply the binary operator.

        DIV uses floor division. Raises Exception for an operator type other
        than PLUS, MINUS, MUL or DIV.
        """
        op_type = node.op.type
        if op_type == PLUS:
            return self.visit(node.left) + self.visit(node.right)
        if op_type == MINUS:
            return self.visit(node.left) - self.visit(node.right)
        if op_type == MUL:
            return self.visit(node.left) * self.visit(node.right)
        if op_type == DIV:
            return self.visit(node.left) // self.visit(node.right)
        # Fail here rather than return None and let a later arithmetic step
        # raise a confusing TypeError.
        raise Exception('Unsupported binary operator: {}'.format(op_type))

    def visit_Num(self, node):
        """Return the value stored in a number node."""
        return node.value

    def visit_UnaryOp(self, node):
        """Evaluate the operand and apply unary plus or minus.

        Raises Exception for an operator type other than PLUS or MINUS.
        """
        op_type = node.op.type
        if op_type == PLUS:
            return +self.visit(node.expr)
        if op_type == MINUS:
            return -self.visit(node.expr)
        raise Exception('Unsupported unary operator: {}'.format(op_type))

    def interpret(self):
        """Parse the input and return the value of the expression.

        Returns an empty string when the parser yields no tree.
        """
        tree = self.parser.parse()
        if tree is None:
            return ''
        return self.visit(tree)
def main():
    """Run a read-evaluate-print loop on standard input.

    Each non-empty line read at the ``spi> `` prompt is lexed, parsed and
    interpreted, and the result is printed. The loop ends on end-of-file.
    Exceptions raised while processing a line are not caught here.
    """
    while True:
        try:
            text = input('spi> ')
        except EOFError:
            # End of input (e.g. Ctrl-D) ends the session.
            break
        if not text:
            # Nothing was typed; prompt again.
            continue
        lexer = Lexer(text)
        parser = Parser(lexer)
        interpreter = Interpreter(parser)
        result = interpreter.interpret()
        print(result)


if __name__ == '__main__':
    main()
评论系统采用 utterances ,加载有延迟,请稍等片刻。