源码位置:
- https://github.com/apache/tvm/blob/main/python/tvm/relax/pipeline.py
- https://github.com/apache/tvm/blob/main/src/relax/transform/legalize_ops.cc
- https://github.com/apache/tvm/blob/main/src/relax/transform/fold_constant.cc
本文分析的是zero_pipeline: 完成relax func降级,算子融合、常量折叠等工作。
mod_after_pipeline = relax.pipeline.get_pipeline()(mod)
pipeline包含的transform op
# Passes run in list order on the IRModule.
seq = tvm.transform.Sequential(
[
# Lower Relax ops to TIR implementations.
transform.LegalizeOps(enable_warning=enable_warning),
# Tag each TIR PrimFunc with its "op_pattern" attribute.
transform.AnnotateTIROpPattern(),
# Fold constant sub-expressions.
transform.FoldConstant(),
# Group fusible operators.
transform.FuseOps(),
# NOTE(review): presumably merges each fused group into a single TIR PrimFunc -- confirm.
transform.FuseTIR(),
]
)
LegalizeOps #
pass功能:对relax op进行降级,生成 tir op 相关类:
- LegalizeMutator:对mod进行重写
/*!
 * \brief Create the module pass that legalizes Relax operators.
 * \param cmap Optional map from operator name to a custom legalization function; it is forwarded
 *        to LegalizeMutator.
 * \param enable_warning Forwarded to LegalizeMutator.
 * \return A module pass named "LegalizeOps" registered at opt_level 0.
 */
Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
                                                                            PassContext pc) {
    // The "relax.transform.apply_legalize_ops" config is an opt-out switch: legalization runs
    // unless the option is explicitly set to false.
    bool apply_legalize_ops =
        pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
    if (apply_legalize_ops) {
      mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
    }
    return mod;
  };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"LegalizeOps",
                          /*required=*/{});
}
LegalizeMutator #
数据成员:
- mod_:IRModule
- cmap_:Optional<Map<String, PackedFunc>>,包含自定义的legalization function map(可以为空)
- enable_warning_:bool:
Transform:
- 遍历所有的函数,只对relax function做处理
/*!
 * \brief Rewrite every Relax function of the module and return the updated module.
 * \return The IRModule held by the block builder after all Relax functions were updated.
 */
IRModule Transform() {
  for (const auto& [gv, func] : mod_->functions) {
    // Anything that is not a Relax function is left untouched.
    if (!func->IsInstance<FunctionNode>()) {
      continue;
    }
    auto updated_func = Downcast<Function>(this->VisitExpr(func));
    builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
  }
  return builder_->GetContextIRModule();
}
处理relax.call的op:
- 获取降级属性Map:legalize_map
- 获取Op:call_pure_packed_op、call_tir_op、call_dps_packed_op
- 判断是否自定义的算子实现,如果是,使用自定义的cmap进行legalize, 并WrapPureCall
- 使用Legalize_map进行降级,并WrapPureCall
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;
/*!
 * \brief Legalize a single call after its sub-expressions have been visited.
 *
 * A customized legalization registered in cmap_ takes priority over the default one found in
 * the "FLegalize" attribute map. Calls whose callee is not an Op, or for which no legalization
 * is registered, are returned as visited.
 *
 * \param call The call node to rewrite.
 * \return The legalized expression, wrapped by WrapPureCall when WrapPureCondition holds.
 */
Expr VisitExpr_(const CallNode* call) final {
  Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
  static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
  // NOTE(review): the three ops and is_data_dependent_op below are not read in this excerpt;
  // presumably they are used by code that was elided -- confirm before removing.
  static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
  static const Op& call_tir_op = Op::Get("relax.call_tir");
  static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
  auto* op_node = visited_call->op.as<OpNode>();
  // The callee is not an Op (as<OpNode>() yields nullptr), so there is nothing to legalize and
  // op_node must not be dereferenced below.
  if (op_node == nullptr) {
    return visited_call;
  }
  auto op = GetRef<Op>(op_node);
  std::string op_name(op->name);
  bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
  // Priority: customize > default.
  // Check if it has customize legalization registered.
  if (cmap_.defined() && cmap_.value().count(op->name)) {
    auto ret = cmap_.value()[op->name](this->builder_, visited_call);
    if (ret.IsObjectRef<Expr>() && WrapPureCondition(op, ret.AsObjectRef<Expr>())) {
      return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
    }
    return ret;
  }
  // Check if it has default legalization registered.
  if (legalize_map.count(op)) {
    auto ret = legalize_map[op](this->builder_, visited_call);
    if (WrapPureCondition(op, ret)) {
      return WrapPureCall(Downcast<Call>(ret));
    }
    return ret;
  }
  return visited_call;
}
WrapPureCall:
/*!
 * \brief Wrap a call into a call to "relax.call_pure_packed".
 * \param ret The call to wrap.
 * \return A call to relax.call_pure_packed whose first argument is the original callee, followed
 *         by the original arguments; attrs and sinfo_args are carried over.
 */
Call WrapPureCall(const Call& ret) {
  static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
  Array<Expr> wrapped_args;
  wrapped_args.push_back(ret->op);
  for (const auto& original_arg : ret->args) {
    wrapped_args.push_back(original_arg);
  }
  return Call(call_pure_packed_op, wrapped_args, ret->attrs, ret->sinfo_args);
}
实例分析 add #
大多数op的legalize注册使用topi实现:
# /python/tvm/relax/transform/legalize_ops/binary.py
register_legalize("relax.add", _binary(topi.add))
def _binary(te_func: TEFunc) -> LegalizeFunc:
    """A common wrapper util for the legalization of binary operators.

    It detects if one of the binary op arguments is a constant scalar. If so,
    it extracts the scalar value to simplify the generated PrimFunc.

    Parameters
    ----------
    te_func : TEFunc
        The TE function implementing the binary operator.

    Returns
    -------
    LegalizeFunc
        The legalization function, which emits the operator through ``bb.call_te``.
    """

    def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
        # To simplify the created PrimFunc, we first check if arg1 is a constant scalar.
        # If it is not, we then check if arg0 is a constant scalar.
        arg0 = call.args[0]
        arg1 = _try_convert_to_scalar_const(call.args[1])
        # arg1 still being an Expr means the conversion did not produce a scalar.
        if isinstance(arg1, Expr):  # type: ignore
            arg0 = _try_convert_to_scalar_const(arg0)
        return bb.call_te(te_func, arg0, arg1)

    return binary_call_te
topi实现:
# /src/topi/broadcast.cc
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_DEFINE_BCAST_OP(add, { return a + b; });
/*!
 * \brief Define the four overloads of a binary broadcast operator called Name.
 *
 * ComputeRule is a statement over two PrimExpr values named a and b. The generated overloads are:
 *  - (PrimExpr, PrimExpr): applies ComputeRule directly.
 *  - (Tensor, Tensor): applies ComputeRule through detail::WithBroadcast, default tag kBroadcast.
 *  - (Tensor, PrimExpr) and (PrimExpr, Tensor): apply ComputeRule to every element of the tensor
 *    through te::compute over the tensor's shape, default tag kElementWise.
 * The default output name is "T_" followed by Name.
 */
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
std::string name = "T_" #Name, std::string tag = kBroadcast) { \
auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return detail::WithBroadcast(l, A, B, name, tag); \
} \
inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \
std::string name = "T_" #Name, std::string tag = kElementWise) { \
auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return tvm::te::compute( \
A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
} \
inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \
std::string name = "T_" #Name, std::string tag = kElementWise) { \
auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return tvm::te::compute( \
B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
}
ConstantFolder #
- 对于每个函数,调用ConstantFolder::Fold(f, m)
/*!
 * \brief Create the function pass that folds constants in each Relax function.
 * \return A function pass named "FoldConstant" registered at opt_level 0.
 */
Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> fold_one_function =
      [=](Function f, IRModule m, PassContext pc) {
        // All the work is delegated to ConstantFolder.
        return ConstantFolder::Fold(f, m);
      };
  return CreateFunctionPass(/*pass_function=*/fold_one_function,
                            /*opt_level=*/0,
                            /*pass_name=*/"FoldConstant",
                            /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant);
ConstantFolder实现:
/*!
 * \brief Fold constants in a function, then drop bindings that became unused.
 * \param func The function to rewrite.
 * \param ctx_module The module handed to the ConstantFolder.
 * \return The folded function after RemoveAllUnused.
 */
static Function Fold(Function func, IRModule ctx_module) {
  ConstantFolder folder(std::move(ctx_module));
  Function folded = Downcast<Function>(folder(func));
  func = RemoveAllUnused(folded);
  return func;
}
VisitExpr #
// Constant-fold a single call (excerpt: the `...` lines mark code omitted from the listing).
// The call is first mutated in post order; a call to relax.call_tir is folded directly through
// VisitCallTIR, while other ops inside a dataflow block are legalized first and folded when the
// legalized form is a call_tir. When folding does not succeed the post-order call is returned.
Expr VisitExpr_(const CallNode* call) final {
// post-order mutation
Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
// Check if it is useful to fold this call
if (!ShouldBeFolded(post_call)) return post_call;
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
// NOTE(review): as<OpNode>() yields nullptr when the callee is not an Op; no null check is
// visible in this excerpt -- confirm one exists in the omitted code.
auto* op_node = post_call->op.as<OpNode>();
auto op = GetRef<Op>(op_node);
if (op.same_as(call_tir_op)) {
return VisitCallTIR(post_call).value_or(post_call);
}
// NOTE(review): new_args is not defined in the visible lines; presumably it comes from omitted
// code -- confirm.
post_call =
Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span);
...
// If we are in a dataflow block, we can fold ops.
if (builder_->CurrentBlockIsDataFlow()) {
// Check if we can legalize them to call_tir
if (legalize_map.count(op)) {
// Get the legalized expression
Call post_call_normalized = Downcast<Call>(builder_->Normalize(post_call));
Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call_normalized));
// If the legalized expression is call_tir, try to fold it.
const CallNode* call = legalized_expr.as<CallNode>();
if (call && call->op.same_as(call_tir_op)) {
return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
}
...
}
return std::move(post_call);
}
AnnotateTIROpPattern #
类型:函数PASS PASS功能:标记TIR函数的Op模式类型
# src/relax/transform/annotate_tir_op_pattern.cc
namespace transform {
/*!
 * \brief Create the PrimFunc pass that annotates each TIR function with its op pattern.
 * \return A PrimFunc pass named "AnnotateTIROpPattern" registered at opt_level 0.
 */
Pass AnnotateTIROpPattern() {
  auto annotate_prim_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
    return AnnotateOpPattern(std::move(f));
  };
  return tir::transform::CreatePrimFuncPass(/*pass_function=*/annotate_prim_func,
                                            /*opt_level=*/0,
                                            /*pass_name=*/"AnnotateTIROpPattern",
                                            /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);
实现:
- 分析函数的算子模式类型
- 将类型属性添加到函数属性中
# src/relax/transform/annotate_tir_op_pattern.cc
/*!
 * \brief Attach the analyzed op pattern to a PrimFunc as its "op_pattern" attribute.
 * \param f The PrimFunc to annotate.
 * \return f unchanged when it already has a nonzero "op_pattern" attribute; otherwise f with
 *         "op_pattern" set to the integer value of the analyzed pattern kind.
 */
tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
  // Keep an existing (nonzero) annotation as is.
  if (f->HasNonzeroAttr("op_pattern")) {
    return f;
  }
  relay::OpPatternKind analyzed_kind = AnalyzeOpPatternKind(f);
  return WithAttr(std::move(f), "op_pattern", Integer(static_cast<int>(analyzed_kind)));
}
算子模式类型: kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable < kTuple < kOpaque
/*!
 * \brief Operator pattern categories used by graph fusion.
 *
 * The numeric order is meaningful: smaller values denote stricter patterns. Values 5 and 6 are
 * not assigned.
 */
enum OpPatternKind {
  /*! \brief Elementwise operation. */
  kElemWise = 0,
  /*!
   * \brief Broadcasting operator: every output axis maps to an input axis in order,
   *  e.g. out[i, ax1, j, ax2] = input[i, j]. Because the axes must stay in order, transpose is
   *  not a broadcast operator.
   */
  kBroadcast = 1,
  /*!
   * \brief Injective operator: every output axis maps injectively to a single input axis.
   *  Injective operators can still be fused safely into injective and reduction operators.
   */
  kInjective = 2,
  /*! \brief Communicative reduction operator. */
  kCommReduce = 3,
  /*!
   * \brief Complex operation whose output can still absorb elementwise operations, but which
   *  cannot be chained with another complex operation.
   */
  kOutEWiseFusable = 4,
  /*!
   * \brief Pattern of tuple nodes: can be fused into subsequent injective operators, but is
   *  treated specially.
   */
  kTuple = 7,
  /*! \brief Opaque operation: nothing can be fused into it. */
  kOpaque = 8
};
PatternKindAnalyzer #
分析器:
# src/relax/analysis/tir_op_pattern_kind.cc
/*!
 * \brief Analyze the op pattern kind of a PrimFunc.
 * \param func The PrimFunc whose body is analyzed.
 * \return The pattern kind reported by PatternKindAnalyzer after visiting the function body.
 */
relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) {
  PatternKindAnalyzer body_analyzer(func);
  body_analyzer(func->body);
  const relay::OpPatternKind detected_kind = body_analyzer.GetResult();
  return detected_kind;
}
VisitStmt实现:
- visit 函数参数
- visit block
- visit BufferStoreNode:只支持一个BufferStore
- visit BufferLoadNode:push_back BufferLoad
/*!
 * \brief Collect the buffers bound to the function parameters.
 * \param func The PrimFunc to analyze; parameters without a buffer_map entry are ignored.
 */
explicit PatternKindAnalyzer(const tir::PrimFunc& func) {
  for (const tir::Var& param : func->params) {
    Optional<Buffer> bound_buffer = func->buffer_map.Get(param);
    if (!bound_buffer.defined()) {
      continue;
    }
    param_buffers_.insert(bound_buffer.value());
  }
}
BufferStore和BufferLoad
/*!
 * \brief Record the BufferStore of the current block.
 *
 * If a store was already recorded and the new one uses different indices, the kind becomes
 * kOpaque and the store is not visited further.
 */
void VisitStmt_(const BufferStoreNode* op) final {
  const bool conflicts_with_seen_store =
      store_.defined() && !IsSameArray(op->indices, store_.value()->indices);
  if (conflicts_with_seen_store) {
    kind_ = relay::kOpaque;
    return;
  }
  store_ = GetRef<BufferStore>(op);
  StmtVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode* op) final {
loads_.push_back(GetRef<BufferLoad>(op));
ExprVisitor::VisitExpr_(op);
}
数据成员 #
- store_:bufferstore节点
- loads_:bufferload节点
- kind_:节点模式
- param_buffers_:来自函数参数的buffer,即输入和输出buffer
private:
/*!
* \brief The BufferStore node in the current block.
* \note We only support one BufferStore node in a block (usually generated by TE compute)
*/
Optional<BufferStore> store_;
/*! \brief The BufferLoad nodes in the current block. */
Array<BufferLoad> loads_;
/*! \brief The result of op pattern. */
relay::OpPatternKind kind_ = relay::kElemWise;
/*! \brief The buffers from function params. I.e. the input and output buffers. */
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_;
public:
relay::OpPatternKind GetResult() { return kind_; }
};
Visit Block #
- 不处理root block
- 如果当前block没有任何存储,则分类为Opaque
- 遍历所有的BufferLoad,根据load和store判断op模式
- IsElemwisePattern
- IsBroadcastPattern
- IsInjectivePattern
void VisitStmt_(const BlockNode* op) final {
// Step 1. Clear loads and store
loads_.clear();
store_ = NullOpt;
// Step 2. Visit block body.
StmtVisitor::VisitStmt(op->body);
BufferStore store = store_.value();
// Step 3. Checking load store indices pattern
relay::OpPatternKind index_pair_pattern = relay::kElemWise;
bool has_elem_wise = false;
for (const BufferLoad& load : loads_) {
// Since elemwise is stricter than broadcast and broadcast is stricter than injective,
// while the order amount enums: kElemWise < kBroadcast < kInjective.
// We can simply use `std::max` to detect these three patterns.
// E.g Here is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i]
// Buffer C and A are elemwise but C and B are broadcast. So the whole block follows
// broadcast pattern.
if (IsElemwisePattern(store, load)) {
index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
has_elem_wise = true;
} else if (IsBroadcastPattern(store, load)) {
index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
} else if (IsInjectivePattern(store, load)) {
index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
} else {
index_pair_pattern = relay::kOpaque;
break;
}
}
// If there is a index pair is kElemWise and others are kBroadcast, we regard it as kElemWise
// e.g. A[i, j] = B[i, j] + C[i]
if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
index_pair_pattern = relay::kElemWise;
}
// If the block index pattern is not opaque, update kind.
if (index_pair_pattern != relay::kOpaque) {
// This rule for softmax: reduce + injective.
if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
kind_ = relay::kOutEWiseFusable;
} else {
kind_ = std::max(kind_, index_pair_pattern);
}
return;
}
// Step 4. Checking if the block contains reduce axis by looking into block iterators.
bool has_reduction = false;
Array<tir::Var> reduce_vars;
for (const IterVar& it : op->iter_vars) {
if (it->iter_type == kCommReduce) {
has_reduction = true;
reduce_vars.push_back(it->var);
}
}
if (has_reduction) {
if (IsFMA(op->body)) {
// FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv.
kind_ = std::max(kind_, relay::kOutEWiseFusable);
return;
} else {
for (size_t i = 0; i < loads_.size(); ++i) {
// If it's not a pure reduce, regards as kOutEWiseFusable.
// This rule works for pooling for now.
if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
kind_ = std::max(kind_, relay::kOutEWiseFusable);
return;
}
}
}
kind_ = std::max(kind_, relay::kCommReduce);
} else {
kind_ = relay::kOpaque;
}
}
判断算子模式类型 #
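模式顺序:kElemWise(最严格) < kBroadcast < kInjective。例子:下面的索引对中,C与A是elemwise,C与B是broadcast,逐对取max的结果是broadcast;注意按上面Visit Block的代码,当存在elemwise索引对且其余为broadcast时(has_elem_wise规则),整个block最终被归为kElemWise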
C[i, j] = A[i, j] + B[i]
# C[i, j], A[i, j] == elemwise
# C[i, j], B[i] == broadcast
- IsElemwisePattern:
- load indices == store indices, 则是elementwise算子
- 不相同的情况:长度不同,值不同
/*!
 * \brief Check whether a load is elementwise with respect to a store.
 * \return true iff the store indices and the load indices are the same array (IsSameArray).
 */
static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
  const bool same_indices = IsSameArray(store->indices, load->indices);
  return same_indices;
}
- IsBroadcastPattern
- 判断依据:Load的indice是否在Store的indices中
- 不处理(跳过)shape为1且index为0的load维度(unit dimension)
/*!
 * \brief Check whether a load is a broadcast with respect to a store.
 *
 * Every load index, except those of unit dimensions (extent 1 accessed at index 0), has to be
 * found among the store indices, scanning the store indices in order.
 *
 * \return true iff all non-unit load indices are found in order.
 */
static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) {
  const size_t load_rank = load->buffer->shape.size();
  const size_t store_rank = store->buffer->shape.size();
  size_t store_pos = 0;
  for (size_t load_pos = 0; load_pos < load_rank; ++load_pos) {
    const bool is_unit_dim = is_const_int(load->buffer->shape[load_pos], 1) &&
                             is_const_int(load->indices[load_pos], 0);
    if (is_unit_dim) {
      // Unit load dimensions do not take part in the matching.
      continue;
    }
    // Advance through the store indices until the current load index is found.
    for (; store_pos < store_rank; ++store_pos) {
      if (store->indices[store_pos].same_as(load->indices[load_pos])) {
        break;
      }
    }
    // Running off the end means the load index does not appear in order.
    if (store_pos == store_rank) {
      return false;
    }
  }
  return true;
}
- IsInjectivePattern
- 所有的load indices在store indices中,不考虑顺序
static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) {
std::unordered_set<const tir::VarNode*> vars;
for (const PrimExpr& store_index : store->indices) {
if (const auto* v = store_index.as<tir::VarNode>()) {
vars.insert(v);
} else {
return false;
}
}
for (const PrimExpr& load_index : load->indices) {
// return false if there are vars used in load indices but not in store indices.
if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) {
return false;
}
}
return true;
}
reduce判断/kOutEWiseFusable #
寻找reduce轴:
bool has_reduction = false;
Array<tir::Var> reduce_vars;
for (const IterVar& it : op->iter_vars) {
if (it->iter_type == kCommReduce) {
has_reduction = true;
reduce_vars.push_back(it->var);
}
}
判断流程:
- 保底是:kCommReduce
- 如果包含乘加(FMA),例如 `C[i, j] += A[i, k] * B[j, k]`:kOutEWiseFusable
- 如果非Pure Reduce:kOutEWiseFusable
if (has_reduction) {
if (IsFMA(op->body)) {
// FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv.
kind_ = std::max(kind_, relay::kOutEWiseFusable);
return;
} else {
for (size_t i = 0; i < loads_.size(); ++i) {
// If it's not a pure reduce, regards as kOutEWiseFusable.
// This rule works for pooling for now.
if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
kind_ = std::max(kind_, relay::kOutEWiseFusable);
return;
}
}
}
kind_ = std::max(kind_, relay::kCommReduce);
}
pure reduce判断:
- 所有的reduce轴是一个reduce var, 则pure
- 如果出现j+k这种,是一个add,则not pure
/*!
 * \brief Check whether a set of load indices forms a pure reduce pattern.
 *
 * The pattern is pure iff every index that uses a reduce variable is exactly that variable.
 * E.g. A[i] = sum(B[i, j]) is pure reduce
 *      A[i] = sum(B[i, j + k]) is not pure reduce
 *      pooling is not pure reduce
 *
 * \param reduce_loops The reduce variables of the block.
 * \param indices The load indices to inspect.
 * \return true iff no index combines a reduce variable with anything else.
 */
static bool IsPureReducePattern(Array<tir::Var> reduce_loops, Array<PrimExpr> indices) {
  for (const PrimExpr& index : indices) {
    // Position of the reduce variable that the predicate below matched, if any.
    int matched_pos = -1;
    auto uses_reduce_var = [&](const tir::VarNode* var) {
      for (size_t pos = 0; pos < reduce_loops.size(); ++pos) {
        if (reduce_loops[pos].get() == var) {
          matched_pos = pos;
          return true;
        }
      }
      return false;
    };
    if (!UsesVar(index, uses_reduce_var)) {
      continue;
    }
    // The index uses a reduce variable: it must be that very variable.
    if (!reduce_loops[matched_pos].same_as(index)) {
      return false;
    }
  }
  return true;
}
FuseOps #
类型:模块PASS PASS功能:对特定特征的算子作融合
实现:
# src/relax/transform/fuse_ops.cc
/*!
 * \brief Create the module pass that fuses operators.
 * \param fuse_opt_level Fusion optimization level; -1 means "use the level of the PassContext".
 * \return A module pass named "FuseOps" registered at opt_level 0.
 */
Pass FuseOps(int fuse_opt_level) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule m, PassContext pc) {
        int opt_level = fuse_opt_level;
        if (fuse_opt_level == -1) {
          opt_level = pc->opt_level;
        }
        // The fusion depth limit can be overridden through "relax.FuseOps.max_depth".
        auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps));
        return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue());
      };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"FuseOps",
                          /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
FuseOps:
- 根据输入IRModule创建index-forward图; arena是一个内存分配器
- 应用融合算法对图进行分区
- 根据图分区结果,对IRModule中的算子进行融合
/*!
 * \brief Fuse the operators of a module.
 * \param mod The module to transform.
 * \param opt_level The optimization level handed to the graph partitioner.
 * \param max_fuse_depth The fusion depth limit handed to the graph partitioner.
 * \return The module produced by OperatorFusor::Transform.
 */
IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
  // Allocator shared by the graph creator and the partitioner.
  support::Arena arena;

  // 1) Build the indexed-forward graph of the module.
  IndexedForwardGraph forward_graph = GraphCreator::Create(mod, &arena);

  // 2) Run the fusion algorithm to partition the graph into groups.
  GraphPartitioner partitioner(&arena, opt_level, max_fuse_depth);
  std::vector<GraphPartitioner::Group*> fused_groups = partitioner.Partition(forward_graph);

  // 3) Rewrite the module according to the partition result.
  return OperatorFusor(mod, forward_graph, fused_groups, /*lift_constants*/ true).Transform();
}
创建IndexedForward图 #
GraphCreator分析, 数据成员:
- mod: IRModule
- arenas_:所有内部节点对象的内存分配器
- graph_:创建的indexed forward图
- initialized_nodes:已经设置模式的图节点
class GraphCreator : public ExprVisitor {
public:...
private:
/*! \brief The IRModule from which the indexed forward graph is created */
IRModule mod_;
/*! \brief The allocator of all the internal node objects */
support::Arena* arena_;
/*! \brief The created indexed forward graph */
IndexedForwardGraph graph_;
/*! \brief The graph nodes whose patterns are set */
std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};
Create分析:
- 处理的对象:relax function
- 保证以post-dfs的顺序添加创建的节点,然后我们检查是否containers有相同的大小
/*!
 * \brief Build the indexed forward graph of a module.
 * \param mod The module whose Relax functions are visited.
 * \param arena The allocator used for the graph nodes.
 * \return The created graph.
 */
static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
  GraphCreator creator(mod, arena);
  for (const auto& it : mod->functions) {
    // Only visit Relax function without attr kPrimitive.
    const auto* func = it.second.as<FunctionNode>();
    if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
      continue;
    }
    creator(GetRef<Function>(func));
  }
  // The algorithm of the graph creator ensures that each created node will be added to the
  // post-dfs order and will be set its op pattern. Thus we check whether all these containers
  // have the same size.
  size_t n_nodes = creator.graph_.node_map.size();
  ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size());
  // NOTE(review): assumes initialized_nodes_ receives one entry per node whose pattern is set
  // (as its declaration comment says) -- confirm against SetNodePattern.
  ICHECK_EQ(n_nodes, creator.initialized_nodes_.size());
  return creator.graph_;
}
图分析 #
以forward方向的索引的数据流图 数据成员:
- node_map:将node映射到图节点
- post_dfs_order:post-dfs顺序
/*! \brief Dataflow graph indexed in the forward direction (excerpt: only the containers). */
class IndexedForwardGraph {
 public:
  /*! \brief The node map that maps an object to its graph node. */
  std::unordered_map<const tvm::Object*, Node*> node_map;
  /*! \brief All the nodes in post DFS order. */
  std::vector<Node*> post_dfs_order;
};
边:
/*! \brief A forward edge of the graph. */
struct Edge {
  /*! \brief The node this edge points to. */
  Node* node = nullptr;
  /*! \brief The op pattern attached to this edge; kOpaque by default. */
  OpPatternKind pattern = kOpaque;
};
节点:
/*! \brief A node of the indexed forward graph. */
struct Node {
  /*! \brief Non-owning pointer to the object this node stands for. */
  const tvm::Object* ref = nullptr;
  /*! \brief Index of the node in topological order. */
  size_t index = 0;
  /*! \brief Whether the node is referenced by an external source. */
  bool extern_ref = false;
  /*! \brief The general op pattern of the node; kOpaque by default. */
  OpPatternKind pattern = kOpaque;
  /*! \brief The outgoing edges of the node. */
  LinkedList<Edge> outputs;
};
创建节点:
/*!
 * \brief Allocate a graph node from the arena and register it under the given key.
 * \param key The object the node is created for.
 * \return The newly allocated node.
 */
IndexedForwardGraph::Node* CreateNode(const Object* key) {
  IndexedForwardGraph::Node* fresh_node = arena_->make<IndexedForwardGraph::Node>();
  // Same effect as operator[] assignment: an existing entry for key is overwritten.
  graph_.node_map.insert_or_assign(key, fresh_node);
  return fresh_node;
}
visit Function #
- 遍历所有的参数,对每个参数创建Graph节点,并标记为Extern, 且节点模式标记为kOpaque
/*!
 * \brief Create a graph node for every function parameter, then visit the function.
 *
 * Parameters come from outside the function, so each one is marked as an external reference
 * with pattern kOpaque before it is appended to the post-DFS order.
 */
void VisitExpr_(const FunctionNode* func) final {
  for (const Var& param : func->params) {
    const Object* param_key = param.get();
    IndexedForwardGraph::Node* node_for_param = CreateNode(param_key);
    MarkAsExternRef(node_for_param);
    SetNodePattern(node_for_param, OpPatternKind::kOpaque);
    AddToPostDFSOrder(node_for_param, param_key);
  }
  ExprVisitor::VisitExpr_(func);
}
visit bindingBlock #
void VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
VisitBindingBlock_(df_block);
}
// We skip ordinary binding blocks since they might be impure (with side effect or control flow)
}
visit binding #
- visit VarBindingNode
- visit MatchCast
MatchCast:
- 创建图Node,并设置其节点模式为kOpaque
- 将节点添加到graph.post_dfs_order中
void VisitBinding_(const MatchCastNode* binding) final {
IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
LOG(INFO) << "[FuseOps][GraphCreator] MatchCastNode, not call node var: " << binding->var << ", value: " << binding->value << ", struct_info: " << binding->struct_info;
SetNodePattern(node, OpPatternKind::kOpaque);
AddToPostDFSOrder(node, binding->var.get());
}
VarBinding:
- 创建节点,根据value类型,选择visit目标
- 添加到post_dfs_order中
- Call:如果函数已经定义模式,则设置;若无,则设置为kOpaque
- TupleGetItem:设置节点模式为kInjective
- VisitUnsupportedNode:设置节点模式为kOpaque
void VisitBinding_(const VarBindingNode* binding) final {
IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
// If the variable is not a dataflow variable, it must be the output variable of this dataflow
// block
if (!binding->var->IsInstance<DataflowVarNode>()) {
this->MarkAsExternRef(node);
}
if (const auto* call = binding->value.as<CallNode>()) {
// Case 1. The expression is a CallNode
VisitCall(call, node);
} else if (const auto* tuple_get_item = binding->value.as<TupleGetItemNode>()) {
// Case 2. The expression is a TupleGetItemNode
VisitTupleGetItem(tuple_get_item, node);
} else {
VisitUnsupportedNode(binding->value, node);
// Case 3. The type of the expression is not fusion-supported.
// In this case, we skip adding edges, adding an empty node into graph.
}
AddToPostDFSOrder(node, binding->var.get());
}
VisitCall:
- 设置创建节点的模式
- Visit每个函数参数,将会对每个参数创建节点Node/选择已有的node,并将每个节点放到CallNode的edges中
/*!
 * \brief Set the pattern of a binding node from the call bound to it and link its arguments.
 *
 * For a call to relax.call_tir the pattern is read from the "op_pattern" attribute of the
 * called PrimFunc and the arguments are the fields of the argument tuple. In every other case
 * the pattern is kOpaque and the call's own arguments are used.
 *
 * \param call The call bound to the variable.
 * \param binding_var_node The graph node of the bound variable; must not be null.
 */
void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) {
  ICHECK_NOTNULL(binding_var_node);
  static const Op& call_tir_op_ = Op::Get("relax.call_tir");

  // Defaults, used for everything that is not a call_tir.
  OpPatternKind pattern = OpPatternKind::kOpaque;
  Array<Expr> args = call->args;

  const auto* callee_op = call->op.as<OpNode>();
  const bool is_call_tir = (callee_op == call_tir_op_.get());
  if (is_call_tir) {
    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
    tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
    // The real arguments of a call_tir live in its argument tuple.
    args = Downcast<Tuple>(call->args[1])->fields;
    Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
    // Without the attribute the pattern stays kOpaque.
    if (opt_pattern.defined()) {
      pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
    }
  }

  SetNodePattern(binding_var_node, pattern);
  for (const Expr& arg : args) {
    ICHECK(IsLeafOrTuple(arg));
    VisitLeaf(arg, binding_var_node, pattern);
  }
}
VisitLeaf:
/*!
 * \brief Link a leaf expression to the binding node that consumes it.
 *
 * Tuples are visited field by field. Expressions that are not LeafExprNode are ignored. A leaf
 * without a graph node must be a constant-like expression, for which an opaque node is created;
 * anything else is a fatal use-before-definition error.
 *
 * \param leaf_expr The argument expression.
 * \param binding_var_node The graph node of the consuming binding; must not be null.
 * \param pattern The pattern stored on the created edge.
 */
void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node,
               const OpPatternKind& pattern) {
  ICHECK_NOTNULL(binding_var_node);

  if (const auto* tuple = leaf_expr.as<TupleNode>()) {
    for (const Expr& field : tuple->fields) {
      VisitLeaf(field, binding_var_node, pattern);
    }
    return;
  }

  // Skip GlobalVar, ExternFunc, OpNode.
  if (!leaf_expr->IsInstance<LeafExprNode>()) {
    return;
  }

  IndexedForwardGraph::Node* leaf_node = nullptr;
  const auto known = graph_.node_map.find(leaf_expr.get());
  if (known != graph_.node_map.end()) {
    leaf_node = known->second;
  } else {
    const bool is_constant_like =
        leaf_expr->IsInstance<ConstantNode>() || leaf_expr->IsInstance<ShapeExprNode>() ||
        leaf_expr->IsInstance<PrimValueNode>() || leaf_expr->IsInstance<StringImmNode>() ||
        leaf_expr->IsInstance<DataTypeImmNode>();
    if (!is_constant_like) {
      LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr
                 << " used before definition.";
    }
    leaf_node = CreateNode(leaf_expr.get());
    // Constants are never fused, so their pattern is `kOpaque`.
    SetNodePattern(leaf_node, OpPatternKind::kOpaque);
    AddToPostDFSOrder(leaf_node, leaf_expr.get());
  }
  AddEdge(leaf_node, binding_var_node, pattern);
}
/*!
 * \brief Append an edge start --> end (leaf --> current node) to the outputs of start.
 * \param start The node the edge leaves from.
 * \param end The node the edge points to.
 * \param pattern The pattern stored on the edge.
 */
void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end,
             OpPatternKind pattern) {
  auto* edge_link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
  IndexedForwardGraph::Edge& edge = edge_link->value;
  edge.node = end;
  edge.pattern = pattern;
  start->outputs.Push(edge_link);
}
AddToPostDFSOrder #
函数功能:将输入节点以post-dfs顺序添加到图中
/*!
 * \brief Append a node to the post-DFS order of the graph.
 *
 * The node's ref and index are only assigned here, so a node whose ref is already set has
 * already been added.
 *
 * \param node The node to append; it must be the node registered in node_map for key.
 * \param key The object the node was created for.
 */
void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
  auto it = graph_.node_map.find(key);
  ICHECK(it != graph_.node_map.end() && it->second == node)
      << "Cannot add a node to the post-DFS order before it is created for its key.";
  ICHECK(node->ref == nullptr) << "The node has already been added to the post-DFS order.";
  node->ref = key;
  node->index = graph_.post_dfs_order.size();
  graph_.post_dfs_order.push_back(node);
}
图分区 #
// Step 2. Partition the graph by applying the fusion algorithm.
std::vector<GraphPartitioner::Group*> groups =
GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph);
GraphPartitioner #
类描述:由并查集(union-find)数据结构标记的图分区 数据成员:
- arena_:内部的arena用于临时空间
- opt_level_:融合操作的优化级别
- max_fuse_depth:一个fused函数中的最大的算子数目
- groups_:内部group
- visited:
# src/relay/analysis/graph_partitioner.h
/*! \brief Partitions an indexed forward graph into groups (excerpt: data members only). */
class GraphPartitioner {
 private:
  /*! \brief The internal arena for temporary space. */
  support::Arena* arena_;
  /*! \brief optimization level for fuse operation. */
  int opt_level_;
  /*! \brief The maximum number of operations in one fused function */
  size_t max_fuse_depth_;
  /*! \brief The internal groups. */
  std::vector<Group*> groups_;
  /*! \brief internal field used for deduplication */
  std::unordered_set<IndexedForwardGraph::Node*> visited_;
};
Group:
/*! \brief A group of graph nodes, organised as a union-find structure. */
struct Group {
  /*! \brief The parent in the union find data structure. */
  Group* parent = nullptr;
  /*! \brief The op pattern kind of this group. */
  OpPatternKind pattern;
  /*! \brief Reference to the root node of the group. */
  const tvm::Object* root_ref = nullptr;
  /*! \brief Reference to the anchor node; InitGroups sets it for kOutEWiseFusable nodes. */
  const tvm::Object* anchor_ref = nullptr;
  /*! \brief Count of nodes in this group (NOTE(review): inferred from the name -- confirm). */
  uint32_t num_nodes = 1;
  /*! \brief The attributes of the function of this group. */
  runtime::Map<runtime::String, ObjectRef> attrs;
  /*! \brief Find the root of the group, performing path compression. */
  Group* FindRoot();
};
Partition #
函数描述:对图进行拆分 函数流程:
- 初始化groups:每个节点一个group
- 获取dominator tree:reverse topo遍历,组合OPPattern
- 运行融合算法
# src/relay/analysis/graph_partitioner.cc
/*!
 * \brief Partition a graph into groups.
 * \param graph The indexed forward graph to partition.
 * \return The groups, moved out of the partitioner. At opt level 0 no fusion is run and every
 *         node keeps the group created by InitGroups.
 */
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  const bool fusion_enabled = (opt_level_ != 0);
  if (fusion_enabled) {
    // The fusion algorithm works on the post dominator tree and runs in three phases.
    auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
    int phase = 0;
    while (phase < 3) {
      this->RunFuse(graph, post_dom_tree, phase);
      ++phase;
    }
  }
  return std::move(groups_);
}
InitGroups #
函数描述:每个节点新建一个group 函数流程:
- 如果pattern为kOutEWiseFusable,设置anchor_ref
/*!
 * \brief Create one group per graph node, in post-DFS order.
 * \param graph The graph whose nodes are turned into single-node groups.
 */
void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
  const size_t num_groups = graph.post_dfs_order.size();
  groups_.resize(num_groups);
  // Walk over every node of the graph.
  for (size_t idx = 0; idx < groups_.size(); ++idx) {
    const auto* source_node = graph.post_dfs_order[idx];
    auto* new_group = arena_->make<Group>();
    new_group->pattern = source_node->pattern;
    new_group->root_ref = source_node->ref;
    // Only kOutEWiseFusable groups get an anchor reference.
    const bool needs_anchor = (new_group->pattern == relay::kOutEWiseFusable);
    if (needs_anchor) {
      new_group->anchor_ref = source_node->ref;
    }
    groups_[idx] = new_group;
  }
}
RunFuse #
函数描述 函数流程:
- 遍历所有的groups
- 当前节点==kOpaque或者当前节点没有dominator,则不fuse
- Phase0:完成(kElemWise,kBroadcast)、(kOutEWiseFusable)节点的融合
- Phase1:完成kInjective、kTuple节点的融合(injective融合推迟到该阶段,保证conv2d先完成融合)
- Phase2:将模式不高于kInjective的算子融合进中间tuple
/*!
 * \brief Run one phase of the fusion algorithm over all groups.
 *
 * Each node is considered for fusion into its immediate post-dominator:
 *  - phase 0 handles kOutEWiseFusable nodes and nodes whose pattern is at most kBroadcast;
 *  - phase 1 handles kInjective and kTuple nodes;
 *  - phase 2 fuses nodes whose pattern is at most kInjective into intermediate tuples.
 * Opaque nodes, nodes without a post-dominator and fusions that would exceed max_fuse_depth_
 * are skipped.
 *
 * \param graph The indexed forward graph.
 * \param post_dom_tree The post dominator tree of the graph.
 * \param phase The phase to run (0, 1 or 2).
 */
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
const DominatorTree& post_dom_tree, //
int phase) {
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
if (dom_node->parent == nullptr) continue;
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > relay::kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == relay::kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
// Phase 2 does nothing else for this node.
continue;
}
// Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
// Do not fuse into tuple for now
if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
// Path for OutEWiseFusable: conv2d
// Check if the dominator relation is elemwise.
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
ICHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern <= kBroadcast) {
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
(dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast anchor.
return kind <= kInjective;
} else {
return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else {
// do nothing.
// The only pattern left at this point is kCommReduce.
ICHECK(group_node->pattern == kCommReduce);
}
}
}
Phase0 #
- kOutEWiseFusable 当前节点和dom节点:当前节点(如conv2d)是kOutEWiseFusable, dom_node的pattern是kElemWise, 中间节点:kElemWise、kBroadcast(即 kind <= kBroadcast)
if (group_node->pattern == kOutEWiseFusable) {
// Path for OutEWiseFusable: conv2d
// Check if the dominator relation is elemwise.
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
ICHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
}
- kBroadcast 当前节点/dom节点:当前节点(kElemWise、kBroadcast),dom节点(kElemWise、kBroadcast、kInjective、kCommReduce) 中间节点:not sink(kElemWise、kBroadcast、kInjective)、sink(kElemWise、kBroadcast、kInjective、kCommReduce、kOutEWiseFusable)
} else if (group_node->pattern <= kBroadcast) {
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
(dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast anchor.
return kind <= kInjective;
} else {
return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
}
Phase 1 #
- kInjective、kTuple 中间路径节点:(kElemWise、kBroadcast、kInjective)
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
Phase 2 #
- 当前节点:(kElemWise、kBroadcast、kInjective);pattern > kInjective 的节点(kCommReduce、kOutEWiseFusable、kTuple等)会被直接跳过
- dom节点:(kElemWise、kBroadcast、kInjective、kTuple)
- 中间节点:(kElemWise、kBroadcast、kInjective)
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > relay::kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == relay::kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
continue;
}
CheckPath #
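函数描述:检查从src的各个输出出发、到达sink的路径上的节点(不含src本身)是否都满足fcond 函数流程:清空visited_,对src的每个output递归调用CheckPath_,只要有一个返回false就返回false,否则返回true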
/*!
 * \brief Test fcond along the paths that leave src, stopping each path at sink.
 *
 * src itself is not passed to fcond; the check starts at each of its outputs
 * and is carried out by CheckPath_.
 *
 * \param src The starting node; must not be an extern_ref node and must differ from sink.
 * \param sink The node at which path traversal stops.
 * \param fcond Predicate called as fcond(pattern, is_sink).
 * \return false as soon as one output path fails the check, true otherwise.
 */
template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                 F fcond) {
  ICHECK(!src->extern_ref);
  // Each query starts from a fresh visited set.
  visited_.clear();
  ICHECK(src != sink);
  auto edge = src->outputs.head;
  while (edge != nullptr) {
    if (!CheckPath_(edge->value.node, sink, fcond)) {
      return false;
    }
    edge = edge->next;
  }
  return true;
}
/*!
 * \brief Recursive worker of CheckPath.
 *
 * A node that is already in visited_ is accepted without being re-checked.
 * Otherwise the pattern of the root of the node's group is passed to fcond,
 * together with a flag telling whether the node is the sink.
 *
 * \param src The node currently being examined.
 * \param sink The node at which recursion stops.
 * \param fcond Predicate called as fcond(pattern, is_sink).
 * \return true when this node and everything reachable from it up to sink pass.
 */
template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                  F fcond) {
  if (visited_.count(src)) {
    return true;
  }
  visited_.insert(src);

  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  Group* root = group->FindRoot();

  const bool at_sink = (src == sink);
  if (!fcond(root->pattern, at_sink)) {
    return false;
  }
  if (at_sink) {
    return true;
  }

  auto edge = src->outputs.head;
  while (edge != nullptr) {
    if (!CheckPath_(edge->value.node, sink, fcond)) {
      return false;
    }
    edge = edge->next;
  }
  return true;
}
CommitFuse #
函数描述:输入graph node和dom node, 合并到dom target group中 函数流程:
- 清空visited
- 将src合并到dom target group中,
- 遍历src的outputs,合并到dom target group中
- 遍历outputs的output,合并到dom target group中,直到src == sink
/*!
 * \brief Merge src, and the nodes reachable from it before sink, into sink's group.
 *
 * \param src The node where merging starts; must differ from sink.
 * \param sink The node whose group is the merge target.
 */
void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  // The traversal in CommitFuse_ relies on an empty visited set.
  visited_.clear();
  ICHECK(src != sink);
  CommitFuse_(src, sink, groups_[sink->index]);
}
/*!
 * \brief Recursive worker of CommitFuse.
 *
 * Stops at sink and at nodes already visited. Every other node has its group
 * merged into target before the recursion continues through its outputs.
 *
 * \param src The node currently being merged.
 * \param sink The node at which recursion stops.
 * \param target The group that receives the merged nodes.
 */
void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                   Group* target) {
  if (src == sink || visited_.count(src)) {
    return;
  }
  // 1. Mark the node as visited.
  visited_.insert(src);

  Group* group = groups_[src->index];
  ICHECK(group != nullptr);
  // 2. Merge the current group into the target group.
  MergeFromTo(group, target);

  auto edge = src->outputs.head;
  while (edge != nullptr) {
    CommitFuse_(edge->value.node, sink, target);
    edge = edge->next;
  }
}
/*!
 * \brief Attach the root of child's group underneath the root of parent's group.
 *
 * Does nothing when both already share a root. If the child root carries an
 * anchor_ref, it is moved to the parent root (which must not have one yet)
 * and the parent pattern is recombined with the child pattern.
 *
 * \param child Any member of the group being merged.
 * \param parent Any member of the group being merged into.
 */
void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
  Group* child_root = child->FindRoot();
  Group* parent_root = parent->FindRoot();
  if (child_root == parent_root) {
    return;
  }

  // The parent root now also accounts for every node of the child group.
  parent_root->num_nodes += child_root->num_nodes;
  child_root->parent = parent_root;

  if (child_root->anchor_ref == nullptr) {
    return;
  }
  // Carry the anchor and the combined pattern over to the new root.
  ICHECK(parent_root->anchor_ref == nullptr);
  parent_root->anchor_ref = child_root->anchor_ref;
  parent_root->pattern = CombinePattern(child_root->pattern, parent_root->pattern);
}
DominatorTree #
类描述:用于表示domination或者节点的post domination关系 数据成员:std::vector<Node*> nodes;
- gnode:来源于indexForwardgraph
- depth, parent, pattern都是新创建的,reverse topo得到的, pattern是parent的combine pattern
/*! \brief One entry of the dominator tree. */
struct Node {
  /*! \brief The graph node this tree entry refers to. */
  IndexedForwardGraph::Node* gnode = nullptr;
  /*! \brief Parent entry in the tree; nullptr when there is none. */
  Node* parent = nullptr;
  /*! \brief Depth of this entry in the tree. */
  int depth = 0;
  /*! \brief Pattern aggregated on the way to the parent. */
  OpPatternKind pattern = kOpaque;
};
PostDom #
函数描述:根据IndexedForwardGraph的post_dfs_order,为每个图节点创建DominatorTree节点 函数流程:
- reverse topo序, 从gv开始遍历
/*!
 * \brief Build a DominatorTree with one tree node per entry of graph.post_dfs_order.
 *
 * Entries are processed from the last index down to the first.
 *
 * \param arena Arena passed to GetNode for allocating tree nodes.
 * \param graph The indexed forward graph to build the tree for.
 * \return The populated tree.
 */
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  const size_t num_nodes = graph.post_dfs_order.size();
  DominatorTree tree;
  tree.nodes.resize(num_nodes, nullptr);
  // Walk in reverse topological order.
  for (size_t index = num_nodes; index-- > 0;) {
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}
GetNode #
函数描述:为图节点gnode创建对应的DominatorTree节点,并设置其depth、parent、pattern 函数流程:
- make Node
- gv和函数参数的depth = 1, parent=nullptr, pattern = opaque
- 当前节点的最近公共祖先==当前节点所有edges/outputs的最近公共祖先
/*!
 * \brief Allocate the tree node for a graph node and fill in parent, depth and pattern.
 *
 * An extern_ref graph node gets depth 1, no parent and the kOpaque pattern.
 * Any other node takes the least common ancestor of its outputs as parent.
 *
 * \param arena Arena used to allocate the tree node.
 * \param gnode The graph node to wrap.
 * \return The newly allocated tree node.
 */
DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* result = arena->make<Node>();
  result->gnode = gnode;

  if (gnode->extern_ref) {
    result->depth = 1;
    result->parent = nullptr;
    result->pattern = kOpaque;
    return result;
  }

  // Find the LCA of all outputs; the pattern is combined along the way,
  // starting from kElemWise.
  OpPatternKind combined = kElemWise;
  Node* dominator = LeastCommonAncestor(gnode->outputs, &combined);
  result->parent = dominator;
  result->depth = (dominator != nullptr) ? dominator->depth + 1 : 1;
  result->pattern = combined;
  return result;
}
LeastCommonAncestor: #
/*!
 * \brief Fold the pairwise LCA over the tree nodes referenced by a list of edges.
 *
 * \param input_nodes The edges whose target nodes are combined.
 * \param edge_pattern In/out pattern accumulator; each edge's pattern is combined into it,
 *        as are the patterns gathered by the pairwise LCA.
 * \return The common ancestor, or nullptr when the list is empty or no common ancestor exists.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(
    const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
  auto cursor = input_nodes.head;
  if (cursor == nullptr) {
    return nullptr;
  }

  // Map an edge to the already-built tree node of its target.
  auto tree_node_of = [&](const IndexedForwardGraph::Edge& edge) {
    size_t target_index = edge.node->index;
    ICHECK_LT(target_index, nodes.size());
    Node* tree_node = nodes[target_index];
    ICHECK(tree_node != nullptr);
    return tree_node;
  };

  Node* ancestor = tree_node_of(cursor->value);
  *edge_pattern = CombinePattern(*edge_pattern, cursor->value.pattern);
  cursor = cursor->next;

  while (cursor != nullptr) {
    ancestor = LeastCommonAncestor(ancestor, tree_node_of(cursor->value), edge_pattern);
    *edge_pattern = CombinePattern(*edge_pattern, cursor->value.pattern);
    cursor = cursor->next;
  }
  return ancestor;
}
/*!
 * \brief Find the common ancestor of two tree nodes by comparing their depths.
 *
 * The deeper node climbs to its parent; at equal depth both climb. The pattern
 * of every node that is climbed past is combined into *edge_pattern.
 *
 * \param lhs First tree node (may become nullptr while climbing).
 * \param rhs Second tree node (may become nullptr while climbing).
 * \param edge_pattern In/out pattern accumulator.
 * \return The common ancestor, or nullptr if either side runs out of parents first.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
                                                        OpPatternKind* edge_pattern) {
  while (lhs != rhs) {
    if (lhs == nullptr || rhs == nullptr) {
      return nullptr;
    }
    // Decide who climbs before touching either pointer.
    const bool climb_lhs = lhs->depth >= rhs->depth;
    const bool climb_rhs = rhs->depth >= lhs->depth;
    if (climb_lhs) {
      *edge_pattern = CombinePattern(*edge_pattern, lhs->pattern);
    }
    if (climb_rhs) {
      *edge_pattern = CombinePattern(*edge_pattern, rhs->pattern);
    }
    if (climb_lhs) {
      lhs = lhs->parent;
    }
    if (climb_rhs) {
      rhs = rhs->parent;
    }
  }
  return lhs;
}
融合 OperatorFusor #
类构造:
// Expression mutator used by the fusion pass (abbreviated excerpt: only one
// constructor of the class is shown here).
class OperatorFusor : public ExprMutator {
// Delegating constructor: builds a group map from the graph and its groups via
// CreateGroupMap and forwards it, with mod and lift_constant, to another
// constructor overload that is not part of this excerpt.
OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector<Group*>& groups,
bool lift_constant = true)
: OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {}
};
Transform:
- 处理relax函数,并且该relax函数不是kPrimitive的
/*!
 * \brief Rewrite every Relax function of mod_ that is not marked kPrimitive.
 * \return The module held by the block builder after all updates.
 */
IRModule Transform() {
  for (const auto& [gv, func] : mod_->functions) {
    // Anything that is not a Relax function is left untouched.
    if (!func->IsInstance<relax::FunctionNode>()) {
      continue;
    }
    // Functions carrying the kPrimitive attribute are not visited.
    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      continue;
    }
    auto rewritten = Downcast<Function>(VisitExpr(func));
    builder_->UpdateFunction(gv, rewritten);
  }
  return builder_->GetContextIRModule();
}
VisitBindingBlock:
- 收集每个grouped的函数
- 收集所有groups的边界
- 为每个grouped函数创建group函数
- 开始生成新的binding block
/*!
 * \brief Dispatch on the block kind: only dataflow blocks are rewritten.
 * \param block The binding block to visit.
 * \return The rewritten block for a dataflow block, otherwise the input block unchanged.
 */
BindingBlock VisitBindingBlock(const BindingBlock& block) final {
  const auto* df_block = block.as<DataflowBlockNode>();
  if (df_block == nullptr) {
    // Ordinary binding blocks are skipped since they might be impure
    // (with side effect or control flow).
    return block;
  }
  return VisitBindingBlock_(df_block);
}
// Rewrites one dataflow block: collects the bindings and boundaries of each
// group, creates one function per group, then re-emits the block so that every
// multi-binding group is replaced by a single call to its grouped function.
// Returns the block produced by the block builder.
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
// State from a previously visited block must not leak into this one.
group2func_.clear();
// Step 1. Collect the bindings for each grouped function.
CollectFuncBindings(block->bindings);
// Step 2. Collect all group's boundary (i.e. the output vars for each group)
CollectFuncBoundary(block->bindings);
// Step 3. Create the grouped function for each group.
for (auto& [g, creator] : group2func_) {
creator.CreateFunction(g->attrs);
}
// Step 4. Start generating the new binding block.
// - For groups with single binding, we directly recurse into the binding and emit the new one.
// - For groups with multiple bindings, we emit the call to the grouped function only when
// visiting the last binding of the group, because only by doing this we don't break the
// dependencies among the bindings of different groups. And therefore, we will skip all but the
// last binding of the group.
builder_->BeginDataflowBlock();
// For each group, record which variables need to be remapped to the output of TupleGetItem.
// Only relevant when the output of the grouped function is a tuple.
std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;
// A grouped function which returns a tuple requires attaching TupleGetItem to each element and
// remapping variables in earlier bindings appropriately. Thus, a binding whose value depends on
// some elements of a tuple from other group's function must be emitted after a call to the
// tuple-producing function is emitted and remapping is done.
// To guarantee this, we process bindings in the order of the topological sort of the group
// dependency relations.
for (const auto& binding : TopoSortByGroupDep(block->bindings)) {
// Case 1. If the binding is the only binding in its group, recurse into it and emit the
// transformed binding as usual.
Group* group = GetGroupFromBinding(binding);
if (group->num_nodes == 1 && group->attrs.empty()) {
VisitBinding(binding);
continue;
}
const auto& it_creator = group2func_.find(group);
ICHECK(it_creator != group2func_.end());
const FunctionCreator& func_info = it_creator->second;
if (!func_info.function_.defined()) {
// The function is not created yet, so we skip the binding.
continue;
}
const Function& func = func_info.function_.value();
// If this binding belongs to a group whose output is a tuple, the original bound variable
// needs to be remapped to the output of TupleGetItem after the corresponding tuple is
// emitted.
if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) {
if (!GetStructInfo(binding->var)->IsInstance<TupleStructInfoNode>() ||
IsNestedTupleOutput(func)) {
// When binding->var itself is a tuple, we do not need to remap this variable to the
// output of TupleGetItem unless the output is a nested tuple.
pending_tuple_get[group].push_back(binding->var);
}
}
// Case 2. If the binding is not the last binding of the group, we skip it.
if (!func_info.bindings_.back().same_as(binding)) {
continue;
}
// Case 3. The binding is the last binding of the group.
const auto* var_binding = binding.as<VarBindingNode>();
ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 "
"is supposed to be a variable binding";
// Step a. Add the grouped function to the IRModule
GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_);
// Step b. Create the call to the deduplicated function, and then emit the call.
// - If this binding is an output binding, emit an output variable.
// - Otherwise, emit a dataflow variable.
Var new_var;
Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_));
if (var_binding->var->IsInstance<DataflowVarNode>()) {
new_var = builder_->Emit(call_to_emit);
} else {
new_var = builder_->EmitOutput(call_to_emit);
}
// Step c. Update the mapping used for the remapping of the binding variables.
if (IsTupleOutput(func) && !pending_tuple_get.empty()) {
// If the output is a tuple, attach TupleGetItem to all tuple elements, and
// remap variables appropriately.
// The variables that need to be remapped and the corresponding tuple indices are
// available in pending_tuple_get and tuple_get_indices_ respectively.
for (const auto& var : pending_tuple_get[group]) {
auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]);
var_remap_[var->vid] = builder_->Emit(tuple_get);
}
} else {
// Non-tuple output: the original variable maps directly to the call result.
var_remap_[var_binding->var->vid] = new_var;
}
}
// Step 5. Finish the binding block generation.
return builder_->EndBlock();
}
FuseOpsByPattern #
- 类型:ModulePASS
- 用途:常用fuseOpsByPattern实现对两个relax.call的合并为一个新的relax函数, 然后使用mutator重写该relax函数
/*!
 * \brief Create the module pass that runs relax::FuseOpsByPattern with the given options.
 * \param patterns The fusion patterns forwarded to relax::FuseOpsByPattern.
 * \param bind_constants Forwarded unchanged to relax::FuseOpsByPattern.
 * \param annotate_codegen Forwarded unchanged to relax::FuseOpsByPattern.
 * \return A module pass named "FuseOpsByPattern" with opt_level 0 and no required passes.
 */
Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants,
                      bool annotate_codegen) {
  // The options are copied into the closure so the pass can outlive this call.
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [patterns, bind_constants, annotate_codegen](IRModule mod, PassContext ctx) {
        return relax::FuseOpsByPattern(patterns, mod, bind_constants, annotate_codegen);
      };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"FuseOpsByPattern",
                          /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
/*!
 * \brief Apply each fusion pattern in turn, regrouping the module after every pattern.
 *
 * For one pattern, every function that is not a tir::PrimFunc is partitioned by
 * PatternBasedPartitioner and the resulting groups are merged into one map,
 * which is then handed to MakeGroupedFunctions.
 *
 * \param patterns The fusion patterns, applied in order.
 * \param mod The module to transform.
 * \param bind_constants When true, MakeGroupedFunctions is called with lift_constants = false.
 * \param annotate_codegen When true, the result is post-processed by CompositeFunctionAnnotator.
 * \return The transformed module.
 */
IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns, IRModule mod,
                          bool bind_constants, bool annotate_codegen) {
  support::Arena arena;
  // Iterate over all patterns.
  for (const auto& pattern : patterns) {
    OperatorFusor::GroupMap group_map;
    // Iterate over all functions of the current module.
    for (const auto& entry : mod->functions) {
      const auto& func = entry.second;
      if (!func->IsInstance<tir::PrimFuncNode>()) {
        auto partition = PatternBasedPartitioner::Run(
            pattern->name, pattern->pattern, pattern->annotation_patterns,
            pattern->check.value_or(nullptr), func, &arena);
        group_map.insert(partition.begin(), partition.end());
      }
    }
    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ !bind_constants);
  }
  if (!annotate_codegen) {
    return mod;
  }
  return CompositeFunctionAnnotator(mod).Run();
}
FuseTIR #
pass功能:
- 将primitive relax函数融合到更大的TIR函数中
/*!
 * \brief Create the module pass that runs relax::FuseTIR.
 * \return A module pass named "FuseTIR" with opt_level 0 and no required passes.
 */
Pass FuseTIR() {
  // Nothing needs to be captured: the pass only forwards the module.
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [](IRModule mod, PassContext ctx) { return relax::FuseTIR(mod); };
  return CreateModulePass(/*pass_function=*/pass_func,
                          /*opt_level=*/0,
                          /*pass_name=*/"FuseTIR",
                          /*required=*/{});
}
TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
调用TIRFuseMutator完成:
/*!
 * \brief Run TIRFuseMutator over the module.
 * \param mod The module to transform.
 * \return The module returned by TIRFuseMutator::Transform.
 */
IRModule FuseTIR(IRModule mod) {
  return TIRFuseMutator::Transform(mod);
}
Mutator执行流程:
- VisitExpr,删除bunch of primfunc, 创建一个空的block builder
- 遍历所有的primitive relax func
- 更新所有的non-primitive relax func,
- 复制mod属性到modified_mod中
实例分析 #
FuseOps #
relax func:
def before(dtype):
    """Build the example module used to walk through FuseOps.

    The "main" function takes x, w1, w2, w3 and emits, inside one dataflow
    block: add -> conv2d -> add -> add, followed by two conv2d branches that
    are joined by a final add, which is the function output.

    Args:
        dtype: Element dtype used for every tensor and constant.

    Returns:
        The module produced by the block builder.
    """
    builder = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
    w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
    w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
    w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))

    with builder.function("main", [x, w1, w2, w3]):
        with builder.dataflow():
            lv0 = builder.emit_te(topi.add, x, relax.const(1, dtype))
            lv1 = builder.emit_te(
                topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1
            )
            # this is the next dominator.
            lv2 = builder.emit_te(topi.add, relax.const(1, dtype), lv1)
            lv3 = builder.emit_te(topi.add, lv1, lv2)
            # second path
            lv4 = builder.emit_te(
                topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1
            )
            lv5 = builder.emit_te(
                topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1
            )
            joined = builder.call_te(topi.add, lv4, lv5)
            gv = builder.emit_output(joined)
        builder.emit_func_output(gv)

    return builder.get()
构造indexforwardGraph:
- 先push参数,再push varBinding
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:116: IndexedForwardGraph, func: main
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 5, var: lv
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 6, var: lv1
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 8, var: lv2
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 9, var: lv3
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 10, var: lv4
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 11, var: lv5
[21:31:09] /media/l0phtg/movable/tools-git/tvm/src/relax/transform/fuse_ops.cc:179: AddToPostDFSOrder: 12, var: gv
- push x, w1, w2, w3
- push relax.const(1), lv0(edge/outputs = {x, relax.const(1)})
- push lv1(edge/outputs=lv0, w1)
- push relax.const(1), lv2(edge/outputs = {relax.const(1), lv1})
- push lv3(edge/outputs= {lv1, lv2})
- push lv4(edge/outputs= {lv3, w2})
- push lv5(edge/outputs= {lv3, w3})
- push gv(edge/outputs= {lv4, lv5})
| node | gidx | type | pattern | output/edges | tree-idx | depth | lca-parent(outputs) | tree pattern |
|---|---|---|---|---|---|---|---|---|
| x | 0 | params/Var | kOpaque | | 0 | 1 | null | |
| w1 | 1 | params/Var | kOpaque | | 1 | 1 | null | |
| w2 | 2 | params/Var | kOpaque | | 2 | 1 | null | |
| w3 | 3 | params/Var | kOpaque | | 3 | 1 | null | |
| R.const | 4 | ConstantNode | kOpaque | {5: lv} | 4 | 5 | 5: lv | kElemWise |
| lv = add | 5 | VarBinding:Call | kElemWise | {6: lv1} | 5 | 4 | 6: lv1 | kOutEWiseFusable |
| lv1 = conv2d | 6 | VarBinding:Call | kOutEWiseFusable | {8: lv2, 9: lv3} | 6 | 3 | 9: lv3 | kElemWise |
| R.const | 7 | ConstantNode | kOpaque | {8: lv2} | 7 | 4 | 8: lv2 | kElemWise |
| lv2 = add | 8 | VarBinding:Call | kElemWise | {9: lv3} | 8 | 3 | 9: lv3 | kElemWise |
| lv3 = add | 9 | VarBinding:Call | kElemWise | {10: lv4, 11: lv5} | 9 | 2 | 12: gv | kOutEWiseFusable |
| lv4 = conv2d | 10 | VarBinding:Call | kOutEWiseFusable | {12: gv} | 10 | 2 | 12: gv | kElemWise |
| lv5 = conv2d | 11 | VarBinding:Call | kOutEWiseFusable | {12: gv} | 11 | 2 | 12: gv | kElemWise |
| gv = add | 12 | Var->Extern? | kElemWise | | 12 | 1 | null | kOpaque |
graph TB
x --> lv0
digraph G {
x [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.Var("x")</td></tr>
</table>>
shape = "plaintext"
];
x -> lv0
Constant1 [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.const(1)</td></tr>
</table>>
shape = "plaintext"
];
Constant1 -> lv0
W1 [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.Var("w1")</td></tr>
</table>>
shape = "plaintext"
];
lv0 -> lv1
W1 -> lv1
Constant2 [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.const(1)</td></tr>
</table>>
shape = "plaintext"
];
lv1 -> lv2
Constant2 -> lv2
lv1 -> lv3
lv2 -> lv3
lv3 -> lv4
lv3 -> lv5
W2 [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.Var("w2")</td></tr>
</table>>
shape = "plaintext"
];
W3 [
label = <
<table border="0" cellborder="1" cellspacing="0">
<tr><td>R.Var("w3")</td></tr>
</table>>
shape = "plaintext"
];
W2 -> lv4
W3 -> lv5
lv4 -> gv
lv5 -> gv
start [shape=Mdiamond];
gv [shape=Msquare];
}
dominator tree
- gv, depth = 1
- lv5, depth = 2
- lv4, depth = 2
- lv3, depth = 2