def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile.
        def_name: The name to give to the resulting AST object. This is not
            always the same as ``fn.__name__``, for example::

                def _forward(self):
                    ...
                forward = _forward

            In this case, the ``__name__`` attribute of the function object is
            "_forward", but we want the resulting AST to have the name
            "forward".
        self_name: If this function is a method, what the type name of ``self``
            is.
    """
    # dedent_src is the (dedented) source string of the function to be scripted.
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    # Parse the source string into a standard Python AST via the `ast` module.
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        # Report the real file location instead of a hard-coded "(unknown)".
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    # If torch.jit.annotate was used to annotate variable types, type_line
    # holds that type annotation line.
    type_line = torch.jit.annotations.get_type_line(source)
    # SourceContext is a wrapper around SourceRangeFactory carrying all the
    # metadata needed for compilation.
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    # We compiled a single standalone function, so take the first (and only)
    # top-level member.
    fn_def = py_ast.body[0]
    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db. (MonkeyType traces types at
    # runtime and can add annotations to Python 3 code automatically.)
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace:
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)
    # build_def converts the Python AST into the AST format PyTorch uses.
    return build_def(ctx, fn_def, type_line, def_name,
                     self_name=self_name, pdt_arg_types=pdt_arg_types)
# Only the build functions relevant to the example are kept here.
class StmtBuilder(Builder):
    @staticmethod
    def build_Expr(ctx, stmt):
        # Translate a bare expression statement; docstrings are dropped.
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        # Translate an assignment: build the right-hand side once, then one
        # expression per assignment target.
        rhs = build_expr(ctx, stmt.value)  # stmt.value is e.g. an _ast.BinOp
        lhs = [build_expr(ctx, x) for x in stmt.targets]  # each x is an _ast.Name
        return Assign(lhs, rhs)
# Reject integer division unless the module opted into true division; the
# garbled `andnot` is restored to `and not`.
if op == ast.Div and not ctx.uses_true_division:
    err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
    raise FrontendError(err_range, 'Division of ints in TorchScript uses Python 3 true '
                        'division semantics. Please put `from __future__ '
                        'import division` at the top of your file')
// for historic reasons, these are defined in ir_emitter.cpp // Returns the list of Functions just defined. std::vector<Function*> CompilationUnit::define( const c10::optional<c10::QualifiedName>& prefix, const std::vector<Property>& properties, const std::vector<ResolverPtr>& propResolvers, const std::vector<Def>& definitions, // 这个是 torch 抽象语法树 const std::vector<ResolverPtr>& defResolvers, // determines how we handle free variables in each definition const Self* self, // if non-null, the first argument to each def, is bound to this value bool shouldMangle = false/* see [name mangling] */){ TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size()); TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size()); std::vector<Function*> functions; std::unordered_map<std::string, Function*> function_table;
// Records fn in function_table, functions and with register_function. // This is done several times below, so this lambda helps avoid repeating // code. auto record_function = [&](std::unique_ptr<Function> fn) { function_table[fn->name()] = fn.get(); functions.emplace_back(fn.get()); this->register_function(std::move(fn)); };
for (size_t i = 0; i < properties.size(); i++) { PropertyPair property_fns = define_property( prefix, properties[i], propResolvers[i], self, function_table, shouldMangle);
if (setter_fn) { record_function(std::move(setter_fn)); } }
// torch 抽象语法树传给了 define 函数的另一个重载 for (size_t i = 0; i < definitions.size(); i++) { auto fn = define( prefix, definitions[i], defResolvers[i], self, function_table, shouldMangle, CompilationUnit::FunctionType::Method);
// 通过后面的代码可以知道,fn 包含了 creator 函数,其功能是将 AST 转为 IR record_function(std::move(fn)); }
// We need to compile `__init__` first, since it can determine what attributes // are available to other methods. So reorder the definitions accordingly. for (auto& kv : function_table) { if (kv.first == "__init__") { kv.second->ensure_defined(); } }
for (Function* function : functions) { // 调用 function 中保存的 creator,通过 AST 创建 IR,存入 function 中 function->ensure_defined(); }
return functions; }
voidGraphFunction::ensure_defined(){ if (function_creator_) { auto creator = function_creator_; function_creator_ = placeholderCreator; creator(*this); function_creator_ = nullptr; } check_single_output(); }
std::unique_ptr<Function> CompilationUnit::define( const c10::optional<QualifiedName>& prefix, const Def& def, const ResolverPtr& resolver, const Self* self, const std::unordered_map<std::string, Function*>& function_table, bool shouldMangle, CompilationUnit::FunctionType type)const{ TORCH_INTERNAL_ASSERT(resolver); auto _resolver = resolver; // script_compile_function 传进来的是个 nullptr if (!self) { // if self is defined, then these are methods and do not go into the // global namespace otherwise, they get defined together so we add them to // the function table so the methods can see each other _resolver = std::make_shared<FunctionResolver>(resolver.get(), function_table); } // creator 将 def(存有 torch 抽象语法树)通过 FunctionResolver 转成 IR 存入 method 中 auto creator = [def, _resolver, self](Function& method) { // Store the function name so that it can be referenced if there is an error // while compiling this function std::string call_name = method.qualname().name(); if (self) { auto atoms = method.qualname().atoms(); // There should be at least a ClassName.method_name TORCH_INTERNAL_ASSERT(atoms.size() >= 2); call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1); } ErrorReport::CallStack call(call_name, def.range()); to_ir(def, _resolver, self, method); };
auto name = prefix ? QualifiedName(*prefix, def.name().name()) : QualifiedName(def.name().name()); if (shouldMangle) { // If `shouldMangle` is set, we should generate a unique name for this // function if there is already an existing one. if (auto fn = find_function(name)) { name = mangle(name); } } // 将 creator 存入 fn 中 auto fn = torch::make_unique<GraphFunction>( std::move(name), std::make_shared<Graph>(), creator); if (self) { // Register this as a method on `self`'s type if (type == CompilationUnit::FunctionType::Hook) { self->getClassType()->addForwardHook(fn.get()); } elseif (type == CompilationUnit::FunctionType::PreHook) { self->getClassType()->addForwardPreHook(fn.get()); } else { self->getClassType()->addMethod(fn.get()); } }
voidemitStatements( List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end){ for (; begin != end; ++begin) { auto stmt = *begin; ErrorReport::CallStack::update_pending_range(stmt.range()); switch (stmt.kind()) { case TK_IF: emitIf(If(stmt)); break; case TK_WHILE: emitWhile(While(stmt)); break; case TK_FOR: emitFor(For(stmt)); break; case TK_ASSIGN: emitAssignment(Assign(stmt)); break; case TK_AUG_ASSIGN: emitAugAssignment(AugAssign(stmt)); break; case TK_EXPR_STMT: { auto expr = ExprStmt(stmt).expr(); emitSugaredExpr(expr, 0); } break; case TK_RAISE: emitRaise(Raise(stmt)); break; case TK_ASSERT: emitAssert(Assert(stmt)); break; case TK_RETURN: { emitReturn(Return(stmt)); } break; case TK_CONTINUE: { emitContinue(Continue(stmt)); } break; case TK_BREAK: { emitBreak(Break(stmt)); } break; case TK_PASS: // Emit nothing for pass break; case TK_DEF: emitClosure(Def(stmt)); break; case TK_DELETE: emitDelete(Delete(stmt)); break; case TK_WITH: emitWith(With(stmt)); break; default: throwErrorReport(stmt) << "Unrecognized statement kind " << kindToString(stmt.kind()); } // Found an exit statement in this block. The remaining statements aren't // reachable so we don't emit them. if (exit_blocks.count(environment_stack->block())) return; } }
// 将 Module::save 注册到 torch.jit.ScriptFunction py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr()) .def( "save", [](const StrongFunctionPtr& self, const std::string& filename, const ExtraFilesMap& _extra_files = ExtraFilesMap()) { Module module("__torch__.PlaceholderModule"); // [issue 27343] // Modules have 'training' attributes by default, but due to // https://github.com/pytorch/pytorch/issues/27343, functions end // up having a training attribute when they are loaded. This adds // a fake 'training' attribute that shouldn't be used, but prevents // jitter on saving and loading. Once that issue is fixed this can // be deleted. module.register_attribute("training", BoolType::get(), true); // 将 StrongFunctionPtr 添加到 Module 对象中 addFunctionToModule(module, self); module.save(filename, _extra_files); }, py::arg("filename"), py::arg("_extra_files") = ExtraFilesMap());