TorchScript 踩坑篇 | PyTorch

本文介绍了 TorchScript 在使用中常见的报错和解决方法。

实用技巧

list、dict 对象作为 module 的成员

空的 list、dict 对象是不能直接作为 module 的成员变量的,如果要用的话,需要在构造函数之前进行声明。对于非空 list、dict 对象,则可以直接在构造函数中声明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import List, Dict

class Foo(torch.nn.Module):
# 需要预先声明
empty_list: List[str]
empty_dict: Dict[str, int]

def __init__(self, a_dict):
super(Foo, self).__init__()

self.empty_list = []
self.empty_dict = {}

self.non_empty_list: List[str] = ["0"]
self.non_empty_dict = {"0": 0}

参考:TorchScript Language Reference — Module Attributes

list 自定义排序

list 排序不支持自定义函数,但是可以通过自定义排序类实现。举个例子,有一个二维列表 List[List[int]],我们想根据列表的长度进行排序,就可以这么做:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Python >= 3.7
from __future__ import annotations

class ListWrapper:
def __init__(self, list_member: List[int]):
self.list_member: List[int] = list_member

# Python >= 3.7
def __lt__(self, other: ListWrapper):
return len(self.list_member) < len(other.list_member)

# Python < 3.7
def __lt__(self, other: "ListWrapper"):
return len(self.list_member) < len(other.list_member)


def sort_list_list(list_to_sort: List[List[int]]) -> List[List[int]]:
list_wrapper_list: List[ListWrapper] = []
for l in list_to_sort:
list_wrapper: ListWrapper = ListWrapper(l)
list_wrapper_list.append(list_wrapper)
list_wrapper_list = sorted(list_wrapper_list)
for i, l in enumerate(list_wrapper_list):
list_to_sort[i] = l.list_member

in 的使用限制

在 TorchScript 中,in 操作仅支持有限的几种类型,如果对复杂类型使用 in 会报错,支持类型如下:

  • int in List[int]
  • float in List[float]
  • str in List[str]
  • str in str
  • str in Dict[str, T]
  • int in Dict[int, T]
  • bool in Dict[bool, T]
  • float in Dict[float, T]
  • complex in Dict[complex, T]
  • Tensor in Dict[Tensor, T]

所以当需要对其他类型使用 in 操作时,需要自行实现判断逻辑。

如何表达 None 类型

None 是 Python 中很重要一种类型,万物皆可为 None,那么在 TorchScript 中应该如何表达呢?使用 typing.Optional 即可。

1
2
3
4
5
6
7
8
9
# Python
def func(x = None):
pass

# TorchScript
import typing

def func(typing.Optional[<type_name>] = None):
pass

参考:[JIT] Infer type of argument foo = None to be Optional[Tensor] · Issue #40867

去除变量的 Optional 属性

1
2
3
def cast_away_optional(arg: Optional[变量类型]) -> 变量类型:
assert arg is not None
return arg

参考:Expose a cast variant to remove Optional · Issue #645

ParameterList 的替代方式

ParameterList 是不允许使用的,但是如果一定要用,可以自定义一个继承自 nn.Module 的类,这个类中只放一个参数,然后使用 ModuleList 的方式代替。

1
2
3
4
5
6
7
8
9
# PyTorch
param_list = nn.ParameterList([nn.Parameter(torch.Tensor(1, n)) for _ in range(m)])

# TorchScript
class ParamWrapper(nn.Module):
def __init__(self, n):
super(ParamWrapper, self).__init__()
self.param = nn.Parameter(torch.Tensor(1, n))
param_list = nn.ModuleList([ParamWrapper(n) for _ in range(m)])

从内存加载模型

load 函数有两个重载:

  • std::shared_ptr<script::Module> load(const std::string& filename, c10::optional<c10::Device> device = c10::nullopt);
  • std::shared_ptr<script::Module> load(std::istream& in, c10::optional<c10::Device> device = c10::nullopt);

第一个函数就是常用的从磁盘上加载模型。第二个函数就可以直接从内存加载模型,这个函数常用的场景是将模型进行解密后加载进 LibTorch(私有化部署场景要求模型加密)。

pass None in C++

传入 at::nullopt 即可。

参考:None Default Arguments · Issue #5

错误处理

  1. Python 中使用 assert、raise 会导致最终 C++ 中抛出异常,注意使用 try catch。
  2. Python 代码中如果有访问越界,也会在 C++ 中抛出异常。

判断 nn.Module 所在的设备

不像 Tensor 类,nn.Module 本身没有 device 成员(原因见这里),不能判断其位于哪个设备上。这时就需要用一个比较 trick 的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class ModuleWrapper(nn.Module):
def __init__(self):
super(ModuleWrapper, self).__init__()
self.dummy_param = nn.Parameter(torch.empty(0))
self.device = self.dummy_param.device

def get_device(self):
return self.dummy_param.device


class MyModule(ModuleWrapper):
def __init__(self):
super(EventTable, self).__init__()

def forward():
print(self.device)

常见报错

Unknown builtin op: aten::Tensor

使用 torch.tensor 而非 torch.Tensor

参考:Unknown builtin op: aten::Tensor - jit - PyTorch Forums

Expected a value of type ‘t’ for argument ‘el’

此问题常见于想要 push None 到某个 list 中,由于 TorchScript 要求 list 中的元素类型必须一样,所以对于一个 List[int],如果想要 push None 到里面就会报错,因为 None 和 int 不是一个类型。

想要解决这个问题,只需要将 list 声明为 List[Optional[int]] 即可,这样的话 list 中就可以 push int 和 None 两种类型了。

参考:Different types in if statement branches produce error in TorchScript - jit - PyTorch Forums

Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals

不要使用索引去获取 ModuleList 的元素,应该使用 for elem in module_list 来获取。

参考:

Attribute 1 is not of annotated type

通常是由 ModuleList 或者 ModuleDict 成员引起的,使用索引去或者 key 去获取其元素会造成该问题,使用 for elem in module_list 来获取元素可以解决。

ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save

保存 TorchScript module 时使用 module.save("module.pt") 而非 torch.save("module.pt)

Output annotation element type and runtime tensor element type must match for tolist()

在使用 torch.Tensor.tolist() 这个方法的时候,注意标注 list 的具体类型。也注意元素的数据类型是否一致,比如 int32 与 int64、float32 与 float64。

参考:[jit] Tensor .tolist() is not bound in TorchScript · Issue #26752

Expected all tensors to be on the same device

想要令某个函数在某个特定设备上执行,需要做的就是把参与运算的所有 Tensor 都拷贝到相同的设备上,否则会报 Expected all tensors to be on the same device, but found at least two devices 之类的错误。拷贝的操作也很简单,直接使用 Tensor 的 to 方法就行。

RuntimeError(没有其他任何提示)

对于一些比较冷门的错误,运行时可能不会打出具体的错误信息,导致没办法定位问题所在,这种情况就比较麻烦。这种情况下就只能通过在 Python 中添加日志进行调试,找到问题在哪一行代码。

参考

PyTorch C++ API — PyTorch master documentation

TorchScript 如何实现Python -> C++ 代码转换

Garry’s Blog - Advanced libtorch

python - How do I type hint a method with the type of the enclosing class?

TorchScript 踩坑篇 | PyTorch

http://www.zh0ngtian.tech/posts/3f804c9b.html

作者

zhongtian

发布于

2021-07-29

更新于

2023-12-16

许可协议

评论