《JIT Technical Overview》翻译 | PyTorch

本文是 PyTorch 官方文档 JIT Technical Overview 的中文翻译。对于看不太明白的部分,暂不翻译,直接放上原文。

概述

JIT 可以独立于 Python 解释器运行和优化 PyTorch 程序。本文根据组件被组织成各个小节:

  • 核心程序表示:JIT 执行的是 TorchScript,在语法上属于 Python 的子集。本节描述了TorchScript 程序如何在 JIT 中表示,并用作 JIT 组件之间的交换格式
  • 生成程序:可以通过跟踪 Python 代码或直接编写 TorchScript 来创建 TorchScript 程序。本节介绍如何从这些前端创建模型
  • 执行程序:创建程序后,将优化并运行 TorchScript 模型。由于这是一个即时编译器,程序在执行时会进行优化,因此本节将介绍如何优化程序以及如何运行程序
  • 保存程序:TorchScript 通常是由 Python 创建的,然后在 C++ 中使用。本节介绍保存和加载过程的工作方式
  • Python 绑定:TorchScript 代码通常是从 Python 创建和使用的,因此本节描述了 Python 组件如何与 TorchScript 中的代码交互

核心程序表示

Modules

api/module.h

在顶层设计上,所有 TorchScript 程序都表示为一个 Module,包括:

  • named Parameters:训练中用到的 Tensor,比如 weight 或者 bias
  • named Buffers:作为 Module 训练状态的一部分,但不出现在 module.parameters() 中且不参与梯度下降的 Tensor
  • named sub-Modules:用于代码组织
  • named Attributes:没有出现在上述三种类别中的所有其他属性,通常用于配置,并且不会在 state_dict 中保存
  • named Methods:Module 中可以运行的函数,比如 forward

所有 TorchScrip 代码都是某个 Module 的成员。这包括纯函数,例如通过使用@torch.jit.script 注释 Python 函数创建的函数,这个函数在内部表示为一个 Module,该且只有一个方法 forward,即函数本身的实现。

Method

api/module.h

Method 是 TorchScript 中的一段代码,其接受多个参数并且产生输出值。Method 有几个子组件:

  • FunctionSchema:描述了输入参数和返回值的类型和名称
  • member_inputs 列表:描述了 Method 访问了哪些 Parameter(对于纯函数该列表为空)
  • Graph 对象:描述 Method 内部的实际代码
  • GraphExecutor:执行上述 Graph 对象

Method 使用的 Parameter 在运行之前作为附加输入添加到此图中。这允许 GraphExecutor 出于优化和执行的目的将 method inputs 和 method parameters 同等对待,从而简化了执行程序的过程。

Methods also contain helper functions for inserting calls to the Method from other Method objects.

Method 还包括一些帮助函数,用于插入其他 Method 对象对该 Method 对象的调用。

FunctionSchema

aten/src/ATen/core/function_schema.h

每个 Method 都有一个 FunctionSchema,用于描述参数和返回值的类型。Operator 也有 FunctionSchema。FunctionSchema 类似 C++ 中的声明,描述了如何调用该函数但是没有实现。

Graph

ir.h

Graph 由 Node、Block 和 Value 组成。Node 是指令(例如「做矩阵乘法」)。Block 则是顺序执行的多个 Node。Node 的输入输出都是 Value 的列表。

下面展示一个例子:

1
2
3
4
5
6
@torch.jit.script
def f(a, b):
c = a + b
d = c * c
e = torch.tanh(d * c)
return d + (e + e)

对应的 Graph 是这样的:

1
2
3
4
5
6
7
8
9
10
graph(%a.1 : Tensor,
%b.1 : Tensor):
%4 : int = prim::Constant[value=1]()
%c.1 : Tensor = aten::add(%a.1, %b.1, %4)
%d.1 : Tensor = aten::mul(%c.1, %c.1)
%11 : Tensor = aten::mul(%d.1, %c.1)
%e.1 : Tensor = aten::tanh(%11)
%17 : Tensor = aten::add(%e.1, %e.1, %4)
%19 : Tensor = aten::add(%d.1, %17, %4)
return (%19)

这是 IR 的规范文本表示,可以在其中找到部分前述元素的对应:

  • graph 是 Graph
  • %c.1 是 Value
  • %c.1 : Tensor 是带有类型标注的 Value
  • %c.1 : Tensor = aten::add(%a.1, %b.1, %4)) 则是算子 aten::add 对应的 Node,其接受 %a.1%b.1%4 三个 Value 作为输入,并返回 %c.1 这个 Value 作为输出

最后,可以为 Node 赋予额外的信息,这些信息称为属性。在上面的例子中,属性在 prim::Constant 节点中使用,该节点在被调用时返回属性值。

JIT 中的 Graph 是单静态赋值(SSA)形式的,这意味着每个 Value 正好有一个可以直接从其中查找的 Node:Node* n = v.node()

所有权模型:Block、Node 和 Value 由它们出现在其中的 Graph 所有,并且可能只出现在单个Graph 中。Node 的创建和删除是通过 Graph 对象完成的(例如 Graph::create)完成的。Bloack 和 Value 的创建和删除是通过 Node 对象完成的(例如 Node::addOutput、Node::addBlock)完成的。某些一致性属性还会被强制执行,例如,Node::destroy 会移除一个 Node,但只有在不再使用该 Node 生成的 Value 时才能调用此函数,这可以使用 Value::replaceAllUseWith 等其他函数来完成。

因为 Graph 拥有它的所有 Block、Node 和 Value,所以这些对象总是由原始指针传递。一般来说,开发人员不应该编写无限期地持有 Block、Node 或 Value 对象的代码,也不应该同时持有它们所在的 Graph 的 shared_ptr。

Node

ir.h

Node 表示单个内置算子,如矩阵乘法或卷积。通过例如 NodeKind Node::kind() 的方式来表示。不同的运算符(例如卷积和矩阵乘法)由不同的种类表示,而不是像在 LLVM 中那样通过将 Node 子类化来表示。NodeKind 是一个 Symbol 对象,它只是某个命名空间中的一个驻留字符串。Symbol 可以从字符串创建,例如通过 Symbol::fromQualString("aten::add"),因此 Node 的取值不是一个预定义的集合。选择这种设计是为了允许新算子和用户自定义算子的开放注册。

Node 生成输出 Value,并将输入 Value 作为参数。Node 节点可以产生多个输出。例如 prim::TupleUnpack 将一个元组拆分成多个元素,因此它的输出数量等于该元组的成员数量。尽管节点可能有多个输出,但每个节点的输出数量是静态已知的。可能产生动态量结果的操作,例如将 Tensor 分割成大小为 2 的块,将被表示为产生列表对象的算子。

Because Nodes are not subclassed per-operator, it is very easy to construct invalid Nodes, e.g. by forgetting an input or an output, or by passing Values of the wrong Type. To help avoid this, Graph provides the method Graph::insert for constructing Nodes that guarantees Nodes have the correct setup. This method uses the database of registered Operators and their FunctionSchema to construct Nodes using that schema.

PyTorch IR 支持函数重载,这意味着单个 NodeKind 可能对应多个运算符。例如,aten::add 类型具有以下重载(标量表示为 float 或 int):

  • aten::add(Tensor self, Tensor other) -> Tensor
  • aten::add(Tensor self, Scalar other) -> Tensor
  • aten::add(int self, int other) -> int
  • aten::add(float self, float other) -> float

对于表示内置算子的 Node,Node::schema 方法也可以查找该算子注册的 FunctionSchema。每个重载对应一个不同的 FunctionSchema 对象。可以使用 schema() 方法查询节点的 schema(将检查节点的参数类型,并尝试为其匹配)。

每个 Node 还具有一组属性,这些属性被命名为整型、字符串、浮点、Tensor、子图或这些类型的列表。这些被特殊的原生算子用来编码 Node 中的附加数据。例如 prim::Constant 定义了一个编译时常量值。对于 Tensor 常量,它将有一个名为 attr::value 的属性,其中包含常量的值。

属性很少使用。卷积或矩阵乘法等运算符没有属性,并且通过输入列表获取参数。这包括通常被认为是常数的东西,比如卷积的 stride。在 PyTorch 中,这些信息中的任何一个都可能是程序的动态属性,因此节点总是以允许动态确定这些值的方式编码。许多输入几乎总是常量,通过 c10::optional<IValue> Node::get(Symbol name) 获取输入的值可以检查其是否为常量:如果是常量,则返回 IValue(输入的具体值),否则返回 nullopt。

Block

ir.h

Node 被组织成在 Block 内顺序执行的列表。Graph 本身有一个顶层 Block 概念 Graph::block(),控制流节点(prim::If 和 prim::Loop)也有子 Block。虽然可以设计一个 Node 无序的 Graph 表示,但是当所有 Node 都有特定的规范顺序时,调试和理解 Block 要容易得多。在进行 pass 优化时,如果更改顺序但是保留语义(很像乱序处理器)可以提升性能,则允许解释器乱序执行 Node。有序的 Node 很容易打印,也可以轻松地逐步执行 Graph。

Values are Block-scoped. A Value is in scope for the remainder of the Block it is defined in, including in the sub-blocks of any Node defined after it. Values go out of scope at the end of the block in which they are defined.

When Nodes are inserted into a Graph, they are inserted at a special “insertion point” that is part of the state of the Graph. On construction, this will go to the end of the Graph.

每个 Block 有两个虚拟 Node,它们不包含在 Block 的节点列表中。 prim::Param 节点代表 Block 的输入,没有 prev() 或 next() 节点。 prim::Return 节点表示 Block 的输出。Block 中的节点列表被实现为一个循环链表,其中 prim::Return 节点用作开始/结束标记。在任意位置插入和删除是高效的。开发人员也可能会遇到基于此事实的 IR 对象内部的实现(例如,添加 Node 到 Block 相当于将其放在 prim::Return 节点之前)。

当 Block::nodes() 列表的迭代器指向的 Node 被移动或者删除时,该迭代器无效。

If

控制流使用子 Block 而不是控制流图表示来表示。 prim::If 有一个 Block 用于 true 分支,一个 Block 用于 else。一个 prim:Loop 有一个循环体的 Block(没有条件 Block,而是循环体的末尾计算是否重新进入循环体)。这种表示确保我们有结构化的控制流。这种限制使许多优化变得更容易,并且对于绝大多数网络都是如此。一个 Node 可以查找它所在的 Block,一个 Block 可以查找其父 Block。

If-statement (prim::If) 块没有输入,输出是外部块中变量的新值,其值在 if 语句中被更改。if 语句的输出的作用类似于传统 SSA 控制流图中的 Φ 函数节点。if 语句的示例 IR 如下所示:

1
2
3
4
5
6
7
%y_1, ..., %y_r = prim::If(%condition)
block0(): # TRUE BRANCH, never takes arguments, has to return r outputs
%t_1, ..., %t_k = some::node(%a_value_from_outer_block)
-> (%t_1, ..., %t_r)
block1(): # FALSE BRANCH, never takes arguments, has to return r outputs
%f_1, ..., %f_m = some::node(%a_value_from_outer_block)
-> (%f_1, ..., %f_r)

对应于 %y_1, …, %y_r 的值将变为 %t_1, …, %t_r 或 %f_1, …, %f_r,具体取决于运行时 %condition 的值。

下面是一个 TorchScript 程序及其对应 IR 的示例:

1
2
3
4
5
6
7
8
@torch.jit.script
def f(a, b, c):
d = a + b
if c:
e = d + d
else:
e = b + d
return e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
graph(%a.1 : Tensor,
%b.1 : Tensor,
%c.1 : Tensor):
%5 : int = prim::Constant[value=1]()
%d.1 : Tensor = aten::add(%a.1, %b.1, %5)
%9 : bool = aten::Bool(%c.1)
%e : Tensor = prim::If(%9)
block0():
%e.1 : Tensor = aten::add(%d.1, %d.1, %5)
-> (%e.1)
block1():
%e.2 : Tensor = aten::add(%b.1, %d.1, %5)
-> (%e.2)
return (%e)

什么是 Φ 函数(参考自如何构建SSA形式的CFG):

在编译器设计中,静态单赋值形式(通常缩写为 SSA 形式或简称 SSA)是 IR 的属性,它要求每个变量只分配一次,并且每个变量在使用之前定义。SSA 形式的代码极大地降低了定义使用链的可能数目。在传统的非 SSA 形式的代码中,如果有 D 处定义和 U 处使用,就可能有 D×U 种可能的组合。因而 SSA 形式的代码有利于程序的优化和分析。

顺序执行的代码 SSA 形式较为简单。但程序会有分支和合并,通过在合并处插入 ϕ 函数,就能解决带分支代码的 SSA 形式。Φ 函数表示从进来的分支中选取某一个值作为新的值。如下面的代码:

1
2
3
4
5
if (p)
v = 1;
else
v = 2;
return v;

就会被转化成:

1
2
3
4
5
6
if (p)
v1 = 1;
else
v2 = 2;
v3 = phi(v1, v2);
return v3;

使用 SSA 形式中的一个分析例子是常量传播分析。常量传播分析是指分析哪些变量是常量,对于非 SSA 形式的分析,这较为困难。对于 SSA 形式,我们可以将那些使用常量定义的变量,将其所有出现的地方替换成常量,不断迭代直到到达不动点即可。

Loops

循环是用 prim::Loop 实现的,它涵盖了 while 和 for 循环。loop 语句的示例 IR 如下所示:

1
2
3
4
5
%y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
block0(%i, %a_1, ..., %a_r):
%b_1, ..., %b_m = some::node(%a_value_from_outer_block, %a_1)
%iter_condition = some::other_node(%a_2)
-> (%iter_condition, %b_1, ..., %b_r)

解释其语义的最简单方法是参考这个类似 Python 的伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# init intermediate variable
a_1, ..., a_r = x_1, ..., x_r

# init outputs
y_1, ..., y_r = a_1, ..., a_r

i = 0
while condition and i < max_trip_count:
#################### actual body of the loop ####################
b_1, ..., b_r = some::node_1(a_1, ..., a_r)
iter_condition = some::node_2(a_1, ..., a_r)
#################################################################

# assign intermediate variable
a_1, ..., a_r = b_1, ..., b_r

# assign condition
condition = iter_condition

# assign outputs
y_1, ..., y_r = b_1, ..., b_r

# increase count
i += 1

For 循环和 while 循环都使用上述 IR。在 for 循环中 %condition 是常量 true;在 while 循环中,%max_trip_count 被设置为 int64_t 的最大值,并且不使用 %i。

下面是一个 TorchScript 程序及其对应 IR 的示例:

1
2
3
4
5
@torch.jit.script
def f(x):
for i in range(2, 10):
x += i
return x
1
2
3
4
5
6
7
8
9
10
11
graph(%x.1 : Tensor):
%5 : bool = prim::Constant[value=1]() # condition
%27 : int = prim::Constant[value=8]() # max_trip_count
%3 : int = prim::Constant[value=1]() # step
%1 : int = prim::Constant[value=2]() # start
%x : Tensor = prim::Loop(%27, %5, %x.1)
block0(%6 : int, %x.11 : Tensor):
%i.1 : int = aten::__derive_index(%6, %1, %3) # index, start, step
%x.5 : Tensor = aten::add_(%x.11, %i.1, %3)
-> (%5, %x.5)
return (%x)

With

With 语句以两种不同的方式表示。对于大多数编译和优化过程,它们表示为一对 prim::Enter 和 prim::Exit 节点,它们包装了对应于 with 语句主体的 Node(s)。然而,with 语句在 exit_transform 期间(后面会有一节专门讲 exit_transform)使用基于 Block 的表示形式临时表示,其中 prim::With 节点插入到 prim::Exit 节点之后,所有 prim::Exit 和 prim::Enter 之间的节点被移动到 prim::With 的第一个 Block 中,而 prim::Exit 被移动到 prim::With 的第二个 Block 中。例如这段代码:

1
2
with c as increment:
y = x + increment

最终被转换为:

1
2
3
4
%2 : int = prim::Constant[value=1]()
%increment.1 : int = prim::Enter(%c.1)
%y.1 : Tensor = aten::add(%x.1, %increment.1, %2)
%11 : Tensor = prim::Exit(%c.1)

并将暂时转换为:

1
2
3
4
5
6
7
8
%increment.1 : int = prim::Enter(%c.1)
= prim::With()
block0():
%y.1 : Tensor = aten::add(%x.1, %increment.1, %4)
-> ()
block1():
%11 : Tensor = prim::Exit(%c.1)
-> ()

Value

ir.h

Value 表示流经程序中操作的数据,例如矩阵乘法运算的输出。由于 SSA 形式,Value 对象始终由单个 Node 定义,即可以用 v.node() 获取到 Value 对应的 Node。对于 Block/Graph 的输入,此节点是一个特殊的 prim::Param 节点,它不会出现在 Block 的节点列表中。Value 对象会有一个 Type,这提供了静态保证,它的值将是那个类型。

Value 对象具有返回其定义和所有用途的方法:v.node()v.uses()。每个 Use 都有一个指向 Node 的指针,这个 Node 的输入列表包含该 Value。在 v.uses() 上迭代的同时修改 v 的用法时要小心,因为对 v 的每次更改都会使 v.uses() 迭代器无效。

Value 是程序中数据的抽象表示。执行时,实际的 Tensor、List、Tuple 等存储在 IValue 中,它们是 TorchScript 中所有可能值类型的 tagged union。Value 这个名字有点令人困惑,因为它看起来应该是 tagged union,但它最初来自与 llvm::Value 的类比,它的用途与 jit::Value 相同。

Type

aten/src/ATen/core/jit_type.h

与 Python 不同,TorchScript 是静态类型的,因此每个 Value 都有一个与之关联的 Type,每个 FunctionSchema 都有一个参数类型列表和一个函数的返回类型。Type 是代表 TorchScript 内置类型的 C++ 对象层次结构的基类。类型提供诸如 Type::isSubtypeOf 之类的方法来描述类型关系。常见类型有:

  • TensorType:具有可选细化信息的张量。它可能知道它的设备、类型、requires_grad 状态和维数。如果它确实知道维度的数量,它可能知道特定维度的大小
  • Tuple[T1, T2]:例如 Tuple[Tensor, Int]。元组的每个成员都是静态类型的,元组的长度是静态已知的
  • List[T]:例如 List[Tensor]。特定类型的可变列表
  • Optional[T]: 例如 Optional[Tensor]。含义为 T 类型的值或 None
  • Dict[K, V]:例如 Dict[String, Tensor]

如果 S 类型是 P 的子类型,那么我们可以在任何需要 P 类型的地方替换具有 S 类型的 IValue,这意味着子类型的 IValue 表示与基类型的 IValue 表示兼容。

生成程序

JIT 程序是使用 torch.jit.trace 或 torch.jit.script 创建的。在这两种情况下,得到的对象都是一个完整的 Module,其中包含方法中的所有代码,以及模块参数中的所有模型权重。但是,每个前端都通过不同的途径来生成这些模块。

Trace

tracer.h
tracer_state.h

tracer 通过记录对 Tensor 进行的实际操作来生成 Graph。从 Python 到 C++ 使用 torch.jit.trace 进行跟踪的入口点是 _create_method_from_trace

TracingState 对象的 thread local 实例维护「存储在 IValues 中的跟踪期间计算的实际数据(如 Tensor)」与「Graph 中将计算的每个值的抽象值」之间的映射。跟踪器使用函数 void setValueTrace(const IValue&, Value*)Value* getValueTrace(const IValue&) 来维护此映射。

在被跟踪的函数的输入和被构造的 Graph 的 Value 输入之间建立了初始的 IValue 到 Value 的映射。如果我们正在跟踪一个 torch.nn.Module,那么 tracer 还会将 Parameters 和 sub-Module 添加到正在构造的 Module 中,这些 Module 对应于被跟踪的 Python torch.nn.Module。还添加了这些值的映射,这样在跟踪中随着 Parameter 的使用,这些 Parameters 的使用也会在 Graph 中创建。

随着跟踪运行,各个算子在被跟踪的 Graph 中创建 Node 以记录发生的事情。此代码当前是在 tools/autograd/gen_variable_type.py 中为每个算子生成的。它会生成如下所示的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
tracer_state = jit::tracer::getTracingState();
at::Symbol op_name;
op_name = jit::Symbol::fromQualString("aten::__ilshift__");
node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
jit::tracer::recordSourceLocation(node);
jit::tracer::addInputs(node, "self", self);
jit::tracer::addInputs(node, "other", other);
tracer_state->graph->insertNode(node);

jit::tracer::setTracingState(nullptr);
}
TypeDefault::__ilshift__(self, other);
if (tracer_state) {
jit::tracer::setTracingState(std::move(tracer_state));
jit::tracer::addOutput(node, self);
}

函数 addInputs 和 addOutput 被重载以处理算子使用的不同数据类型。trace 仅适用于 Tensor 和 Future。其他类型的跟踪不是原生支持的。相反,像 Tuple 或 List 这样的容器通常在跟踪结束时被展平为多个 Tensor。

The tracer has special behavior when tracing calls to other TorchScript functions. This behavior is implemented in the GraphExecutor right before a Graph is about to be run. If tracing is enabled while running the graph, the GraphExecutor will disable tracing, run the graph as normal, and then inline the Graph into the trace. It then hooks up the IValues computed by running the Graph to out Values in the inlined graph.

通过跟踪创建的 Graph 会被安装到相应 Module 的 forward 方法。无论被跟踪的事物是函数还是 torch.nn.Module,都会产生一个 Module。在函数情况下,生成的模块将仅具有单个 forward 函数,没有 Parameters 和 no sub-Module。

Script

Script 前端直接将 Python 语法转换为 Module。像许多编译器一样,这分为两个阶段:首先生成一个抽象语法树(AST),它由 Tree 对象构成。然后 IR emitter 对 Tree 进行语义分析并 lower 成 Module。可以通过两种方式生成 Tree:(1) 使用 frontend.py,它将 Python AST 转换为 Tree 对象,或者 (2) 通过 Lexer 和 Parser 直接解析 Python 语法。Lexer + Parser 的方式看似多余,但至关重要。我们需要在未链接 Python 时定义内置函数(frontend/builtin_functions.cpp),因为我们允许用户直接从包含 Python 源代码(api/include/torch/jit.h)的字符串生成 TorchScript 程序,而无需链接完整的 Python 实现(例如 CPython)。我们还使用这种 Python 语法作为 TorchScript 的序列化格式,因为它允许我们在不破坏向后兼容性的情况下对 IR 进行更改。此外,Lexer 也被用于实现 FunctionSchema 解析器,它将 FunctionSchema 声明从字符串转换为 FunctionSchema 对象。

Tree

frontend/tree.h

前端以 Tree 对象的形式生成 AST。Tree 类似于 S-表达式。叶子节点总是字符串。复合树有一个 kind(例如 lexer.h 中定义的 TK_CONST 或 TK_IDENT)和一个子树列表。例如 z.sigmoid() - (x + y) 对应的 Tree 是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
(-
(+
(variable (ident x))
(variable (ident y))
)
(apply
(.
(variable (ident z))
(ident sigmoid)
)
(list)
(list)
)
)

上述示例以 S-表达式样式打印,(kind …) 表示复合树,string_value 表示字符串。每个 Tree 中还有一个强制的 SourceRange 对象,用于描述它来自的文本范围。这些将用于代码中的错误报告。

Tree View

frontend/tree_views.h

Tree 很容易构建、可视化和遍历,但是从像函数定义这样的大型复合树中提取信息并不方便,因为它需要数字索引。Tree View 是树顶部的一个小层,可以创建和解构特定种类的树。例如,这里是 apply 节点的 Tree View,它为其子树提供 named accessor:被调用的函数、输入和属性(即 kwargs):

1
2
3
4
5
6
7
8
9
10
11
struct Apply : public Expr {
Expr callee() const {
return Expr(subtree(0));
}
List<Expr> inputs() const {
return List<Expr>(subtree(1));
}
List<Attribute> attributes() const {
return List<Attribute>(subtree(2));
}
};

遍历 Tree 的典型方法是对其种类使用 switch 操作符,然后构造适当的 Tree View:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
switch (tree.kind()) {
case TK_VAR: {
auto var = Var(tree); // construct tree view
return environment_stack->getSugaredVar(var.name());
}
case '.': {
auto select = Select(tree); // construct tree view
auto sv = emitSugaredExpr(select.value(), 1);
return sv->attr(select.range(), method, select.selector().name());
}
case TK_APPLY: {
auto apply = Apply(tree); // construct tree view
return emitApplyExpr(apply, n_binders);
}
}

frontend.py

torch/jit/frontend.py

我们构造 Tree 对象的一种方法是直接从 Python AST 中构建。此逻辑包含在 frontend.py 中。我们努力保持大部分 JIT 代码用 C++ 编写,因为大多数 JIT 功能仍然需要在没有安装 Python 的情况下工作。这段代码只是简单地构造了 Tree,过滤掉了我们不支持的 Python 的 AST 节点。

Lexer

frontend/lexer.h

当直接从字符串加载 TorchScript 代码时,我们使用标准的 Lexer + Parser 组合。Lexer 接受一个初始字符串,然后暴露一个有状态的接口来遍历字符串的 Token,提供一组标准的函数:

  • next() 推进 Lexer,返回当前 Token
  • cur() 提供当前 Token
  • lookahead() 提供当前 Token 之后的 Token
  • nextIf(int token_kind) 如果与 Token 种类匹配,则推进 Token

与 Python 类似,Lexer 对缩进是敏感的。Token TK_INDENT、TK_DEDENT 和 TK_NEWLINE 在代码缩进、反缩进时以及语句结束时被注入到 Token 流中。例如对于这个流:

1
2
3
if
.
.

我们会得到一个 Token 流:TK_IF, TK_NEWLINE, TK_INDENT, .,TK_NEWLINE, ., TK_NEWLINE, TK_DEDENT。不匹配的左括号会禁用这些 Token 的注入。结果是 Parser 可以像 C 语言的 {}; 一样简单地处理 TK_INDENT、TK_DEDENT 和 TK_NEWLINE。

Tokens

frontend/lexer.h

Token 可以是关键字(def)、运算符(+)、文字(3.4)或标识符(foo)。整数类型的 token_kind 标记了 Token 的类型,并且与 Tree 的 kind 完全相同。

Parser

frontend/parser.h

Parser 使用 Lexer 为函数定义构建 AST。 parseFunction 是解析单个 def ... 的入口,并将返回一个 Def Tree View。

The Parser is written as a top-down precedence parser, or “Pratt” parser. They are simpler and easier to understand than typical parser generators, while still being flexible enough to parse programming languages. For the most part parsing is done by recursive decent. To resolve operator precedence issues, the function to parse an expression is augmented with a precedent p such that calling the function means parse an expression whose operators all have precedence higher than p.

IR Emitter

frontend/ir_emitter.h

文件 ir_emitter.cpp 将 Tree 转换为 Module。主要入口是 defineMethodsInModule,它接受代表函数定义的 Def Tree Views 列表并将它们作为 Method 添加到 Module 中。在 lower 期间发生语义检查。IR Emitter 检查所有使用的变量是否已定义(作用域检查),以及所有值是否具有兼容的类型(类型检查)。在此过程中,它还会发出与 Tree 中的每个语句对应的图节点,并为整个定义生成一个 FunctionSchema。

lower 过程中存在一些辅助对象。SugaredValues 是这样一种特殊值,表示在编译期间可能出现但不是第一类值的对象。例如,在 TorchScript Method 中 self 指的是 Module,而 self.weight 指的是 Module 的 Parameter。两者都不是第一类 Type,并且在 Graph 中没有对应的 Value。Resolver 对象是将外部定义的变量解析为 SugaredValue 的 std::function。例如,包含大部分内置操作的标识符 torch 是通过 Resolver 对象查找的,Resolver 对象负责与程序的 Python 态交互。

Environment 维护变量名称和它们引用的 SugaredValues 之间的映射。

SugaredValue

frontend/sugared_value.h

SugaredValues 是 IR Emitter 在 Graph 创建期间表示非第一类值的方式。这些值类似于 Module 或 Python 函数调用,在 Graph 中没有对应的 Value 对象。IR Emitter 根据 SugaredValue 对象的使用方式将其 desugar 到图中的指令。SugaredValue 类上有许多抽象方法,例如 attr 或 call。考虑表达式 self.foo。如果 foo 是一个 Method,self 将解析为一个特殊的 SugaredValue 子类 ModuleValue。当 Emitter 看到 self.foo 时,它会调用这个 ModuleValue 的函数 sv.attr(“foo”),询问 ModuleValue 当属性 “foo” 被访问时它应该如何为自己 desugar。如果 foo 是一个蚕丝被胡,那么它将确保该参数已添加到正在编译的 Method 中,并返回一个 sugared value SimpleValue,其中包含将参数表示为输入的 Value 对象。如果 foo 是一个 sub-Module,那么它将返回另一个 SugaredModule。The method call is invoked when the emitter sees the value used as a function call.

SugaredValues 也是在编译过程中与 Python 运行时交互的方式。例如,math.pi 被解析为 3.1415…,首先将 math 解析为一个 SugaredValue 表示对 Python 模块(PythonModuleValue)的访问,其 attr 函数将 Python 数字转换为图中的 prim::Constant 节点。

Finally, normal Values are also represented by the SimpleValue SugaredValue in places where it is valid that either a SugaredValue or a normal Value will appear.

Resolver

frontend/resolver.h

编译期间的任何未定义变量都通过调用外部提供的 Resolver 来解析。当从 Python 调用时,此 Resolver 通过 pybind11 与 Python runtime 交互,以解析诸如 torch 和 math 之类的符号。

SugaredValue 和 Resolver 的组合将 IR Emitter 的实现与 pybind11 Python 绑定分离,从而使其能够与 Python 状态交互。这使得在 Python runtime 不存在时可以使用大部分 IR Emitter 功能。

Environment

frontend/ir_emitter.cpp

Environment 对象在编译期间跟踪变量名称的分配。It is local to the IR emitter file. 当控制流引入 sub-Block 时,一个新 Environment 会被创建。Environment 保留了两个表,一个用于类型系统中非第一类值(SugaredValues),一个用于类型系统中的第一类值。当设置第一类值时,一个 prim::Store 被 emit。当引用第一类值时,一个 prim::Load 被 emit。SugaredValues 不可重新分配。

Conversion To SSA

frontend/convert_to_ssa.cpp

如 Block 一节所述,IR 使用结构化控制流表示,结构化控制流由 if 和 loop 组成。这使得优化和 lower 到其他不支持非结构化控制流的编译器变得更容易。Python 控制流(中断、继续、返回)被 lower 为这种简化形式。Environment 中任何变量都会被闭包进来,因此 Environment 中的所有写入和读取都可被直接转换为 SSA 形式。

向 SSA 的转换分多个部分进行:

  1. Store 和 load 被添加到控制流算子(if 和 loop)中
  2. Break 和 continue 语句被从 Graph 中删除,并被 prim::LoopContinuation(%loop_continue_condition, %loop_carried_vars) 所替换。对于 break 语句,continue 条件被设置为 false;对于 continue 语句,则会内联循环条件。%loop_carried_vars 是包含 break 或 continue 语句的最内层循环的循环携带变量,通过在语句位置插入 prim::Load 调用来添加
  3. 将循环条件内联到 Graph 的循环
  4. 删除所有 store 并用变量名称的作用域内值替换所有 load
  5. 在 exit_transform pass 中删除 prim::LoopContinuations 和 prim::ReturnStmts

Exit Transform

frontend/exit_transforms.cpp

该 pass 接受一个 Graph 并移除其中的 LoopContinuation 和 ReturnStmts,并且正确设置 Block 输出。prim::LoopContinuation(*vals) 表示这些值针对的是最近的 loop Block。prim::ReturnStmt(*vals) 表示这些值针对的是最近的 Closure 或 Graph Block。

如果一个 Block 有一个 Exit Node,在到达 Exit 目标之前不会执行进一步的指令。如果我们遇到一个包含嵌套 Block 的 Node,这些嵌套 Block 可能已经命中 Exit Node,例如在一个 Block 中退出而在另一个 Block 中不退出的 if 语句,我们使用布尔值来标识 exit 是否已被命中。接下来的指令则被有条件地执行(条件化执行)。

例如下面这段 Python 代码:

1
2
3
4
5
while i < 5:
if i == 3:
i += 1
continue
i += 2

将会被转换成:

1
2
3
4
5
6
7
8
9
10
11
continue_loop = i < 5
while continue_loop:
if i == 3:
i = i + 1
continue_loop = i < 5
did_exit = True
if did_exit:
pass
else:
i = i + 2
continue_loop = i < 5

该 pass 还跟踪总是抛出异常的 Node 或 Block,这样我们就不会不必要地条件化执行。在下面的示例中,我们可以将 if 语句视为始终返回并删除 print 语句:

1
2
3
4
5
if i < 0:
raise Exception("Negative input")
else:
return math.sqrt(i)
print(i) # unreachable code

prim::Uninitialized 在编译器可以证明该值永远不会被使用时插入。它可以通过 exception、break、continue 和 return 来引入。

Python-Compiler Interaction

python/script_init.cpp

一组特殊的 SugaredValues 用于在编译过程中「Python 环境中的对象」和「Graph 中的 Value」之间的转换。此行为的入口是 toSugaredValue(py::object obj, ...),它通过 pybind11 Python 值推断出如何将其转换为适当的 SugaredValue。Values exist to represent Python functions, Python modules, and ScriptModule objects.

剩余章节

该文档还有 6 个章节:分析程序、执行程序、保存程序、测试程序、Python 打印、Python 绑定。就先不翻译了。相关细节可以参考《TorchScript 原理篇(下) - 运行模型》一文,其中有模型运行的相关源码分析。

《JIT Technical Overview》翻译 | PyTorch

http://www.zh0ngtian.tech/posts/fc5ee8eb.html

作者

zhongtian

发布于

2022-07-24

更新于

2023-12-16

许可协议

评论