TorchScript 踩坑篇 | PyTorch
本文介绍了 TorchScript 在使用中常见的报错和解决方法。
实用技巧
list、dict 对象作为 module 的成员
空的 list、dict 对象是不能直接作为 module 的成员变量的,如果要用的话,需要在构造函数之前进行声明。对于非空 list、dict 对象,则可以直接在构造函数中声明。
1 | from typing import List, Dict |
参考:TorchScript Language Reference — Module Attributes
list 自定义排序
list 排序不支持自定义函数,但是可以通过自定义排序类实现。举个例子,有一个二维列表 List[List[int]],我们想根据列表的长度进行排序,就可以这么做:
1 | # Python >= 3.7 |
in 的使用限制
在 TorchScript 中,in
操作仅支持有限的几种类型,如果对复杂类型使用 in
会报错,支持类型如下:
- int
in
List[int] - float
in
List[float] - str
in
List[str] - str
in
str - str
in
Dict[str, T] - int
in
Dict[int, T] - bool
in
Dict[bool, T] - float
in
Dict[float, T] - complex
in
Dict[complex, T] - Tensor
in
Dict[Tensor, T]
所以当需要对其他类型使用 in 操作时,需要自行实现判断逻辑。
如何表达 None 类型
None 是 Python 中很重要一种类型,万物皆可为 None,那么在 TorchScript 中应该如何表达呢?使用 typing.Optional 即可。
1 | # Python |
参考:[JIT] Infer type of argument foo = None to be Optional[Tensor] · Issue #40867
去除变量的 Optional 属性
1 | def cast_away_optional(arg: Optional[变量类型]) -> 变量类型: |
参考:Expose a cast variant to remove Optional · Issue #645
ParameterList 的替代方式
ParameterList 是不允许使用的,但是如果一定要用,可以自定义一个继承自 nn.Module 的类,这个类中只放一个参数,然后使用 ModuleList 的方式代替。
1 | # PyTorch |
从内存加载模型
load 函数有两个重载:
std::shared_ptr<script::Module> load(const std::string& filename, c10::optional<c10::Device> device = c10::nullopt);
std::shared_ptr<script::Module> load(std::istream& in, c10::optional<c10::Device> device = c10::nullopt);
第一个函数就是常用的从磁盘上加载模型。第二个函数就可以直接从内存加载模型,这个函数常用的场景是将模型进行解密后加载进 LibTorch(私有化部署场景要求模型加密)。
pass None in C++
传入 at::nullopt
即可。
参考:None Default Arguments · Issue #5
错误处理
- Python 中使用 assert、raise 会导致最终 C++ 中抛出异常,注意使用 try catch。
- Python 代码中如果有访问越界,也会在 C++ 中抛出异常。
判断 nn.Module 所在的设备
不像 Tensor 类,nn.Module 本身没有 device 成员(原因见这里),不能判断其位于哪个设备上。这时就需要用一个比较 trick 的方法:
1 | class ModuleWrapper(nn.Module): |
常见报错
Unknown builtin op: aten::Tensor
使用 torch.tensor
而非 torch.Tensor
。
参考:Unknown builtin op: aten::Tensor - jit - PyTorch Forums
Expected a value of type ‘t’ for argument ‘el’
此问题常见于想要 push None 到某个 list 中,由于 TorchScript 要求 list 中的元素类型必须一样,所以对于一个 List[int],如果想要 push None 到里面就会报错,因为 None 和 int 不是一个类型。
想要解决这个问题,只需要将 list 声明为 List[Optional[int]] 即可,这样的话 list 中就可以 push int 和 None 两种类型了。
参考:Different types in if statement branches produce error in TorchScript - jit - PyTorch Forums
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals
不要使用索引去获取 ModuleList 的元素,应该使用 for elem in module_list
来获取。
参考:
torch.jit.script cannot index a ModuleList with an int returned from range() · Issue #47496
https://github.com/pytorch/pytorch/blob/master/test/jit/test_module_containers.py#L499
Attribute 1 is not of annotated type
通常是由 ModuleList 或者 ModuleDict 成员引起的,使用索引去或者 key 去获取其元素会造成该问题,使用 for elem in module_list
来获取元素可以解决。
ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save
保存 TorchScript module 时使用 module.save("module.pt")
而非 torch.save("module.pt)
。
Output annotation element type and runtime tensor element type must match for tolist()
在使用 torch.Tensor.tolist() 这个方法的时候,注意标注 list 的具体类型。也注意元素的数据类型是否一致,比如 int32 与 int64、float32 与 float64。
参考:[jit] Tensor .tolist() is not bound in TorchScript · Issue #26752
Expected all tensors to be on the same device
想要令某个函数在某个特定设备上执行,需要做的就是把参与运算的所有 Tensor 都拷贝到相同的设备上,否则会报 Expected all tensors to be on the same device, but found at least two devices
之类的错误。拷贝的操作也很简单,直接使用 Tensor 的 to 方法就行。
RuntimeError(没有其他任何提示)
对于一些比较冷门的错误,运行时可能不会打出具体的错误信息,导致没办法定位问题所在,这种情况就比较麻烦。这种情况下就只能通过在 Python 中添加日志进行调试,找到问题在哪一行代码。
参考
PyTorch C++ API — PyTorch master documentation
TorchScript 如何实现Python -> C++ 代码转换
Garry’s Blog - Advanced libtorch
python - How do I type hint a method with the type of the enclosing class?
TorchScript 踩坑篇 | PyTorch