【TVM】Relax.zero Pipeline源码分析

源码位置:python/tvm/relax/pipeline.py(zero_pipeline / get_pipeline),各pass的C++实现路径见下文各节标注。

本文分析的是zero_pipeline: 完成relax func降级,算子融合、常量折叠等工作。

# Apply the default Relax pipeline (the zero pipeline analyzed here) to `mod`:
# get_pipeline() returns a callable that is then applied to the module.
mod_after_pipeline = relax.pipeline.get_pipeline()(mod)

pipeline包含的transform pass(按顺序执行):

# The passes run in list order: legalize Relax ops, annotate each PrimFunc's op pattern,
# fold constants, group fusible bindings (FuseOps), then FuseTIR (not analyzed in this excerpt).
# NOTE(review): `enable_warning` comes from the enclosing function, which is not shown here.
seq = tvm.transform.Sequential(
	[
		transform.LegalizeOps(enable_warning=enable_warning),
		transform.AnnotateTIROpPattern(),
		transform.FoldConstant(),
		transform.FuseOps(),
		transform.FuseTIR(),
	]
)

LegalizeOps

Pass功能:对relax op进行降级(legalize),将其替换为对应的TIR实现。相关类:

  • LegalizeMutator:对mod进行重写
/*!
 * \brief Module pass that legalizes (lowers) Relax operator calls in every Relax function.
 * \param cmap Optional user-supplied map from op name to legalization function; an entry here
 *        takes priority over the op's default FLegalize attribute.
 * \param enable_warning Forwarded to LegalizeMutator.
 * \note The rewrite can be switched off with the PassContext config
 *       "relax.transform.apply_legalize_ops" (defaults to true).
 */
Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
                                                                            PassContext pc) {
    bool apply_legalize_ops =
        pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
    if (apply_legalize_ops) {
      mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
    }
    return mod;
  };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"LegalizeOps",
                          /*required=*/{});
}

LegalizeMutator

数据成员:

  • mod_:IRModule
  • cmap_:Optional<Map<String, PackedFunc>>,用户自定义的legalization函数表(op名 → legalize函数),优先级高于默认注册的FLegalize
  • enable_warning_:bool,是否输出legalize相关的警告信息

Transform:

  • 遍历所有的函数,只对relax function做处理
  /*!
   * \brief Legalize every Relax function of mod_ and return the builder's resulting module.
   * \note Functions that are not Relax functions (e.g. PrimFuncs) are left untouched.
   */
  IRModule Transform() {
    for (const auto& kv : mod_->functions) {
      const GlobalVar& gv = kv.first;
      const BaseFunc& func = kv.second;
      if (!func->IsInstance<FunctionNode>()) {
        continue;
      }
      Function rewritten = Downcast<Function>(this->VisitExpr(func));
      builder_->UpdateFunction(gv, rewritten);
    }
    return builder_->GetContextIRModule();
  }

处理Call节点(VisitExpr_(const CallNode*)):

  • 获取降级属性Map:legalize_map
  • 获取Op:call_pure_packed_op、call_tir_op、call_dps_packed_op
  • 如果cmap_中注册了该op的自定义实现,优先使用cmap_进行legalize,并在满足WrapPureCondition时用WrapPureCall包装
  • 否则使用默认的legalize_map(FLegalize属性)进行降级,同样在满足条件时WrapPureCall
  /*! \brief Legalization function: rewrites `call` using the block builder `bb`. */
  using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

  /*!
   * \brief Legalize one call after its arguments have been visited (post-order).
   *
   * Priority: a customized legalization in cmap_ wins over the default FLegalize attribute.
   * Calls whose callee is not an Op (e.g. a GlobalVar or a local Var) are returned unchanged.
   */
  Expr VisitExpr_(const CallNode* call) final {
    Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));

    static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
    static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
    static const Op& call_tir_op = Op::Get("relax.call_tir");
    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");

    // Only Op callees can be legalized; GetRef on a null OpNode* would be invalid.
    auto* op_node = visited_call->op.as<OpNode>();
    if (op_node == nullptr) {
      return visited_call;
    }
    auto op = GetRef<Op>(op_node);

    // Priority: customize > default.
    // Check if it has customize legalization registered.
    if (cmap_.defined() && cmap_.value().count(op->name)) {
      auto ret = cmap_.value()[op->name](this->builder_, visited_call);
      if (ret.IsObjectRef<Expr>() && WrapPureCondition(op, ret.AsObjectRef<Expr>())) {
        return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
      }
      return ret;
    }
    // Check if it has default legalization registered.
    if (legalize_map.count(op)) {
      auto ret = legalize_map[op](this->builder_, visited_call);
      if (WrapPureCondition(op, ret)) {
        return WrapPureCall(Downcast<Call>(ret));
      }
      return ret;
    }

    // No legalization found. The call_* ops are not expected to have one, so only warn for others.
    if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
        op != call_pure_packed_op) {
      LOG(WARNING) << "No legalization func for " << op->name << " is found.";
    }
    return visited_call;
  }

WrapPureCall:将legalize得到的调用改写为relax.call_pure_packed调用——原调用的op作为第一个参数,其后依次是原参数,attrs与sinfo_args保持不变。

  /*!
   * \brief Re-express `ret` as a relax.call_pure_packed call.
   *
   * The original callee becomes the first argument, followed by the original arguments;
   * attrs and sinfo_args are carried over unchanged.
   */
  Call WrapPureCall(const Call& ret) {
    static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
    Array<Expr> wrapped_args;
    wrapped_args.push_back(ret->op);
    for (const Expr& arg : ret->args) {
      wrapped_args.push_back(arg);
    }
    return Call(call_pure_packed_op, wrapped_args, ret->attrs, ret->sinfo_args);
  }

实例分析:add

大多数op的legalize注册使用topi实现:

# /python/tvm/relax/transform/legalize_ops/binary.py
def _binary(te_func: TEFunc) -> LegalizeFunc:
    """A common wrapper util for the legalization of binary operators.

    It detects if one of the binary op arguments is a constant scalar. If so,
    it extracts the scalar value to simplify the generated PrimFunc.
    """

    def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
        # To simplify the created PrimFunc, we first check if arg1 is a constant scalar.
        # If it is not, we then check if arg0 is a constant scalar.
        arg0 = call.args[0]
        arg1 = _try_convert_to_scalar_const(call.args[1])
        if isinstance(arg1, Expr):  # type: ignore
            # arg1 presumably stayed a Relax Expr (no scalar extracted), so try arg0 instead.
            arg0 = _try_convert_to_scalar_const(arg0)
        return bb.call_te(te_func, arg0, arg1)

    return binary_call_te


# The registration must come after `_binary` is defined; otherwise importing this module
# raises NameError.
register_legalize("relax.add", _binary(topi.add))

topi实现:

# /src/topi/broadcast.cc
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);

TOPI_DEFINE_BCAST_OP(add, { return a + b; });

#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                                   \
  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; }      \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B,                 \
                              std::string name = "T_" #Name, std::string tag = kBroadcast) {      \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                               \
    return detail::WithBroadcast(l, A, B, name, tag);                                             \
  }                                                                                               \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B,                   \
                              std::string name = "T_" #Name, std::string tag = kElementWise) {    \
    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                               \
    return tvm::te::compute(                                                                      \
        A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
  }                                                                                               \
  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B,                   \
                              std::string name = "T_" #Name, std::string tag = kElementWise) {    \
    auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };                              \
    return tvm::te::compute(                                                                      \
        B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
  }

FoldConstant / ConstantFolder

  • 对于每个函数,调用ConstantFolder::Fold(f, m)
/*!
 * \brief Function-level pass that folds constant sub-expressions of each Relax function.
 *
 * Every function is handed to ConstantFolder::Fold together with its enclosing module.
 */
Pass FoldConstant() {
  auto fold_fn = [=](Function f, IRModule m, PassContext pc) {
    return ConstantFolder::Fold(f, m);
  };
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func(fold_fn);
  return CreateFunctionPass(/*pass_function=*/pass_func, /*opt_level=*/0,
                            /*pass_name=*/"FoldConstant", /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant);

ConstantFolder实现:

  /*!
   * \brief Fold constants in `func`, with `ctx_module` as the module context, then remove the
   *        bindings that are no longer used.
   */
  static Function Fold(Function func, IRModule ctx_module) {
    ConstantFolder folder(std::move(ctx_module));
    Function folded = Downcast<Function>(folder(func));
    return RemoveAllUnused(folded);
  }

VisitExpr

  /*!
   * \brief Try to fold a call whose arguments have already been visited (post-order).
   *
   * - Calls rejected by ShouldBeFolded are returned as-is.
   * - relax.call_tir calls go to VisitCallTIR; if that yields nothing, the call is kept.
   * - Inside a dataflow block, an op with an FLegalize entry is legalized first; if the
   *   legalized expression is a call_tir, VisitCallTIR is tried on it.
   * NOTE(review): this is an excerpt — `new_args` is defined in elided code ("..."), and a
   * closing brace of the dataflow-block branch is elided as well.
   */
  Expr VisitExpr_(const CallNode* call) final {
    // post-order mutation
    Call post_call = Downcast<Call>(VisitExprPostOrder_(call));

    // Check if it is useful to fold this call
    if (!ShouldBeFolded(post_call)) return post_call;

    static const Op& call_tir_op = Op::Get("relax.call_tir");
    static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
    auto* op_node = post_call->op.as<OpNode>();

    // NOTE(review): the excerpt shows no guard for non-Op callees; GetRef on a null pointer is
    // only safe if the elided source returns early when op_node == nullptr — confirm.
    auto op = GetRef<Op>(op_node);

    if (op.same_as(call_tir_op)) {
      return VisitCallTIR(post_call).value_or(post_call);
    }
    post_call =
        Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span);
    ...
    // If we are in a dataflow block, we can fold ops.
    if (builder_->CurrentBlockIsDataFlow()) {
      // Check if we can legalize them to call_tir
      if (legalize_map.count(op)) {
        // Get the legalized expression
        Call post_call_normalized = Downcast<Call>(builder_->Normalize(post_call));
        Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call_normalized));
        // If the legalized expression is call_tir, try to fold it.
        const CallNode* call = legalized_expr.as<CallNode>();
        if (call && call->op.same_as(call_tir_op)) {
          // If VisitCallTIR cannot fold, fall back to the (un-legalized) post_call.
          return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
        }
        ...
    }
    return std::move(post_call);
  }

AnnotateTIROpPattern

类型:PrimFunc Pass(通过tir::transform::CreatePrimFuncPass创建,作用于每个TIR函数)。Pass功能:为TIR函数标注op_pattern属性(算子模式类型)。

# src/relax/transform/annotate_tir_op_pattern.cc
namespace transform {

/*!
 * \brief PrimFunc-level pass that attaches an "op_pattern" attribute to every TIR function.
 *
 * Each PrimFunc is annotated through AnnotateOpPattern; the module and pass context are unused.
 */
Pass AnnotateTIROpPattern() {
  auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
    return AnnotateOpPattern(std::move(f));
  };
  return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {});
}

TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);

}  // namespace transform

实现:

  • 分析函数的算子模式类型
  • 将类型属性添加到函数属性中
# src/relax/transform/annotate_tir_op_pattern.cc
/*!
 * \brief Attach the analyzed op pattern of `f` as the integer attribute "op_pattern".
 *
 * A function whose "op_pattern" attribute passes HasNonzeroAttr is returned unchanged.
 * NOTE(review): judging by the name, an existing value of 0 (kElemWise) is presumably treated
 * as absent, so such functions are simply re-analyzed — confirm.
 */
tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
  if (f->HasNonzeroAttr("op_pattern")) return f;

  const relay::OpPatternKind pattern = AnalyzeOpPatternKind(f);
  const Integer encoded(static_cast<int>(pattern));
  return WithAttr(std::move(f), "op_pattern", encoded);
}

算子模式类型: kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable < kTuple < kOpaque

/*!
 * \brief operator pattern used in graph fusion
 *
 * Larger values are less fusible; fusion code compares kinds with `<=` and `std::max`.
 * Values 5 and 6 are not used by this enum.
 */
enum OpPatternKind {
  // Elementwise operation
  kElemWise = 0,
  // Broadcasting operator, can always map output axis to the input in order.
  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
  // Note that the axis need to be in order so transpose is not a bcast operator.
  kBroadcast = 1,
  // Injective operator, can always injectively map output axis to a single input axis.
  // All injective operator can still be safely fused to injective and reduction.
  kInjective = 2,
  // Commutative reduction operator.
  kCommReduce = 3,
  // Complex operation, can still fuse elemwise operations into its output.
  // but cannot chain another complex op
  kOutEWiseFusable = 4,
  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
  // but treated specially
  kTuple = 7,
  // Opaque operation, cannot fuse anything.
  kOpaque = 8
};

PatternKindAnalyzer

分析器:

# src/relax/analysis/tir_op_pattern_kind.cc
/*!
 * \brief Classify a PrimFunc into an OpPatternKind by walking its body with
 *        PatternKindAnalyzer.
 */
relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) {
  PatternKindAnalyzer pattern_analyzer(func);
  const auto& body = func->body;
  pattern_analyzer(body);
  return pattern_analyzer.GetResult();
}

VisitStmt实现:

  • 构造函数:收集函数参数对应的buffer(param_buffers_)
  • visit block
  • visit BufferStoreNode:每个block只支持一个BufferStore(若出现索引不同的第二个store,则归为kOpaque)
  • visit BufferLoadNode:将BufferLoad加入loads_
  /*!
   * \brief Record the buffers bound to `func`'s parameters, i.e. its input and output buffers.
   */
  explicit PatternKindAnalyzer(const tir::PrimFunc& func) {
    for (const tir::Var& var : func->params) {
      const Optional<Buffer> bound = func->buffer_map.Get(var);
      if (!bound.defined()) {
        continue;  // parameter without a buffer binding
      }
      param_buffers_.insert(bound.value());
    }
  }

BufferStore和BufferLoad

  /*!
   * \brief Remember the current block's BufferStore.
   *
   * A second store whose indices differ from the first one sets kind_ to kOpaque and stops
   * the traversal of this store.
   */
  void VisitStmt_(const BufferStoreNode* op) final {
    const bool has_conflicting_store =
        store_.defined() && !IsSameArray(op->indices, store_.value()->indices);
    if (has_conflicting_store) {
      kind_ = relay::kOpaque;
      return;
    }
    store_ = GetRef<BufferStore>(op);
    StmtVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    loads_.push_back(GetRef<BufferLoad>(op));
    ExprVisitor::VisitExpr_(op);
  }

数据成员

  • store_:bufferstore节点
  • loads_:bufferload节点
  • kind_:节点模式
  • param_buffers_:函数参数对应的buffer(即输入和输出buffer)
 private:
  /*!
   * \brief The BufferStore node in the current block.
   * \note We only support one BufferStore node in a block (usually generated by TE compute)
   */
  Optional<BufferStore> store_;
  /*! \brief The BufferLoad nodes in the current block (cleared at the start of every block). */
  Array<BufferLoad> loads_;
  /*!
   * \brief The result of op pattern. Starts at kElemWise and is updated as blocks are
   *        visited.
   */
  relay::OpPatternKind kind_ = relay::kElemWise;
  /*! \brief The buffers from function params. I.e. the input and output buffers. */
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_;

 public:
  /*! \brief The pattern accumulated over all visited blocks. */
  relay::OpPatternKind GetResult() { return kind_; }
};

Visit Block

  1. 不处理root block
  2. 如果当前block没有任何存储,则分类为Opaque
  3. 遍历所有的BufferLoad,根据load和store判断op模式
    • IsElemwisePattern
    • IsBroadcastPattern
    • IsInjectivePattern
  void VisitStmt_(const BlockNode* op) final {
    // Step 1. Clear loads and store
    loads_.clear();
    store_ = NullOpt;
    // Step 2. Visit block body.
    StmtVisitor::VisitStmt(op->body);

    BufferStore store = store_.value();

    // Step 3. Checking load store indices pattern
    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
    bool has_elem_wise = false;
    for (const BufferLoad& load : loads_) {
      // Since elemwise is stricter than broadcast and broadcast is stricter than injective,
      // while the order amount enums: kElemWise < kBroadcast < kInjective.
      // We can simply use `std::max` to detect these three patterns.
      // E.g Here is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i]
      // Buffer C and A are elemwise but C and B are broadcast. So the whole block follows
      // broadcast pattern.
      if (IsElemwisePattern(store, load)) {
        index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
        has_elem_wise = true;
      } else if (IsBroadcastPattern(store, load)) {
        index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
      } else if (IsInjectivePattern(store, load)) {
        index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
      } else {
        index_pair_pattern = relay::kOpaque;
        break;
      }
    }
    // If there is a index pair is kElemWise and others are kBroadcast, we regard it as kElemWise
    // e.g. A[i, j] = B[i, j] + C[i]
    if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
      index_pair_pattern = relay::kElemWise;
    }
    // If the block index pattern is not opaque, update kind.
    if (index_pair_pattern != relay::kOpaque) {
      // This rule for softmax: reduce + injective.
      if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
        kind_ = relay::kOutEWiseFusable;
      } else {
        kind_ = std::max(kind_, index_pair_pattern);
      }
      return;
    }

    // Step 4. Checking if the block contains reduce axis by looking into block iterators.
    bool has_reduction = false;
    Array<tir::Var> reduce_vars;
    for (const IterVar& it : op->iter_vars) {
      if (it->iter_type == kCommReduce) {
        has_reduction = true;
        reduce_vars.push_back(it->var);
      }
    }

    if (has_reduction) {
      if (IsFMA(op->body)) {
        // FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv.
        kind_ = std::max(kind_, relay::kOutEWiseFusable);
        return;
      } else {
        for (size_t i = 0; i < loads_.size(); ++i) {
          // If it's not a pure reduce, regards as kOutEWiseFusable.
          // This rule works for pooling for now.
          if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
            kind_ = std::max(kind_, relay::kOutEWiseFusable);
            return;
          }
        }
      }
      kind_ = std::max(kind_, relay::kCommReduce);
    } else {
      kind_ = relay::kOpaque;
    }
  }

判断算子模式类型

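模式顺序:kElemWise(最严格) < kBroadcast < kInjective。例子:逐对比较时C与A为elemwise、C与B为broadcast,取max后index_pair_pattern为kBroadcast;但由于同时存在elemwise对,最终按特殊规则(见VisitStmt_中Step 3之后)被视为kElemWise。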

C[i, j] = A[i, j] + B[i]
# C[i, j], A[i, j] == elemwise
# C[i, j], B[i] == broadcast
  1. IsElemwisePattern:
  • load indices == store indices, 则是elementwise算子
  • 不相同的情况:长度不同,值不同
  static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
    return IsSameArray(store->indices, load->indices);
  }
  2. IsBroadcastPattern
  • 判断依据:load的每个index是否按顺序出现在store的indices中
  • 跳过extent为1且索引为常量0的load维度(例如B的第一维extent为1时的B[0, j])
  /*!
   * \brief Check whether `load` reads a broadcast of the stored region.
   *
   * Every non-unit load index must appear, as the same Var object, among the store indices
   * in the same relative order, and each store index may be matched at most once.
   * Unit load dimensions (extent 1, indexed by constant 0) are ignored.
   * E.g. C[i, j] = B[j] and C[i, j] = B[0, j] (B's first extent is 1) are broadcast;
   *      C[i, j] = B[j, i] and C[i, j] = B[i, i] are not.
   */
  static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) {
    size_t ndim_load_buf = load->buffer->shape.size();
    size_t ndim_store_buf = store->buffer->shape.size();

    for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) {
      if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) {
        // Skip unit load dimensions
        // E.g. A[i, j] = B[0, j] (where B's first extent is 1) is still broadcast
        continue;
      }

      // Try to find the i-th load index in the store indices.
      while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) {
        ++j;
      }

      // It's not broadcast if we cannot find load indices in the store indices in order.
      if (j == ndim_store_buf) {
        return false;
      }
      // Consume the matched store axis so later load indices must map to later store axes;
      // otherwise a repeated index such as B[i, i] would be accepted as broadcast.
      ++j;
    }
    return true;
  }
  3. IsInjectivePattern
  • store的每个index都必须是单个变量;load indices中用到的变量都必须出现在store indices中,不要求顺序
  static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) {
    std::unordered_set<const tir::VarNode*> vars;
    for (const PrimExpr& store_index : store->indices) {
      if (const auto* v = store_index.as<tir::VarNode>()) {
        vars.insert(v);
      } else {
        return false;
      }
    }
    for (const PrimExpr& load_index : load->indices) {
      // return false if there are vars used in load indices but not in store indices.
      if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) {
        return false;
      }
    }
    return true;
  }

reduce判断 / kOutEWiseFusable

寻找reduce轴:

// Excerpt of Step 4 of VisitStmt_(BlockNode): collect the block's reduction iterators.
// NOTE(review): the unqualified kCommReduce here is presumably the TIR IterVarType enumerator,
// not relay::kCommReduce — confirm.
bool has_reduction = false;
Array<tir::Var> reduce_vars;
for (const IterVar& it : op->iter_vars) {
  if (it->iter_type == kCommReduce) {
	has_reduction = true;
	reduce_vars.push_back(it->var);
  }
}

判断流程:

  • 保底是:kCommReduce
  • 如果包含:乘加:C[i, j] += A[i, k] * B[j, k], kOutEWiseFusable
  • 如果非Pure Reduce:kOutEWiseFusable
    // Excerpt of Step 4 of VisitStmt_(BlockNode): classify a block that has reduction iterators.
    // FMA -> kOutEWiseFusable; any load that is not a pure reduce -> kOutEWiseFusable;
    // otherwise kCommReduce. kind_ is only raised via std::max here.
    if (has_reduction) {
      if (IsFMA(op->body)) {
        // FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv.
        kind_ = std::max(kind_, relay::kOutEWiseFusable);
        return;
      } else {
        for (size_t i = 0; i < loads_.size(); ++i) {
          // If it's not a pure reduce, regards as kOutEWiseFusable.
          // This rule works for pooling for now.
          if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
            kind_ = std::max(kind_, relay::kOutEWiseFusable);
            return;
          }
        }
      }
      kind_ = std::max(kind_, relay::kCommReduce);
    }

pure reduce判断:

  • 每个用到reduce变量的load index都必须恰好就是该reduce变量本身,才是pure reduce
  • 如果出现j + k这种表达式(reduce变量参与了运算),则not pure
  /*!
   * \brief Check whether the load `indices` form a pure reduce pattern.
   *
   * Every index that mentions one of `reduce_loops` must be exactly that reduce Var.
   * E.g. A[i] = sum(B[i, j]) is pure reduce;
   *      A[i] = sum(B[i, j + k]) is not (pooling falls in this case).
   */
  static bool IsPureReducePattern(Array<tir::Var> reduce_loops, Array<PrimExpr> indices) {
    for (const PrimExpr& index : indices) {
      const tir::VarNode* hit = nullptr;
      auto is_reduce_var = [&](const tir::VarNode* var) {
        for (const tir::Var& reduce_var : reduce_loops) {
          if (reduce_var.get() == var) {
            hit = var;
            return true;
          }
        }
        return false;
      };
      // The index uses a reduce var but is not that var itself -> not a pure reduce.
      if (UsesVar(index, is_reduce_var) && hit != index.get()) {
        return false;
      }
    }
    return true;
  }

FuseOps

类型:模块Pass(CreateModulePass)。Pass功能:根据算子模式(op pattern)把可融合的算子划分到同一组并进行融合。

实现:

# src/relax/transform/fuse_ops.cc
/*!
 * \brief Module-level pass wrapper around relax::FuseOps.
 * \param fuse_opt_level Fusion opt level; -1 means "use the PassContext's opt_level".
 * \note The maximum number of ops per fused group comes from the PassContext config
 *       "relax.FuseOps.max_depth" (default kMaxFusedOps).
 */
Pass FuseOps(int fuse_opt_level) {
  auto fuse_fn = [=](IRModule m, PassContext pc) {
    const int effective_level = (fuse_opt_level == -1) ? pc->opt_level : fuse_opt_level;
    const Integer depth_limit =
        pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps)).value();
    return relax::FuseOps(m, effective_level, depth_limit.IntValue());
  };
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(fuse_fn);
  return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0,
                          /*pass_name=*/"FuseOps", /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);

FuseOps:

  1. 根据输入IRModule创建indexed-forward图(arena是用于分配图节点等对象的内存分配器)
  2. 应用融合算法对图进行分区
  3. 根据图分区结果,对IRModule中的算子进行融合
/*!
 * \brief Fuse operators in `mod`: build the indexed forward graph, partition it into fusion
 *        groups, then rewrite the module according to those groups.
 * \param opt_level Fusion opt level; 0 keeps every graph node in its own group.
 * \param max_fuse_depth Maximum number of operators in one fused group.
 */
IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
  // Graph nodes, edges and groups are allocated from this call-local arena.
  support::Arena arena;

  IndexedForwardGraph graph = GraphCreator::Create(mod, &arena);

  GraphPartitioner partitioner(&arena, opt_level, max_fuse_depth);
  std::vector<GraphPartitioner::Group*> groups = partitioner.Partition(graph);

  OperatorFusor fusor(mod, graph, groups, /*lift_constants*/ true);
  return fusor.Transform();
}

创建IndexedForward图

GraphCreator分析, 数据成员:

  • mod: IRModule
  • arenas_:所有内部节点对象的内存分配器
  • graph_:创建的indexed forward图
  • initialized_nodes:已经设置模式的图节点
/*!
 * \brief Visitor that builds an IndexedForwardGraph from the Relax functions of a module.
 * \note Excerpt: the public members are elided ("...").
 */
class GraphCreator : public ExprVisitor {
 public:...
 private:
  /*! \brief The IRModule from which the indexed forward graph is created */
  IRModule mod_;
  /*! \brief The allocator of all the internal node objects */
  support::Arena* arena_;
  /*! \brief The created indexed forward graph */
  IndexedForwardGraph graph_;
  /*! \brief The graph nodes whose patterns are set */
  std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};

Create分析:

  1. 处理的对象:relax function
  2. 创建算法保证每个新建节点都会按post-dfs顺序加入post_dfs_order并被设置op pattern,因此最后检查node_map、post_dfs_order与initialized_nodes_的大小是否相同
  /*!
   * \brief Build the indexed forward graph of every Relax function in `mod`.
   * \param mod The module to scan. Non-Relax functions and functions carrying a nonzero
   *        kPrimitive attribute are skipped.
   * \param arena Allocator for the graph nodes and edges.
   * \return The created graph.
   */
  static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
    GraphCreator creator(mod, arena);
    for (const auto& it : mod->functions) {
      // Only visit Relax function without attr kPrimitive.
      const auto* func = it.second.as<FunctionNode>();
      if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
        continue;
      }
      creator(GetRef<Function>(func));
    }

    // The algorithm of the graph creator ensures that each created node will be added to the
    // post-dfs order and will be set its op pattern. Thus we check whether all these containers
    // have the same size.
    size_t n_nodes = creator.graph_.node_map.size();
    ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size());
    ICHECK_EQ(n_nodes, creator.initialized_nodes_.size());

    return creator.graph_;
  }

图分析

按前向(producer → consumer)方向索引的数据流图。数据成员:

  • node_map:将表达式对象映射到图节点
  • post_dfs_order:post-dfs顺序
/*!
 * \brief Dataflow graph indexed in the forward direction: each node lists the consumers it
 *        feeds through its `outputs` edges.
 */
class IndexedForwardGraph {
public:
  // Nested node type (ref, index, extern_ref, pattern, outputs); its full definition is shown
  // separately in this article, so it is only forward-declared here.
  struct Node;
  /*! \brief The node map that maps node to graph */
  std::unordered_map<const tvm::Object*, Node*> node_map;
  /*! \brief All the nodes in post DFS order */
  std::vector<Node*> post_dfs_order;
};

边:

  /*! \brief A forward edge from a producer node to one of its consumers. */
  struct Edge {
    /*! \brief The corresponding node (the consumer this edge points to) */
    Node* node{nullptr};
    /*! \brief The respective pattern of this op (the consumer's pattern when the edge is added) */
    OpPatternKind pattern{kOpaque};
  };

节点:

  /*! \brief A graph node: one parameter, binding variable, or literal leaf expression. */
  struct Node {
    /*! \brief Weak (non-owning) reference to the corresponding expression object. */
    const tvm::Object* ref{nullptr};
    /*! \brief The index of the node in topological order. */
    size_t index{0}; // position of the node in post_dfs_order
    /*! \brief Whether this node is referenced by external source */
    bool extern_ref{false};
    /*! \brief The general pattern in the node */
    OpPatternKind pattern{kOpaque};
    /*! \brief The outputs of the node. */
    LinkedList<Edge> outputs; // outgoing edges of this node (to its consumers)
  };

创建节点:

  /*!
   * \brief Allocate a fresh graph node in the arena and register it in node_map under `key`.
   * \note An existing entry for `key` is overwritten.
   */
  IndexedForwardGraph::Node* CreateNode(const Object* key) {
    IndexedForwardGraph::Node* fresh_node = arena_->make<IndexedForwardGraph::Node>();
    graph_.node_map[key] = fresh_node;
    return fresh_node;
  }

visit Function

  • 遍历所有的参数,对每个参数创建Graph节点,并标记为Extern, 且节点模式标记为kOpaque
  /*!
   * \brief Create one graph node per function parameter, then visit the function body.
   *
   * Parameters come from outside the function, so each one is marked as an external
   * reference with pattern kOpaque.
   */
  void VisitExpr_(const FunctionNode* func) final {
    for (const Var& param : func->params) {
      const Object* key = param.get();
      IndexedForwardGraph::Node* node = CreateNode(key);
      MarkAsExternRef(node);
      SetNodePattern(node, OpPatternKind::kOpaque);
      AddToPostDFSOrder(node, key);
    }
    ExprVisitor::VisitExpr_(func);
  }

visit bindingBlock

  /*!
   * \brief Visit dataflow blocks only.
   *
   * Ordinary binding blocks are skipped since they might be impure (side effects or control
   * flow), so their bindings get no graph nodes.
   */
  void VisitBindingBlock(const BindingBlock& block) final {
    const auto* df_block = block.as<DataflowBlockNode>();
    if (df_block == nullptr) {
      return;
    }
    VisitBindingBlock_(df_block);
  }

visit binding

  • visit VarBindingNode
  • visit MatchCast

MatchCast:

  • 创建图Node,并设置其节点模式为kOpaque
  • 将节点添加到graph.post_dfs_order中
  /*!
   * \brief A MatchCast binding becomes an opaque node with no incoming edges.
   *
   * The debug LOG(INFO) that printed every MatchCast binding was removed: it only added noise
   * to a visitor that runs on every dataflow block.
   */
  void VisitBinding_(const MatchCastNode* binding) final {
    IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
    SetNodePattern(node, OpPatternKind::kOpaque);
    AddToPostDFSOrder(node, binding->var.get());
  }

VarBinding:

  • 创建节点,根据value类型,选择visit目标
  • 添加到post_dfs_order中
  1. Call:若为call_tir且被调用的PrimFunc带有op_pattern属性,则使用该模式;否则设置为kOpaque
  2. TupleGetItem:设置节点模式为kInjective
  3. VisitUnsupportedNode:设置节点模式为kOpaque
  void VisitBinding_(const VarBindingNode* binding) final {
    IndexedForwardGraph::Node* node = CreateNode(binding->var.get());

    // If the variable is not a dataflow variable, it must be the output variable of this dataflow
    // block
    if (!binding->var->IsInstance<DataflowVarNode>()) {
      this->MarkAsExternRef(node);
    }
    if (const auto* call = binding->value.as<CallNode>()) {
      // Case 1. The expression is a CallNode
      VisitCall(call, node);
    } else if (const auto* tuple_get_item = binding->value.as<TupleGetItemNode>()) {
      // Case 2. The expression is a TupleGetItemNode
      VisitTupleGetItem(tuple_get_item, node);
    } else {
      VisitUnsupportedNode(binding->value, node);
      // Case 3. The type of the expression is not fusion-supported.
      // In this case, we skip adding edges, adding an empty node into graph.
    }
    AddToPostDFSOrder(node, binding->var.get());
  }

VisitCall:

  • 设置创建节点的模式
  • Visit每个函数参数,将会对每个参数创建节点Node/选择已有的node,并将每个节点放到CallNode的edges中
  void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) {
    ICHECK_NOTNULL(binding_var_node);

    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
    OpPatternKind pattern = OpPatternKind::kOpaque;
    Array<Expr> args = call->args;

    // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the
    // function attribute and visit the arguments one by one.
    // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we
    // recurse into the call expression.
    const auto* op = call->op.as<OpNode>();
    if (op == call_tir_op_.get()) {
      const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

      // Override args for call_tir
      args = Downcast<Tuple>(call->args[1])->fields;

      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
      if (opt_pattern.defined()) {
        pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
      } else {
        pattern = OpPatternKind::kOpaque;
      }
    }
    // The pattern of the current binding variable node is set to the pattern of this operator.
    SetNodePattern(binding_var_node, pattern);
    // Visit all call args
    for (const Expr& arg : args) {
      ICHECK(IsLeafOrTuple(arg));
      VisitLeaf(arg, binding_var_node, pattern);
    }
  }

VisitLeaf:

  void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node,
                 const OpPatternKind& pattern) {
    ICHECK_NOTNULL(binding_var_node);

    // Recursive visit if it's Tuple
    if (const auto* tuple = leaf_expr.as<TupleNode>()) {
      for (const Expr& expr : tuple->fields) {
        VisitLeaf(expr, binding_var_node, pattern);
      }
      return;
    }

    if (!leaf_expr->IsInstance<LeafExprNode>()) {
      // Skip GlobalVar, ExternFunc, OpNode.
      return;
    }

    auto it = graph_.node_map.find(leaf_expr.get());
    IndexedForwardGraph::Node* leaf_node = nullptr;
    if (it != graph_.node_map.end()) {
      leaf_node = it->second;
    } else if (leaf_expr->IsInstance<ConstantNode>() || leaf_expr->IsInstance<ShapeExprNode>() ||
               leaf_expr->IsInstance<PrimValueNode>() || leaf_expr->IsInstance<StringImmNode>() ||
               leaf_expr->IsInstance<DataTypeImmNode>()) {
      leaf_node = CreateNode(leaf_expr.get());
      // Since we never fuse constants, the pattern of the constant is set to `kOpaque`.
      SetNodePattern(leaf_node, OpPatternKind::kOpaque);
      AddToPostDFSOrder(leaf_node, leaf_expr.get());
    } else {
      LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr
                 << " used before definition.";
    }
    AddEdge(leaf_node, binding_var_node, pattern);
  }
  // Forward edge: producer (leaf) --> consumer (current binding), tagged with the consumer's
  // pattern. The link node is allocated from the arena.
  void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end,
               OpPatternKind pattern) {
    auto* edge_link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
    IndexedForwardGraph::Edge& edge = edge_link->value;
    edge.node = end;
    edge.pattern = pattern;
    start->outputs.Push(edge_link);
  }

AddToPostDFSOrder

函数功能:将输入节点以post-dfs顺序添加到图中

  /*!
   * \brief Append `node` (registered under `key`) to the post-DFS order and record its index.
   *
   * The node must already be in node_map under `key` and must not have been appended before.
   */
  void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
    // Previously `it` was looked up but never used; check the invariant it was meant for.
    auto it = graph_.node_map.find(key);
    ICHECK(it != graph_.node_map.end() && it->second == node)
        << "Cannot find the graph node registered for the given key in the node map";

    // `ref` is only assigned below, so a non-null value means the node was already appended.
    ICHECK(node->ref == nullptr) << "The node has already been added to the post-DFS order";

    node->ref = key;
    node->index = graph_.post_dfs_order.size();
    graph_.post_dfs_order.push_back(node);
  }

图分区

// Step 2. Partition the graph by applying the fusion algorithm.
// (Excerpt of relax::FuseOps: the partitioner allocates its groups from the same arena as the
// graph, and returns one group pointer per node in post-DFS order.)
std::vector<GraphPartitioner::Group*> groups =
	  GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph);

GraphPartitioner

类描述:使用并查集(union-find)数据结构表示的图分区。数据成员:

  • arena_:内部的arena用于临时空间
  • opt_level_:融合操作的优化级别
  • max_fuse_depth:一个fused函数中的最大的算子数目
  • groups_:内部group
  • visited_:内部用于去重的已访问节点集合
# src/relay/analysis/graph_partitioner.h
/*!
 * \brief Partition of an IndexedForwardGraph into fusion groups, tracked with a union-find
 *        structure (Group::parent / Group::FindRoot).
 * \note Only the private data members are shown in this excerpt.
 */
class GraphPartitioner {
 private:
  /*! \brief The internal arena for temporary space. */
  support::Arena* arena_;
  /*! \brief optimization level for fuse operation. */
  int opt_level_;
  /*! \brief The maximum number of operations in one fused function */
  size_t max_fuse_depth_;
  /*! \brief The internal groups (one per graph node, indexed like post_dfs_order). */
  std::vector<Group*> groups_;
  /*! \brief internal field used for deduplication */
  std::unordered_set<IndexedForwardGraph::Node*> visited_;
};

Group:

  /*! \brief A fusion group; groups are merged through a union-find structure. */
  struct Group {
    Group* parent{nullptr}; // The parent in the union find data structure.
    OpPatternKind pattern; // Op pattern of this group (no initializer; InitGroups sets it).
    const tvm::Object* root_ref{nullptr}; // Object of the node this group was created for.
    const tvm::Object* anchor_ref{nullptr}; // Set when the group's pattern is kOutEWiseFusable.
    uint32_t num_nodes{1}; // Number of nodes belonging to this group.
    runtime::Map<runtime::String, ObjectRef> attrs; // Attributes of this group's function.
    Group* FindRoot(); // Find the group root, performing path compression.
  };

Partition

函数描述:对图进行分区。函数流程:

  1. 初始化groups:每个节点一个group;若opt_level为0,直接返回(不做融合)
  2. 构建post-dominator tree(按逆拓扑序遍历,并组合沿途的OpPattern)
  3. 依次运行3个phase的融合算法
# src/relay/analysis/graph_partitioner.cc
/*!
 * \brief Split `graph` into fusion groups.
 *
 * Every node starts in its own group. Unless opt_level_ is 0, the post-dominator tree is
 * built and the fusion algorithm runs in three phases (0, 1, 2).
 * \return The groups, indexed like graph.post_dfs_order (groups_ is moved out).
 */
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ != 0) {
    auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
    constexpr int kNumPhases = 3;
    for (int phase = 0; phase < kNumPhases; ++phase) {
      this->RunFuse(graph, post_dom_tree, phase);
    }
  }
  return std::move(groups_);
}

InitGroups

函数描述:每个节点新建一个group 函数流程:

  • 如果pattern为kOutEWiseFusable,设置anchor_ref
void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
  groups_.resize(graph.post_dfs_order.size());
  // 遍历图的每个节点
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    const auto* graph_node = graph.post_dfs_order[nid];
    auto* group_node = arena_->make<Group>();
    group_node->pattern = graph_node->pattern;
    group_node->root_ref = graph_node->ref;
    // set anchor ref if necessary.
    if (group_node->pattern == relay::kOutEWiseFusable) {
      group_node->anchor_ref = graph_node->ref;
    }
    groups_[nid] = group_node;
  }
}

RunFuse

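函数描述:尝试把每个节点融合进其post-dominator父节点所在的group。函数流程:

  • 遍历所有的groups
  • 当前节点为kOpaque,或当前节点没有post-dominator,则不fuse
  • 若融合后的节点数超过max_fuse_depth,则不fuse
  • Phase0:完成(kElemWise,kBroadcast)、(kOutEWiseFusable)
  • Phase1:(对应代码在本节截取范围之后,此处未展示)
  • Phase2:当dom父节点为kTuple且其root group不超过kInjective时,将injective算子融合进中间的tuple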
/*!
 * \brief Run one fusion phase over all nodes in post-DFS order.
 *
 * A node is considered only if it is not kOpaque and has a post-dominator parent. Its group may
 * then be merged (CommitFuse) into the group of that post-dominator:
 *  - phase 0: kOutEWiseFusable nodes (e.g. conv2d) when the aggregated path pattern is
 *    kElemWise and every node on the paths is <= kBroadcast;
 *  - phases 0 and 1: nodes with pattern <= kBroadcast (that branch has no phase guard);
 *  - phase 1: kInjective / kTuple nodes when every node on the paths is <= kInjective;
 *  - phase 2: nodes with pattern <= kInjective into an intermediate kTuple whose group root is
 *    <= kInjective.
 * In phases 0 and 1 nothing is fused into a kTuple post-dominator, and kCommReduce nodes are
 * never fused as the source node. Fusion is skipped when CountFusedNodesWithNewChild exceeds
 * max_fuse_depth_.
 *
 * \param graph The indexed forward graph.
 * \param post_dom_tree Post-dominator tree of \p graph.
 * \param phase The fusion phase (0, 1 or 2).
 */
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph,    //
                               const DominatorTree& post_dom_tree,  //
                               int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    // no actions for opaque nodes
    if (group_node->pattern == kOpaque) continue;
    // no actions needed if the current node have no dominator
    if (dom_node->parent == nullptr) continue;
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

    // refuse the fusion if too many ops are going to be fused together
    if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
      continue;

    if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      continue;
    }

    // Skip if current node is already fused to the parent.
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
      continue;
    }
    // Do not fuse into tuple for now
    if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
    // Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else {
      // do nothing.
      ICHECK(group_node->pattern == kCommReduce);
    }
  }
}
Phase0 #
  • kOutEWiseFusable 当前节点:kOutEWiseFusable(如conv2d);dom_node的聚合pattern必须是kElemWise;中间路径节点:<= kBroadcast(kElemWise、kBroadcast)
// Excerpt of RunFuse, phase 0: a kOutEWiseFusable node (e.g. conv2d) is fused into its
// post-dominator only if the aggregated path pattern is kElemWise and every node on the paths
// in between is <= kBroadcast. (The `if (phase != 0) continue;` guard of the full function
// is omitted in this excerpt.)
if (group_node->pattern == kOutEWiseFusable) {
  // Path for OutEWiseFusable: conv2d
  // Check if the dominator relation is elemwise.
  if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
	ICHECK(dom_node->parent->gnode != nullptr);
	// The fuse can be executed if all the intermediate ops are still broadcast.
	auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
	if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
	  CommitFuse(graph_node, dom_node->parent->gnode);
	}
  }
}
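  • kBroadcast 当前节点/dom节点:当前节点(kElemWise、kBroadcast),dom路径的聚合pattern(kElemWise、kBroadcast、kInjective、kCommReduce) 中间节点:非sink(kElemWise、kBroadcast、kInjective)、sink(kElemWise、kBroadcast、kInjective、kCommReduce、kOutEWiseFusable)。注意该分支没有phase判断,Phase0和Phase1都会执行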
    // Excerpt of RunFuse: nodes with pattern <= kBroadcast. This branch has no phase guard,
    // so it runs in both phase 0 and phase 1. The sink may already belong to a
    // kOutEWiseFusable group, hence the wider condition when is_sink is true.
    } else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    }
Phase 1 #
  • 当前节点:kInjective、kTuple(该分支只在Phase1执行) 中间路径节点:(kElemWise、kBroadcast、kInjective)
// Excerpt of RunFuse, phase 1 only: kInjective / kTuple nodes are fused into their
// post-dominator when every node on the paths in between is <= kInjective.
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
  // defer injective fusion to second phase.
  // so conv2d always finishes fusing.
  if (phase != 1) continue;
  // Check if all path are injective.
  auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
  if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
	CommitFuse(graph_node, dom_node->parent->gnode);
  }
}
Phase 2 #
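  • 当前节点:pattern <= kInjective(kElemWise、kBroadcast、kInjective),pattern > kInjective的节点直接跳过
  • dom节点:dom parent节点本身的pattern为kTuple,且其所在group的root pattern <= kInjective(root为kTuple时跳过),即该tuple已经与后续injective算子融合
  • 中间节点:(kElemWise、kBroadcast、kInjective)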
    // Excerpt of RunFuse, phase 2: fuse a node whose pattern is <= kInjective into its
    // post-dominator when that post-dominator is a kTuple whose group root is <= kInjective,
    // i.e. the tuple was already fused into later injective ops.
    if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      continue;
    }

CheckPath #

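函数描述:检查从src出发(不含src本身)直到sink的所有路径上的节点,其所在group root的pattern是否都满足fcond;到达sink时以is_sink=true调用fcond 函数流程:清空visited_,对src的每个output递归调用CheckPath_;已访问过的节点直接返回true,到达sink后不再继续向下遍历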

/*!
 * \brief Check every node on every path from src's consumers up to (and including) sink.
 *
 * src itself is not tested. Each visited node is tested through the pattern of its group root;
 * the sink is tested with is_sink == true.
 *
 * \param src The start node; must not be externally referenced and must differ from sink.
 * \param sink The node where the walk stops.
 * \param fcond Predicate (OpPatternKind kind, bool is_sink) -> bool.
 * \return true iff fcond held for every visited node.
 */
template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                 F fcond) {
  ICHECK(!src->extern_ref);
  visited_.clear();
  ICHECK(src != sink);
  for (auto* edge = src->outputs.head; edge != nullptr; edge = edge->next) {
    if (!CheckPath_(edge->value.node, sink, fcond)) {
      return false;
    }
  }
  return true;
}

/*!
 * \brief Recursive worker of CheckPath: depth-first walk from src toward sink.
 *
 * \param src Current node.
 * \param sink The node where the walk stops.
 * \param fcond Predicate applied to the pattern of src's group root.
 * \return false as soon as fcond fails on some node, true otherwise.
 */
template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                  F fcond) {
  // Nodes already seen during this walk have been checked; do not check them again.
  if (!visited_.insert(src).second) return true;
  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  const bool is_sink = (src == sink);
  // The test uses the root's pattern, i.e. the pattern of everything already fused with src.
  if (!fcond(group->FindRoot()->pattern, is_sink)) return false;
  if (is_sink) return true;
  for (auto* edge = src->outputs.head; edge != nullptr; edge = edge->next) {
    if (!CheckPath_(edge->value.node, sink, fcond)) return false;
  }
  return true;
}

CommitFuse #

函数描述:输入graph node和dom node, 合并到dom target group中 函数流程:

  • 清空visited
  • 将src合并到dom target group中,
  • 遍历src的outputs,合并到dom target group中
  • 遍历outputs的output,合并到dom target group中,直到src == sink
/*!
 * \brief Merge src and every node on the paths from src to sink into sink's group.
 *
 * \param src The node being fused; must differ from sink.
 * \param sink The post-dominator whose group receives the nodes.
 */
void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  Group* const sink_group = groups_[sink->index];
  visited_.clear();
  ICHECK(src != sink);
  CommitFuse_(src, sink, sink_group);
}

/*!
 * \brief Recursive worker of CommitFuse: merge src into target, then recurse into src's outputs.
 *
 * \param src Current node.
 * \param sink The node where the walk stops; it is not merged here.
 * \param target The group (sink's group in CommitFuse) that receives the nodes.
 */
void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                   Group* target) {
  // Stop at the sink and at nodes already merged during this walk.
  if (src == sink || visited_.count(src) != 0) return;
  visited_.insert(src);
  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  MergeFromTo(group, target);
  for (auto* edge = src->outputs.head; edge != nullptr; edge = edge->next) {
    CommitFuse_(edge->value.node, sink, target);
  }
}

/*!
 * \brief Union the set containing child into the set containing parent.
 *
 * The parent root absorbs the child root's node count. If the child root carries an anchor, the
 * anchor moves to the parent root (which must not have one yet) and the patterns are combined.
 *
 * \param child Any group of the set to be merged.
 * \param parent Any group of the receiving set.
 */
void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
  Group* const child_root = child->FindRoot();
  Group* const parent_root = parent->FindRoot();
  if (child_root == parent_root) return;  // already the same set
  parent_root->num_nodes += child_root->num_nodes;
  child_root->parent = parent_root;
  if (child_root->anchor_ref == nullptr) return;
  // A fused group may hold at most one anchor.
  ICHECK(parent_root->anchor_ref == nullptr);
  parent_root->anchor_ref = child_root->anchor_ref;
  parent_root->pattern = CombinePattern(child_root->pattern, parent_root->pattern);
}

DominatorTree #

类描述:用于表示domination或者节点的post domination关系 数据成员:std::vector<Node*> nodes;

  • gnode:来源于indexForwardgraph
  • depth、parent、pattern是按reverse topo序新计算出来的;pattern是从当前节点到其parent之间所有路径上的聚合(combine)pattern
  /*! \brief A node of the (post-)dominator tree; PostDom builds one per graph node. */
  struct Node {
    /*! \brief The IndexedForwardGraph node this tree node was created for */
    IndexedForwardGraph::Node* gnode{nullptr};
    /*! \brief parent of the tree (LCA of the node's outputs); nullptr for roots */
    Node* parent{nullptr};
    /*! \brief current depth; roots get depth 1 in GetNode */
    int depth{0};
    /*! \brief pattern aggregated over the paths from this node to its parent */
    OpPatternKind pattern{kOpaque};
  };

PostDom #

函数描述:为IndexedForwardGraph构建post-dominator tree 函数流程:

  • reverse topo序, 从gv开始遍历
/*!
 * \brief Build the post-dominator tree of graph.
 *
 * Nodes are created in reverse post-DFS order so that, when GetNode runs for a node, the tree
 * nodes of all its outputs already exist (LeastCommonAncestor ICHECKs this).
 *
 * \param arena Arena that owns the created tree nodes.
 * \param graph The indexed forward graph.
 * \return The tree; nodes[i] corresponds to graph.post_dfs_order[i].
 */
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  DominatorTree tree;
  const size_t num_nodes = graph.post_dfs_order.size();
  tree.nodes.resize(num_nodes, nullptr);
  for (size_t index = num_nodes; index-- > 0;) {
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}

GetNode #

函数描述:为一个graph node创建对应的dominator tree节点 函数流程:

  1. make Node
  2. gv和函数参数的depth = 1, parent=nullptr, pattern = opaque
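  3. 非extern节点:其在post-dominator tree中的parent = 该节点所有outputs(消费者)在树中的最近公共祖先(LCA);depth = parent->depth + 1(没有parent时为1);pattern为沿途各边与各节点pattern的combine结果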
/*!
 * \brief Create the tree node for gnode.
 *
 * Nodes with extern_ref become roots (depth 1, no parent, kOpaque). Every other node's parent is
 * the least common ancestor of the tree nodes of its outputs, and its pattern combines the edge
 * and node patterns met while computing that ancestor (starting from kElemWise).
 *
 * \param arena Arena that owns the created node.
 * \param gnode The graph node.
 * \return The new tree node.
 */
DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* tnode = arena->make<Node>();
  tnode->gnode = gnode;
  if (!gnode->extern_ref) {
    // gnode->outputs are the edges to the nodes that consume gnode.
    OpPatternKind pattern = kElemWise;
    Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    tnode->parent = parent;
    tnode->depth = (parent == nullptr) ? 1 : parent->depth + 1;
    tnode->pattern = pattern;
    return tnode;
  }
  tnode->depth = 1;
  tnode->parent = nullptr;
  tnode->pattern = kOpaque;
  return tnode;
}

LeastCommonAncestor: #

/*!
 * \brief Least common ancestor of the tree nodes that a list of edges points to.
 *
 * \param input_nodes The edges (GetNode passes the node's outputs).
 * \param edge_pattern In/out: combined with every edge pattern and with the pattern of every
 *        tree node walked past while computing the ancestor.
 * \return The LCA, or nullptr if the list is empty or no common ancestor exists.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(
    const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
  // Tree node of an edge's target; it must already have been built.
  auto tree_node_of = [&](const IndexedForwardGraph::Edge& edge) {
    const size_t oindex = edge.node->index;
    ICHECK_LT(oindex, nodes.size());
    Node* onode = nodes[oindex];
    ICHECK(onode != nullptr);
    return onode;
  };
  Node* lca = nullptr;
  for (auto* link = input_nodes.head; link != nullptr; link = link->next) {
    Node* current = tree_node_of(link->value);
    // The first edge seeds the result; each later edge narrows it.
    lca = (link == input_nodes.head) ? current : LeastCommonAncestor(lca, current, edge_pattern);
    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
  }
  return lca;
}
/*!
 * \brief Least common ancestor of two tree nodes, found by comparing depths.
 *
 * The deeper node steps up to its parent (both step on equal depth); the pattern of every node
 * stepped past is combined into *edge_pattern, lhs before rhs.
 *
 * \return The meeting node, or nullptr if either walk runs past a root first.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
                                                        OpPatternKind* edge_pattern) {
  while (lhs != rhs) {
    if (lhs == nullptr || rhs == nullptr) return nullptr;
    const bool step_lhs = lhs->depth >= rhs->depth;
    const bool step_rhs = rhs->depth >= lhs->depth;
    if (step_lhs) {
      *edge_pattern = CombinePattern(*edge_pattern, lhs->pattern);
      lhs = lhs->parent;
    }
    if (step_rhs) {
      *edge_pattern = CombinePattern(*edge_pattern, rhs->pattern);
      rhs = rhs->parent;
    }
  }
  return lhs;
}

融合 OperatorFusor #

类构造:

/*!
 * \brief Rewrites a module so that each fusion group becomes a grouped function
 * (excerpt: only the graph/groups constructor is shown).
 */
class OperatorFusor : public ExprMutator {
 public:
  /*!
   * \brief Build the fusor from an indexed forward graph and a partition of it.
   *
   * Converts (graph, groups) into a group map with CreateGroupMap and delegates to the
   * map-based constructor. It must be public: members of a `class` are private by default,
   * so without the specifier nobody could construct the fusor.
   *
   * \param mod The module to rewrite.
   * \param graph The indexed forward graph the partition was computed on.
   * \param groups Per-node groups (presumably the result of GraphPartitioner::Partition).
   * \param lift_constant Forwarded to the delegated constructor.
   */
  OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector<Group*>& groups,
                bool lift_constant = true)
      : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {}
};

Transform:

  • 处理relax函数,并且该relax函数不是kPrimitive的
  /*!
   * \brief Rewrite every Relax function of mod_ that does not carry a nonzero kPrimitive attr.
   *
   * Other functions (e.g. TIR PrimFuncs) are left untouched.
   * \return The builder's context module after the updates.
   */
  IRModule Transform() {
    for (const auto& [gvar, base_func] : mod_->functions) {
      const bool is_relax_func = base_func->IsInstance<relax::FunctionNode>();
      if (!is_relax_func || base_func->HasNonzeroAttr(attr::kPrimitive)) continue;
      Function rewritten = Downcast<Function>(VisitExpr(base_func));
      builder_->UpdateFunction(gvar, rewritten);
    }
    return builder_->GetContextIRModule();
  }

VisitBindingBlock:

  1. 收集每个grouped的函数
  2. 收集所有groups的边界
  3. 为每个grouped函数创建group函数
  4. 开始生成新的binding block
  /*!
   * \brief Fuse only inside dataflow blocks.
   *
   * Ordinary binding blocks might be impure (side effects or control flow), so they are
   * returned unchanged.
   */
  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
    const auto* dataflow_block = block.as<DataflowBlockNode>();
    return dataflow_block != nullptr ? VisitBindingBlock_(dataflow_block) : block;
  }

  /*!
   * \brief Rewrite a dataflow block so that each multi-binding group becomes one call to a
   * newly created grouped function.
   *
   * Bindings are visited in the topological order of the group dependencies.
   * - A binding alone in its group (whose attrs are empty) is visited as usual.
   * - For any other group, the call to the group's function is emitted at the group's last
   *   binding.
   * - The original binding variables are then remapped to the call result, or to TupleGetItem
   *   on it when the function returns a tuple.
   *
   * \param block The dataflow block to rewrite.
   * \return The rewritten binding block.
   */
  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
    group2func_.clear();

    // Step 1. Collect the bindings for each grouped function.
    CollectFuncBindings(block->bindings);

    // Step 2. Collect all group's boundary (i.e. the output vars for each group)
    CollectFuncBoundary(block->bindings);

    // Step 3. Create the grouped function for each group.
    for (auto& [g, creator] : group2func_) {
      creator.CreateFunction(g->attrs);
    }

    // Step 4. Start generating the new binding block.
    //  - For groups with single binding, we directly recurse into the binding and emit the new one.
    //  - For groups with multiple bindings, we emit the call to the grouped function only when
    //  visiting the last binding of the group, because only by doing this we don't break the
    //  dependencies among the bindings of different groups. And therefore, we will skip all but the
    //  last binding of the group.
    builder_->BeginDataflowBlock();

    // For each group, record which variables need to be remapped to the output of TupleGetItem.
    // Only relevant when the output of the grouped function is a tuple.
    std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;

    // A grouped function which returns a tuple requires attaching TupleGetItem to each element and
    // remapping variables in earlier bindings appropriately. Thus, a binding whose value depends on
    // some elements of a tuple from other group's function must be emitted after a call to the
    // tuple-producing function is emitted and remapping is done.
    // To guarantee this, we process bindings in the order of the topological sort of the group
    // dependency relations.
    for (const auto& binding : TopoSortByGroupDep(block->bindings)) {
      // Case 1. If the binding is the only binding in its group, recurse into it and emit the
      // transformed binding as usual.
      Group* group = GetGroupFromBinding(binding);
      if (group->num_nodes == 1 && group->attrs.empty()) {
        VisitBinding(binding);
        continue;
      }

      const auto& it_creator = group2func_.find(group);
      ICHECK(it_creator != group2func_.end());
      const FunctionCreator& func_info = it_creator->second;

      if (!func_info.function_.defined()) {
        // The function is not created yet, so we skip the binding.
        continue;
      }
      const Function& func = func_info.function_.value();

      // If this binding belongs to a group whose output is a tuple, the original bound variable
      // needs to be remapped to the output of TupleGetItem after the corresponding tuple is
      // emitted.
      if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) {
        if (!GetStructInfo(binding->var)->IsInstance<TupleStructInfoNode>() ||
            IsNestedTupleOutput(func)) {
          // When binding->var itself is a tuple, we do not need to remap this variable to the
          // output of TupleGetItem unless the output is a nested tuple.
          pending_tuple_get[group].push_back(binding->var);
        }
      }

      // Case 2. If the binding is not the last binding of the group, we skip it.
      if (!func_info.bindings_.back().same_as(binding)) {
        continue;
      }

      // Case 3. The binding is the last binding of the group.
      const auto* var_binding = binding.as<VarBindingNode>();
      ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 "
                                        "is supposed to be a variable binding";

      // Step a. Add the grouped function to the IRModule
      GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_);

      // Step b. Create the call to the deduplicated function, and then emit the call.
      //  - If this binding is an output binding, emit an output variable.
      //  - Otherwise, emit a dataflow variable.
      Var new_var;
      Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_));

      if (var_binding->var->IsInstance<DataflowVarNode>()) {
        new_var = builder_->Emit(call_to_emit);
      } else {
        new_var = builder_->EmitOutput(call_to_emit);
      }

      // Step c. Update the mapping used for the remapping of the binding variables.
      // Only the pending remaps of *this* group matter. Entries of earlier groups are never
      // erased, so testing the whole map for emptiness would also succeed when this group has
      // nothing pending. The loop would then run over an empty list and `var_binding->var`
      // would never be remapped to `new_var`.
      auto it_pending = pending_tuple_get.find(group);
      const bool has_pending_tuple_get =
          it_pending != pending_tuple_get.end() && !it_pending->second.empty();
      if (IsTupleOutput(func) && has_pending_tuple_get) {
        // If the output is a tuple, attach TupleGetItem to all tuple elements, and
        // remap variables appropriately.
        // The variables that need to be remapped and the corresponding tuple indices are
        // available in pending_tuple_get and tuple_get_indices_ respectively.
        for (const auto& var : it_pending->second) {
          auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]);
          var_remap_[var->vid] = builder_->Emit(tuple_get);
        }
      } else {
        var_remap_[var_binding->var->vid] = new_var;
      }
    }
    // Step 5. Finish the binding block generation.
    return builder_->EndBlock();
  }

FuseOpsByPattern #

  • 类型:ModulePASS
  • 用途:按给定的pattern,把匹配到的多个relax.call合并成一个新的relax函数(grouped function),再通过MakeGroupedFunctions重写模块;若annotate_codegen为true,还会用CompositeFunctionAnnotator处理结果
/*!
 * \brief Module pass that fuses operators matching the given patterns into grouped functions.
 *
 * \param patterns Fusion patterns, forwarded to relax::FuseOpsByPattern.
 * \param bind_constants Forwarded to relax::FuseOpsByPattern.
 * \param annotate_codegen Forwarded to relax::FuseOpsByPattern.
 * \return A module pass named "FuseOpsByPattern" with opt_level 0 and no required passes.
 */
Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
                      bool annotate_codegen) {
  const auto run_fusion = [=](IRModule m, PassContext pc) -> IRModule {
    return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
  };
  const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(run_fusion);
  return CreateModulePass(pass_func, /*opt_level=*/0, /*pass_name=*/"FuseOpsByPattern",
                          /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
/*!
 * \brief Apply each fusion pattern in turn to every non-TIR function of mod.
 *
 * For each pattern, PatternBasedPartitioner results from all functions are merged into one
 * group map. MakeGroupedFunctions then rewrites the module (lift_constants = !bind_constants).
 * The next pattern sees the rewritten module.
 *
 * \param patterns The fusion patterns, applied in order.
 * \param mod The module to transform.
 * \param bind_constants Whether constants are bound (passed as !lift_constants).
 * \param annotate_codegen If true, the result is post-processed by CompositeFunctionAnnotator.
 * \return The transformed module.
 */
IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
                          bool bind_constants, bool annotate_codegen) {
  // One arena serves the partitioners of all patterns.
  support::Arena arena;
  for (const auto& pattern : patterns) {
    OperatorFusor::GroupMap group_map;
    for (const auto& entry : mod->functions) {
      const auto& func = entry.second;
      if (func->IsInstance<tir::PrimFuncNode>()) continue;
      const auto partial = PatternBasedPartitioner::Run(
          pattern->name, pattern->pattern, pattern->annotation_patterns,
          pattern->check.value_or(nullptr), func, &arena);
      group_map.insert(partial.begin(), partial.end());
    }
    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
  }
  return annotate_codegen ? CompositeFunctionAnnotator(mod).Run() : mod;
}

FuseTIR #

pass功能:

  • 将primitive relax函数融合到更大的TIR函数中
/*!
 * \brief Module pass wrapping relax::FuseTIR.
 *
 * \return A module pass named "FuseTIR" with opt_level 0 and no required passes.
 */
Pass FuseTIR() {
  const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(
      [](IRModule m, PassContext pc) -> IRModule { return relax::FuseTIR(m); });
  return CreateModulePass(pass_func, /*opt_level=*/0, /*pass_name=*/"FuseTIR",
                          /*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);

调用TIRFuseMutator完成:

/*!
 * \brief Fuse primitive Relax functions into TIR functions by running TIRFuseMutator.
 *
 * \param mod The module to transform.
 * \return The transformed module.
 */
IRModule FuseTIR(IRModule mod) { return TIRFuseMutator::Transform(mod); }

Mutator执行流程:

  1. VisitExpr,删除bunch of primfunc, 创建一个空的block builder
  2. 遍历所有的primitive relax func
  3. 更新所有的non-primitive relax func,
  4. 复制mod属性到modified_mod中

实例分析 #

FuseOps #

relax func:

def before(dtype):
    """Build the ``main`` function used in the FuseOps walk-through.

    Dataflow: x + 1 -> conv2d(w1) -> (1 + conv) -> conv + that -> two parallel
    conv2d branches (w2: 1x1, w3: 3x3) -> add of both branches (the output).

    Args:
        dtype: Element dtype of every tensor and constant.

    Returns:
        The result of ``BlockBuilder.get()`` for the built function.
    """
    builder = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
    w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
    w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
    w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))

    def conv(data, weight, padding):
        # Every convolution here uses unit stride and dilation.
        return builder.emit_te(
            topi.nn.conv2d, data, weight, strides=1, padding=padding, dilation=1
        )

    with builder.function("main", [x, w1, w2, w3]):
        with builder.dataflow():
            shifted = builder.emit_te(topi.add, x, relax.const(1, dtype))
            conv_a = conv(shifted, w1, padding=1)
            plus_one = builder.emit_te(topi.add, relax.const(1, dtype), conv_a)
            # conv_a's post-dominator: this add consumes conv_a directly and via plus_one.
            merged = builder.emit_te(topi.add, conv_a, plus_one)
            # Two parallel branches that meet again at the output add.
            branch_1x1 = conv(merged, w2, padding=0)
            branch_3x3 = conv(merged, w3, padding=1)
            out = builder.emit_output(builder.call_te(topi.add, branch_1x1, branch_3x3))
        builder.emit_func_output(out)

    return builder.get()

构造indexforwardGraph:

  • 先push参数,再push varBinding
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:116: IndexedForwardGraph, func: main
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 5, var: lv
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 6, var: lv1
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 8, var: lv2
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 9, var: lv3
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 10, var: lv4
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 11, var: lv5
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 12, var: gv
  1. push x, w1, w2, w3(下面括号中列出的是各节点的输入;在IndexedForwardGraph中边是正向记录的,即每个输入节点的outputs中包含当前节点,见下表output/edges列)
  2. push relax.const(1), lv0(inputs = {x, relax.const(1)})
  3. push lv1(inputs = {lv0, w1})
  4. push relax.const(1), lv2(inputs = {relax.const(1), lv1})
  5. push lv3(inputs = {lv1, lv2})
  6. push lv4(inputs = {lv3, w2})
  7. push lv5(inputs = {lv3, w3})
  8. push gv(inputs = {lv4, lv5})
gidx type pattern output/edges tree-idx depth lca-parent(outputs)
x 0 params/Var kOpaque 0 1 null
w1 1 params/Var kOpaque 1 1 null
w2 2 params/Var kOpaque 2 1 null
w3 3 params/Var kOpaque 3 1 null
R.const 4 ConstantNode kOpaque {5: lv} 4 5 5: lv kElemWise
lv = add 5 VarBinding:Call kElemWise {6: lv1} 5 4 6: lv1 kOutEWiseFusable
lv1 = conv2d 6 VarBinding:Call kOutEWiseFusable {8: lv2, 9: lv3} 6 3 9: lv3 kElemWise
R.const 7 ConstantNode kOpaque {8: lv2} 7 4 8: lv2 kElemWise
lv2 = add 8 VarBinding:Call kElemWise {9: lv3} 8 3 9: lv3 kElemWise
lv3 = add 9 VarBinding:Call kElemWise {10: lv4, 11: lv5} 9 2 12: gv kOutEWiseFusable
lv4 = conv2d 10 VarBinding:Call kOutEWiseFusable {12: gv} 10 2 12: gv kElemWise
lv5 = conv2d 11 VarBinding:Call kOutEWiseFusable {12: gv} 11 2 12: gv kElemWise
gv = add 12 Var->Extern? kElemWise 12 1 null kOpaque
graph TB
	x --> lv0
	
// Dataflow graph of the `before` example: an edge a -> b means b consumes a.
// Parameters and constants are drawn as plain-text boxes. The `start` node is declared but has
// no edges; gv is the function output.
digraph G {
  x [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.Var("x")</td></tr>
    </table>>
    shape = "plaintext"
  ];
  x -> lv0
  Constant1 [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.const(1)</td></tr>
    </table>>
    shape = "plaintext"
  ];
  Constant1 -> lv0
  W1 [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.Var("w1")</td></tr>
    </table>>
    shape = "plaintext"
  ];
  lv0 -> lv1
  W1 -> lv1
  Constant2 [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.const(1)</td></tr>
    </table>>
    shape = "plaintext"
  ];
  lv1 -> lv2
  Constant2 -> lv2
  
  lv1 -> lv3
  lv2 -> lv3
  
  lv3 -> lv4
  lv3 -> lv5
  W2 [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.Var("w2")</td></tr>
    </table>>
    shape = "plaintext"
  ];
  W3 [
    label = <
    <table border="0" cellborder="1" cellspacing="0">
        <tr><td>R.Var("w3")</td></tr>
    </table>>
    shape = "plaintext"
  ];
  W2 -> lv4
  W3 -> lv5
  
  lv4 -> gv
  lv5 -> gv

  start [shape=Mdiamond];
  gv [shape=Msquare];
}

dominator tree

  • gv, depth = 1
  • lv5, depth = 2
  • lv4, depth = 2
  • lv3, depth = 2