TorchScript Internals, Part 1: Saving a Model | PyTorch

This post builds on 《PyTorch 源码解读之即时编译篇》 (PyTorch Source Code Walkthrough: JIT Compilation), with some corrections and additions.

Overview

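The code discussed here comes from PyTorch 1.9.0 (commit_id = d69c22d), lightly reformatted for readability. TorchScript can script several kinds of objects (functions, nn.Modules, classes). To keep the focus on the underlying mechanism, this post walks through the simplest case, a scripted function.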

The example analyzed throughout this post:

```python
import torch

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
scripted_func.save("my_func.pt")
```

Python side

All C++ code discussed here lives in the torch::jit namespace. The prefix is omitted below, except where a name could be confused with its Python counterpart.

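The overall call flow:

  • script
    • get_jit_def: build the TorchScript AST
      • ast.parse: build the Python AST
      • build_def: convert the Python AST into the TorchScript AST
        • build_stmts: build the statements of the function body
    • torch._C._jit_script_compile: call into C++ to compile the TorchScript AST into IR. The result is a ScriptFunction (C++ StrongFunctionPtr) that refers to a Function stored in a torch::jit::CompilationUnit
  • save: wrap the compiled function in a Module, serialize it, and write it to disk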

script

```python
# Code unrelated to scripting a plain function has been removed
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    # Objects that have already been scripted are returned unchanged
    if isinstance(obj, ScriptFunction):
        return obj

    # In the full source, the qualified name is computed before the branches below
    qualified_name = _qualified_name(obj)

    # this is a decorated fn, and we need to get the underlying fn and its rcb
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # Reject overloaded functions: functions are looked up by name,
    # so an overloaded function cannot be compiled directly
    _check_directly_compile_overloaded(obj)

    # Return the cached result if this function has already been compiled
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn

    # Build the TorchScript AST
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        # rcb (resolution callback) resolves free variables used in the function body
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # Call into C++ to compile the AST into IR
    fn = torch._C._jit_script_compile(
        qualified_name, ast, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    fn.__doc__ = obj.__doc__

    # Cache the compiled result
    _set_jit_function_cache(obj, fn)
    return fn
```

get_jit_def

```python
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
    """

    # Fetch the source of fn; dedent_src ends up holding the source text of the function to script
    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    dedent_src = dedent(source)

    # Parse the source string into a Python AST with the standard ast module
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])

    # If the signature is annotated with a `# type:` comment, type_line holds that comment
    type_line = torch.jit.annotations.get_type_line(source)

    # SourceContext wraps SourceRangeFactory and carries the source metadata needed for compilation
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)

    # A single function is being compiled, so take the first (and only) element of the body
    fn_def = py_ast.body[0]

    # If MonkeyType is used, reuse the type information it recorded. MonkeyType traces
    # types at runtime and adds type annotations to Python 3 code automatically.
    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace:
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)

    # build_def converts the Python AST into the AST format used by TorchScript
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
```

py_ast

To inspect the contents of py_ast, you can print it with the following code:

```python
import yaml, jsonpickle
# yaml.safe_load: since PyYAML 6, plain yaml.load() requires an explicit Loader
print(yaml.dump(yaml.safe_load(jsonpickle.encode(py_ast)), indent=2))
```

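The output below was produced with a Python version older than 3.8. That is why the literal 2 appears as _ast.Num; Python 3.8+ emits ast.Constant instead.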

```yaml
- py/type: _ast.Module
- body:
  - py/reduce:
    - py/type: _ast.FunctionDef
    - args:
        py/reduce:
        - py/type: _ast.arguments
        - args:
          - py/reduce:
            - py/type: _ast.arg
            - annotation:
                py/reduce:
                - py/type: _ast.Name
                - col_offset: 15
                  ctx:
                    py/reduce:
                    - py/type: _ast.Load
                    - {}
                  id: int
                  lineno: 1
              arg: a
              col_offset: 12
              lineno: 1
          defaults: []
          kw_defaults: []
          kwarg: null
          kwonlyargs: []
          vararg: null
      body:
      - py/reduce:
        - py/type: _ast.Assign
        - col_offset: 4
          lineno: 2
          targets:
          - py/reduce:
            - py/type: _ast.Name
            - col_offset: 4
              ctx:
                py/reduce:
                - py/type: _ast.Store
                - {}
              id: a
              lineno: 2
          value:
            py/reduce:
            - py/type: _ast.BinOp
            - col_offset: 8
              left:
                py/reduce:
                - py/type: _ast.Name
                - col_offset: 8
                  ctx:
                    py/id: 12
                  id: a
                  lineno: 2
              lineno: 2
              op:
                py/reduce:
                - py/type: _ast.Add
                - {}
              right:
                py/reduce:
                - py/type: _ast.Num
                - col_offset: 12
                  lineno: 2
                  n: 2
      - py/reduce:
        - py/type: _ast.Return
        - col_offset: 4
          lineno: 3
          value:
            py/reduce:
            - py/type: _ast.Name
            - col_offset: 11
              ctx:
                py/id: 12
              id: a
              lineno: 3
      col_offset: 0
      decorator_list: []
      lineno: 1
      name: my_func
      returns:
        py/reduce:
        - py/type: _ast.Name
        - col_offset: 23
          ctx:
            py/id: 12
          id: int
          lineno: 1
```

Simplified, the function body looks like this:

  • py_ast.body[0] (_ast.FunctionDef)
    • _ast.Assign
      • value: _ast.BinOp
        • op: _ast.Add
        • left: _ast.Name
          • id: a
        • right: _ast.Num
          • n: 2
    • _ast.Return
      • value: _ast.Name
        • id: a

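py_ast.body is a list with one entry per top-level statement in the parsed string. Here that is a single function definition, which get_jit_def enforces. The body of that first element holds two nodes, an _ast.Assign and an _ast.Return. The Assign's value is a BinOp with three fields: op = Add, left = Name, right = Num. That BinOp represents a + 2. Together with the Assign's target a, it forms the statement a = a + 2.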
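You can check the structure above without jsonpickle or PyYAML. The following minimal sketch uses only the standard ast module: it parses the same source and pulls out the nodes just discussed. Node class names depend on the Python version (3.8+ produces ast.Constant where the dump above shows ast.Num). For that reason the sketch reads the literal with ast.literal_eval instead of relying on either class.

```python
import ast
from textwrap import dedent

# Same source as the running example; dedent mirrors what get_jit_def does
src = dedent("""
    def my_func(a: int) -> int:
        a = a + 2
        return a
""")

py_ast = ast.parse(src)
fn_def = py_ast.body[0]          # the single top-level FunctionDef
assign, ret = fn_def.body        # the two statements: Assign, Return
binop = assign.value             # BinOp for the right-hand side `a + 2`

print(type(fn_def).__name__, [type(s).__name__ for s in fn_def.body])  # FunctionDef ['Assign', 'Return']
print(type(binop.op).__name__, binop.left.id, ast.literal_eval(binop.right))  # Add a 2
print(ast.dump(binop))           # exact field layout varies across Python versions
```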

build_def

```python
def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    body = py_def.body

    # Build the declaration metadata from the inputs: source range, parameters, return type
    r = ctx.make_range(py_def.lineno + len(py_def.decorator_list),
                       py_def.col_offset,
                       py_def.col_offset + len("def"))

    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    return_type = None
    if getattr(py_def, 'returns', None) is not None:
        return_type = build_expr(ctx, py_def.returns)

    decl = Decl(r, param_list, return_type)
    is_method = self_name is not None
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)

    # build_stmts converts the function body: ctx carries the source metadata,
    # body is the list of Python AST statements
    return Def(Ident(r, def_name),
               decl,
               build_stmts(ctx, body))
```

build_stmts

```python
# These tree-view classes are all implemented in C++ and imported from torch._C
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)
```
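```python
# Build each statement; build_stmt recurses into the statement's sub-nodes.
# "stmt" is short for "statement". build_stmt returns None for statements
# that should be dropped (e.g. docstrings), and filter removes them.
def build_stmts(ctx, stmts):
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))
```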
```python
class Builder(object):
    # StmtBuilder and ExprBuilder inherit from Builder without overriding __call__,
    # so calling either of them ends up in Builder.__call__
    def __call__(self, ctx, node):
        # Build the method name from the node's class name,
        # e.g. an ast.Assign node is dispatched to build_Assign
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)


# Only the build_* methods relevant to this example are kept
class StmtBuilder(Builder):
    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        rhs = build_expr(ctx, stmt.value)                 # stmt.value is an _ast.BinOp
        lhs = [build_expr(ctx, x) for x in stmt.targets]  # each x is an _ast.Name
        return Assign(lhs, rhs)

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))


class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }

    binop_map[ast.MatMult] = '@'

    @staticmethod
    def build_Name(ctx, expr):
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(r, "names of variables used in JIT-ed functions "
                                       "can't start with " + _reserved_prefix)
        if expr.id == "True":
            return TrueLiteral(r)
        elif expr.id == "False":
            return FalseLiteral(r)
        elif expr.id == "None":
            return NoneLiteral(r)
        elif expr.id == "Ellipsis":
            return Dots(r)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_Num(ctx, expr):
        value = str(expr.n)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
        return Const(r, value)

    @staticmethod
    def build_BinOp(ctx, expr):
        lhs = build_expr(ctx, expr.left)   # _ast.Name
        rhs = build_expr(ctx, expr.right)  # _ast.Num
        op = type(expr.op)                 # _ast.Add

        if op == ast.Div and not ctx.uses_true_division:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise FrontendError(err_range, 'Division of ints in TorchScript uses Python 3 true '
                                           'division semantics. Please put `from __future__ '
                                           'import division` at the top of your file')

        # Map the operator class to the agreed string token
        op_token = ExprBuilder.binop_map.get(op)  # op_token == "+"
        if op_token is None:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
        return BinOp(op_token, lhs, rhs)


# Module-level builder instances: build_stmt(ctx, node) and build_expr(ctx, node)
# both go through Builder.__call__
build_stmt = StmtBuilder()
build_expr = ExprBuilder()
```

The resulting TorchScript AST can be printed directly:

```
(def
  (ident my_func)
  (decl
    (list
      (param
        (ident a)
        (option (variable (ident int)))
        (option)
        (False)))
    (option (variable (ident int))))
  (list
    (assign
      (list (variable (ident a)))
      (option
        (+
          (variable (ident a))
          (const 2)))
      (option))
    (return (variable (ident a)))))
```
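The getattr-based dispatch in Builder.__call__ is easy to reproduce outside PyTorch. Below is a hypothetical, much-reduced builder; ToyBuilder and its output format are inventions for illustration, not PyTorch APIs. It walks the Python AST of the example and emits S-expressions shaped like the printout above. It handles only the node types the example uses and, like Builder, raises on anything else.

```python
import ast
from textwrap import dedent


class ToyBuilder:
    """Hypothetical mini-builder: dispatches on the node's class name, like Builder.__call__."""

    binop_map = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}

    def __call__(self, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise NotImplementedError(f"unsupported node: {node.__class__.__name__}")
        return method(node)

    def build_FunctionDef(self, node):
        body = " ".join(self(s) for s in node.body)
        return f"(def (ident {node.name}) (list {body}))"

    def build_Assign(self, node):
        targets = " ".join(self(t) for t in node.targets)
        return f"(assign (list {targets}) {self(node.value)})"

    def build_Return(self, node):
        return f"(return {self(node.value)})"

    def build_BinOp(self, node):
        op = self.binop_map.get(type(node.op))
        if op is None:
            raise NotImplementedError(f"unsupported operator: {type(node.op).__name__}")
        return f"({op} {self(node.left)} {self(node.right)})"

    def build_Name(self, node):
        return f"(variable (ident {node.id}))"

    def build_Constant(self, node):  # literal nodes on Python 3.8+
        return f"(const {node.value!r})"

    def build_Num(self, node):       # literal nodes on Python < 3.8
        return f"(const {node.n!r})"


src = dedent("""
    def my_func(a: int) -> int:
        a = a + 2
        return a
""")
print(ToyBuilder()(ast.parse(src).body[0]))
```

Unlike the real build_def, this sketch skips the declaration (parameters and types) and all source ranges. Those are what let TorchScript point error messages back at the original Python lines.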

C++ side

Key classes

This section introduces the classes that appear most often in the JIT, ordered from low level to high level.

Graph

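A Graph is the core IR data structure used to represent the implementation of a TorchScript function. If you are familiar with LLVM, it is analogous to llvm::Function. A Graph is made up of Nodes, Blocks, and Values. A Node is an instruction (for example, a matrix multiplication). A Block is an ordered list of Nodes. Each Node takes a list of Values as inputs and produces a list of Values as outputs.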

Take the earlier example and print its Graph:

```python
import torch

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
print(scripted_func.graph)
# graph(%a.1 : int):
#   %2 : int = prim::Constant[value=2]()
#   %a.5 : int = aten::add(%a.1, %2)
#   return (%a.5)
```

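This is the standard textual form of the TorchScript IR, and the elements described above are easy to spot in it:

  • graph is the Graph
  • %a.5 : int is a Value with a type annotation
  • %a.5 : int = aten::add(%a.1, %2) is a Node representing the aten::add operator. It takes %a.1 and %2 as inputs and produces %a.5 as its output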
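The textual IR is regular enough to take apart with a regular expression. This sketch is a rough illustration only. The pattern handles simple single-output nodes like the ones in this example; it is not a parser for the full IR grammar and ignores multiple outputs, nested blocks, and source-location comments.

```python
import re

GRAPH_TEXT = """\
graph(%a.1 : int):
  %2 : int = prim::Constant[value=2]()
  %a.5 : int = aten::add(%a.1, %2)
  return (%a.5)
"""

# <output> : <type> = <node kind>[<attributes>](<inputs>), attributes optional
NODE_RE = re.compile(
    r"^\s*(?P<out>%[\w.]+) : (?P<type>\w+) = (?P<kind>[\w:]+)"
    r"(?:\[(?P<attrs>[^\]]*)\])?\((?P<inputs>[^)]*)\)"
)


def parse_nodes(text):
    """Return (output, type, kind, inputs) for each simple node line."""
    nodes = []
    for line in text.splitlines():
        m = NODE_RE.match(line)
        if m:  # the graph header and the return line do not match
            inputs = [s.strip() for s in m.group("inputs").split(",") if s.strip()]
            nodes.append((m.group("out"), m.group("type"), m.group("kind"), inputs))
    return nodes


for node in parse_nodes(GRAPH_TEXT):
    print(node)
# ('%2', 'int', 'prim::Constant', [])
# ('%a.5', 'int', 'aten::add', ['%a.1', '%2'])
```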

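A Graph is in static single assignment (SSA) form: each Value is assigned exactly once, so every assignment in the source defines a new Value. This is why the single Python variable a appears as two Values: %a.1 (the parameter) and %a.5 (the result of the addition).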
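Here is a minimal sketch of the renaming idea behind SSA. It covers straight-line code only: with no branches there is nothing to merge, so no block arguments are needed. Every assignment creates a fresh versioned name, and every use refers to the latest version. The naming scheme (%a.1, %a.2, ...) is a simplification. The numbers TorchScript prints (%a.1, %a.5) come from its own value numbering and will not match. The sketch assumes Python 3.8+ (ast.Constant).

```python
import ast
from textwrap import dedent


def to_ssa(src):
    """Rename the variables of a straight-line function into SSA form (toy version)."""
    fn = ast.parse(dedent(src)).body[0]
    version = {}  # variable name -> latest version number
    current = {}  # variable name -> latest SSA name
    lines = []

    def fresh(name):
        version[name] = version.get(name, 0) + 1
        current[name] = f"%{name}.{version[name]}"
        return current[name]

    def use(expr):
        if isinstance(expr, ast.Name):
            return current[expr.id]  # always read the most recent definition
        if isinstance(expr, ast.Constant):
            return repr(expr.value)
        if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.Add):
            return f"add({use(expr.left)}, {use(expr.right)})"
        raise NotImplementedError(ast.dump(expr))

    params = [fresh(arg.arg) for arg in fn.args.args]
    for stmt in fn.body:
        if isinstance(stmt, ast.Assign):
            rhs = use(stmt.value)  # evaluate the RHS before defining the new name
            lines.append(f"{fresh(stmt.targets[0].id)} = {rhs}")
        elif isinstance(stmt, ast.Return):
            lines.append(f"return {use(stmt.value)}")
    return params, lines


params, lines = to_ssa("""
    def my_func(a: int) -> int:
        a = a + 2
        return a
""")
print(params)  # ['%a.1']
print(lines)   # ['%a.2 = add(%a.1, 2)', 'return %a.2']
```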

Function

There are two kinds of Function: BuiltinOpFunction and GraphFunction. The former wraps a native C++ function. The latter has these main members:

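  • FunctionSchema: describes the names and types of the inputs and return values

  • Graph: the IR that defines the function's implementation, as described in the Graph section above

  • GraphExecutor: actually executes the Graph that defines the function. It is covered in more detail later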

Method

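A Method can be thought of as a wrapper around a Function that also binds the Module (the self object) it belongs to. A Module's executable methods are exposed as Methods.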

CompilationUnit

A CompilationUnit is essentially a collection of Functions. Its define overloads turn ASTs into IR and store the result in the Functions it owns.

StrongFunctionPtr

StrongFunctionPtrCompilationUnitFunction 的关系正如下面代码所描述的。

1
2
3
4
5
6
7
8
9
struct StrongFunctionPtr {
StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
: cu_(std::move(cu)), function_(function) {
TORCH_INTERNAL_ASSERT(cu_);
TORCH_INTERNAL_ASSERT(function_);
}
std::shared_ptr<CompilationUnit> cu_;
Function* function_;
};

Module

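Module is the intermediary between TorchScript and LibTorch. Both serialization and deserialization operate on whole Modules. Module derives from Object and has a member of type ObjectPtr. The ObjectPtr in turn has a member type_ of type StrongTypePtr and a member slots_ of type std::vector<IValue>. Through the class type, type_ records all of the module's Methods; slots_ stores the values of all its attributes. Roughly: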

  • Module : Object
    • ObjectPtr _ivalue_
      • StrongTypePtr type_
      • std::vector<IValue> slots_
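To make the nesting concrete, here is a hypothetical Python model of this structure. The class names echo the C++ ones, but this is an illustration, not PyTorch's implementation. The type records attribute names and methods, and slots_ holds attribute values by position, with the type mapping each attribute name to its slot index. Like ClassType::addMethod (shown later in this post), the model rejects a second method with an existing name. The usage mirrors what saving does further below: a placeholder module with a training attribute and a forward method that takes self first.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class ToyClassType:
    """Stands in for ClassType: attribute names plus methods (hypothetical)."""
    name: str
    attribute_names: List[str] = field(default_factory=list)
    methods: Dict[str, Callable] = field(default_factory=dict)

    def add_attribute(self, name: str) -> int:
        self.attribute_names.append(name)
        return len(self.attribute_names) - 1  # slot index for this attribute

    def add_method(self, name: str, fn: Callable) -> None:
        if name in self.methods:  # mirrors the TORCH_CHECK in ClassType::addMethod
            raise RuntimeError(f"Can't redefine method: {name} on class: {self.name}")
        self.methods[name] = fn


@dataclass
class ToyObject:
    """Stands in for the object behind ObjectPtr: type_ plus slots_ (hypothetical)."""
    type_: ToyClassType
    slots_: List[Any] = field(default_factory=list)

    def register_attribute(self, name: str, value: Any) -> None:
        index = self.type_.add_attribute(name)
        assert index == len(self.slots_)  # slots are kept in attribute order
        self.slots_.append(value)

    def get_attribute(self, name: str) -> Any:
        return self.slots_[self.type_.attribute_names.index(name)]

    def run_method(self, name: str, *args: Any) -> Any:
        return self.type_.methods[name](self, *args)  # methods receive `self` first


module = ToyObject(ToyClassType("__torch__.PlaceholderModule"))
module.register_attribute("training", True)
module.type_.add_method("forward", lambda self, a: a + 2)
print(module.run_method("forward", 3))   # 5
print(module.get_attribute("training"))  # True
```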

Steps

The overall call flow:

  • script_compile_function
    • get_python_cu
    • CompilationUnit::define

script_compile_function

```cpp
// Expose script_compile_function to Python as torch._C._jit_script_compile
m.def(
    "_jit_script_compile",
    [](const std::string& qualname,
       const Def& def,
       const ResolutionCallback& rcb,
       const FunctionDefaults& defaults) {
      C10_LOG_API_USAGE_ONCE("torch.script.compile");
      const auto name = c10::QualifiedName(qualname);
      TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
      return script_compile_function(name, def, defaults, rcb);
    });
```
```cpp
// On the Python side, StrongFunctionPtr is exposed as torch.jit.ScriptFunction
static StrongFunctionPtr script_compile_function(
    const c10::QualifiedName& name,
    const Def& def,
    const FunctionDefaults& defaults,
    const ResolutionCallback& rcb) {
  auto cu = get_python_cu();
  // define creates one Function per AST, each holding a creator that lowers the
  // AST to IR, and runs those creators before returning (see below)
  auto defined_functions = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {def},
      {pythonResolver(rcb)},
      nullptr,
      true);
  TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
  auto& defined = defined_functions[0];
  defined->setSchema(getSchemaWithNameAndDefaults(
      def.range(), defined->getSchema(), def.name().name(), defaults));
  // ret shares ownership of the CompilationUnit and points at the newly defined Function
  StrongFunctionPtr ret(std::move(cu), defined);
  didFinishEmitFunction(ret);
  return ret;
}
```

get_python_cu

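```cpp
// Fetch the global CompilationUnit shared by scripted Python code
// (torch.jit._state._python_cu); it is not created anew on each call
inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}
```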

CompilationUnit::define

```cpp
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> CompilationUnit::define(
    const c10::optional<c10::QualifiedName>& prefix,
    const std::vector<Property>& properties,
    const std::vector<ResolverPtr>& propResolvers,
    const std::vector<Def>& definitions, // the TorchScript ASTs
    const std::vector<ResolverPtr>& defResolvers, // determines how we handle free variables in each definition
    const Self* self, // if non-null, the first argument to each def, is bound to this value
    bool shouldMangle = false /* see [name mangling] */) {
  TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
  TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
  std::vector<Function*> functions;
  std::unordered_map<std::string, Function*> function_table;

  // Records fn in function_table, functions and with register_function.
  // This is done several times below, so this lambda helps avoid repeating
  // code.
  auto record_function = [&](std::unique_ptr<Function> fn) {
    function_table[fn->name()] = fn.get();
    functions.emplace_back(fn.get());
    this->register_function(std::move(fn));
  };

  for (size_t i = 0; i < properties.size(); i++) {
    PropertyPair property_fns = define_property(
        prefix,
        properties[i],
        propResolvers[i],
        self,
        function_table,
        shouldMangle);

    auto& getter_fn = property_fns.getGetter();
    auto& setter_fn = property_fns.getSetter();

    record_function(std::move(getter_fn));

    if (setter_fn) {
      record_function(std::move(setter_fn));
    }
  }

  // Each TorchScript AST is passed to the other define overload (shown below)
  for (size_t i = 0; i < definitions.size(); i++) {
    auto fn = define(
        prefix,
        definitions[i],
        defResolvers[i],
        self,
        function_table,
        shouldMangle,
        CompilationUnit::FunctionType::Method);

    // fn carries a creator callback that lowers the AST to IR
    record_function(std::move(fn));
  }

  // We need to compile `__init__` first, since it can determine what attributes
  // are available to other methods. So reorder the definitions accordingly.
  for (auto& kv : function_table) {
    if (kv.first == "__init__") {
      kv.second->ensure_defined();
    }
  }

  for (Function* function : functions) {
    // Run the creator stored in the function: build the IR from the AST and store it there
    function->ensure_defined();
  }

  return functions;
}

void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}
```
```cpp
std::unique_ptr<Function> CompilationUnit::define(
    const c10::optional<QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle,
    CompilationUnit::FunctionType type) const {
  TORCH_INTERNAL_ASSERT(resolver);
  auto _resolver = resolver;

  // script_compile_function passes nullptr for self
  if (!self) {
    // if self is defined, then these are methods and do not go into the
    // global namespace otherwise, they get defined together so we add them to
    // the function table so the methods can see each other
    _resolver =
        std::make_shared<FunctionResolver>(resolver.get(), function_table);
  }

  // The creator lowers def (the TorchScript AST) to IR via to_ir, resolving
  // names through _resolver, and stores the result in method
  auto creator = [def, _resolver, self](Function& method) {
    // Store the function name so that it can be referenced if there is an error
    // while compiling this function
    std::string call_name = method.qualname().name();
    if (self) {
      auto atoms = method.qualname().atoms();
      // There should be at least a ClassName.method_name
      TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
      call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
    }
    ErrorReport::CallStack call(call_name, def.range());
    to_ir(def, _resolver, self, method);
  };

  auto name = prefix ? QualifiedName(*prefix, def.name().name())
                     : QualifiedName(def.name().name());
  if (shouldMangle) {
    // If `shouldMangle` is set, we should generate a unique name for this
    // function if there is already an existing one.
    if (auto fn = find_function(name)) {
      name = mangle(name);
    }
  }

  // Store the creator in fn; it runs later, inside ensure_defined
  auto fn = torch::make_unique<GraphFunction>(
      std::move(name), std::make_shared<Graph>(), creator);
  if (self) {
    // Register this as a method on `self`'s type
    if (type == CompilationUnit::FunctionType::Hook) {
      self->getClassType()->addForwardHook(fn.get());
    } else if (type == CompilationUnit::FunctionType::PreHook) {
      self->getClassType()->addForwardPreHook(fn.get());
    } else {
      self->getClassType()->addMethod(fn.get());
    }
  }

  return fn;
}
```

define 函数中,function 中保存的 creator 被调用,其通过 AST 创建 IR,存入 function 中。creator 函数会调用 to_ir 创建 IR。

to_ir

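At a high level, scripting has two steps. First the AST is built as a tree object. Then the IR emitter performs semantic analysis on that tree and lowers it to IR: the Graph of a Function here, or a Module when scripting an nn.Module. The second step is done mainly in the constructor of the struct to_ir, which calls emitDef. emitDef in turn calls emitStatements, which emits each statement into the Graph according to its kind.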

```cpp
struct to_ir {
  to_ir(
      const Def& def, // AST
      ResolverPtr resolver_,
      const Self* self,
      Function& method) // method being constructed
      : method(method),
        graph(method.graph()),
        resolver(std::move(resolver_)),
        typeParser_(resolver),
        environment_stack(nullptr) {
    // ......

    method.setSchema(emitDef(def, self, graph->block()));

    // ......
  }

  FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
    // ......

    emitStatements(stmts_list.begin(), stmts_list.end());

    // ......
  }

  // ...... (other members, including emitStatements below)
};
```
```cpp
void emitStatements(
    List<Stmt>::const_iterator begin,
    List<Stmt>::const_iterator end) {
  for (; begin != end; ++begin) {
    auto stmt = *begin;
    ErrorReport::CallStack::update_pending_range(stmt.range());
    switch (stmt.kind()) {
      case TK_IF:
        emitIf(If(stmt));
        break;
      case TK_WHILE:
        emitWhile(While(stmt));
        break;
      case TK_FOR:
        emitFor(For(stmt));
        break;
      case TK_ASSIGN:
        emitAssignment(Assign(stmt));
        break;
      case TK_AUG_ASSIGN:
        emitAugAssignment(AugAssign(stmt));
        break;
      case TK_EXPR_STMT: {
        auto expr = ExprStmt(stmt).expr();
        emitSugaredExpr(expr, 0);
      } break;
      case TK_RAISE:
        emitRaise(Raise(stmt));
        break;
      case TK_ASSERT:
        emitAssert(Assert(stmt));
        break;
      case TK_RETURN: {
        emitReturn(Return(stmt));
      } break;
      case TK_CONTINUE: {
        emitContinue(Continue(stmt));
      } break;
      case TK_BREAK: {
        emitBreak(Break(stmt));
      } break;
      case TK_PASS:
        // Emit nothing for pass
        break;
      case TK_DEF:
        emitClosure(Def(stmt));
        break;
      case TK_DELETE:
        emitDelete(Delete(stmt));
        break;
      case TK_WITH:
        emitWith(With(stmt));
        break;
      default:
        throw ErrorReport(stmt)
            << "Unrecognized statement kind " << kindToString(stmt.kind());
    }
    // Found an exit statement in this block. The remaining statements aren't
    // reachable so we don't emit them.
    if (exit_blocks.count(environment_stack->block()))
      return;
  }
}
```

The emitX functions insert the corresponding Nodes into the Graph, for example through graph->insert. Their details are fairly involved and are not covered here.
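To give a feel for what emitStatements and the emitX functions do, here is a hypothetical, drastically simplified emitter; it is not PyTorch's to_ir. It switches on the statement kind the way the C++ switch does. It keeps an environment that maps variable names to their current Value, and appends one node per operation to a flat list standing in for the Graph. Constants become prim::Constant nodes and + becomes aten::add, mirroring the graph printed earlier. All types are fixed to int and values are numbered sequentially, so the Value names differ slightly from PyTorch's. The statements are given as plain tuples instead of TreeView objects.

```python
from typing import Dict, List, Optional


class ToyEmitter:
    """Hypothetical mini IR emitter for straight-line int code."""

    def __init__(self, params: List[str]):
        self.nodes: List[str] = []
        self.counter = 0
        self.env: Dict[str, str] = {}  # variable name -> current Value
        self.inputs = [self._bind(p) for p in params]
        self.returned: Optional[str] = None

    def _fresh(self, hint: str = "") -> str:
        self.counter += 1
        return f"%{hint}.{self.counter}" if hint else f"%{self.counter}"

    def _bind(self, name: str) -> str:
        value = self._fresh(name)
        self.env[name] = value
        return value

    def emit_expr(self, expr, hint: str = "") -> str:
        kind = expr[0]
        if kind == "var":
            return self.env[expr[1]]  # reads the current Value; no node emitted
        if kind == "const":
            out = self._fresh(hint)
            self.nodes.append(f"{out} : int = prim::Constant[value={expr[1]}]()")
            return out
        if kind == "+":
            lhs = self.emit_expr(expr[1])
            rhs = self.emit_expr(expr[2])
            out = self._fresh(hint)  # output named after the assigned variable
            self.nodes.append(f"{out} : int = aten::add({lhs}, {rhs})")
            return out
        raise NotImplementedError(f"unsupported expression kind: {kind}")

    def emit_statements(self, stmts) -> None:
        for stmt in stmts:
            kind = stmt[0]
            if kind == "assign":    # cf. case TK_ASSIGN
                _, target, expr = stmt
                self.env[target] = self.emit_expr(expr, hint=target)
            elif kind == "return":  # cf. case TK_RETURN
                self.returned = self.emit_expr(stmt[1])
                return              # later statements are unreachable
            elif kind == "pass":    # cf. case TK_PASS: emit nothing
                continue
            else:
                raise NotImplementedError(f"Unrecognized statement kind {kind}")

    def to_text(self) -> str:
        header = ", ".join(f"{v} : int" for v in self.inputs)
        body = "".join(f"  {n}\n" for n in self.nodes)
        return f"graph({header}):\n{body}  return ({self.returned})\n"


stmts = [
    ("assign", "a", ("+", ("var", "a"), ("const", 2))),  # a = a + 2
    ("return", ("var", "a")),                             # return a
]
emitter = ToyEmitter(params=["a"])
emitter.emit_statements(stmts)
print(emitter.to_text(), end="")
# graph(%a.1 : int):
#   %2 : int = prim::Constant[value=2]()
#   %a.3 : int = aten::add(%a.1, %2)
#   return (%a.3)
```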

Module::save

```cpp
// Expose save on torch.jit.ScriptFunction: wrap the function in a Module, then call Module::save
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
    .def(
        "save",
        [](const StrongFunctionPtr& self,
           const std::string& filename,
           const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
          Module module("__torch__.PlaceholderModule");
          // [issue 27343]
          // Modules have 'training' attributes by default, but due to
          // https://github.com/pytorch/pytorch/issues/27343, functions end
          // up having a training attribute when they are loaded. This adds
          // a fake 'training' attribute that shouldn't be used, but prevents
          // jitter on saving and loading. Once that issue is fixed this can
          // be deleted.
          module.register_attribute("training", BoolType::get(), true);
          // Add the function to the Module (as its forward method)
          addFunctionToModule(module, self);
          module.save(filename, _extra_files);
        },
        py::arg("filename"),
        py::arg("_extra_files") = ExtraFilesMap());
```
```cpp
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
  // Make a graph with a fake self argument
  auto graph = func.function_->graph()->copy();
  auto v = graph->insertInput(0, "self");
  v->setType(module._ivalue()->type());
  const auto name = QualifiedName(*module.type()->name(), "forward");
  // Create a new Function named forward in the module's CompilationUnit from the copied graph
  auto method = module._ivalue()->compilation_unit()->create_function(name, graph);
  // Register it as a method of the module's class type
  module.type()->addMethod(method);
}

void ClassType::addMethod(torch::jit::Function* method) {
  TORCH_CHECK(
      findMethod(method->name()) == nullptr,
      "Can't redefine method: ",
      method->name(),
      " on class: ",
      repr_str());
  methods_.push_back(method);
}
```
```cpp
void Module::save(const std::string& filename, const ExtraFilesMap& extra_files) const {
  ExportModule(*this, filename, extra_files, false /* bytecode_format */);
}

void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info) {
  caffe2::serialize::PyTorchStreamWriter writer(filename);
  ScriptModuleSerializer serializer(writer);
  serializer.serialize(module, extra_files, bytecode_format, save_mobile_debug_info);
}
```

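So when saving, the StrongFunctionPtr is first wrapped in a Module as its forward method. ScriptModuleSerializer then serializes that Module into a zip archive written through PyTorchStreamWriter. ScriptModuleSerializer and its loading counterpart, ScriptModuleDeserializer, are the classes that implement TorchScript model serialization; for the file format, see TorchScript serialization.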
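The following sketch shows the result of this path end to end. It assumes a reasonably recent PyTorch installation. Exact archive entry names vary between PyTorch versions, so the sketch only looks at entry-name suffixes. It saves the scripted function, lists the entries of the resulting zip file, and loads it back. Because save wrapped the function as the forward method of a placeholder Module, torch.jit.load returns a module that is called just like the original function.

```python
import os
import tempfile
import zipfile

import torch


def my_func(a: int) -> int:
    a = a + 2
    return a


scripted_func = torch.jit.script(my_func)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_func.pt")
    scripted_func.save(path)

    # The saved file is a zip archive (pickled data, TorchScript code, version record, ...)
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    for name in names:
        print(name)

    loaded = torch.jit.load(path)  # a ScriptModule whose forward is my_func
    print(loaded(3))               # 5
    print(loaded.code)             # the serialized forward(self, a) method
```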

References

PyTorch 源码解读之即时编译篇 (PyTorch Source Code Walkthrough: JIT Compilation)

TorchScript 如何实现Python -> C++ 代码转换 (How TorchScript Converts Python to C++ Code)

PyTorch JIT Source Code Read Note

pytorch/OVERVIEW.md

http://www.zh0ngtian.tech/posts/b1864870.html

Author

zhongtian

Published

2021-07-29

Updated

2023-12-16
