The Type System of TorchScript IR | PyTorch

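An earlier post, a translation of "JIT Technical Overview", briefly introduced the main classes in TorchScript IR, including Graph, Node, Block, Value, and Type. This post maps out how these classes relate to one another from the perspective of the C++ API.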

Module::attributes()

torch::jit::Module::attributes() returns all attributes (member variables) of a module. Every attribute is ultimately represented in the IR as a torch::jit::IValue.

graph TD
module(torch::jit::Module)==>|"attributes()"|attributes(torch::jit::attribute_list)
attributes==>|"attribute()"|attribute(torch::jit::detail::AttributePolicy)
attribute==>|"attribute()"|attribute_(torch::jit::NameValue)
attribute_==>|"name"|name(std::string)
attribute_==>|"value"|value(torch::jit::IValue)

style value fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
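To make the slot structure above concrete, here is a self-contained C++ sketch of a module whose attributes are stored as ordered (name, value) pairs. `ToyModule`, `ToyNameValue`, and `ToyIValue` are hypothetical stand-ins, not libtorch types, and `ToyIValue` only covers a few scalar types. The two accessors imitate the split in libtorch, where iterating `attributes()` yields the `IValue`s themselves and `named_attributes()` yields `NameValue` entries with `.name` and `.value` fields.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for torch::jit::IValue, limited to three scalar types.
using ToyIValue = std::variant<int64_t, double, bool>;

// Same shape as torch::jit::NameValue: an attribute name plus its value.
struct ToyNameValue {
  std::string name;
  ToyIValue value;
};

// Hypothetical module that keeps its attributes as ordered slots.
class ToyModule {
 public:
  void register_attribute(std::string name, ToyIValue value) {
    slots_.push_back({std::move(name), std::move(value)});
  }

  // Analogous to attributes(): values only, in registration order.
  std::vector<ToyIValue> attributes() const {
    std::vector<ToyIValue> out;
    for (const auto& slot : slots_) out.push_back(slot.value);
    return out;
  }

  // Analogous to named_attributes(): names and values together.
  const std::vector<ToyNameValue>& named_attributes() const { return slots_; }

 private:
  std::vector<ToyNameValue> slots_;
};
```

For example, after registering `training = true` and `scale = 0.5`, `attributes()` returns two values and `named_attributes()[1].name` is `"scale"`.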

Module::get_methods()

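torch::jit::Module::get_methods() returns all methods (member functions) of a module. The key classes are marked in bold in the figure below. Concrete data, such as the inputs and outputs of a statement or the arguments of a function, is ultimately represented in the IR as torch::jit::Value or torch::jit::IValue.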

graph TD
module(torch::jit::Module)==>|"get_methods()"|methods(std::vector of torch::jit::Method)
methods==>method(torch::jit::Method)
method==>|"graph()"|graph_(torch::jit::Graph)
graph_==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
ios==>io(torch::jit::Value)

graph_==>|"nodes()"|nodes(torch::jit::graph_node_list)
nodes==>|"node()"|node(torch::jit::Node)
node==>|"kind()"|kind(torch::jit::NodeKind)
kind==>|"ns()"|ns(torch::jit::Symbol // like an interned string)

node==>|"inputs() & outputs()"|ios

node==>|"blocks()"|blocks(std::vector of torch::jit::Block)
blocks==>block(torch::jit::Block)
block==>|"nodes()"|nodes

method==>|"function()"|function(torch::jit::Function)
function==>|"getSchema()"|schema(c10::FunctionSchema)
function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
graph_function==>|"graph()"|graph_
schema==>|"arguments() & returns()"|ars(std::vector of c10::Argument)
ars==>ar(c10::Argument)
ar==>|"name"|name(std::string)
ar==>|"value"|ivalue(torch::jit::IValue)

style graph_ fill:#fff,stroke:#000,stroke-width:4px
style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
style node fill:#fff,stroke:#000,stroke-width:4px
style block fill:#fff,stroke:#000,stroke-width:4px
style ivalue fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
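Because control-flow nodes such as `prim::If` and `prim::Loop` own nested blocks, visiting every node of a method's graph means recursing from each node into its blocks, which in turn contain nodes. The sketch below shows that recursion on hypothetical `ToyNode` and `ToyBlock` types, not the libtorch classes. With libtorch one would typically start from `graph->block()` and walk `block->nodes()` and `node->blocks()` in the same way.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct ToyBlock;

// Hypothetical node: a kind string plus any nested blocks it owns,
// the way prim::If owns one block per branch.
struct ToyNode {
  std::string kind;  // e.g. "aten::add", "prim::If"
  std::vector<std::unique_ptr<ToyBlock>> blocks;
};

// Hypothetical block: an ordered list of nodes.
struct ToyBlock {
  std::vector<std::unique_ptr<ToyNode>> nodes;
};

// Helper: append a new node with the given kind to a block.
ToyNode* append_node(ToyBlock& block, std::string kind) {
  block.nodes.push_back(std::make_unique<ToyNode>());
  block.nodes.back()->kind = std::move(kind);
  return block.nodes.back().get();
}

// Helper: attach a new, empty sub-block to a node.
ToyBlock* add_block(ToyNode& node) {
  node.blocks.push_back(std::make_unique<ToyBlock>());
  return node.blocks.back().get();
}

// Depth-first walk: record each node's kind, then recurse into
// every block that node owns before moving to the next node.
void collect_kinds(const ToyBlock& block, std::vector<std::string>& out) {
  for (const auto& node : block.nodes) {
    out.push_back(node->kind);
    for (const auto& sub : node->blocks) {
      collect_kinds(*sub, out);
    }
  }
}
```

For a top-level block holding `prim::Constant`, a `prim::If` whose two branches contain `aten::add` and `aten::mul`, and then `aten::relu`, `collect_kinds` visits Constant, If, add, mul, relu in that order.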

torch::jit::Value

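torch::jit::Value is one of the most important data classes in TorchScript IR and is linked to many other classes. The figure below shows how several of these data classes relate to one another.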

graph TD
node(torch::jit::Node)==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
ios==>io(torch::jit::Value)
io==>|"node()"|node
io==>|"uses()"|uses(std::vector of torch::jit::Use)
uses==>use(torch::jit::Use)
use==>|"user"|node

style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
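The def-use links in the figure (a Value points back to the node that produces it, and each Use records a consuming node plus the input slot it reads) are what make rewrites like "replace every use of x with y" cheap. Below is a minimal sketch on hypothetical `ToyValue`, `ToyNode`, and `ToyUse` types. In libtorch, `torch::jit::Use` has `user` and `offset` members and `Value::replaceAllUsesWith` performs the real rewrite; this toy only imitates the idea and is not that implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct ToyNode;

// Like torch::jit::Use: the consuming node and the input index it reads.
struct ToyUse {
  ToyNode* user;
  std::size_t offset;
};

// Hypothetical value: knows its producer and all of its consumers.
struct ToyValue {
  ToyNode* node = nullptr;   // producing node, cf. Value::node()
  std::vector<ToyUse> uses;  // consuming nodes, cf. Value::uses()
};

struct ToyNode {
  std::string kind;
  std::vector<ToyValue*> inputs;
  std::vector<ToyValue*> outputs;

  // Record the edge on both sides: node -> input and input -> use.
  void add_input(ToyValue* v) {
    v->uses.push_back({this, inputs.size()});
    inputs.push_back(v);
  }

  // Mark this node as the producer of v.
  void add_output(ToyValue* v) {
    v->node = this;
    outputs.push_back(v);
  }
};

// Point every consumer of `from` at `to` instead and move the use records.
void replace_all_uses_with(ToyValue* from, ToyValue* to) {
  assert(from != to);  // otherwise we would append to the list being read
  for (const ToyUse& u : from->uses) {
    u.user->inputs[u.offset] = to;
    to->uses.push_back(u);
  }
  from->uses.clear();
}
```

Keeping the offset in each use is what lets the rewrite patch the exact input slot without scanning the consumer's whole input list.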

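Every input and output of a node in the IR appears as a torch::jit::Value, and each Value carries a type from TorchScript's rich type system, reachable through type().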

graph LR
value(torch::jit::Value)==>|"debugName()"|debug_name(std::string)

value==>|"type()"|type(c10::TypePtr)
type==>list(c10::ListType)
type==>dict(c10::DictType)
type==>class_(c10::ClassType)
type==>tensor(c10::TensorType)
type==>int(c10::IntType)
type==>function(c10::FunctionType)
type==>type_etc(......)

list==>|"kind()"|kind_list(c10::TypeKind::ListType)
list==>|"getElementType()"|element_type(c10::TypePtr)

dict==>|"kind()"|kind_dict(c10::TypeKind::DictType)
dict==>|"getKeyType()"|key_type(c10::TypePtr)
dict==>|"getValueType()"|value_type(c10::TypePtr)

class_==>|"kind()"|kind_class_(c10::TypeKind::ClassType)
class_==>|"name"|name(std::string)

tensor==>|"kind()"|kind_tensor(c10::TypeKind::TensorType)
tensor==>|"sizes()"|sizes(std::vector of int64_t)
tensor==>|"scalarType()"|scalar_type(at::ScalarType)
tensor==>etc_tensor(......)

int==>|"kind()"|kind_int(c10::TypeKind::IntType)

function==>|"expect()"|func_type_ptr(c10::FunctionTypePtr)
func_type_ptr==>|"function()"|jit_function(torch::jit::Function)
jit_function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
graph_function==>|"graph()"|graph_(torch::jit::Graph)
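Code that consumes a `c10::TypePtr` usually inspects `kind()` first and then downcasts to the concrete type (in c10, via `cast<T>()` or `expect<T>()`) to reach fields such as a list's element type. The sketch below imitates that pattern with hypothetical `ToyType`, `ToyListType`, and `ToyDictType` classes and renders nested types as TorchScript-style annotations. It is only an illustration of the dispatch; real c10 types already offer `annotation_str()` for printing.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical subset of c10::TypeKind.
enum class ToyTypeKind { TensorType, IntType, StringType, ListType, DictType };

class ToyType {
 public:
  explicit ToyType(ToyTypeKind k) : kind_(k) {}
  virtual ~ToyType() = default;
  ToyTypeKind kind() const { return kind_; }

 private:
  ToyTypeKind kind_;
};
using ToyTypePtr = std::shared_ptr<const ToyType>;

class ToyListType : public ToyType {
 public:
  explicit ToyListType(ToyTypePtr elem)
      : ToyType(ToyTypeKind::ListType), elem_(std::move(elem)) {}
  const ToyTypePtr& getElementType() const { return elem_; }

 private:
  ToyTypePtr elem_;
};

class ToyDictType : public ToyType {
 public:
  ToyDictType(ToyTypePtr key, ToyTypePtr value)
      : ToyType(ToyTypeKind::DictType), key_(std::move(key)), value_(std::move(value)) {}
  const ToyTypePtr& getKeyType() const { return key_; }
  const ToyTypePtr& getValueType() const { return value_; }

 private:
  ToyTypePtr key_;
  ToyTypePtr value_;
};

// Switch on kind(), then downcast: the kind check guarantees that
// static_pointer_cast targets the right subclass.
std::string annotation(const ToyTypePtr& t) {
  switch (t->kind()) {
    case ToyTypeKind::TensorType: return "Tensor";
    case ToyTypeKind::IntType: return "int";
    case ToyTypeKind::StringType: return "str";
    case ToyTypeKind::ListType: {
      auto list = std::static_pointer_cast<const ToyListType>(t);
      return "List[" + annotation(list->getElementType()) + "]";
    }
    case ToyTypeKind::DictType: {
      auto dict = std::static_pointer_cast<const ToyDictType>(t);
      return "Dict[" + annotation(dict->getKeyType()) + ", " +
             annotation(dict->getValueType()) + "]";
    }
  }
  return "?";
}
```

For instance, a list of string-to-tensor dictionaries is rendered as `List[Dict[str, Tensor]]`.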

torch::jit::IValue

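One difference between torch::jit::IValue and torch::jit::Value is that the former mainly appears while the program is running, whereas the latter exists only in the IR. torch::jit::IValue is a tagged-union container that can hold any kind of data that appears at runtime.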

graph LR
ivalue(torch::jit::IValue)==>|"type()"|type(c10::TypePtr)

ivalue==>|"payload"|payload(torch::jit::IValue::Payload)
payload==>|"u"|nt_payload(torch::jit::IValue::Payload::TriviallyCopyablePayload)
nt_payload==>|"as_int"|as_int(int64_t)
nt_payload==>|"as_double"|as_double(double)
nt_payload==>|"as_bool"|as_bool(bool)
nt_payload==>|"as_intrusive_ptr"|as_intrusive_ptr(c10::intrusive_ptr_target*)

payload==>|"as_tensor"|as_tensor(at::Tensor // Tensor's actual implementation)

ivalue==>|"tag"|tag(torch::jit::IValue::Tag)
tag==>tag_tensor(torch::jit::IValue::Tag::Tensor)
tag==>tag_int(torch::jit::IValue::Tag::Int)
tag==>tag_etc(......)
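The figure boils down to a tagged union: the tag records which member of the payload is currently live. The simplified class below, `ToyIValue`, is a hypothetical sketch of that design restricted to scalars; the real `Payload` also stores an `at::Tensor` and an `intrusive_ptr` target for heap-allocated objects, which are omitted here. The accessor names `isInt()`/`toInt()` and friends follow the ones `IValue` exposes, but the bodies are illustrative only.

```cpp
#include <cassert>
#include <cstdint>

// Simplified tagged union in the spirit of torch::jit::IValue.
class ToyIValue {
 public:
  enum class Tag { None, Int, Double, Bool };

  ToyIValue() : tag_(Tag::None) { payload_.as_int = 0; }
  explicit ToyIValue(int64_t v) : tag_(Tag::Int) { payload_.as_int = v; }
  explicit ToyIValue(double v) : tag_(Tag::Double) { payload_.as_double = v; }
  explicit ToyIValue(bool v) : tag_(Tag::Bool) { payload_.as_bool = v; }

  Tag tag() const { return tag_; }
  bool isInt() const { return tag_ == Tag::Int; }
  bool isDouble() const { return tag_ == Tag::Double; }
  bool isBool() const { return tag_ == Tag::Bool; }

  // Each accessor checks the tag before reading the union, so a
  // mismatched read fails the assert instead of reinterpreting bits.
  int64_t toInt() const { assert(isInt()); return payload_.as_int; }
  double toDouble() const { assert(isDouble()); return payload_.as_double; }
  bool toBool() const { assert(isBool()); return payload_.as_bool; }

 private:
  // Only one member is meaningful at a time, as indicated by tag_.
  union Payload {
    int64_t as_int;
    double as_double;
    bool as_bool;
  };
  Payload payload_;
  Tag tag_;
};
```

Storing the tag next to an untagged union keeps the object small (one machine word of payload plus a tag) while still letting every read be checked.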


http://www.zh0ngtian.tech/posts/dbbbf040.html

Author: zhongtian
Published: 2022-10-14
Updated: 2023-12-16
