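TorchScript Internals (Part 1): Saving a Model | PyTorch
Last updated: September 11, 2021
Note: this article builds on 《PyTorch 源码解读之即时编译篇》 (a source-code walkthrough of the PyTorch JIT), with some corrections and additions.
The code discussed here comes from PyTorch 1.9.0 (commit_id = d69c22d); its formatting has been adjusted slightly for readability. TorchScript can script several kinds of objects. To keep the principles easy to follow, this article walks through the simplest case, a scripted function.
The example analyzed throughout this article is: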
import torch

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
scripted_func.save("my_func.pt")
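The Python Part
All C++ code discussed in this article lives in the torch::jit namespace, which is omitted below. It is spelled out only where a name could be confused with its Python counterpart.
The overall call flow is:
- script
  - get_jit_def: obtains the torch AST
    - ast.parse: obtains the Python AST
    - build_def: converts the Python AST into the torch AST
      - build_stmts: builds the torch AST
  - torch._C._jit_script_compile: calls into C++ to turn the torch AST into the intermediate representation (IR), held in a torch::jit::CompilationUnit object
- save: serializes the torch::jit::CompilationUnit object and writes it to disk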
script
# Code unrelated to scripting a plain function has been removed
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    # An object that has already been scripted is returned as-is
    if isinstance(obj, ScriptFunction):
        return obj

    # Needed by _jit_script_compile below; the full source computes it here
    qualified_name = _qualified_name(obj)

    # this is a decorated fn, and we need to the underlying fn and its rcb
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # Reject overloaded functions, because a function is looked up by its name
    _check_directly_compile_overloaded(obj)
    # Return the cached result if this function was compiled before
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn
    # Build the torch AST
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
    # Call into C++ to turn the AST into IR
    fn = torch._C._jit_script_compile(
        qualified_name, ast, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    fn.__doc__ = obj.__doc__
    # Cache the compiled result
    _set_jit_function_cache(obj, fn)
    return fn
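The order of checks in script (return early if already scripted, reuse a cached result, otherwise compile and cache) can be illustrated with a small stand-alone sketch. Everything here is hypothetical: ToyScriptFunction, toy_script, and the plain dict cache are illustrative names rather than PyTorch APIs, and the "compilation" simply wraps the function.

```python
# Hypothetical stand-in for the JIT function cache (a plain dict here;
# the real cache lives inside PyTorch and is not reproduced).
_cache = {}
compile_count = 0


class ToyScriptFunction:
    """Hypothetical placeholder for torch.jit.ScriptFunction."""

    def __init__(self, fn):
        self.fn = fn
        self.__doc__ = fn.__doc__  # forward the docstring, as script does

    def __call__(self, *args):
        return self.fn(*args)


def toy_script(obj):
    global compile_count
    # 1. Already scripted: return it unchanged.
    if isinstance(obj, ToyScriptFunction):
        return obj
    # 2. Compiled before: reuse the cached result.
    cached = _cache.get(obj)
    if cached is not None:
        return cached
    # 3. Otherwise "compile" once and remember the result.
    compile_count += 1
    compiled = ToyScriptFunction(obj)
    _cache[obj] = compiled
    return compiled


def my_func(a: int) -> int:
    """Add two."""
    return a + 2


first = toy_script(my_func)
second = toy_script(my_func)
print(first is second, compile_count)
```

Calling toy_script twice on the same function triggers only one "compilation", which is the effect the cache lookup in script is meant to achieve.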
get_jit_def
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
    """
    # dedent_src is a string containing the source of the function to script
    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    # Parse the string into a Python AST with Python's ast package
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    # If the signature is annotated with a `# type:` comment, type_line holds that comment
    type_line = torch.jit.annotations.get_type_line(source)
    # SourceContext wraps SourceRangeFactory and carries the source metadata needed for compilation
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    # A single function is being compiled, so take the first element of the body
    fn_def = py_ast.body[0]

    # If MonkeyType is in use, the type data it collected can be reused (MonkeyType traces
    # types at runtime and automatically adds type annotations to Python 3 code)
    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace:
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)

    # build_def converts the Python AST into the AST format used by PyTorch
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
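The first half of get_jit_def (dedent the source, parse it, check that it is a single function, and record how much indentation was removed) only relies on the standard library, so it can be tried in isolation. The sketch below is a minimal illustration under that assumption: the source is supplied as a string literal rather than fetched with get_source_lines_and_file, and nothing PyTorch-specific is involved.

```python
import ast
from textwrap import dedent

# Source of a function as it might appear when nested inside a class or
# another function, i.e. with leading indentation.
source = (
    "    def my_func(a: int) -> int:\n"
    "        a = a + 2\n"
    "        return a\n"
)

dedent_src = dedent(source)
py_ast = ast.parse(dedent_src)

# Same sanity check as get_jit_def: exactly one top-level function.
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
    raise RuntimeError("Expected a single top-level function")

# Number of columns that dedent removed from the first line. get_jit_def
# hands this value to SourceContext (presumably so that source ranges can
# still refer to the original columns; that purpose is an assumption here).
leading_whitespace_len = len(source.split("\n", 1)[0]) - len(dedent_src.split("\n", 1)[0])

fn_def = py_ast.body[0]
print(fn_def.name, leading_whitespace_len)
```

Without the dedent step, ast.parse would reject the indented source with an IndentationError, which is why the function source is normalized before parsing.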
py_ast
To inspect the contents of the py_ast object, it can be printed with the following line:
import yaml, jsonpickle; print(yaml.dump(yaml.safe_load(jsonpickle.encode(py_ast)), indent=2))
The output is:
- py/type: _ast.Module
- body:
- py/reduce:
- py/type: _ast.FunctionDef
- args:
py/reduce:
- py/type: _ast.arguments
- args:
- py/reduce:
- py/type: _ast.arg
- annotation:
py/reduce:
- py/type: _ast.Name
- col_offset: 15
ctx:
py/reduce:
- py/type: _ast.Load
- {}
id: int
lineno: 1
arg: a
col_offset: 12
lineno: 1
defaults: []
kw_defaults: []
kwarg: null
kwonlyargs: []
vararg: null
body:
- py/reduce:
- py/type: _ast.Assign
- col_offset: 4
lineno: 2
targets:
- py/reduce:
- py/type: _ast.Name
- col_offset: 4
ctx:
py/reduce:
- py/type: _ast.Store
- {}
id: a
lineno: 2
value:
py/reduce:
- py/type: _ast.BinOp
- col_offset: 8
left:
py/reduce:
- py/type: _ast.Name
- col_offset: 8
ctx:
py/id: 12
id: a
lineno: 2
lineno: 2
op:
py/reduce:
- py/type: _ast.Add
- {}
right:
py/reduce:
- py/type: _ast.Num
- col_offset: 12
lineno: 2
n: 2
- py/reduce:
- py/type: _ast.Return
- col_offset: 4
lineno: 3
value:
py/reduce:
- py/type: _ast.Name
- col_offset: 11
ctx:
py/id: 12
id: a
lineno: 3
col_offset: 0
decorator_list: []
lineno: 1
name: my_func
returns:
py/reduce:
- py/type: _ast.Name
- col_offset: 23
ctx:
py/id: 12
id: int
lineno: 1
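Simplified, the structure is:
- ast.body[0]
  - _ast.Assign
    - value: _ast.BinOp
      - op: _ast.Add
      - left: _ast.Name
        - id: a
      - right: _ast.Num
        - n: 2
  - _ast.Return
    - value: _ast.Name
      - id: a
ast.body is a list of the top-level statements in the parsed string, so here it has a single element, the function definition. That element's body member holds two statements, of type _ast.Assign and _ast.Return. The value of the former is a BinOp with three members: op = Add, left = Name, and right = Num. This BinOp is the parsed form of a + 2, the right-hand side of a = a + 2.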
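The same structure can be checked with the standard library alone, without jsonpickle or yaml. The following is a small sketch that assumes Python 3.8 or newer, where the literal 2 is reported as ast.Constant instead of the older ast.Num seen in the dump above.

```python
import ast

src = "def my_func(a: int) -> int:\n    a = a + 2\n    return a\n"
fn_def = ast.parse(src).body[0]

# The function body holds exactly two statements.
assign, ret = fn_def.body
binop = assign.value

print(type(assign).__name__, type(ret).__name__)  # Assign Return
print(type(binop.op).__name__)                     # Add
print(type(binop.left).__name__, binop.left.id)    # Name a
# On Python 3.8+ the literal is an ast.Constant with a .value attribute.
print(type(binop.right).__name__, binop.right.value)
```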
build_def
def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    body = py_def.body
    # The code below builds metadata from the input arguments
    r = ctx.make_range(py_def.lineno + len(py_def.decorator_list),
                       py_def.col_offset,
                       py_def.col_offset + len("def"))
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    return_type = None
    if getattr(py_def, 'returns', None) is not None:
        return_type = build_expr(ctx, py_def.returns)
    decl = Decl(r, param_list, return_type)
    is_method = self_name is not None
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
    # build_stmts receives ctx and body: ctx carries the source metadata, body is the Python AST
    return Def(Ident(r, def_name),
               decl,
               build_stmts(ctx, body))
build_stmts
# All of the names below are imported from C++
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)

# Walk the syntax tree recursively; stmt is short for statement
def build_stmts(ctx, stmts):
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))

class Builder(object):
    # StmtBuilder and ExprBuilder inherit from Builder without overriding __call__,
    # so Builder.__call__ is what actually runs
    def __call__(self, ctx, node):
        # Assemble the method name from the type of the syntax-tree node, then call it
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)

# Only the build functions relevant to the example are kept here
class StmtBuilder(Builder):
    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        rhs = build_expr(ctx, stmt.value)  # stmt.value = _ast.BinOp
        lhs = [build_expr(ctx, x) for x in stmt.targets]  # x = _ast.Name
        return Assign(lhs, rhs)

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }
    binop_map[ast.MatMult] = '@'

    @staticmethod
    def build_Name(ctx, expr):
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(r, "names of variables used in JIT-ed functions "
                                       "can't start with " + _reserved_prefix)
        if expr.id == "True":
            return TrueLiteral(r)
        elif expr.id == "False":
            return FalseLiteral(r)
        elif expr.id == "None":
            return NoneLiteral(r)
        elif expr.id == "Ellipsis":
            return Dots(r)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_Num(ctx, expr):
        value = str(expr.n)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
        return Const(r, value)

    @staticmethod
    def build_BinOp(ctx, expr):
        lhs = build_expr(ctx, expr.left)  # _ast.Name
        rhs = build_expr(ctx, expr.right)  # _ast.Num
        op = type(expr.op)  # _ast.Add
        if op == ast.Div and not ctx.uses_true_division:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise FrontendError(err_range, 'Division of ints in TorchScript uses Python 3 true '
                                           'division semantics. Please put `from __future__ '
                                           'import division` at the top of your file')
        # Map the operator type to its agreed-upon string token
        op_token = ExprBuilder.binop_map.get(op)  # op_token = "+"
        if op_token is None:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
        return BinOp(op_token, lhs, rhs)

# Calling build_stmt invokes StmtBuilder.__call__ (these instances are created once the classes exist)
build_stmt = StmtBuilder()
build_expr = ExprBuilder()
The final PyTorch AST can be printed directly:
(def
  (ident my_func)
  (decl
    (list
      (param
        (ident a)
        (option (variable (ident int)))
        (option)
        (False)))
    (option (variable (ident int))))
  (list
    (assign
      (list (variable (ident a)))
      (option
        (+
          (variable (ident a))
          (const 2)))
      (option))
    (return (variable (ident a)))))
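The name-based dispatch used by Builder.__call__ can be reproduced on a small scale with the standard ast module. The sketch below is an illustration only: ToyBuilder and the nested tuples it returns are hypothetical stand-ins for the C++ tree views, source ranges and type annotations are ignored, and it assumes Python 3.8 or newer (where literals are ast.Constant nodes).

```python
import ast


class ToyBuilder:
    """Dispatches on the node's class name, like Builder.__call__."""

    def __call__(self, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise NotImplementedError(node.__class__.__name__)
        return method(node)

    def build_FunctionDef(self, node):
        return ("def", ("ident", node.name), [self(s) for s in node.body])

    def build_Assign(self, node):
        return ("assign", [self(t) for t in node.targets], self(node.value))

    def build_Return(self, node):
        return ("return", self(node.value))

    def build_BinOp(self, node):
        # Only a few operators are mapped in this sketch.
        op_token = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}[type(node.op)]
        return (op_token, self(node.left), self(node.right))

    def build_Name(self, node):
        return ("variable", ("ident", node.id))

    def build_Constant(self, node):  # Python 3.8+; older versions use Num
        return ("const", str(node.value))


src = "def my_func(a: int) -> int:\n    a = a + 2\n    return a\n"
tree = ToyBuilder()(ast.parse(src).body[0])
print(tree)
```

The resulting nested tuples have the same shape as the assign and return entries of the printed PyTorch AST, minus the decl and option wrappers.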
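The C++ Part
Key Classes
This section introduces several classes that appear frequently in the JIT, ordered from low level to high level.
Graph
Graph is the foundation of the IR used to define the implementation of TorchScript functions. For readers familiar with LLVM, it is similar to the llvm::Function class. A Graph is made up mainly of Node, Block, and Value objects. A Node is an instruction (for example, a matrix multiplication). A Block is an ordered list of Node objects that execute in sequence. Each Node takes a list of Value objects as input and produces a list of Value objects as output.
Taking the Python example from earlier, we can print its Graph: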
import torch

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
print(scripted_func.graph)
"""
graph(%a.1 : int):
  %2 : int = prim::Constant[value=2]()
  %a.5 : int = aten::add(%a.1, %2)
  return (%a.5)
"""
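The printed text is the standard textual form of the PyTorch IR, and the elements described above are easy to spot in it:
- graph is the Graph
- %a.5 : int is a Value with its type annotation
- %a.5 : int = aten::add(%a.1, %2) is a Node. It represents the aten::add operator, which takes %a.1 and %2 as inputs and produces %a.5 as its output
The Graph is in static single assignment (SSA) form: each variable is assigned exactly once, and every assignment defines a new variable. This is why the reassigned a shows up as two distinct values, %a.1 and %a.5.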
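To make the Value / Node / Graph vocabulary and the SSA property concrete, here is a toy model in plain Python. It is a sketch under simplifying assumptions, not the torch::jit data structures: Blocks are omitted, each Node has a single output, and the class and method names are illustrative.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Value:
    name: str
    type: str

    def __str__(self):
        return f"%{self.name}"


@dataclass
class Node:
    kind: str
    inputs: List[Value]
    output: Value
    attr: str = ""

    def __str__(self):
        args = ", ".join(str(v) for v in self.inputs)
        return f"{self.output} : {self.output.type} = {self.kind}{self.attr}({args})"


@dataclass
class Graph:
    inputs: List[Value]
    nodes: List[Node] = field(default_factory=list)
    outputs: List[Value] = field(default_factory=list)

    def insert(self, kind, inputs, out_name, out_type, attr=""):
        # SSA: every node defines a brand-new Value; names are never reused.
        assert out_name not in {v.name for v in self.inputs}
        assert out_name not in {n.output.name for n in self.nodes}
        out = Value(out_name, out_type)
        self.nodes.append(Node(kind, list(inputs), out, attr))
        return out

    def __str__(self):
        params = ", ".join(f"{v} : {v.type}" for v in self.inputs)
        lines = [f"graph({params}):"]
        lines += [f"  {n}" for n in self.nodes]
        lines.append("  return (" + ", ".join(str(v) for v in self.outputs) + ")")
        return "\n".join(lines)


a1 = Value("a.1", "int")
g = Graph(inputs=[a1])
two = g.insert("prim::Constant", [], "2", "int", attr="[value=2]")
a5 = g.insert("aten::add", [a1, two], "a.5", "int")
g.outputs.append(a5)
print(g)
```

The Python statement a = a + 2 rebinds a, but the toy graph, like the real one, never overwrites %a.1; the addition defines the new value %a.5 instead.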
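Function
There are two kinds of Function: BuiltinOpFunction and GraphFunction. The former is a wrapper around a native C++ function. The latter has the following main members:
- FunctionSchema: describes the types and names of the arguments and return values
- Graph: the IR that defines the function's implementation (see the Graph section above)
- GraphExecutor: actually executes the Graph that defines the function; it is covered in more detail later
Method
A Method can be thought of simply as a wrapper around a Function. All of a Module's executable methods are held as Method objects.
CompilationUnit
A CompilationUnit is essentially a collection of Function objects. It provides several define overloads that convert an AST into IR and store the result in the Function objects it owns.
StrongFunctionPtr
The relationship between StrongFunctionPtr, CompilationUnit, and Function is shown by the code below: a StrongFunctionPtr pairs a raw Function pointer with a shared_ptr to the CompilationUnit that owns the function, which keeps the function alive.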
struct StrongFunctionPtr {
StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
: cu_(std::move(cu)), function_(function) {
TORCH_INTERNAL_ASSERT(cu_);
TORCH_INTERNAL_ASSERT(function_);
}
std::shared_ptr<CompilationUnit> cu_;
Function* function_;
};
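Module
Module is the intermediary between TorchScript and LibTorch, and both serialization and deserialization operate on whole Module objects. Module inherits from Object, which holds a member of type ObjectPtr. The object it points to has a member type_ of type StrongTypePtr and a member slots_ of type std::vector<IValue>. The former records all of the module's Method entries, and the latter stores all of the module's attribute values. The rough structure is:
- Module : Object
  - ObjectPtr _ivalue_
    - StrongTypePtr type_
    - std::vector<IValue> slots_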
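The split between type_ (what the object can do) and slots_ (what the object currently holds) can be sketched in a few lines of Python. This is a hypothetical model for illustration: ToyClassType and ToyObject are invented names, and the real classes track considerably more information.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class ToyClassType:
    """Stands in for the type reached through type_: attribute names and methods."""
    name: str
    attribute_names: List[str] = field(default_factory=list)
    methods: Dict[str, Callable] = field(default_factory=dict)


@dataclass
class ToyObject:
    """Stands in for the object behind ObjectPtr: a type plus a slot vector."""
    type_: ToyClassType
    slots_: List[Any] = field(default_factory=list)

    def register_attribute(self, name, value):
        self.type_.attribute_names.append(name)
        self.slots_.append(value)

    def getattr(self, name):
        # The type maps a name to a slot index; the object stores the value.
        return self.slots_[self.type_.attribute_names.index(name)]

    def run_method(self, name, *args):
        return self.type_.methods[name](self, *args)


module = ToyObject(ToyClassType("__torch__.PlaceholderModule"))
module.register_attribute("training", True)
module.type_.methods["forward"] = lambda self, a: a + 2
print(module.getattr("training"), module.run_method("forward", 1))
```

Keeping methods on the type and values in the slot vector means that two objects of the same type can share one set of methods while holding different attribute values.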
Steps
The overall call flow is:
- script_compile_function
  - get_python_cu
  - CompilationUnit::define
script_compile_function
// Expose script_compile_function to Python as torch._C._jit_script_compile
m.def(
"_jit_script_compile",
[](const std::string& qualname,
const Def& def,
const ResolutionCallback& rcb,
const FunctionDefaults& defaults) {
C10_LOG_API_USAGE_ONCE("torch.script.compile");
const auto name = c10::QualifiedName(qualname);
TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
return script_compile_function(name, def, defaults, rcb);
});
// On the Python side, StrongFunctionPtr corresponds to torch.jit.ScriptFunction
static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
const FunctionDefaults& defaults,
const ResolutionCallback& rcb) {
auto cu = get_python_cu();
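// As explained below, each Function in defined_functions is created with a creator
// closure that converts the AST into IR; define runs it via ensure_defined before returning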
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
/*properties=*/{},
/*propResolvers=*/{},
{def},
{pythonResolver(rcb)},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
// ret bundles the CompilationUnit with the Function it owns, i.e. the function compiled above
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}
get_python_cu
// Fetch the shared CompilationUnit object stored in torch.jit._state._python_cu
inline std::shared_ptr<CompilationUnit> get_python_cu() {
return py::module::import("torch.jit._state")
.attr("_python_cu")
.cast<std::shared_ptr<CompilationUnit>>();
}
CompilationUnit::define
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> CompilationUnit::define(
const c10::optional<c10::QualifiedName>& prefix,
const std::vector<Property>& properties,
const std::vector<ResolverPtr>& propResolvers,
const std::vector<Def>& definitions, // the torch ASTs
const std::vector<ResolverPtr>& defResolvers, // determines how we handle free variables in each definition
const Self* self, // if non-null, the first argument to each def, is bound to this value
bool shouldMangle = false /* see [name mangling] */) {
TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
std::vector<Function*> functions;
std::unordered_map<std::string, Function*> function_table;
// Records fn in function_table, functions and with register_function.
// This is done several times below, so this lambda helps avoid repeating
// code.
auto record_function = [&](std::unique_ptr<Function> fn) {
function_table[fn->name()] = fn.get();
functions.emplace_back(fn.get());
this->register_function(std::move(fn));
};
for (size_t i = 0; i < properties.size(); i++) {
PropertyPair property_fns = define_property(
prefix,
properties[i],
propResolvers[i],
self,
function_table,
shouldMangle);
auto& getter_fn = property_fns.getGetter();
auto& setter_fn = property_fns.getSetter();
record_function(std::move(getter_fn));
if (setter_fn) {
record_function(std::move(setter_fn));
}
}
// Each torch AST is handed to another overload of define
for (size_t i = 0; i < definitions.size(); i++) {
auto fn = define(
prefix,
definitions[i],
defResolvers[i],
self,
function_table,
shouldMangle,
CompilationUnit::FunctionType::Method);
// As the overload below shows, fn carries a creator closure whose job is to convert the AST into IR
record_function(std::move(fn));
}
// We need to compile `__init__` first, since it can determine what attributes
// are available to other methods. So reorder the definitions accordingly.
for (auto& kv : function_table) {
if (kv.first == "__init__") {
kv.second->ensure_defined();
}
}
for (Function* function : functions) {
// Run the creator stored in function: it builds the IR from the AST and stores it in function
function->ensure_defined();
}
return functions;
}
void GraphFunction::ensure_defined() {
if (function_creator_) {
auto creator = function_creator_;
function_creator_ = placeholderCreator;
creator(*this);
function_creator_ = nullptr;
}
check_single_output();
}
std::unique_ptr<Function> CompilationUnit::define(
const c10::optional<QualifiedName>& prefix,
const Def& def,
const ResolverPtr& resolver,
const Self* self,
const std::unordered_map<std::string, Function*>& function_table,
bool shouldMangle,
CompilationUnit::FunctionType type) const {
TORCH_INTERNAL_ASSERT(resolver);
auto _resolver = resolver;
// script_compile_function passes nullptr for self
if (!self) {
// if self is defined, then these are methods and do not go into the
// global namespace otherwise, they get defined together so we add them to
// the function table so the methods can see each other
_resolver =
std::make_shared<FunctionResolver>(resolver.get(), function_table);
}
// creator lowers def (the torch AST) into IR, resolving names through _resolver, and stores the result in method
auto creator = [def, _resolver, self](Function& method) {
// Store the function name so that it can be referenced if there is an error
// while compiling this function
std::string call_name = method.qualname().name();
if (self) {
auto atoms = method.qualname().atoms();
// There should be at least a ClassName.method_name
TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
}
ErrorReport::CallStack call(call_name, def.range());
to_ir(def, _resolver, self, method);
};
auto name = prefix ? QualifiedName(*prefix, def.name().name())
: QualifiedName(def.name().name());
if (shouldMangle) {
// If `shouldMangle` is set, we should generate a unique name for this
// function if there is already an existing one.
if (auto fn = find_function(name)) {
name = mangle(name);
}
}
// Store creator in fn
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::make_shared<Graph>(), creator);
if (self) {
// Register this as a method on `self`'s type
if (type == CompilationUnit::FunctionType::Hook) {
self->getClassType()->addForwardHook(fn.get());
} else if (type == CompilationUnit::FunctionType::PreHook) {
self->getClassType()->addForwardPreHook(fn.get());
} else {
self->getClassType()->addMethod(fn.get());
}
}
return fn;
}
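Inside define, the creator stored in each function is invoked through ensure_defined: it builds the IR from the AST and stores it in the function. The creator in turn calls to_ir to build the IR.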
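The deferred-creator pattern behind ensure_defined can be sketched in Python as follows. ToyFunction, make_creator, and the string standing in for the AST are hypothetical, and the behaviour of the placeholder (raising on re-entry) is an assumption about what placeholderCreator is for rather than something shown in the excerpt.

```python
class ToyFunction:
    """Hypothetical model of a function whose body is built lazily."""

    def __init__(self, name, creator):
        self.name = name
        self.graph = []          # filled in by the creator
        self._creator = creator  # the deferred "AST -> IR" step

    def ensure_defined(self):
        if self._creator is not None:
            creator = self._creator
            # While the body is being built, a re-entrant call hits the
            # placeholder instead of recursing forever (assumed behaviour).
            self._creator = _placeholder_creator
            creator(self)
            self._creator = None  # built once; later calls are no-ops


def _placeholder_creator(fn):
    raise RuntimeError(f"{fn.name} is being defined recursively")


calls = []


def make_creator(ast_text):
    # Capture the "AST" now, lower it only when the creator is run.
    def creator(fn):
        calls.append(fn.name)
        fn.graph.append(f"lowered({ast_text})")
    return creator


f = ToyFunction("my_func", make_creator("a = a + 2; return a"))
f.ensure_defined()
f.ensure_defined()  # second call does nothing
print(f.graph, calls)
```

Separating registration from lowering lets a group of functions be recorded first and compiled afterwards, so that definitions made together can refer to each other.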
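to_ir
The script pipeline can be summarized in two steps: first an AST is produced (as a tree object), then the IR emitter performs semantic analysis on that tree and lowers it into a Module. Most of the second step is done by the constructor of the struct to_ir. Inside emitStatements, each statement is emitted into the Graph according to its kind.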
struct to_ir {
to_ir(
const Def& def, // AST
ResolverPtr resolver_,
const Self* self,
Function& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
typeParser_(resolver),
environment_stack(nullptr) {
......
method.setSchema(emitDef(def, self, graph->block()));
......
}
}
FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
......
emitStatements(stmts_list.begin(), stmts_list.end());
......
}
void emitStatements(
List<Stmt>::const_iterator begin,
List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
ErrorReport::CallStack::update_pending_range(stmt.range());
switch (stmt.kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_FOR:
emitFor(For(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
break;
case TK_AUG_ASSIGN:
emitAugAssignment(AugAssign(stmt));
break;
case TK_EXPR_STMT: {
auto expr = ExprStmt(stmt).expr();
emitSugaredExpr(expr, 0);
} break;
case TK_RAISE:
emitRaise(Raise(stmt));
break;
case TK_ASSERT:
emitAssert(Assert(stmt));
break;
case TK_RETURN: {
emitReturn(Return(stmt));
} break;
case TK_CONTINUE: {
emitContinue(Continue(stmt));
} break;
case TK_BREAK: {
emitBreak(Break(stmt));
} break;
case TK_PASS:
// Emit nothing for pass
break;
case TK_DEF:
emitClosure(Def(stmt));
break;
case TK_DELETE:
emitDelete(Delete(stmt));
break;
case TK_WITH:
emitWith(With(stmt));
break;
default:
throw ErrorReport(stmt)
<< "Unrecognized statement kind " << kindToString(stmt.kind());
}
// Found an exit statement in this block. The remaining statements aren't
// reachable so we don't emit them.
if (exit_blocks.count(environment_stack->block()))
return;
}
}
The emitX functions call graph->insert to add the various Node objects to the Graph. Their details are fairly involved and are not covered here.
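The control flow of emitStatements (dispatch on the statement kind, then stop once the block has exited) can be mimicked with a dispatch table. The sketch below is a simplification with hypothetical names: statements are plain tuples, and the handlers record what they would emit instead of inserting Node objects into a Graph.

```python
def emit_statements(stmts):
    """Toy dispatcher: one handler per statement kind, stop after an exit."""
    emitted = []

    def emit_assign(stmt):
        emitted.append(("assign", stmt[1], stmt[2]))
        return False

    def emit_return(stmt):
        emitted.append(("return", stmt[1]))
        return True  # the block has exited

    def emit_pass(stmt):
        return False  # emit nothing for pass

    handlers = {"assign": emit_assign, "return": emit_return, "pass": emit_pass}

    for stmt in stmts:
        handler = handlers.get(stmt[0])
        if handler is None:
            raise ValueError(f"Unrecognized statement kind {stmt[0]}")
        if handler(stmt):
            break  # the remaining statements are unreachable
    return emitted


program = [
    ("assign", "a", "a + 2"),
    ("return", "a"),
    ("assign", "a", "a + 3"),  # unreachable, never emitted
]
print(emit_statements(program))
```

The early break plays the role of the exit_blocks check at the end of the loop: statements after a return are not emitted.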
Module::save
// Register save on torch.jit.ScriptFunction; it ends up calling Module::save
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
.def(
"save",
[](const StrongFunctionPtr& self,
const std::string& filename,
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
Module module("__torch__.PlaceholderModule");
// [issue 27343]
// Modules have 'training' attributes by default, but due to
// https://github.com/pytorch/pytorch/issues/27343, functions end
// up having a training attribute when they are loaded. This adds
// a fake 'training' attribute that shouldn't be used, but prevents
// jitter on saving and loading. Once that issue is fixed this can
// be deleted.
module.register_attribute("training", BoolType::get(), true);
// Add the function held by the StrongFunctionPtr to the Module object
addFunctionToModule(module, self);
module.save(filename, _extra_files);
},
py::arg("filename"),
py::arg("_extra_files") = ExtraFilesMap());
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
// Make a graph with a fake self argument
auto graph = func.function_->graph()->copy();
auto v = graph->insertInput(0, "self");
v->setType(module._ivalue()->type());
const auto name = QualifiedName(*module.type()->name(), "forward");
// Turn the copied graph into a new function named forward, owned by the module's CompilationUnit
auto method = module._ivalue()->compilation_unit()->create_function(name, graph);
// Register that function as a method on the module's type
module.type()->addMethod(method);
}
void ClassType::addMethod(torch::jit::Function* method) {
TORCH_CHECK(
findMethod(method->name()) == nullptr,
"Can't redefine method: ",
method->name(),
" on class: ",
repr_str());
methods_.push_back(method);
}
void Module::save(const std::string& filename, const ExtraFilesMap& extra_files) const {
ExportModule(*this, filename, extra_files, false /* bytecode_format */);
}
void ExportModule(
const Module& module,
const std::string& filename,
const ExtraFilesMap& extra_files,
bool bytecode_format,
bool save_mobile_debug_info) {
caffe2::serialize::PyTorchStreamWriter writer(filename);
ScriptModuleSerializer serializer(writer);
serializer.serialize(module, extra_files, bytecode_format, save_mobile_debug_info);
}
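As the code shows, saving first wraps the StrongFunctionPtr in a Module object and then serializes it with ScriptModuleSerializer. ScriptModuleSerializer and ScriptModuleDeserializer are the PyTorch classes responsible for serialization and deserialization. For details on how models are serialized, see TorchScript serialization.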
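The wrapping step performed by addFunctionToModule (copy the graph, prepend a fake self input, register the result as forward) can be sketched with plain dictionaries. This is an illustrative model only: the dictionary layout and the add_function_to_module helper are hypothetical, and actual serialization is not reproduced.

```python
import copy


def add_function_to_module(module, func_graph):
    """Toy version of the wrapping step: copy, prepend self, register as forward."""
    graph = copy.deepcopy(func_graph)  # leave the original function untouched
    graph["inputs"].insert(0, "self")  # fake self argument at position 0
    if "forward" in module["methods"]:
        raise ValueError("Can't redefine method: forward")
    module["methods"]["forward"] = graph
    return module


my_func_graph = {
    "inputs": ["a.1"],
    "nodes": ["%a.5 = aten::add(%a.1, %2)"],
    "outputs": ["a.5"],
}
module = {
    "name": "__torch__.PlaceholderModule",
    "attributes": {"training": True},
    "methods": {},
}

add_function_to_module(module, my_func_graph)
print(module["methods"]["forward"]["inputs"])  # ['self', 'a.1']
print(my_func_graph["inputs"])                 # ['a.1']
```

Working on a copy matters because the scripted function stays usable after saving: only the copy stored in the placeholder module gains the extra self input.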
References
TorchScript 如何实现Python -> C++ 代码转换 (How TorchScript converts Python code to C++)