This article describes how Relay IR Passes are constructed. The core of a Relay IR Pass is still implemented in C++, but a Python interface is exposed so that upper layers can invoke the passes directly to transform and optimize the computation graph.
The pass manager lives in include/tvm/relay/transform.h, which contains the declarations of all passes. The Python interface functions are declared in python/tvm/relay/transform.py, and python/tvm/relay/_transform.py calls the underlying C++ functions through the FFI under the namespace relay._transform.
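As a quick illustration of this split, the sketch below (assuming a reasonably recent TVM build with the Python package installed; the exact file layout of the Python wrappers varies across versions) applies one of the built-in Relay passes from Python. relay.transform.FoldConstant is the thin Python wrapper that dispatches to the C++ implementation through the FFI:

```python
import tvm
from tvm import relay

# Build a tiny Relay module whose body contains a foldable sub-expression.
x = relay.var("x", shape=(2, 2), dtype="float32")
body = relay.add(x, relay.add(relay.const(1.0), relay.const(2.0)))
mod = tvm.IRModule.from_expr(relay.Function([x], body))

# FoldConstant is declared on the Python side and forwarded to the C++ pass
# through the relay._transform FFI namespace.
mod = relay.transform.FoldConstant()(mod)
print(mod)  # the inner add of two constants is now a single constant
```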
The C++ implementation itself is split into two parts:

- the passes under src/relay/pass, which are applied collectively by relay::Module Optimize in src/relay/backend/build_module.cc;
- the passes under src/relay/backend/vm, which are applied collectively by the lower function in python/tvm/build_module.py.
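A sketch of how these collective transformations are usually triggered from user code (assuming an LLVM-enabled TVM build; the exact entry points differ slightly across versions): relay.build runs the collected graph-level passes before lowering, and the enclosing PassContext decides which of them are enabled.

```python
import tvm
from tvm import relay

# A toy dense + relu network.
data = relay.var("data", shape=(1, 4), dtype="float32")
weight = relay.var("weight", shape=(8, 4), dtype="float32")
net = relay.nn.relu(relay.nn.dense(data, weight))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], net))

# opt_level controls which of the collected passes are applied during build.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
```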
The core data structures are excerpted below. PassInfoNode records the metadata of a pass:

```cpp
class PassInfoNode : public RelayNode {
  // Name of the pass.
  std::string name;
  // Minimum optimization level at which the pass is enabled.
  int opt_level;
  // Passes that this pass depends on.
  std::vector<std::string> required;
};
```
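On the Python side this metadata is reachable through the info property of a pass object; a minimal sketch:

```python
from tvm import relay

fold = relay.transform.FoldConstant()
info = fold.info          # the PassInfo of the pass
print(info.name)          # pass name, e.g. "FoldConstant"
print(info.opt_level)     # minimum opt_level at which it is enabled
print(info.required)      # names of passes it depends on
```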
PassContextNode carries the configuration a pass runs under, such as the optimization level, the fallback device, and the lists of required and disabled passes:

```cpp
class PassContextNode : public RelayNode {
 public:
  ErrorReporter err_reporter;
  int opt_level{2};
  int fallback_device{static_cast<int>(kDLCPU)};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
};
```
PassContext is the handle user code works with; EnterWithScope and ExitWithScope, together with tvm::With<PassContext>, provide the Python with-style scoping:

```cpp
class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();
  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};
```
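This is exactly what backs the Python with syntax; a minimal sketch (the keyword arguments mirror the PassContextNode fields above, and the two pass names are just examples of built-in Relay passes):

```python
import tvm

# Entering the block calls EnterWithScope(); leaving it calls ExitWithScope().
with tvm.transform.PassContext(opt_level=3,
                               required_pass=["FoldScaleAxis"],
                               disabled_pass=["EliminateCommonSubexpr"]):
    # Any pass executed here sees opt_level=3 plus the required/disabled lists.
    ...
```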
The entered contexts are tracked in a thread-local stack, with a default context as the fallback:

```cpp
struct RelayPassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;

  RelayPassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
    RelayPassContextThreadLocalStore;
```
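The effect of the thread-local stack is easiest to see from Python, where PassContext.current() returns the default context outside any scope and the innermost entered context inside one:

```python
import tvm

print(tvm.transform.PassContext.current().opt_level)      # default context: 2

with tvm.transform.PassContext(opt_level=3):
    print(tvm.transform.PassContext.current().opt_level)  # top of the stack: 3
```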
PassNode is the base class of all passes: each pass exposes its PassInfo and is applied as a functor to a module under a given pass context:

```cpp
class PassNode : RelayNode {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const Module& mod,
                            const PassContext& pass_ctx) const = 0;
};
```
In other words, a pass always acts on an IRModule under a particular context, and every pass is designed as a mapping from a Module to a Module. The complete pass definitions live in src/relay/ir/transform.cc and src/ir/transform.cc.
ModulePassNode implements module-level passes, whose pass_func maps a whole Module to a new Module:

```cpp
class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted.
};
```
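From Python, a module-level pass can be created with the tvm.transform.module_pass decorator; the wrapped function has exactly the Module-to-Module shape of pass_func above. A small sketch (the pass add_abs_func and the function it inserts are made up for illustration):

```python
import tvm
from tvm import relay

@tvm.transform.module_pass(opt_level=1)
def add_abs_func(mod, ctx):
    """Append a trivial function abs_x to the module and return it."""
    x = relay.var("x", shape=(2,), dtype="float32")
    mod["abs_x"] = relay.Function([x], relay.abs(x))
    return mod

mod = tvm.IRModule()
mod = add_abs_func(mod)   # the decorated object is a ModulePass
print(mod)
```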
FunctionPassNode implements function-level passes: pass_func is applied to each Relay Function in the module, and SkipFunction decides whether a particular function should be skipped:

```cpp
class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};
```
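The Python counterpart is the relay.transform.function_pass decorator, whose transform_function receives the same (Function, Module, PassContext) triple as pass_func. A minimal sketch with a do-nothing pass (IdentityPass is a made-up name):

```python
import tvm
from tvm import relay

@relay.transform.function_pass(opt_level=1)
class IdentityPass:
    # A real pass would rewrite func here, e.g. with an ExprMutator.
    def transform_function(self, func, mod, ctx):
        return func

x = relay.var("x", shape=(2, 2), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
mod = IdentityPass()(mod)   # applied to every function in the module
```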
Similar to nn.Sequential in PyTorch, SequentialPassNode executes multiple passes in order:
```cpp
class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
```
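Usage from Python mirrors the nn.Sequential analogy; a sketch chaining a few built-in Relay passes (whether each one actually runs is decided by PassEnabled against the surrounding PassContext):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.add(x, relay.multiply(relay.const(2.0), relay.const(3.0)))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

seq = tvm.transform.Sequential([
    relay.transform.SimplifyInference(),
    relay.transform.FoldConstant(),
    relay.transform.EliminateCommonSubexpr(),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)
```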