TorchScript 入门篇 | PyTorch
本文介绍了 TorchScript 的基本原理和用法。
简介
什么是 TorchScript?这里引用官方的介绍:
TorchScript 是可以由 TorchScript 编译器理解、编译和序列化的 PyTorch 模型的表示形式。从根本上说,TorchScript 本身就是一种编程语言。它是使用 PyTorch API 的 Python 的子集。
简单来说,TorchScript 软件栈可以将 Python 代码转换成 C++ 代码。TorchScript 软件栈包括两部分:TorchScript(Python)和 LibTorch(C++)。TorchScript 负责将 Python 代码转成一个模型文件,LibTorch 负责解析运行这个模型文件。
原理
保存模型
TorchScript 保存模型有两种模式:trace 模式和 script 模式。
trace 模式
trace 模式顾名思义,就是跟踪模型的执行,然后将其路径记录下来。在使用 trace 模式时,需要构造一个符合要求的输入,然后使用 TorchScript tracer 运行一遍,整个运行过程就会被记录下来。在 trace 模式中运行时,每执行一个算子,就会往当前的 graph 加入一个 node。所有代码执行完毕,每一步的操作就会以一个计算图里的某个节点的形式被保存下来。值得一提的是,PyTorch 导出 ONNX 也是使用了这部分代码,所以理论上能够导出 ONNX 的模型也能够使用 trace 模式导出 torch 模型。
trace 模式有比较大的限制:
- 不能有 if-else 等控制流
- 只支持 Tensor 操作
通过上述对实现方式的解释,很容易理解为什么有这种限制:1. 跟踪出的 graph 是静态的,如果有控制流,那么记录下来的只是当时生成模型时走的那条路;2. 追踪代码是跟 Tensor 算子绑定在一起的,如果是非 Tensor 的操作,是无法被记录的。
通过 trace 模式的特点可以看出,trace 模式通常用于深度模型的导出(深度模型通常没有 if-else 控制流且没有非 Tensor 操作)。
script 模式
使用方式两种模式很接近,但是实现原理却大相径庭。TorchScript 实现了一个完整的编译器以支持 script 模式。保存模型阶段对应编译器的前端(语法分析、类型检查、中间代码生成)。在保存模型时,TorchScript 编译器解析 Python 代码,并构建代码的 AST(抽象语法树)。
script 模式在的限制比较小,不仅支持 if-else 等控制流,还支持非 Tensor 操作,如 List、Tuple、Map 等容器操作。
运行模型
运行模型阶段对应编译器后端(代码优化、目标代码生成、目标代码优化)。除此之外,LibTorch 还实现了一个可以运行该编译器所生成代码的解释器。在运行代码时,在 LibTorch 中,AST 被加载,在进行一系列代码优化后生成目标代码(并非机器码),然后由解释器运行。
使用
trace 模式
对于下面这种只有 Tensor 操作的模型,比较适合使用 trace 模式:
1 | class Module_0(torch.nn.Module): |
script 模式
对于下面这种存在控制流和非 Tensor 操作的模型,比较适合使用 script 模式:
1 | class Module_1(torch.nn.Module): |
混合模式
trace 模式和 script 模式各有千秋也各有局限,在使用时将两种模式结合在一起使用可以最大化发挥 TorchScript 的优势。例如,一个 module 包含控制流,同时也包含一个只有 Tensor 操作的子模型。这种情况下当然可以直接使用 script 模式,但是 script 模式需要对部分变量进行类型标注,比较繁琐。这种情况下就可以仅对上述子模型进行 trace,整体再进行 script:
1 | class Module_2(torch.nn.Module): |
C++
针对上面模型导出的例子,C++ 中加载使用的方式如下:
1 |
|
CMakeLists.txt
的写法可以参考这里。
TorchScript 的语法限制
正如前面所介绍的,在使用 trace 模式时,不能使用控制流(如果使用则只能记录对应 example input 的那个分支)和非 Tensor 操作。script 模式则自由许多,但是 TorchScript 毕竟是 Python 的子集,在使用时还有是诸多限制的。下面列举出一些已知的限制:
支持的类型有限,这些类型是指在运行(而非初始化)过程中使用的对象或者函数参数
Type Description Tensor
A PyTorch tensor of any dtype, dimension, or backend Tuple[T0, T1, ..., TN]
A tuple containing subtypes T0
,T1
, etc. (e.g.Tuple[Tensor,Tensor]
)bool
A boolean value int
A scalar integer float
A scalar floating point number str
A string List[T]
A list of which all members are type T
Optional[T]
A value which is either None or type T
Dict[K, V]
A dict with key type K
and value typeV
. Onlystr
,int
, andfloat
are allowed as key types.这其中不包括 set 数据类型,这意味着需要使用 set 的地方就要通过其他的方式绕过,比如先用 list 然后去重
使用 tuple 时需要声明其中的类型,例如 Tuple[int, int, int],这也就意味着 tuple 在运行时长度不能变化,所以要使用 list 代替
创建字典时,只有 int、float、comple、string、torch.Tensor 可以作为 key
不支持 lambda 函数,但是可以通过自定义排序类的方式实现,略微麻烦,但是可以解决
因为 TorchScript 是静态类型语言,运行时不能变换变量类型
因为编码问题,所以对中文字符串进行遍历时会抛异常,所以尽量不要处理中文,如果需要处理中文,则需要将中文切分成字符粒度后再送入模型中进行处理
综上,虽然 TorchScript 存在一些限制,但是都是可以通过别的手段绕过的,至少在语义上是可以基本对齐 C++ 的。
参考
PyTorch C++ API — PyTorch master documentation
TorchScript 如何实现Python -> C++ 代码转换
Garry’s Blog - Advanced libtorch
python - How do I type hint a method with the type of the enclosing class?
TorchScript 入门篇 | PyTorch