TVM - Relay IR Pass

本文介绍Relay IR Pass的构造。

Relay IR Pass核心依然是在C++中实现,但提供了Python接口,方便上层直接调用并对计算流图进行变换优化。

Pass管理器在include/tvm/relay/transform.h中,里面包含所有Pass的声明,希望做到

  • 管理调度不同的优化pass
  • 收集需要的分析信息,并且保持它们是最新的
  • 减少程序员实现新pass的麻烦

Python的接口函数声明在python/tvm/relay/transform.py中,在python/tvm/relay/_transform.py中通过FFI对C++函数进行调用,命名空间为relay._transform

具体C++的实现则分为两个部分:

  • 高层IR图变换,源码在src/relay/pass中,集中变换则是在src/relay/backend/build_module.cc中的relay::Module Optimize
  • 后端代码的图变换,源码在src/relay/backend/vm中,集中变换在python/tvm/build_module.py中的lower函数

Pass的构造

  • PassInfo
      class PassInfoNode : public RelayNode {
      std::string name;
      int opt_level;
      std::vector<std::string> required;
      };
    
  • PassContext
      class PassContextNode : public RelayNode {
      public:
      ErrorReporter err_reporter;
      int opt_level{2};
      int fallback_device{static_cast<int>(kDLCPU)};
      tvm::Array<tvm::Expr> required_pass;
      tvm::Array<tvm::Expr> disabled_pass;
      };
    
      class PassContext : public NodeRef {
      public:
      TVM_DLL static PassContext Create();
      TVM_DLL static PassContext Current();
      /* Other fields are omitted. */
    
      private:
      // The entry of a pass context scope.
      TVM_DLL void EnterWithScope();
      // The exit of a pass context scope.
      TVM_DLL void ExitWithScope();
    
      // Classes to get the Python `with` like syntax.
      friend class tvm::With<PassContext>;
      };
    
      struct RelayPassContextThreadLocalEntry {
      /*! \brief The default pass context. */
      PassContext default_context;
      /*! \brief The current pass context. */
      std::stack<PassContext> context_stack;
      RelayPassContextThreadLocalEntry() {
          default_context = PassContext(make_node<PassContextNode>());
      }
      };
    
      /*! \brief The thread-local store to hold the pass context. */
      typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
          RelayPassContextThreadLocalStore;
    
  • Pass Constructs:提供基类
      class PassNode : RelayNode {
      virtual PassInfo Info() const = 0;
      virtual Module operator()(const IRModule& mod
                                  const PassContext& pass_ctx) const = 0;
      };
    

也就是说,一个Pass一定是作用在特定context下的IRModule,所有Pass都设计成ModuleModule的映射,完整Pass的定义在src/relay/ir/transform.ccsrc/ir/transform.cc中。

Module-Level

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

Function-Level

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

Sequential

类似于PyTorch中的nn.Sequential,顺序执行多个Pass

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

参考资料