03 An AST-Based Interpreter for Arithmetic Expressions | "Let's Build A Simple Interpreter"

Last updated: August 26, 2021

IR and AST

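In the previous article, the interpreter code and the parser code were intertwined, and the interpreter evaluated each language construct (an addition, subtraction, multiplication or division) as soon as the parser recognized it. An interpreter of this kind is called a syntax-directed interpreter. Such interpreters usually make a single pass over the input and are suitable only for basic language applications. To analyze the structure of a more complex programming language such as Pascal, we need to build an intermediate representation (IR): the parser is responsible for building the IR, and the interpreter interprets the input as represented by that IR.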
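To make the contrast concrete, here is a minimal sketch restricted to integer addition and subtraction on pre-split tokens (an assumption to keep it short). `eval_directly` computes the result while it scans the tokens, as a syntax-directed interpreter does; `build_tree` only builds a tree, and a separate `evaluate` walks it afterwards. The function names and the tuple-based node format are illustrative, not part of the series' code.

```python
# Tokens are pre-split strings, e.g. ["7", "-", "3", "+", "1"] (hypothetical input format).

def eval_directly(tokens):
    """Syntax-directed: apply each operation as soon as it is recognized."""
    result = int(tokens[0])
    for i in range(1, len(tokens), 2):
        op, operand = tokens[i], int(tokens[i + 1])
        result = result + operand if op == "+" else result - operand
    return result


def build_tree(tokens):
    """Parser phase: build a left-associative tree of (op, left, right) tuples."""
    node = int(tokens[0])
    for i in range(1, len(tokens), 2):
        node = (tokens[i], node, int(tokens[i + 1]))
    return node


def evaluate(node):
    """Interpreter phase: walk the tree produced by build_tree."""
    if isinstance(node, int):
        return node
    op, left, right = node
    lhs, rhs = evaluate(left), evaluate(right)
    return lhs + rhs if op == "+" else lhs - rhs


tokens = ["7", "-", "3", "+", "1"]
tree = build_tree(tokens)
print(tree)                   # ('+', ('-', 7, 3), 1)
print(eval_directly(tokens))  # 5
print(evaluate(tree))         # 5
```

Both approaches give the same answer here; the difference is that the tree can be inspected, transformed or walked again without re-parsing the input.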

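A tree is a very suitable data structure for representing an IR. The IR used in this series is called an abstract syntax tree (AST). Below is an annotated tree representation of the expression 2 * 7 + 3: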
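As a plain-text stand-in for that tree, here is a small sketch that builds the AST for 2 * 7 + 3 by hand and prints it with one node per line, children indented under their parent. The `Num` and `BinOp` classes are simplified versions of the ones in the full code (the operator is stored as a plain string rather than a `Token`, an assumption for brevity), and `render` is a hypothetical helper.

```python
class Num:
    def __init__(self, value):
        self.value = value


class BinOp:
    def __init__(self, left, op, right):
        self.left = left
        self.op = op          # operator symbol as a plain string, e.g. '+'
        self.right = right


def render(node, depth=0):
    """Return the tree as indented lines: the node first, its children indented below it."""
    indent = "  " * depth
    if isinstance(node, Num):
        return [indent + str(node.value)]
    return ([indent + node.op]
            + render(node.left, depth + 1)
            + render(node.right, depth + 1))


# 2 * 7 + 3: '*' binds tighter, so it becomes the left child of '+'
tree = BinOp(BinOp(Num(2), '*', Num(7)), '+', Num(3))
print("\n".join(render(tree)))
```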

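On operator precedence, consider the example in the figure below: on the left is the AST for the expression 2 * 7 + 3, and on the right is the AST for the expression 2 * (7 + 3):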

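Notice that operators with higher precedence sit lower in the tree. Traversing the tree in postorder (left subtree -> right subtree -> root) evaluates the expression it represents, because each operator node is applied only after both of its operands have been evaluated.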
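A minimal sketch of that postorder walk, assuming a tuple-based node format (`(op, left, right)` for operators, plain ints for leaves) chosen only for brevity. It also records the order in which nodes are visited, so the difference between the two trees in the figure shows up in the trace.

```python
import operator

OPS = {'+': operator.add, '-': operator.sub, '*': operator.mul}


def evaluate(node, trace):
    """Postorder: evaluate the left subtree, then the right subtree, then apply the root."""
    if isinstance(node, int):
        trace.append(str(node))
        return node
    op, left, right = node
    lhs = evaluate(left, trace)
    rhs = evaluate(right, trace)
    trace.append(op)           # the operator is visited last
    return OPS[op](lhs, rhs)


# 2 * 7 + 3: '*' sits lower in the tree than '+'
tree_a = ('+', ('*', 2, 7), 3)
# 2 * (7 + 3): the parentheses push '+' below '*'
tree_b = ('*', 2, ('+', 7, 3))

for tree in (tree_a, tree_b):
    trace = []
    result = evaluate(tree, trace)
    print(' '.join(trace), '=>', result)
# 2 7 * 3 + => 17
# 2 7 3 + * => 20
```

The visit order is the expression in postfix (reverse Polish) notation, which is another way to see that postorder traversal evaluates operands before the operator that combines them.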

The Visitor Pattern

class NodeVisitor(object):
    def visit(self, node):
        method_name = 'visit_' + type(node).__name__              # build the method name, e.g. visit_BinOp
        visitor = getattr(self, method_name, self.generic_visit)  # look up that method, falling back to generic_visit
        return visitor(node)                                      # call the method on the node

    def generic_visit(self, node):
        # called when no visit_<ClassName> method exists for this node type
        raise Exception('No visit_{} method'.format(type(node).__name__))


class Interpreter(NodeVisitor):
    def __init__(self, parser):
        self.parser = parser

    def visit_BinOp(self, node):
        # PLUS, MINUS, MUL and DIV are the token-type constants defined in the lexer (see the full code)
        if node.op.type == PLUS:
            return self.visit(node.left) + self.visit(node.right)
        elif node.op.type == MINUS:
            return self.visit(node.left) - self.visit(node.right)
        elif node.op.type == MUL:
            return self.visit(node.left) * self.visit(node.right)
        elif node.op.type == DIV:
            # integer division, matching the full code
            return self.visit(node.left) // self.visit(node.right)

    def visit_Num(self, node):
        return node.value

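Two things about the code above are worth pointing out. First, the visitor code that operates on AST nodes is decoupled from the AST nodes themselves: none of the AST node classes (BinOp and Num) contains any code that operates on the data stored in those nodes. That logic is encapsulated in the Interpreter class, which inherits from NodeVisitor. Second, NodeVisitor's visit method is generic: instead of a large if/elif chain over node types, it builds a method name from the node's class name (visit_BinOp, visit_Num, ...) and dispatches to that method, falling back to generic_visit, which raises an exception, when no such method exists.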
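To illustrate the decoupling, the sketch below adds a second visitor over the same kind of tree without touching the node classes. `Parenthesizer` is a hypothetical name, and the nodes are simplified (operators stored as strings rather than `Token` objects); the `NodeVisitor` base is copied from the code above.

```python
class NodeVisitor:
    # Same dispatch mechanism as above: visit_<ClassName>, else generic_visit
    def visit(self, node):
        method_name = 'visit_' + type(node).__name__
        visitor = getattr(self, method_name, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        raise Exception('No visit_{} method'.format(type(node).__name__))


# Plain data holders: no evaluation or printing logic lives in the nodes.
class Num:
    def __init__(self, value):
        self.value = value


class BinOp:
    def __init__(self, left, op, right):
        self.left = left
        self.op = op          # simplified: operator symbol string instead of a Token
        self.right = right


class Parenthesizer(NodeVisitor):
    """A second, hypothetical visitor: renders the tree as fully parenthesized infix."""

    def visit_BinOp(self, node):
        return '({} {} {})'.format(self.visit(node.left), node.op, self.visit(node.right))

    def visit_Num(self, node):
        return str(node.value)


tree = BinOp(BinOp(Num(2), '*', Num(7)), '+', Num(3))
print(Parenthesizer().visit(tree))  # ((2 * 7) + 3)
```

Adding this new operation required no change to `Num` or `BinOp`. Passing an unsupported value, e.g. `Parenthesizer().visit(2.5)`, falls through to `generic_visit` and raises `Exception('No visit_float method')`.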

Python's standard ast module uses the same mechanism (ast.NodeVisitor) to traverse nodes.
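For comparison, a minimal sketch that evaluates arithmetic with Python's own `ast.NodeVisitor`, whose `visit` method dispatches to `visit_<ClassName>` in the same way. It assumes Python 3.8 or later (numeric literals are `ast.Constant` nodes), and it supports only the operators listed in the two tables; `ArithEvaluator` is a hypothetical name.

```python
import ast
import operator

# Map Python's ast operator node types to functions (only these operators are supported).
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}
UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


class ArithEvaluator(ast.NodeVisitor):
    # ast.NodeVisitor.visit dispatches to visit_<ClassName>, like NodeVisitor above.

    def visit_Expression(self, node):   # root node produced by mode='eval'
        return self.visit(node.body)

    def visit_BinOp(self, node):
        return BIN_OPS[type(node.op)](self.visit(node.left), self.visit(node.right))

    def visit_UnaryOp(self, node):
        return UNARY_OPS[type(node.op)](self.visit(node.operand))

    def visit_Constant(self, node):
        return node.value


tree = ast.parse('2 * (7 + 3) - -4', mode='eval')
print(ArithEvaluator().visit(tree))  # 24
```

One difference from the NodeVisitor above: `ast.NodeVisitor.generic_visit` visits the node's children and returns `None` instead of raising, so unsupported input (a variable name, for example) fails later with a less direct error.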

Unary Operators

Here is an example with unary operators:

5 - - - 2  =  5 - (-(-(2)))  =  5 - (-(-2))  =  5 - 2  =  3
  | | |
  | | +-- unary minus (negation)
  | +---- unary minus (negation)
  +------ binary minus (subtraction)
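The same structure as an AST: the binary minus becomes a `BinOp` node and each unary minus becomes a `UnaryOp` node wrapping a single operand. The sketch below builds that tree by hand with simplified node classes (operators as plain strings, an assumption) and evaluates it the way `visit_UnaryOp` and `visit_BinOp` in the full code do, restricted here to `+` and `-`.

```python
class Num:
    def __init__(self, value):
        self.value = value


class BinOp:
    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right


class UnaryOp:
    def __init__(self, op, expr):
        self.op, self.expr = op, expr   # a single operand, unlike BinOp


def evaluate(node):
    """Postorder evaluation covering Num, UnaryOp and BinOp (+ and - only)."""
    if isinstance(node, Num):
        return node.value
    if isinstance(node, UnaryOp):
        operand = evaluate(node.expr)
        return -operand if node.op == '-' else +operand
    left, right = evaluate(node.left), evaluate(node.right)
    return left - right if node.op == '-' else left + right


# 5 - - - 2  ==  5 - (-(-(2)))
tree = BinOp(Num(5), '-', UnaryOp('-', UnaryOp('-', Num(2))))
print(evaluate(tree))  # 3
```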

Supporting unary operators requires changing the factor rule. Here is the rule before and after the change:

Before:
factor : INTEGER | LPAREN expr RPAREN

After:
factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN

The new factor rule refers to itself, which lets it derive expressions such as "- - - + - 3": a legal expression containing many unary operators.
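A minimal sketch of just the new factor rule, assuming the input is already split into a list of token strings and consists only of signs applied to a single integer (the LPAREN alternative is omitted). Each PLUS or MINUS consumes itself and then calls `parse_factor` again, mirroring the self-reference in the grammar. The function names and tuple node format are hypothetical.

```python
def parse_factor(tokens, pos=0):
    """factor : (PLUS | MINUS) factor | INTEGER   (LPAREN alternative omitted)

    Returns (node, next_pos); nodes are ('unary', op, child) or ('num', value).
    """
    token = tokens[pos]
    if token in ('+', '-'):
        child, pos = parse_factor(tokens, pos + 1)   # the rule refers to itself
        return ('unary', token, child), pos
    if token.isdigit():
        return ('num', int(token)), pos + 1
    raise SyntaxError('unexpected token {!r}'.format(token))


def evaluate(node):
    if node[0] == 'num':
        return node[1]
    _, op, child = node
    value = evaluate(child)
    return -value if op == '-' else value


tree, end = parse_factor('- - - + - 3'.split())
print(tree)
print(evaluate(tree))  # 3
```

Each sign produces one nested node, so the depth of the tree equals the number of unary operators; with four minus signs the result is positive.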

Full Code

""" SPI - Simple Pascal Interpreter """

###############################################################################
#                                                                             #
#  LEXER                                                                      #
#                                                                             #
###############################################################################

# Token types
#
# EOF (end-of-file) token is used to indicate that
# there is no more input left for lexical analysis
INTEGER, PLUS, MINUS, MUL, DIV, LPAREN, RPAREN, EOF = (
    'INTEGER', 'PLUS', 'MINUS', 'MUL', 'DIV', '(', ')', 'EOF'
)


class Token(object):
    def __init__(self, type, value):
        self.type = type
        self.value = value

    def __str__(self):
        """String representation of the class instance.
        Examples:
            Token(INTEGER, 3)
            Token(PLUS, '+')
            Token(MUL, '*')
        """
        return 'Token({type}, {value})'.format(
            type=self.type,
            value=repr(self.value)
        )

    def __repr__(self):
        return self.__str__()


class Lexer(object):
    def __init__(self, text):
        # client string input, e.g. "4 + 2 * 3 - 6 / 2"
        self.text = text
        # self.pos is an index into self.text
        self.pos = 0
        self.current_char = self.text[self.pos]

    def error(self):
        raise Exception('Invalid character')

    def advance(self):
        """Advance the `pos` pointer and set the `current_char` variable."""
        self.pos += 1
        if self.pos > len(self.text) - 1:
            self.current_char = None  # Indicates end of input
        else:
            self.current_char = self.text[self.pos]

    def skip_whitespace(self):
        while self.current_char is not None and self.current_char.isspace():
            self.advance()

    def integer(self):
        """Return a (multidigit) integer consumed from the input."""
        result = ''
        while self.current_char is not None and self.current_char.isdigit():
            result += self.current_char
            self.advance()
        return int(result)

    def get_next_token(self):
        """Lexical analyzer (also known as scanner or tokenizer)
        This method is responsible for breaking a sentence
        apart into tokens. One token at a time.
        """
        while self.current_char is not None:

            if self.current_char.isspace():
                self.skip_whitespace()
                continue

            if self.current_char.isdigit():
                return Token(INTEGER, self.integer())

            if self.current_char == '+':
                self.advance()
                return Token(PLUS, '+')

            if self.current_char == '-':
                self.advance()
                return Token(MINUS, '-')

            if self.current_char == '*':
                self.advance()
                return Token(MUL, '*')

            if self.current_char == '/':
                self.advance()
                return Token(DIV, '/')

            if self.current_char == '(':
                self.advance()
                return Token(LPAREN, '(')

            if self.current_char == ')':
                self.advance()
                return Token(RPAREN, ')')

            self.error()

        return Token(EOF, None)


###############################################################################
#                                                                             #
#  PARSER                                                                     #
#                                                                             #
###############################################################################

class AST(object):
    pass


class BinOp(AST):
    def __init__(self, left, op, right):
        self.left = left
        self.token = self.op = op
        self.right = right


class Num(AST):
    def __init__(self, token):
        self.token = token
        self.value = token.value


class UnaryOp(AST):
    def __init__(self, op, expr):
        self.token = self.op = op
        self.expr = expr


class Parser(object):
    def __init__(self, lexer):
        self.lexer = lexer
        # set current token to the first token taken from the input
        self.current_token = self.lexer.get_next_token()

    def error(self):
        raise Exception('Invalid syntax')

    def eat(self, token_type):
        # compare the current token type with the passed token
        # type and if they match then "eat" the current token
        # and assign the next token to the self.current_token,
        # otherwise raise an exception.
        if self.current_token.type == token_type:
            self.current_token = self.lexer.get_next_token()
        else:
            self.error()

    def factor(self):
        """factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN"""
        token = self.current_token
        if token.type == PLUS:
            self.eat(PLUS)
            node = UnaryOp(token, self.factor())
            return node
        elif token.type == MINUS:
            self.eat(MINUS)
            node = UnaryOp(token, self.factor())
            return node
        elif token.type == INTEGER:
            self.eat(INTEGER)
            return Num(token)
        elif token.type == LPAREN:
            self.eat(LPAREN)
            node = self.expr()
            self.eat(RPAREN)
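            return node
        # No alternative matched (e.g. '*' or ')' where a factor was expected):
        # report a syntax error instead of silently returning None.
        self.error()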

    def term(self):
        """term : factor ((MUL | DIV) factor)*"""
        node = self.factor()

        while self.current_token.type in (MUL, DIV):
            token = self.current_token
            if token.type == MUL:
                self.eat(MUL)
            elif token.type == DIV:
                self.eat(DIV)

            node = BinOp(left=node, op=token, right=self.factor())

        return node

    def expr(self):
        """
        expr   : term ((PLUS | MINUS) term)*
        term   : factor ((MUL | DIV) factor)*
        factor : (PLUS | MINUS) factor | INTEGER | LPAREN expr RPAREN
        """
        node = self.term()

        while self.current_token.type in (PLUS, MINUS):
            token = self.current_token
            if token.type == PLUS:
                self.eat(PLUS)
            elif token.type == MINUS:
                self.eat(MINUS)

            node = BinOp(left=node, op=token, right=self.term())

        return node

    def parse(self):
        node = self.expr()
        if self.current_token.type != EOF:
            self.error()
        return node


###############################################################################
#                                                                             #
#  INTERPRETER                                                                #
#                                                                             #
###############################################################################

class NodeVisitor(object):
    def visit(self, node):
        method_name = 'visit_' + type(node).__name__
        visitor = getattr(self, method_name, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        raise Exception('No visit_{} method'.format(type(node).__name__))


class Interpreter(NodeVisitor):
    def __init__(self, parser):
        self.parser = parser

    def visit_BinOp(self, node):
        if node.op.type == PLUS:
            return self.visit(node.left) + self.visit(node.right)
        elif node.op.type == MINUS:
            return self.visit(node.left) - self.visit(node.right)
        elif node.op.type == MUL:
            return self.visit(node.left) * self.visit(node.right)
        elif node.op.type == DIV:
            return self.visit(node.left) // self.visit(node.right)

    def visit_Num(self, node):
        return node.value

    def visit_UnaryOp(self, node):
        op = node.op.type
        if op == PLUS:
            return +self.visit(node.expr)
        elif op == MINUS:
            return -self.visit(node.expr)

    def interpret(self):
        tree = self.parser.parse()
        if tree is None:
            return ''
        return self.visit(tree)


def main():
    while True:
        try:
            text = input('spi> ')
        except EOFError:
            break
        if not text:
            continue

        lexer = Lexer(text)
        parser = Parser(lexer)
        interpreter = Interpreter(parser)
        result = interpreter.interpret()
        print(result)


if __name__ == '__main__':
    main()