RewriteDataflowReshape #
Pass functionality:
- Converts every reshape-like call_tir into a call to the VM reshape operator. The VM reshape operator is lowered to a CreateView operation at runtime instead of performing a real data copy (a usage sketch follows this list).
- The reshape-like operators here include reshape, expand_dims, flatten, etc.
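A usage sketch, assuming `mod` is an already-legalized IRModule whose dataflow blocks contain reshape-like call_tir nodes (the variable name is illustrative):
from tvm import relax

# Assumption: `mod` is a legalized IRModule. Applying the pass rewrites
# reshape-like call_tir nodes into calls to the relax reshape op, which the
# VM later lowers to zero-copy CreateView operations.
mod = relax.transform.RewriteDataflowReshape()(mod)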
Implementation:
Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) {
return DataflowReshapeRewriter(mod)(f);
}
namespace transform {
Pass RewriteDataflowReshape() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(RewriteDataflowReshape(f, m));
};
return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {});
}
TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape")
.set_body_typed(RewriteDataflowReshape);
}
CallTIRRewrite #
// src/relax/transform/call_tir_rewrite.cc
Pass CallTIRRewrite() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CallTIRRewrite(f));
};
return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
}
TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
Implemented by invoking CallTIRMutator:
Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
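For intuition, a hedged before/after sketch of what CallTIRMutator emits (pseudocode, not exact printer output): call_tir is rewritten into destination-passing style, allocating the output explicitly and appending it to the callee's arguments.
# before: gv = R.call_tir(prim_func, (x,), out_sinfo=R.Tensor((16,), "float32"))
# after:  alloc = R.builtin.alloc_tensor(R.shape([16]), "float32", 0)
#         _ = prim_func(x, alloc)   # output buffer passed as the last argument
#         gv = alloc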
LowerAllocTensor #
Lowers R.builtin.alloc_tensor into an explicit R.memory.alloc_storage / R.memory.alloc_tensor pair, as the following example shows.
Example #
@R.function
def main():
x = R.builtin.alloc_tensor(R.shape([16, 32]), "float32", 0)
return x
After LowerAllocTensor()(Before):
@R.function
def main():
storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") # allocate storage
x = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") # allocate a tensor view into the storage
return x
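Where the 2048 above comes from: the storage must back a [16, 32] float32 tensor, so the pass computes the byte size from shape and dtype. A quick check in plain Python:
import numpy as np

# 16 * 32 float32 elements at 4 bytes each -> 2048 bytes of backing storage
assert 16 * 32 * np.dtype("float32").itemsize == 2048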
VMBuiltinLower #
Source file: src/relax/backend/vm/vm_builtin_lower.cc. Pass type: Mutator. Global registration:
- Core class: VMBuiltinLowerMutator
Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
Pass VMBuiltinLower() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(VMBuiltinLower(f));
};
return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
}
TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
- VMBuiltinLowerMutator members:
class VMBuiltinLowerMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;
...
const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
const StructInfo object_sinfo_ = ObjectStructInfo();
const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
// object to pattern match.
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
// memory alloc/kill ops
const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage");
const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor");
const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage");
const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor");
// functions to lower to
const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
const Op& vm_kill_object_op_ = Op::Get("relax.vm.kill_object");
// Function to compute allocated shape.
const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
VisitExpr #
Expr VisitExpr_(const CallNode* call_node) final {
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == call_tir_dyn_op_) {
return CallTIRDyn(call);
} else if (call->op == reshape_op_) {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
return InvokeClosure(call);
} else if (call->op == alloc_tensor_op_) {
LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression "
<< GetRef<Call>(call_node) << ". "
<< "This operation should have been lowered earlier "
<< "using the 'relax.transform.LowerAllocTensor' pass.";
// storage allocation / tensor allocation / memory release
} else if (call->op == mem_alloc_storage_op_) {
return MakeMemAllocStorage(call);
} else if (call->op == mem_alloc_tensor_op_) {
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
return MakeMemKillObject(call);
} else {
return call;
}
}
Mem #
MakeMemAllocStorage #
Lowers relax.memory.alloc_storage to relax.vm.alloc_storage, analogous to MakeMemAllocTensor below.
MakeMemAllocTensor #
Lowers relax.memory.alloc_tensor to relax.vm.alloc_tensor:
Expr MakeMemAllocTensor(const Call& call) {
PrimValue offset = Downcast<PrimValue>(call->args[1]);
DataTypeImm dtype = Downcast<DataTypeImm>(call->args[3]);
return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs());
}
Target op: const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
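As a hedged before/after sketch (pseudocode, not exact printer output), the rewrite is a one-to-one op swap that keeps the storage/offset/shape/dtype arguments:
# before: t = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
# after:  t = R.vm.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")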
VMShapeLower #
Source file: src/relax/backend/vm/vm_shape_lower.cc. Global registration:
Pass VMShapeLower(bool emit_err_ctx) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) {
return VMShapeLowerMutator::Lower(mod, emit_err_ctx);
};
return CreateModulePass(pass_func, 0, "VMShapeLower", {});
}
TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) {
return VMShapeLower(emit_err_ctx);
});
VMShapeLowerMutator is a Mutator; it overrides Rewrite and rewrites every function in the IRModule:
class VMShapeLowerMutator
: public ExprMutator,
public StructInfoFunctor<void(const StructInfo&, Expr, bool, bool, const String&,
std::vector<MatchShapeTodoItem>*)> {
public:
static IRModule Lower(IRModule mod, bool emit_err_ctx) {
// 1. construct the mutator
VMShapeLowerMutator mutator(mod, emit_err_ctx);
// 2. iterate over all functions; rewrite each Relax Function via Rewrite, then replace it in builder_
for (auto& kv : mod->functions) {
if (auto* func = kv.second.as<FunctionNode>()) {
Function updated_func = mutator.Rewrite(kv.first, GetRef<Function>(func));
mutator.builder_->UpdateFunction(kv.first, updated_func);
}
}
return mutator.builder_->GetContextIRModule();
}
Source analysis #
Algorithm flow:
- Preprocessing: collect the PrimExprSlot set. We scan the function and allocate a PrimExprSlot for each PrimExpr. In the example from the source comments (parameter shapes involving m, n+1, and n), the resulting mapping from slot index to expr is {0: m, 1: n+1, 2: n}. Note that "n+1" also gets its own slot. Each PrimExprSlot carries auxiliary fields that track whether its value is easy to compute (see the toy model after this list).
Steps at each match point:
- Step 1: call CheckMatchCast, which recursively unpacks the StructInfo and generates static checks. Note that this step only generates checks for type and ndim information, not for symbolic shape variables; the symbolic shape-matching work is returned as a vector<MatchShapeTodoItem>, since symbolic shape matching may not complete in a single round. Importantly, CheckMatchCast also handles tuple unpacking.
- Step 2: call RunMatch to generate the statements that match symbolic shapes. In the example, the first round stores the values of m and n into their corresponding slots. RunMatch may return outstanding items.
- Step 3: call EmitOutstandingPrimExprCompute to compute the outstanding PrimExprs (such as n+1) from the populated slots, then run RunMatch once more to resolve the remaining items; this is the sequence visible in Rewrite below.
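A toy Python model of the slot-collection step (illustrative only; the real collector walks PrimExpr nodes in struct infos, while plain strings stand in for them here):
def collect_slots(prim_exprs):
    """Assign one slot per distinct PrimExpr, in first-seen order."""
    slot_map = {}
    for e in prim_exprs:
        if e not in slot_map:
            slot_map[e] = len(slot_map)
    return slot_map

# shapes mentioning m, n+1 and n; the composite "n+1" gets its own slot
assert collect_slots(["m", "n+1", "n", "m"]) == {"m": 0, "n+1": 1, "n": 2}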
Rewrite #
// Unit rewrite function per function.
Function Rewrite(GlobalVar gvar, Function func) {
// prepare mapping and heap var
slot_vec_.clear();
slot_map_.clear();
PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_);
heap_size_ = IntImm(ShapeDType(), static_cast<int64_t>(slot_vec_.size()));
VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_);
shape_heap_ = shape_heap_binding->var;
// prepare slot information
this->PopulateSlotInfo();
Array<BindingBlock> blocks;
builder_->BeginScope(func->params);
{
// Check the parameter section.
builder_->BeginBindingBlock();
this->builder_->EmitNormalized(shape_heap_binding);
std::vector<MatchShapeTodoItem> match_todos;
size_t num_input = func->params.size();
if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
// If the function has the attribute 'num_input', do shape checking only for the real
// inputs and skip weights.
num_input = static_cast<size_t>(opt_num_input.value()->value);
}
for (size_t i = 0; i < func->params.size(); ++i) {
StructInfo sinfo = GetStructInfo(func->params[i]);
std::ostringstream err_ctx;
err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
<< "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") ";
this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input, err_ctx.str(),
&match_todos);
}
// insert heap generation logic.
match_todos = this->RunMatch(match_todos, false);
this->EmitOutstandingPrimExprCompute();
this->RunMatch(match_todos, true);
BindingBlock pre_block = builder_->EndBlock();
blocks.push_back(pre_block);
}
// new body.
auto body_seq = Downcast<SeqExpr>(this->VisitWithNewScope(func->body, func->params));
blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end());
{
// Insert the return value check
builder_->BeginBindingBlock();
std::ostringstream err_ctx;
err_ctx << "ErrorContext(fn=" << gvar->name_hint
<< ", loc=return, annotation=" << func->ret_struct_info << ") ";
std::vector<MatchShapeTodoItem> match_todos;
// NOTE: the return value's shape computation must already be defined.
this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, false, err_ctx.str(),
&match_todos);
this->RunMatch(match_todos, true);
BindingBlock post_block = builder_->EndBlock();
blocks.push_back(post_block);
}
auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body));
// create a new function
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs);
}
VisitExpr #
PrimValueNode #
Expr VisitExpr_(const PrimValueNode* op) final {
using runtime::relax_vm::MakeShapeCode;
// Constant shape can be preserved.
bool is_const_value =
op->value->IsInstance<IntImmNode>() || op->value->IsInstance<FloatImmNode>();
if (is_const_value) {
return GetRef<Expr>(op);
}
Array<Expr> args = {shape_heap_};
auto [code, value_or_index] = MakeSymbolicShapeArg(op->value);
args.push_back(code);
args.push_back(value_or_index);
// make_prim_value(heap, c, r): a single code/operand pair
Call call(builtin_make_prim_value_, args, Attrs(), {Downcast<StructInfo>(op->struct_info_)});
return call;
}
ShapeExprNode #
Expr VisitExpr_(const ShapeExprNode* op) final {
using runtime::relax_vm::MakeShapeCode;
// Constant shape can be preserved.
bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) {
return e->IsInstance<IntImmNode>();
});
if (is_const_shape) {
return GetRef<Expr>(op);
}
Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
for (PrimExpr expr : op->values) {
auto [code, value_or_index] = MakeSymbolicShapeArg(expr);
args.push_back(code);
args.push_back(value_or_index);
}
// make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
Call call(builtin_make_shape_, args, Attrs(),
{ShapeStructInfo(static_cast<int>(op->values.size()))});
return call;
}
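A toy model of the per-dimension encoding that MakeSymbolicShapeArg produces (illustrative: the real code inspects PrimExpr nodes; here Python ints stand in for IntImms and strings for symbolic vars; cf. the MakeShapeCode enum later in this section):
USE_IMM, LOAD_SHAPE = 0, 1  # mirror of MakeShapeCode kUseImm / kLoadShape

def make_symbolic_shape_arg(dim, slot_map):
    # constant dims encode as (kUseImm, value); symbolic dims as (kLoadShape, slot index)
    if isinstance(dim, int):
        return (USE_IMM, dim)
    return (LOAD_SHAPE, slot_map[dim])

slot_map = {"n": 0, "m": 1}
args = [make_symbolic_shape_arg(d, slot_map) for d in ("n", 2, "m")]
assert args == [(LOAD_SHAPE, 0), (USE_IMM, 2), (LOAD_SHAPE, 1)]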
MatchCastNode #
void VisitBinding_(const MatchCastNode* binding) final {
Expr value = ExprMutator::VisitExpr(binding->value);
std::vector<MatchShapeTodoItem> match_todos;
std::ostringstream err_ctx;
err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") ";
// always_check=false
this->CheckMatchCast(binding->struct_info, value, false, false, err_ctx.str(), &match_todos);
match_todos = this->RunMatch(match_todos, false);
this->EmitOutstandingPrimExprCompute();
this->RunMatch(match_todos, true);
// These checks are emitted as extra, in codegen
// match-cast is simply ignored and treated as a normal binding.
builder_->EmitNormalized(GetRef<MatchCast>(binding));
}
Related builtins #
alloc_shape_heap #
// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap);
NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr);
// 1. the host allocator is always the last element
size_t host_device_index = vm->devices.size() - 1;
// 2. special case: for on-device RT such as Hexagon, use device 0 as the host
if (vm->devices[0].device_type == kDLHexagon) {
host_device_index = 0;
}
auto* alloc = vm->allocators[host_device_index];
// 3. allocate a 1-D int64 buffer via alloc->Empty
return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]);
}
check_tensor_info #
Summary: checks whether the argument arg is a Tensor(dtype, ndim).
- arg: the input argument
- ndim: the expected ndim; may be -1 (unknown)
- dtype: the expected data type
- error_ctx: error-reporting context
// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body(CheckTensorInfo);
void CheckTensorInfo(TVMArgs args, TVMRetValue* rv) {
ObjectRef arg = args[0];
int ndim = args[1];
DataType dtype;
Optional<String> err_ctx;
if (args.size() == 3) {
dtype = DataType::Void();
err_ctx = args[2].operator Optional<String>();
} else {
dtype = args[2];
err_ctx = args[3].operator Optional<String>();
}
auto* ptr = arg.as<NDArray::ContainerType>();
// ... ndim/dtype checks on ptr follow (elided)
}
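A toy Python model of the check semantics (illustrative; in the real builtin, ndim == -1 and a void dtype act as wildcards):
import numpy as np

def check_tensor_info(arr, ndim, dtype=None):
    if ndim != -1:
        assert arr.ndim == ndim, f"expected ndim {ndim}, got {arr.ndim}"
    if dtype is not None:  # None stands in for DataType::Void()
        assert str(arr.dtype) == dtype, f"expected {dtype}, got {arr.dtype}"

check_tensor_info(np.zeros((4, 2, 5), dtype="float32"), 3, "float32")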
match_shape #
Summary: a built-in runtime function. Signature:
- Provides runtime shape population and match-cast checking support.
- When a shape variable appears for the first time, load the shape value and populate the variable's slot.
- When a shape variable has already appeared, assert that it equals the existing value.
- (shape_heap may be passed as nullptr if all codes are kAssertEqualToImm)
MatchShape(input_shape, shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n], err_ctx)
enum class MatchShapeCode : int {
  kAssertEqualToImm = 0,   // assert input_shape[i] == r[i], where r[i] is an immediate
  kStoreToHeap = 1,        // first occurrence of a symbolic dim: shape_heap[r[i]] = input_shape[i]
  kNoOp = 2,               // do nothing
  kAssertEqualToLoad = 3,  // assert input_shape[i] == shape_heap[r[i]]
};
Implementation:
// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape);
void MatchShape(TVMArgs args, TVMRetValue* rv) {
// 1. the first argument is a shape or a tensor
ShapeTuple input_shape;
if (args[0].IsObjectRef<NDArray>()) {
input_shape = args[0].operator NDArray().Shape();
} else {
input_shape = args[0];
}
// 2. the second argument: the heap, as a DLTensor
DLTensor* heap = args[1];
int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data);
// 3. the third argument: the number of dims; code/operand pairs start at kBeginCode
int64_t size = args[2];
const int64_t kBeginCode = 3;
// a function that lazily get context for error reporting
const int64_t kErrorContextOffset = kBeginCode + size * 2;
Optional<String> err_ctx = args[kErrorContextOffset];
for (int64_t i = 0; i < size; ++i) {
// decode the MatchShapeCode
MatchShapeCode code = static_cast<MatchShapeCode>(args[kBeginCode + i * 2].operator int());
// decode the operand (immediate or heap slot index)
int64_t reg = args[kBeginCode + i * 2 + 1];
if (code == MatchShapeCode::kAssertEqualToImm) {
CHECK_EQ(input_shape[i], reg)...
} else if (code == MatchShapeCode::kStoreToHeap) {
heap_data[reg] = input_shape[i];
} else if (code == MatchShapeCode::kNoOp) {
} else {
ICHECK(code == MatchShapeCode::kAssertEqualToLoad);
CHECK_EQ(input_shape[i], heap_data[reg])...
}
}
}
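A runnable Python model of the loop above (illustrative only, not TVM API), using the same shapes as Example 1 below: for x: R.Tensor(["n", 2, "m"]), it stores n and m and asserts the middle dim equals 2:
ASSERT_EQUAL_TO_IMM, STORE_TO_HEAP, NO_OP, ASSERT_EQUAL_TO_LOAD = range(4)

def match_shape(input_shape, heap, codes):
    for dim, (code, reg) in zip(input_shape, codes):
        if code == ASSERT_EQUAL_TO_IMM:
            assert dim == reg, f"expected {reg}, got {dim}"
        elif code == STORE_TO_HEAP:
            heap[reg] = dim            # first occurrence: populate the slot
        elif code == ASSERT_EQUAL_TO_LOAD:
            assert dim == heap[reg]    # later occurrence: must match the slot
        # NO_OP: nothing to do

heap = [0, 0]                          # slots: 0 -> n, 1 -> m
match_shape((4, 2, 5), heap,
            [(STORE_TO_HEAP, 0), (ASSERT_EQUAL_TO_IMM, 2), (STORE_TO_HEAP, 1)])
assert heap == [4, 5]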
make_shape #
Summary: builds a ShapeTuple at runtime from immediates and values loaded from the shape heap. Signature:
- shape_heap may be passed as nullptr if all codes are kUseImm
MakeShape(shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n]).
enum class MakeShapeCode : int {
  kUseImm = 0,     // use r[i] as an immediate value
  kLoadShape = 1,  // load shape_heap[r[i]]
};
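The decoding side, as a runnable Python model (illustrative): make_shape is the inverse of the encoding sketched under ShapeExprNode above:
USE_IMM, LOAD_SHAPE = 0, 1

def make_shape(heap, codes):
    return tuple(reg if code == USE_IMM else heap[reg] for code, reg in codes)

heap = [4, 5]  # slot 0 holds n, slot 1 holds m
assert make_shape(heap, [(LOAD_SHAPE, 0), (USE_IMM, 2), (LOAD_SHAPE, 1)]) == (4, 2, 5)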
Example 1: lowering #
before:
@tvm.script.ir_module
class Before:
@R.function
def main(x: R.Tensor(["n", 2, "m"], "float32")):
R.func_attr({"relax.force_pure": True})
return x
after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
sindex = {
"n": 0,
"m": 1,
}
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor(["n", 2, "m"], "float32")):
R.func_attr({"relax.force_pure": True})
# 1. allocate the shape heap via R.call_builtin_with_ctx: the builtin needs the ctx to reach the VM runtime
shape_heap = R.call_builtin_with_ctx(
"vm.builtin.alloc_shape_heap",
[R.prim_value(2)],
sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
)
# 2. check input x's tensor info: x must be Tensor(ndim=3, dtype="float32")
_ = R.call_packed(
"vm.builtin.check_tensor_info",
x,
3, R.dtype("float32"), # ndim == 3, dtype == "float32"
"",
sinfo_args=[R.Tuple()],
)
# 3. match_shape
_ = R.call_packed(
"vm.builtin.match_shape", # 内置的runtime函数, 运行期确定sidx的值
x, # 第一个参数: tensor
shape_heap, 3, # shape_heap, shape大小为3
MS.STORE_TO_HEAP, sindex["n"],
MS.ASSERT_EQUAL_TO_IMM, 2,
MS.STORE_TO_HEAP, sindex["m"],
"",
sinfo_args=[R.Tuple()],
)
return x
Example 2: running #
See the Relax design doc, section "the challenge of lowering this program to VM instructions". The runtime data structures involved:
- ShapeTuple and NDArray.
- the input and the IRModule:
# input
import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R

inputs = tvm.nd.array(
np.array([1,2,3]).astype(np.int64)
)
# IRModule
bb = relax.BlockBuilder()
d = tvm.tir.Var("m", "int64")
p = tvm.tir.Var("n", "int64")
nums = relax.Var("shape_d", R.Tensor([d, p], "int32"))
with bb.function("main"):
with bb.dataflow():
b = relax.op.ones(R.shape((d, p)), "int32")
gv = bb.emit_output(relax.op.astype(b, "int64"))
bb.emit_func_output(gv, [nums])
mod = bb.get()
- After the pipeline, lower builtins and shapes:
mod_after_pipeline = relax.pipeline.get_pipeline("zero")(mod)
mod_after_builtinLower = relax.transform.VMBuiltinLower()(mod_after_pipeline)
mod_after_vmshapeLower = relax.transform.VMShapeLower()(mod_after_builtinLower)
- Build and run (relax.build applies the VM lowering passes internally, so the pipeline-level module can be passed directly; the explicit VMBuiltinLower/VMShapeLower calls above are only for inspecting the lowered IR):
rt_mod = relax.build(mod_after_pipeline, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(rt_mod, tvm.cpu())
inputs = tvm.nd.array(
np.array([[1,2,3], [4,5,6]]).astype(np.int32)
)
vm["main"](inputs)
Analysis:
- The VM invokes the Relax function main with inputs as its argument; the lowered instruction stream is shown below. The first three calls are the shape prologue emitted by VMShapeLower (allocate the shape heap, check the tensor info, match and store the symbolic dims); the remaining calls build shapes from the heap, allocate storage and the output tensor, and invoke the fused kernel via vm.builtin.call_tir_dyn.
@main:
# prologue: allocate the shape heap, check the input tensor, match/store n and m
call vm.builtin.alloc_shape_heap in: %vm, i3 dst: %1
call vm.builtin.check_tensor_info in: %0, i2, c[0], c[1] dst: %void
call vm.builtin.match_shape in: %0, %1, i2, i1, i0, i1, i1, c[1] dst: %void
# body: compute shapes from the heap, allocate storage and the output tensor, call the fused kernel
call shape_func in: %1 dst: %void
call vm.builtin.make_shape in: %1, i1, i1, i2 dst: %2
call vm.builtin.alloc_storage in: %vm, %2, i0, c[2], c[3] dst: %3
call vm.builtin.make_shape in: %1, i2, i1, i0, i1, i1 dst: %4
call vm.builtin.alloc_tensor in: %3, i0, %4, c[4] dst: %5
call vm.builtin.null_value in: dst: %3
call vm.builtin.make_shape in: %1, i2, i1, i1, i1, i0 dst: %6
call vm.builtin.call_tir_dyn in: f[fused_ones_cast], %5, %6 dst: %void
call vm.builtin.match_shape in: %5, %1, i2, i3, i0, i3, i1, c[5] dst: %void
ret %5