TorchScript 踩坑篇 | PyTorch
本文介绍了 TorchScript 在使用中常见的报错和解决方法。
实用技巧
list、dict 对象作为 module 的成员
空的 list、dict 对象是不能直接作为 module 的成员变量的,如果要用的话,需要在构造函数之前进行声明。对于非空 list、dict 对象,则可以直接在构造函数中声明。
1 | from typing import List, Dict |
参考:TorchScript Language Reference — Module Attributes
list 自定义排序
list 排序不支持自定义函数,但是可以通过自定义排序类实现。举个例子,有一个二维列表 List[List[int]],我们想根据列表的长度进行排序,就可以这么做:
1 | # Python >= 3.7 |
in 的使用限制
在 TorchScript 中,in 操作仅支持有限的几种类型,如果对复杂类型使用 in 会报错,支持类型如下:
- int
inList[int] - float
inList[float] - str
inList[str] - str
instr - str
inDict[str, T] - int
inDict[int, T] - bool
inDict[bool, T] - float
inDict[float, T] - complex
inDict[complex, T] - Tensor
inDict[Tensor, T]
所以当需要对其他类型使用 in 操作时,需要自行实现判断逻辑。
如何表达 None 类型
None 是 Python 中很重要一种类型,万物皆可为 None,那么在 TorchScript 中应该如何表达呢?使用 typing.Optional 即可。
1 | # Python |
参考:[JIT] Infer type of argument foo = None to be Optional[Tensor] · Issue #40867
去除变量的 Optional 属性
1 | def cast_away_optional(arg: Optional[变量类型]) -> 变量类型: |
参考:Expose a cast variant to remove Optional · Issue #645
ParameterList 的替代方式
ParameterList 是不允许使用的,但是如果一定要用,可以自定义一个继承自 nn.Module 的类,这个类中只放一个参数,然后使用 ModuleList 的方式代替。
1 | # PyTorch |
从内存加载模型
load 函数有两个重载:
std::shared_ptr<script::Module> load(const std::string& filename, c10::optional<c10::Device> device = c10::nullopt);std::shared_ptr<script::Module> load(std::istream& in, c10::optional<c10::Device> device = c10::nullopt);
第一个函数就是常用的从磁盘上加载模型。第二个函数就可以直接从内存加载模型,这个函数常用的场景是将模型进行解密后加载进 LibTorch(私有化部署场景要求模型加密)。
pass None in C++
传入 at::nullopt 即可。
参考:None Default Arguments · Issue #5
错误处理
- Python 中使用 assert、raise 会导致最终 C++ 中抛出异常,注意使用 try catch。
- Python 代码中如果有访问越界,也会在 C++ 中抛出异常。
判断 nn.Module 所在的设备
不像 Tensor 类,nn.Module 本身没有 device 成员(原因见这里),不能判断其位于哪个设备上。这时就需要用一个比较 trick 的方法:
1 | class ModuleWrapper(nn.Module): |
常见报错
Unknown builtin op: aten::Tensor
使用 torch.tensor 而非 torch.Tensor。
参考:Unknown builtin op: aten::Tensor - jit - PyTorch Forums
Expected a value of type ‘t’ for argument ‘el’
此问题常见于想要 push None 到某个 list 中,由于 TorchScript 要求 list 中的元素类型必须一样,所以对于一个 List[int],如果想要 push None 到里面就会报错,因为 None 和 int 不是一个类型。
想要解决这个问题,只需要将 list 声明为 List[Optional[int]] 即可,这样的话 list 中就可以 push int 和 None 两种类型了。
参考:Different types in if statement branches produce error in TorchScript - jit - PyTorch Forums
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals
不要使用索引去获取 ModuleList 的元素,应该使用 for elem in module_list 来获取。
参考:
torch.jit.script cannot index a ModuleList with an int returned from range() · Issue #47496
https://github.com/pytorch/pytorch/blob/master/test/jit/test_module_containers.py#L499
Attribute 1 is not of annotated type
通常是由 ModuleList 或者 ModuleDict 成员引起的,使用索引去或者 key 去获取其元素会造成该问题,使用 for elem in module_list 来获取元素可以解决。
ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save
保存 TorchScript module 时使用 module.save("module.pt") 而非 torch.save("module.pt)。
Output annotation element type and runtime tensor element type must match for tolist()
在使用 torch.Tensor.tolist() 这个方法的时候,注意标注 list 的具体类型。也注意元素的数据类型是否一致,比如 int32 与 int64、float32 与 float64。
参考:[jit] Tensor .tolist() is not bound in TorchScript · Issue #26752
Expected all tensors to be on the same device
想要令某个函数在某个特定设备上执行,需要做的就是把参与运算的所有 Tensor 都拷贝到相同的设备上,否则会报 Expected all tensors to be on the same device, but found at least two devices 之类的错误。拷贝的操作也很简单,直接使用 Tensor 的 to 方法就行。
RuntimeError(没有其他任何提示)
对于一些比较冷门的错误,运行时可能不会打出具体的错误信息,导致没办法定位问题所在,这种情况就比较麻烦。这种情况下就只能通过在 Python 中添加日志进行调试,找到问题在哪一行代码。
参考
PyTorch C++ API — PyTorch master documentation
TorchScript 如何实现Python -> C++ 代码转换
Garry’s Blog - Advanced libtorch
python - How do I type hint a method with the type of the enclosing class?
TorchScript 踩坑篇 | PyTorch