TorchScript 原理篇(下) - 运行模型 | PyTorch
上一篇描述了如何生成 IR,本篇主要描述如何运行 IR。
概述
下面是本篇分析使用的例子:
1 | // load module |
torch::jit::load
torch.jit.load
和 torch::jit::load
底层都是通过 ScriptModuleDeserializer
类进行反序列化,这个函数会将序列化文件反序列化成 torch::jit::Module
对象。
Module::get_method
Module::get_method
会返回一个 Method
对象,其 operator()
重载如下:
1 | IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) { |
这里的 function_
其实就是前面保存模型时涉及到的 GraphFunction
类型变量。
GraphExecutor
GraphFunction::operator()
1 | IValue GraphFunction::operator()(std::vector<IValue> stack, const Kwargs& kwargs) { |
可以看出,最后是调用了 GraphExecutor::run
函数。
GraphExecutor::run
1 | void GraphExecutor::run(Stack& inputs) { |
1 | const ExecutionPlan& GraphExecutorImpl::getPlanFor(Stack& stack, size_t remaining_bailout_depth) |
1 | const ExecutionPlan& GraphExecutorImpl::getOrCompile(const Stack& stack) { |
GraphExecutorImpl::compileSpec
1 | ExecutionPlan GraphExecutorImpl::compileSpec(const ArgumentSpec& spec) { |
上面这段代码完成了 Graph
的优化,主要做了下面这几件事。为了方便理解,笔者将官方文档的一个示例搬了过来。(笔者注:这部分涉及大量的编译原理知识,看起来着实吃力,先简单总结下,等好好看看编译原理再仔细研究下这一块的代码。)
示例代码如下:
1 |
|
这段代码对应的未经优化的 Graph
如下:
1 | graph(%x : Tensor, |
特化变量类型
获取 Graph
的输入参数的属性 ArgumentSpec
,包括以下几种属性:
- dtype
- rank (not shape)
- requires_grad
- device type (CPU, CUDA)
- defined (whether the
Tensor
exists or is a placeholder)
1 | # post specialization, inputs are now specialized types |
可以看到输入参数的类型已经被重新标注了。
运行必须的 pass
pass 是编译器中的一个概念,意为“遍历一遍 IR,可以同时对它做一些操作”。这一步是为解释器生成合法 Graph
所必需的转换步骤。举个例子,一些 pass(如微分)会引入并非由算子定义的 Node
,这就需要运行 pass 去清除这些多余的 Node
,LowerGradOf
和 specializeAutogradZero
就是做了这件事。
通过 Graph
传播详细信息
这一步会传播常量和 ArgumentSpec
等信息,尽可能进行预计算。
1 | graph(%x : Float(*, *), |
可以看到,除了输入参数,中间变量的类型也被重新标注了。
保留梯度的优化
这些优化不会破坏 autograd。主要包括以下优化:
- 消除 dead code
- 消除公共子表达式
- 将冗余常量池化为单个值
- 窥孔优化,包括将一些代数操作重写为更简单的操作
- 展开小循环
- 展开循环产生的批处理矩阵乘法
1 | graph(%x : Float(*, *), |
这一步对于这个例子没有变化。
非微分相关优化
在不需要梯度的情况下(即仅推理时),可以直接应用优化来生成不带梯度的 Graph
。以 FuseGraph 例,它寻找相邻的 point-wise 操作和诸如 split
、concat
之类的reviewing 操作,并在图中创建 prim::FusionGroup
节点来替换这些操作。被注册以执行 prim::FusionGroup
的节点将为每个唯一的 Node
生成一个新的 CUDA kernel 函数,取代原来分离执行的方式。
注意融合组编译的两个阶段:首先,FuseGraph pass 将 Graph
拆分为可融合的子图,并将生成的 Graph
返回给图执行器。然后,当 Graph 转化为 Code 时,会查找 FusionGroup 节点对应的 Operation,并生成一个新的 CUDA kernel 函数。其他编译器应该以类似的方式工作,首先将一个新的算子引入到编译代码应该运行的图中,然后注册一个运算符来实现执行实际编译的节点。
1 | graph(%x : Float(*, *), |
1 | with prim::FusionGroup_0 = graph(%13 : Float(*, *), |
可以看到,sigmoid、tanh 等 point-wise 的操作被融合成 prim::FusionGroup_0
。
在不需要梯度的情况下,优化过程到此结束。在 compileSpec
函数的最后,从 Graph
构造一个 ExecutionPlan
对象并交给 InterpreterState
执行。
InterpreterState
几个类
ExecutionPlan
1 | struct ExecutionPlan { |
ExecutionPlan
对象的构造仅仅是将 Graph
转换成 Code
。
Code
Code
是 CodeImpl
的 wrapper。这里只列出一部分代码。
1 | struct CodeImpl { |
Code
通过 PreprocessGraph、emitCodeForBlock、insertInstruction 和 insertBailoutBlocks 四个步骤将 Graph 转为一系列指令。
Operation
大多数 Operation
都有一个关联的 FunctionSchema
,它描述了有多少输入将被 pop 以及有多少将被 push。堆栈概念使得定义具有可变数量输入和输出的运算符变得容易,而无需为每个单独的 Operation
分配输入和输出向量。下面的例子展示了 Operation
的使用过程:
1 | using Stack = std::vector<IValue>; |
Operator
Operator
可以简单理解为 Operation
和 FunctionSchema
的 wrapper。
步骤
PreprocessGraph
1 | PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) { |
- insertEnterMethodCalls:在
prim::Enter
节点之后插入显式prim::MethodCall
节点以实际调用对象上的__enter__
方法 - dropUnused:插入
prim::Drop
节点以终止任何未被使用的引用 - insertLastUses:确保每个
Value
在定义它的Block
中都有消费者。对于大多数节点来说,这已经是正确的。例外情况是:- 从未被使用的
Value
:添加一个prim::Drop
节点,该节点在定义后立即使用该值 - 最后一次被使用的
Value
:在最后一次使用的控制流节点之后插入一个prim::Drop
- 从未被使用的
CodeImpl::emitCodeForBlock
1 | void interpreter::CodeImpl::emitCodeForBlock(Block* block) { |
1 | void interpreter::CodeImpl::emitNodeAtBlockLevel(Node* node) { |
1 | void interpreter::CodeImpl::emitNode(Node* node) { |
emitNode
函数根据节点的类型将指令添加到 interpreter::CodeImpl::instructions_
中。
以 emitCall
和 emitOperator
为例:
1 | void interpreter::CodeImpl::emitCall(Function* func, at::ArrayRef<Value*> inputs) { |
1 | virtual void interpreter::CodeImpl::emitOperator(Node* node) { |
这里出现了两个变量:function_table_
和 operator_table_
,前者存放需要调用的函数,后者存放 Operator
对应的 Operation
。
CodeImpl::insertInstruction
1 | void interpreter::CodeImpl::insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) { |
可以看到,insertInstruction
将 Operator
对应的 OpCode
填入了 instructions_
中。
CodeImpl::insertBailoutBlocks
这里涉及到一个概念,什么是 bailout ?介绍这个概念之前,先引入另一个节点 prim::profile
。prim::profile
节点会被 ProfilingRecord::instrumentBlock
插入到每个 Value
的使用者中。prim::profile
节点的作用可以简单理解为记录和推导假设 Tensor
的形状,这有利于代码生成器能够生成更高效的代码。当在实际运行时,假设的 Tensor
形状有可能是错的,为了防止由此导致的运行失败,需要运行原始代码,即不依赖于形状假设的代码。CodeImpl::insertBailoutBlocks
构建原始计算图的去优化版本。bailout 这个单词的中文意思是“紧急救助”,在 LibTorch 中的作用也类似于此。
InterpreterStateImpl::run
InterpreterState
是 InterpreterStateImpl
的 wrapper。而 InterpreterStateImpl
则是一个可以执行指令的虚拟机。
1 | void InterpreterStateImpl::run(Stack& stack) { |
1 | bool runImpl(Stack& stack) { |
这个函数虽然复杂而且比较核心,但是也容易理解,它只是模仿了 CPU 处理指令的行为:
- 从 stack 中获取指令,并将指令、pc 等信息存入当前 frame 中
- 获取到当前 pc 指向的指令
- 根据指令的类型进行不同的操作。例如,对于
OpCode::OP
,获取operator_table_
中存放的Operation
,然后执行它;对于LOAD
将 reg 中的变量放入 stack 中 - 遇到子图调用,push 当前 frame 到 frames
warm up
从上面的分析可以看出,LibTorch 是 lazy initialization 模式,在 Module
第一次 run 的时候才会运行图优化。所以在使用 LibTorch 的时候要注意进行 warm up。需要注意的是,用来保存 Graph
优化结果的 map 所使用的的 key 是 ArgumentSpec
。这就要求在使用的时候尽量保持输入的一致,这里的一致是指输入 Tensor 的和 ArgumentSpec
相关的属性要保持一致。
总结
纵观整个 LibTorch 运行流程,LibTorch 实际上实现了一个编译器。保存模型阶段对应编译器的前端(语法分析、类型检查、中间代码生成)。运行模型阶段对应编译器后端(代码优化、目标代码生成、目标代码优化)。除此之外,LibTorch 还实现了一个可以运行该编译器所生成代码的解释器。
参考
TorchScript 如何实现Python -> C++ 代码转换
TorchScript 原理篇(下) - 运行模型 | PyTorch