TorchScript 入门篇 | PyTorch

本文介绍了 TorchScript 的基本原理和用法。

简介

什么是 TorchScript?这里引用官方的介绍

TorchScript 是可以由 TorchScript 编译器理解、编译和序列化的 PyTorch 模型的表示形式。从根本上说,TorchScript 本身就是一种编程语言。它是使用 PyTorch API 的 Python 的子集。

简单来说,TorchScript 软件栈可以将 Python 代码转换成 C++ 代码。TorchScript 软件栈包括两部分:TorchScript(Python)和 LibTorch(C++)。TorchScript 负责将 Python 代码转成一个模型文件,LibTorch 负责解析运行这个模型文件。

原理

保存模型

TorchScript 保存模型有两种模式:trace 模式和 script 模式。

trace 模式

trace 模式顾名思义,就是跟踪模型的执行,然后将其路径记录下来。在使用 trace 模式时,需要构造一个符合要求的输入,然后使用 TorchScript tracer 运行一遍,整个运行过程就会被记录下来。在 trace 模式中运行时,每执行一个算子,就会往当前的 graph 加入一个 node。所有代码执行完毕,每一步的操作就会以一个计算图里的某个节点的形式被保存下来。值得一提的是,PyTorch 导出 ONNX 也是使用了这部分代码,所以理论上能够导出 ONNX 的模型也能够使用 trace 模式导出 torch 模型。

trace 模式有比较大的限制:

  • 不能有 if-else 等控制流
  • 只支持 Tensor 操作

通过上述对实现方式的解释,很容易理解为什么有这种限制:1. 跟踪出的 graph 是静态的,如果有控制流,那么记录下来的只是当时生成模型时走的那条路;2. 追踪代码是跟 Tensor 算子绑定在一起的,如果是非 Tensor 的操作,是无法被记录的。

通过 trace 模式的特点可以看出,trace 模式通常用于深度模型的导出(深度模型通常没有 if-else 控制流且没有非 Tensor 操作)。

script 模式

使用方式两种模式很接近,但是实现原理却大相径庭。TorchScript 实现了一个完整的编译器以支持 script 模式。保存模型阶段对应编译器的前端(语法分析、类型检查、中间代码生成)。在保存模型时,TorchScript 编译器解析 Python 代码,并构建代码的 AST(抽象语法树)。

script 模式在的限制比较小,不仅支持 if-else 等控制流,还支持非 Tensor 操作,如 List、Tuple、Map 等容器操作。

运行模型

运行模型阶段对应编译器后端(代码优化、目标代码生成、目标代码优化)。除此之外,LibTorch 还实现了一个可以运行该编译器所生成代码的解释器。在运行代码时,在 LibTorch 中,AST 被加载,在进行一系列代码优化后生成目标代码(并非机器码),然后由解释器运行。

使用

trace 模式

对于下面这种只有 Tensor 操作的模型,比较适合使用 trace 模式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Module_0(torch.nn.Module):
def __init__(self, N, M):
super(Module_0, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
self.linear = torch.nn.Linear(N, M)

def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.weight.mm(input)
output = self.linear(output)
return output


scripted_module = torch.jit.trace(Module_0(2, 3).eval(), (torch.zeros(3, 2)))
scripted_module.save("Module_0.pt")

script 模式

对于下面这种存在控制流和非 Tensor 操作的模型,比较适合使用 script 模式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Module_1(torch.nn.Module):
def __init__(self, N, M):
super(Module_1, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
self.linear = torch.nn.Linear(N, M)

def forward(self, input: torch.Tensor, do_linear: bool) -> torch.Tensor:
output = self.weight.mm(input)
if do_linear:
output = self.linear(output)
return output


scripted_module = torch.jit.script(Module_1(3, 3).eval())
scripted_module.save("Module_1.pt")

混合模式

trace 模式和 script 模式各有千秋也各有局限,在使用时将两种模式结合在一起使用可以最大化发挥 TorchScript 的优势。例如,一个 module 包含控制流,同时也包含一个只有 Tensor 操作的子模型。这种情况下当然可以直接使用 script 模式,但是 script 模式需要对部分变量进行类型标注,比较繁琐。这种情况下就可以仅对上述子模型进行 trace,整体再进行 script:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Module_2(torch.nn.Module):
def __init__(self, N, M):
super(Module_2, self).__init__()
self.linear = torch.nn.Linear(N, M)
self.sub_module = torch.jit.trace(Module_0(2, 3).eval(), (torch.zeros(3, 2)))

def forward(self, input: torch.Tensor, do_linear: bool) -> torch.Tensor:
output = self.sub_module(input)
if do_linear:
output = self.linear(output)
return output


scripted_module = torch.jit.script(Module_2(2, 3).eval())

C++

针对上面模型导出的例子,C++ 中加载使用的方式如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <torch/script.h>

int main() {
// load module
torch::jit::script::Module torch_module;
try {
torch_module = torch::jit::load("my_module.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the module" << std::endl;
return -1;
}

// make inputs
std::vector<float> vec(9);
std::vector<torch::jit::IValue> torch_inputs;
torch::Tensor torch_tensor =
torch::from_blob(vec.data(), {3, 3}, torch::kFloat32);
torch_inputs.emplace_back(torch_tensor);
torch_inputs.emplace_back(false);

// run module
torch::jit::IValue torch_outputs;
try {
torch_outputs = torch_module.forward(torch_inputs);
} catch (const c10::Error& e) {
std::cerr << "error running the module" << std::endl;
return -1;
}

auto outputs_tensor = torch_outputs.toTensor();
}

CMakeLists.txt 的写法可以参考这里

TorchScript 的语法限制

正如前面所介绍的,在使用 trace 模式时,不能使用控制流(如果使用则只能记录对应 example input 的那个分支)和非 Tensor 操作。script 模式则自由许多,但是 TorchScript 毕竟是 Python 的子集,在使用时还有是诸多限制的。下面列举出一些已知的限制:

  • 支持的类型有限,这些类型是指在运行(而非初始化)过程中使用的对象或者函数参数

    • Type Description
      Tensor A PyTorch tensor of any dtype, dimension, or backend
      Tuple[T0, T1, ..., TN] A tuple containing subtypes T0, T1, etc. (e.g. Tuple[Tensor,Tensor])
      bool A boolean value
      int A scalar integer
      float A scalar floating point number
      str A string
      List[T] A list of which all members are type T
      Optional[T] A value which is either None or type T
      Dict[K, V] A dict with key type K and value type V. Only str, int, and float are allowed as key types.
    • 这其中不包括 set 数据类型,这意味着需要使用 set 的地方就要通过其他的方式绕过,比如先用 list 然后去重

    • 使用 tuple 时需要声明其中的类型,例如 Tuple[int, int, int],这也就意味着 tuple 在运行时长度不能变化,所以要使用 list 代替

    • 创建字典时,只有 int、float、comple、string、torch.Tensor 可以作为 key

  • 不支持 lambda 函数,但是可以通过自定义排序类的方式实现,略微麻烦,但是可以解决

  • 因为 TorchScript 是静态类型语言,运行时不能变换变量类型

  • 因为编码问题,所以对中文字符串进行遍历时会抛异常,所以尽量不要处理中文,如果需要处理中文,则需要将中文切分成字符粒度后再送入模型中进行处理

综上,虽然 TorchScript 存在一些限制,但是都是可以通过别的手段绕过的,至少在语义上是可以基本对齐 C++ 的。

参考

PyTorch C++ API — PyTorch master documentation

TorchScript 如何实现Python -> C++ 代码转换

Garry’s Blog - Advanced libtorch

python - How do I type hint a method with the type of the enclosing class?

TorchScript Language Reference

TorchScript 入门篇 | PyTorch

http://www.zh0ngtian.tech/posts/76ff5f2a.html

作者

zhongtian

发布于

2021-07-29

更新于

2023-12-16

许可协议

评论