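A previous translation, "JIT Technical Overview", briefly introduced the main classes in the TorchScript IR, including Graph, Node, Block, Value, and Type. This post maps out how those classes relate to one another, seen from the C++ API.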
Module::attributes()
torch::jit::Module::attributes() returns all member variables (attributes) of a module. Each member variable is ultimately represented in the IR as a torch::jit::IValue.
graph TD
    module(torch::jit::Module)==>|"attributes()"|attributes(torch::jit::attribute_list)
    attributes==>|"attribute()"|attribute(torch::jit::detail::AttributePolicy)
    attribute==>|"attribute()"|attribute_(torch::jit::NameValue)
    attribute_==>|"name"|name(std::string)
    attribute_==>|"value"|value(torch::jit::IValue)

    style value fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
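The chain in the diagram above (module, attribute list, name/value pair) can be sketched with a small stand-in model. This is not libtorch code: `ToyModule`, `ToyNameValue`, `registerAttribute`, and `attributeNames` are hypothetical names that only mirror the shape of the real classes, and the value is simplified to a `double` where libtorch stores a `torch::jit::IValue`.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for torch::jit::NameValue: a name paired with a value.
// Assumption: the value is a plain double here; libtorch stores an IValue.
struct ToyNameValue {
  std::string name;
  double value;
};

// Hypothetical stand-in for torch::jit::Module, keeping only its attributes.
class ToyModule {
 public:
  void registerAttribute(std::string name, double value) {
    ToyNameValue entry;
    entry.name = std::move(name);
    entry.value = value;
    attributes_.push_back(std::move(entry));
  }

  // Plays the role of the attribute list in the diagram.
  const std::vector<ToyNameValue>& attributes() const { return attributes_; }

 private:
  std::vector<ToyNameValue> attributes_;
};

// Walk the attribute list and collect every attribute name in registration order.
std::vector<std::string> attributeNames(const ToyModule& module) {
  std::vector<std::string> names;
  for (const ToyNameValue& attribute : module.attributes()) {
    names.push_back(attribute.name);
  }
  return names;
}
```

Registering `weight` and then `bias` on a `ToyModule` and calling `attributeNames` returns the two names in that order, which is the same "iterate, then read name and value" pattern the diagram describes.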
Module::get_methods
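torch::jit::Module::get_methods() returns all member functions (methods) of a module. In the diagram below, the most important classes are drawn with a bold border. Concrete data, such as the inputs and outputs of a statement or the arguments of a function, is ultimately represented in the IR as torch::jit::Value and torch::jit::IValue.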
graph TD
    module(torch::jit::Module)==>|"get_methods()"|methods(std::vector of torch::jit::Method)
    methods==>method(torch::jit::Method)

    method==>|"graph()"|graph_(torch::jit::Graph)
    graph_==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
    ios==>io(torch::jit::Value)

    graph_==>|"nodes()"|nodes(torch::jit::graph_node_list)
    nodes==>|"node()"|node(torch::jit::Node)
    node==>|"kind()"|kind(torch::jit::NodeKind)
    kind==>|"ns()"|ns(torch::jit::Symbol // like an interned string)
    node==>|"inputs() & outputs()"|ios
    node==>|"blocks()"|blocks(std::vector of torch::jit::Block)
    blocks==>block(torch::jit::Block)
    block==>|"nodes()"|nodes

    method==>|"function()"|function(torch::jit::Function)
    function==>|"getSchema()"|schema(c10::FunctionSchema)
    function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
    graph_function==>|"graph()"|graph_

    schema==>|"arguments() & returns()"|ars(std::vector of c10::Argument)
    ars==>ar(c10::Argument)
    ar==>|"name"|name(std::string)
    ar==>|"value"|ivalue(torch::jit::IValue)

    style graph_ fill:#fff,stroke:#000,stroke-width:4px
    style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
    style node fill:#fff,stroke:#000,stroke-width:4px
    style block fill:#fff,stroke:#000,stroke-width:4px
    style ivalue fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
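The cycle in the diagram above (a node owns blocks, and a block owns nodes) means that walking a graph is naturally recursive. Below is a minimal sketch of that recursion using stand-in types. `ToyNode`, `ToyBlock`, `makeNode`, `countNodes`, and `makeExampleBlock` are hypothetical names invented for illustration; they are not the libtorch classes, which carry far more state.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct ToyBlock;  // defined below; a node may own nested blocks

// Hypothetical stand-in for torch::jit::Node: an operation kind plus nested blocks.
struct ToyNode {
  std::string kind;  // e.g. "aten::add" or "prim::If"
  std::vector<std::unique_ptr<ToyBlock>> blocks;
};

// Hypothetical stand-in for torch::jit::Block: an ordered list of nodes.
// The body of a graph is treated as one top-level block in this model.
struct ToyBlock {
  std::vector<ToyNode> nodes;
};

ToyNode makeNode(const std::string& kind) {
  ToyNode node;
  node.kind = kind;
  return node;
}

// Count every node reachable from a block, descending into nested blocks.
std::size_t countNodes(const ToyBlock& block) {
  std::size_t total = 0;
  for (const ToyNode& node : block.nodes) {
    total += 1;
    for (const std::unique_ptr<ToyBlock>& nested : node.blocks) {
      total += countNodes(*nested);
    }
  }
  return total;
}

// Build a toy body: one aten::add, then a prim::If whose two branches hold one node each.
ToyBlock makeExampleBlock() {
  std::unique_ptr<ToyBlock> thenBranch(new ToyBlock());
  thenBranch->nodes.push_back(makeNode("aten::mul"));
  std::unique_ptr<ToyBlock> elseBranch(new ToyBlock());
  elseBranch->nodes.push_back(makeNode("aten::sub"));

  ToyNode branch = makeNode("prim::If");
  branch.blocks.push_back(std::move(thenBranch));
  branch.blocks.push_back(std::move(elseBranch));

  ToyBlock body;
  body.nodes.push_back(makeNode("aten::add"));
  body.nodes.push_back(std::move(branch));
  return body;
}
```

A flat loop over the top-level nodes of `makeExampleBlock()` sees only two nodes; `countNodes` also visits the two nodes inside the branches. Any pass that must see every node has to follow the `blocks()` edge in the same way.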
torch::jit::Value
torch::jit::Value is a central data class in the TorchScript IR, and it is linked to many other classes. The diagram below shows how it relates to Node and Use.
graph TD
    node(torch::jit::Node)==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
    ios==>io(torch::jit::Value)
    io==>|"node()"|node
    io==>|"uses()"|uses(std::vector of torch::jit::Use)
    uses==>use(torch::jit::Use)
    use==>|"user()"|node

    style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
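The diagram above describes a def-use structure: each value points back to the node that produced it and forward to every node that consumes it. The sketch below models that with stand-in types. `ToyValue`, `ToyUse`, `ToyNode`, `addInput`, `addOutput`, and `userKinds` are hypothetical names, and the ownership scheme (a node owns its output values) is an assumption made to keep the example short.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct ToyNode;

// Hypothetical stand-in for torch::jit::Use: the consuming node and the input slot it uses.
struct ToyUse {
  ToyNode* user;
  std::size_t offset;
};

// Hypothetical stand-in for torch::jit::Value.
struct ToyValue {
  std::string debugName;
  ToyNode* node;             // the node that produces this value
  std::vector<ToyUse> uses;  // every place this value is consumed
};

// Hypothetical stand-in for torch::jit::Node.
// Nodes must stay at a fixed address because values keep raw pointers to them.
struct ToyNode {
  std::string kind;
  std::vector<ToyValue*> inputs;
  std::vector<std::unique_ptr<ToyValue>> outputs;

  // Create a new output value whose producer is this node.
  ToyValue* addOutput(const std::string& name) {
    std::unique_ptr<ToyValue> value(new ToyValue());
    value->debugName = name;
    value->node = this;
    outputs.push_back(std::move(value));
    return outputs.back().get();
  }

  // Consume a value and record the use on the value itself.
  void addInput(ToyValue* value) {
    ToyUse use;
    use.user = this;
    use.offset = inputs.size();
    value->uses.push_back(use);
    inputs.push_back(value);
  }
};

// Follow value -> uses -> user and collect the kinds of all consumers.
std::vector<std::string> userKinds(const ToyValue& value) {
  std::vector<std::string> kinds;
  for (const ToyUse& use : value.uses) {
    kinds.push_back(use.user->kind);
  }
  return kinds;
}
```

If a node feeds the same value into two of its input slots, the value records two uses with offsets 0 and 1. Keeping both directions in sync inside `addInput` is what lets a pass answer "who consumes this value?" without scanning the whole graph.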
Every input and output of a node in the IR is a torch::jit::Value, and each Value carries a type, which is how the IR expresses its rich type system.
graph LR
    value(torch::jit::Value)==>|"debugName()"|debug_name(std::string)
    value==>|"type()"|type(c10::TypePtr)

    type==>list(c10::ListType)
    type==>dict(c10::DictType)
    type==>class_(c10::ClassType)
    type==>tensor(c10::TensorType)
    type==>int(c10::IntType)
    type==>function(c10::FunctionType)
    type==>type_etc(......)

    list==>|"kind()"|kind_list(c10::TypeKind::ListType)
    list==>|"getElementType()"|element_type(c10::TypePtr)

    dict==>|"kind()"|kind_dict(c10::TypeKind::DictType)
    dict==>|"getKeyType()"|key_type(c10::TypePtr)
    dict==>|"getValueType()"|value_type(c10::TypePtr)

    class_==>|"kind()"|kind_class_(c10::TypeKind::ClassType)
    class_==>|"name"|name(std::string)

    tensor==>|"kind()"|kind_tensor(c10::TypeKind::TensorType)
    tensor==>|"sizes()"|sizes(std::vector of int64_t)
    tensor==>|"scalarType()"|scalar_type(at::ScalarType)
    tensor==>etc_tensor(......)

    int==>|"kind()"|kind_int(c10::TypeKind::IntType)

    function==>|"expect()"|func_type_ptr(c10::FunctionTypePtr)
    func_type_ptr==>|"function()"|jit_function(torch::jit::Function)
    jit_function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
    graph_function==>|"graph()"|graph_(torch::jit::Graph)
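Because container types such as lists and dicts hold further type pointers, the types in the diagram above form a tree, and code that inspects a type usually switches on its kind and recurses. Here is a minimal sketch of that pattern with stand-in types. `ToyType`, `ToyTypeKind`, `makeScalar`, `makeList`, `makeDict`, and `typeToString` are hypothetical, and storing the contained types in one vector is an assumption of this model rather than a description of the c10 classes.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for c10::TypeKind, reduced to five kinds.
enum class ToyTypeKind { IntType, StringType, TensorType, ListType, DictType };

struct ToyType;
using ToyTypePtr = std::shared_ptr<const ToyType>;

// Hypothetical stand-in for a c10 type object.
struct ToyType {
  ToyTypeKind kind;
  // Element type for a list; key type then value type for a dict; empty otherwise.
  std::vector<ToyTypePtr> contained;
};

ToyTypePtr makeScalar(ToyTypeKind kind) {
  auto type = std::make_shared<ToyType>();
  type->kind = kind;
  return type;
}

ToyTypePtr makeList(ToyTypePtr element) {
  auto type = std::make_shared<ToyType>();
  type->kind = ToyTypeKind::ListType;
  type->contained.push_back(std::move(element));
  return type;
}

ToyTypePtr makeDict(ToyTypePtr key, ToyTypePtr value) {
  auto type = std::make_shared<ToyType>();
  type->kind = ToyTypeKind::DictType;
  type->contained.push_back(std::move(key));
  type->contained.push_back(std::move(value));
  return type;
}

// Dispatch on the kind and recurse into contained types.
std::string typeToString(const ToyType& type) {
  switch (type.kind) {
    case ToyTypeKind::IntType:
      return "int";
    case ToyTypeKind::StringType:
      return "str";
    case ToyTypeKind::TensorType:
      return "Tensor";
    case ToyTypeKind::ListType:
      return "List[" + typeToString(*type.contained.at(0)) + "]";
    case ToyTypeKind::DictType:
      return "Dict[" + typeToString(*type.contained.at(0)) + ", " +
             typeToString(*type.contained.at(1)) + "]";
  }
  return "?";  // not reached for the kinds listed above
}
```

Calling `typeToString` on a dict from `str` to a list of tensors yields `Dict[str, List[Tensor]]`. The `ListType` and `DictType` cases correspond to following the `getElementType()`, `getKeyType()`, and `getValueType()` edges in the diagram.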
torch::jit::IValue
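One difference between torch::jit::IValue and torch::jit::Value is that the former mainly appears while a program is running, whereas the latter exists only in the IR. torch::jit::IValue is a container for every data type that can appear at run time.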
graph LR
    ivalue(torch::jit::IValue)==>|"type()"|type(c10::TypePtr)

    ivalue==>|"payload()"|payload(torch::jit::IValue::Payload)
    payload==>|"payload()"|nt_payload(torch::jit::IValue::Payload::TriviallyCopyablePayload)
    nt_payload==>|"as_int()"|as_int(int64_t)
    nt_payload==>|"as_double()"|as_double(double)
    nt_payload==>|"as_bool()"|as_bool(bool)
    nt_payload==>|"as_intrusive_ptr()"|as_intrusive_ptr(c10::intrusive_ptr_target)
    payload==>|"as_tensor()"|as_tensor(at::Tensor // Tensor's actual implementation)

    ivalue==>|"tag()"|tag(torch::jit::IValue::Tag)
    tag==>tag_tensor(torch::jit::IValue::Tag::Tensor)
    tag==>tag_int(torch::jit::IValue::Tag::Int)
    tag==>tag_etc(......)
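The tag-plus-payload layout in the diagram above is a tagged union: the tag records which member of the payload is currently meaningful. The sketch below shows the idea on three scalar kinds only. `ToyIValue` and its factory functions are hypothetical, the accessor names merely echo the style of the real class, and throwing `std::runtime_error` on a tag mismatch is a choice made for this example, not a claim about how libtorch reports the error.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical, heavily reduced model of a tagged union in the spirit of IValue.
class ToyIValue {
 public:
  enum class Tag { None, Int, Double, Bool };

  ToyIValue() : tag_(Tag::None) { payload_.as_int = 0; }

  // Named factories avoid overload ambiguity between int64_t, double, and bool.
  static ToyIValue fromInt(std::int64_t v) {
    ToyIValue iv;
    iv.tag_ = Tag::Int;
    iv.payload_.as_int = v;
    return iv;
  }
  static ToyIValue fromDouble(double v) {
    ToyIValue iv;
    iv.tag_ = Tag::Double;
    iv.payload_.as_double = v;
    return iv;
  }
  static ToyIValue fromBool(bool v) {
    ToyIValue iv;
    iv.tag_ = Tag::Bool;
    iv.payload_.as_bool = v;
    return iv;
  }

  Tag tag() const { return tag_; }
  bool isNone() const { return tag_ == Tag::None; }
  bool isInt() const { return tag_ == Tag::Int; }

  // Each accessor checks the tag before reading the matching union member.
  std::int64_t toInt() const {
    expect(Tag::Int, "Int");
    return payload_.as_int;
  }
  double toDouble() const {
    expect(Tag::Double, "Double");
    return payload_.as_double;
  }
  bool toBool() const {
    expect(Tag::Bool, "Bool");
    return payload_.as_bool;
  }

 private:
  void expect(Tag wanted, const char* name) const {
    if (tag_ != wanted) {
      throw std::runtime_error(std::string("ToyIValue does not hold ") + name);
    }
  }

  // Only the member selected by tag_ is valid to read.
  union Payload {
    std::int64_t as_int;
    double as_double;
    bool as_bool;
  };

  Payload payload_;
  Tag tag_;
};
```

The model flattens the nested payload from the diagram into a single union and leaves out tensors and reference-counted objects, which need non-trivial copy and destruction logic. What it keeps is the core contract: check the tag first, then read the payload member that the tag selects.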