参考文章:
相关源码文件:
-
【relax.build】python/tvm/relax/vm_build.py
-
【tvm.build】python/tvm/driver/build_module.py
-
src/relax/backend/vm/codegen_vm.cc
-
tests/python/relax/test_vm_build.py
-
tests/python/relax/test_vm_codegen_only.py
-
tests/python/relax/test_vm_codegen_tir.py
-
tests/python/relax/test_vm_execbuilder.py
核心API:
- relax.build
- relax.VirtualMachine
API 接口 #
relax.build #
函数功能:将IRModule构建为可由VM加载执行的Executable
函数参数:
- mod
- target:目标平台target
- exec_mode:执行模式("bytecode":生成字节码由VM解释执行;"compiled":额外生成TIR函数)
返回类型:
- ex: tvm.relax.Executable(可以被relax.VirtualMachine加载执行的可执行对象)
函数流程:
- relax function transform:依次执行一组relax pass
- relax function codegen -> 生成VM指令(写入ExecBuilder),并返回剩余的TIR函数
- tir function build -> tvm.Module
- vmlink -> Executable
函数实现:
def build(
    mod: "tvm.IRModule",
    target: "Union[str, tvm.target.Target]",
    params: "Optional[Dict[str, list]]" = None,
    exec_mode: str = "bytecode",
    *,
    system_lib: "Optional[bool]" = None,
) -> "Executable":
    """Build an IRModule into an Executable that relax.VirtualMachine can load.

    Parameters
    ----------
    mod : tvm.IRModule
        Module containing Relax functions (and any PrimFuncs they call).
    target : str or tvm.target.Target
        Target used to compile the TIR functions left over after VM codegen.
    params : dict, optional
        Parameters to bind into the executable. The caller's dict is never mutated.
    exec_mode : str
        "bytecode" or "compiled", forwarded to ``_vmcodegen``.
    system_lib : bool, optional
        Forwarded to ``_vmlink`` to choose the runtime for the TIR build.

    Returns
    -------
    Executable
        The linked executable.
    """
    if isinstance(target, str):
        target = tvm.target.Target(target)
    # Work on a copy so merging module constants below does not touch the caller's dict.
    params = {} if params is None else dict(params)

    # 1. Run the Relax lowering passes over the module.
    passes = [
        relax.transform.RewriteDataflowReshape(),
        relax.transform.ToNonDataflow(),
        relax.transform.RemovePurityChecking(),
        relax.transform.CallTIRRewrite(),
        relax.transform.StaticPlanBlockMemory(),
    ]
    if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
        passes.append(relax.transform.RewriteCUDAGraph())
    passes.extend(
        [
            relax.transform.VMBuiltinLower(),
            relax.transform.VMShapeLower(),
            relax.transform.AttachGlobalSymbol(),
        ]
    )
    new_mod = tvm.transform.Sequential(passes)(mod)

    # External runtime modules and bound constants travel as module attributes;
    # _extract_attrs is a helper of vm_build.py that is not reproduced in this note.
    ext_libs, constants = _extract_attrs(new_mod)
    params.update(dict(constants))

    # 2. The builder collects the executable.
    builder = relax.ExecBuilder()
    # Codegen: Relax functions -> VM instructions; returns the functions left to compile.
    leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode)
    # Keep only the TIR functions.
    tir_mod = _filter_tir(leftover_mod)
    return _vmlink(builder, target, tir_mod, ext_libs, params, system_lib=system_lib)
_vmcodegen:
if exec_mode == "bytecode":
return _ffi_api.VMCodeGen(builder, mod) # type:ignore
if exec_mode == "compiled":
return _ffi_api.VMTIRCodeGen(builder, mod) # type: ignore
_vmlink:
def _vmlink(
    builder: "relax.ExecBuilder",
    target: "Union[str, tvm.target.Target]",
    tir_mod: "Optional[tvm.IRModule]" = None,
    ext_libs: "Optional[List[tvm.runtime.Module]]" = None,
    params: "Optional[Dict[str, list]]" = None,
    *,
    system_lib: "Optional[bool]" = None,
) -> "Executable":
    """Compile the TIR functions and link them with the VM executable.

    Parameters
    ----------
    builder : relax.ExecBuilder
        Builder holding the VM functions emitted by codegen.
    target : str or tvm.target.Target
        Target for compiling ``tir_mod``.
    tir_mod : tvm.IRModule, optional
        TIR functions to compile. When None, no library is built.
    ext_libs : list of tvm.runtime.Module, optional
        External runtime modules to link (default: none).
    params : dict, optional
        Parameters to bind (default: none).
    system_lib : bool, optional
        Passed to ``_autodetect_system_lib_req`` to choose the runtime.

    Returns
    -------
    Executable
        Wraps the module returned by ``relax.VMLink``.
    """
    if isinstance(target, str):
        target = tvm.target.Target(target)
    if ext_libs is None:
        ext_libs = []
    if params is None:
        params = {}
    # Only tir_mod is needed here; the result is a tvm.Module. lib stays None when there is
    # nothing to compile, and VMLink receives it as an Optional<Module>.
    lib = None
    if tir_mod is not None:
        lib = tvm.build(
            tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib)
        )
    # Uses the builder plus the compiled library; returns an Executable.
    return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))
VMCodeGen #
函数功能:根据给定mod中的relax.Function生成relax VM可执行指令,并将其添加到exec_builder中
函数参数:
- exec_builder:用于收集executables的Builder
- mod
/*!
 * \brief Bytecode codegen entry point ("relax.VMCodeGen").
 *
 * Emits every Relax function of \p mod into \p exec_builder and hands back what
 * CodeGenVM::Run leaves in the module (the functions still to be compiled).
 */
IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
  IRModule leftover = CodeGenVM::Run(exec_builder, mod);
  return leftover;
}
TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
具体实现: CodeGenVM
VMTIRCodeGen #
函数功能:
- 根据给定mod中的relax.Function生成relax VM可执行指令,并将其添加到exec_builder中
- 然后创建额外的TIR Functions
函数参数:
- exec_builder:用于收集executables的Builder
- mod
/*!
 * \brief "Compiled" codegen entry point ("relax.VMTIRCodeGen").
 *
 * Delegates to CodeGenVMTIR::Run, which writes into \p exec_builder and returns the
 * resulting module.
 */
IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) {
  IRModule result_mod = CodeGenVMTIR::Run(exec_builder, mod);
  return result_mod;
}
TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen);
具体实现: CodeGenVMTIR
VMLink #
函数功能:将ExecBuilder中的Executable与TIR编译得到的lib、外部库ext_libs以及params链接在一起(Link the libraries together)
函数返回:tvm.runtime.Module(包装了relax Executable)
/*!
 * \brief Link the executable held by \p builder with the compiled libraries and parameters.
 *
 * \param builder ExecBuilder holding the emitted VM functions.
 * \param target Target forwarded to CreateMetadataModule.
 * \param lib Library compiled from the TIR functions; may be empty when there was no TIR.
 * \param ext_libs External runtime modules to import.
 * \param params Named parameters stored into the metadata module.
 * \return The executable (with the combined library imported), wrapped as a Module.
 */
Module VMLink(ExecBuilder builder,
              Target target, Optional<Module> lib,
              Array<Module> ext_libs,
              Map<String, runtime::NDArray> params) {
  // TODO(relax-team) Revisit the param and ext_lib options.
  ObjectPtr<Executable> executable = builder->Get();
  if (!lib.defined()) {
    // No TIR library was supplied. Use an empty C source module instead of calling
    // lib.value() on an empty Optional.
    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
  }
  std::unordered_map<std::string, runtime::NDArray> conv_params;
  for (const auto& [name, param] : params) {
    conv_params[name] = param;
  }
  Module combined_lib = codegen::CreateMetadataModule(
      conv_params, lib.value(), ext_libs, target,
      // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
      // Before jumping into details, only support cpp runtime for now.
      relay::Runtime::Create("cpp"),
      relay::Executor::Create("graph"),  // TODO(@sunggg): pass arbitrarily executor. CPP runtime
                                         // won't use this anyways.
      relay::backend::ExecutorCodegenMetadata());
  executable->Import(combined_lib);
  return Module(executable);
}
TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink);
relax.VirtualMachine #
relax.build->tvm.build #
tvm.build:
def build(inputs, args, target, target_host, runtime, name):
    """Lower ``inputs`` and compile them into a runtime module (excerpt of ``tvm.build``).

    Parameters
    ----------
    inputs : PrimFunc, IRModule or te.Schedule
        What to compile; passed to ``lower``.
    args : list, optional
        Arguments for ``lower`` (needed when ``inputs`` is a te.Schedule).
    target : str or Target
        Device target.
    target_host : str or Target, optional
        Host target; canonicalized together with ``target``.
    runtime
        Runtime attached to every module as the ``"runtime"`` attribute.
    name : str
        Function name used when lowering and for the returned module.

    Returns
    -------
    OperatorModule
        Wraps the runtime module built by ``driver.tir_to_runtime``.
    """
    # 1. Lower the input (PrimFunc / IRModule / te.Schedule) to a TIR IRModule.
    input_mod = lower(inputs, args, name=name)
    target_input_mod = {target: input_mod}
    # 2. Annotate every module with the runtime it is built for.
    annotated_mods = {
        tar: mod.with_attr("runtime", runtime) for tar, mod in target_input_mod.items()
    }
    # 3. Canonicalize the target map and the host target.
    annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
    # 4. Compile TIR into a runtime module (driver.tir_to_runtime -> C++ TIRToRuntime).
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
    # 5. Wrap the result. The CRT system-lib branch, which wraps rt_mod_host in a CRT
    #    metadata module, is elided in this note.
    to_return = rt_mod_host
    return OperatorModule.from_module(to_return, ir_module_by_target=annotated_mods, name=name)
1. lower #
# python/tvm/driver/build_module.py:lower
def lower(inp, args=None, name="main", binds=None, simple_mode=False):
    """Lower ``inp`` to a TIR IRModule through the C++ driver entry points.

    Parameters
    ----------
    inp : IRModule, PrimFunc or te.Schedule
        What to lower.
    args : list, optional
        Schedule arguments (te.Schedule only).
    name : str
        Name given to the lowered function (PrimFunc / te.Schedule).
    binds : dict, optional
        Tensor-to-buffer bindings (te.Schedule only).
    simple_mode : bool
        Forwarded to the driver (for PrimFuncs it disables loop partitioning).

    Returns
    -------
    IRModule
        The lowered module.

    Raises
    ------
    ValueError
        If ``inp`` is none of the supported types.
    """
    if isinstance(inp, IRModule):
        return ffi.lower_module(inp, simple_mode)
    if isinstance(inp, PrimFunc):
        return ffi.lower_primfunc(inp, name, simple_mode)
    if isinstance(inp, te.Schedule):
        return ffi.lower_schedule(inp, args, name, binds, simple_mode)
    # Previously fell through and returned None.
    raise ValueError(
        f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}"
    )
cpp注册global func:
# src/driver/driver_api.cc
// Expose LowerPrimFunc to Python as "driver.lower_primfunc"; the Python `lower()` calls it
// (as ffi.lower_primfunc) when given a PrimFunc.
TVM_REGISTER_GLOBAL("driver.lower_primfunc")
    .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) {
      return LowerPrimFunc(std::move(func), name, simple_mode);
    });
2. tir_to_runtime #
# src/driver/driver_api.cc
// Expose TIRToRuntime to Python as "driver.tir_to_runtime"; tvm.build() calls it with the
// per-target annotated modules and the host target.
TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
    .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
      return TIRToRuntime(inputs_arg, host_target);
    });
TIRToRuntime:
- 将ir_mod执行变换后拆分为host_mod和device_mod,并分别对拆分后的mod执行变换
- 对mhost_all(合并后的host模块)和device_mod分别调用codegen::Build, 得到runtime::Module(例如 target.build.llvm、target.build.cuda)
// src/driver/driver_api.cc -- excerpt of TIRToRuntime.
// This runs once per entry `it` of the input Map<Target, IRModule>, with
// target = it.first and ir_module = it.second.
// 1. Transform the mixed module and split it into host and device modules.
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;
// 2. Codegen host_mod and device_mod.
// The host part is built on its own (instead of being merged into mhost_all) only when the
// target has the host's device type but a different target kind.
bool overrides_host_target =
    target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
  device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
  mhost_all->Update(host_mod);
}
if (device_mod->functions.size() != 0) {
  device_modules.push_back(codegen::Build(device_mod, it.first));
}
SplitMixedModule #
/*!
 * \brief Run the mixed-module passes, then derive the host and device modules from the result.
 *
 * \param mod_mixed Module containing both host and device code; must be defined.
 * \param target_arg Device target.
 * \param target_host_arg Host target.
 * \return {host_mod, device_mod}. If a GPU target yields no device functions, this only
 *         emits a debug-log warning.
 */
std::pair<IRModule, IRModule> SplitMixedModule(
    IRModule mod_mixed,
    const Target& target_arg,
    const Target& target_host_arg) {
  Target target = target_arg, target_host = target_host_arg;
  CheckAndUpdateHostConsistency(&target, &target_host);
  ICHECK(mod_mixed.defined()) << "This module must be defined";
  // 1. Passes on the mixed module (MixedModulePassManager includes SplitHostDevice).
  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
  // 2. Host module: host-side passes applied to the split module.
  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
  // 3. Device module: device-side passes applied to the split module.
  IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
  auto keys = target->GetKeys();
  CheckAndUpdateHostConsistency(&target, &target_host);
  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
  if (target_is_gpu && device_mod->functions.size() == 0) {
    DLOG(WARNING) << "Specified target " << target->str()
                  << " but cannot find device code. Did you forget to bind?";
  }
  return {host_mod, device_mod};
}
MixedModulePassManager #
包含的PASS:
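- VerifyVTCMLimit
- LowerVtcmAlloc
- BindTarget
- VerifyMemory
- AnnotateEntryFunc
- ThreadSync("global")(可选, tir.detect_global_barrier)
- ThreadSync("shared")
- ThreadSync("shared.dyn")
- MergeDynamicSharedMemoryAllocations
- ThreadSync("warp")
- InferFragment
- LowerThreadAllreduce
- InjectPTXAsyncCopy(可选, tir.use_async_copy)
- InjectPTXLDG32(可选, tir.ptx_ldg32)
- AnnotateDeviceRegions
- SplitHostDevice
- MakePackedAPI / MakeUnpackedAPI(由executor的unpacked-api属性决定)
- FP8StorageLegalize
- BF16StorageLegalize
- LowerDeviceKernelLaunch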
/*!
 * \brief Build the pass sequence that runs on the mixed (host + device) TIR module.
 *
 * The sequence ends by annotating device regions and splitting them out of host code
 * (AnnotateDeviceRegions, SplitHostDevice), choosing the calling convention
 * (MakeUnpackedAPI or MakePackedAPI), legalizing FP8/BF16 storage and lowering device
 * kernel launches. Optional passes are gated by PassContext configs:
 *   - "tir.detect_global_barrier" -> ThreadSync("global")
 *   - "tir.use_async_copy"        -> InjectPTXAsyncCopy
 *   - "tir.ptx_ldg32"             -> InjectPTXLDG32
 *
 * \param mixed_mod The mixed module; only its executor attribute is read (to pick the API).
 * \param target The target bound to the functions by BindTarget.
 * \return The pass sequence (not yet applied).
 */
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  Array<Pass> mixed_pass_list;
  // VerifyVTCMLimit must occur before LowerVtcmAlloc
  mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
  // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
  mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
  mixed_pass_list.push_back(tir::transform::BindTarget(target));
  mixed_pass_list.push_back(tir::transform::VerifyMemory());
  mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
  // The "global" ThreadSync pass is opt-in.
  bool detect_global_barrier =
      pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
  if (detect_global_barrier) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }
  // ThreadSync runs once per storage scope; "warp" runs after the dynamic shared memory
  // allocations have been merged.
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
  mixed_pass_list.push_back(tir::transform::InferFragment());
  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
  if (use_async_copy) {
    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
  }
  bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
  if (ptx_ldg32) {
    mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
  }
  // Annotate device regions, then split them out of the host functions.
  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
  // Without an executor attribute this falls back to a "graph" executor with no attrs, so
  // unpacked-api is off unless explicitly requested.
  bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                          .value_or(relay::Executor::Create("graph", {}))
                          ->GetAttr<Bool>("unpacked-api")
                          .value_or(Bool(false));
  if (unpacked_api) {
    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
  } else {
    mixed_pass_list.push_back(tir::transform::MakePackedAPI());
  }
  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
  mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
  mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
  return transform::Sequential(mixed_pass_list);
}
HostModulePassManager #
包含的PASS:
- Filter(只保留调用约定不是kDeviceKernelLaunch的函数)
- BindTarget
- LowerTVMBuiltin
- LowerCustomDatatypes
- LowerIntrin
- LowerDeviceStorageAccessInfo
- CombineContextCall
- InstallDebugSpans(可选, tir.enable_debug)
/*!
 * \brief Build the pass sequence that produces the host module from the mixed module.
 *
 * The leading Filter keeps every PrimFunc whose calling convention is not
 * kDeviceKernelLaunch (functions without the attribute count as kDefault). The remaining
 * passes bind the host target and run the host lowering passes. "tir.enable_debug"
 * additionally appends InstallDebugSpans.
 *
 * \param mixed_mod The mixed module; must be defined.
 * \param target_host The host target.
 * \return The pass sequence (not yet applied).
 */
transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();
  Array<tvm::transform::Pass> host_pass_list;
  // Keep only non-kernel functions.
  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
           CallingConv::kDeviceKernelLaunch;
  };
  host_pass_list.push_back(tir::transform::Filter(fcond));
  ICHECK(mixed_mod.defined()) << "This module must be defined";
  host_pass_list.push_back(tir::transform::BindTarget(target_host));
  host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
  host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  host_pass_list.push_back(tir::transform::LowerIntrin());
  host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  host_pass_list.push_back(tir::transform::CombineContextCall());
  if (enable_debug) {
    host_pass_list.push_back(tir::transform::InstallDebugSpans());
  }
  return transform::Sequential(host_pass_list);
}
DeviceModulePassManager #
包含的PASS:
- Filter
- BindTarget
- LowerWarpMemory
- Simplify
- LowerCustomDatatypes
- LowerDeviceStorageAccessInfo
- LowerIntrin
/*!
 * \brief Build the pass sequence that produces the device module from the mixed module.
 *
 * The leading Filter keeps only PrimFuncs whose calling convention is kDeviceKernelLaunch.
 * The remaining passes bind the device target and lower warp memory, custom datatypes,
 * device storage access info and intrinsics, with a Simplify right after LowerWarpMemory.
 *
 * \param mixed_mod Not read by this function.
 * \param target The device target.
 * \return The pass sequence (not yet applied).
 */
transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
  Array<Pass> device_pass_list;
  // Keep only device-kernel functions.
  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
           CallingConv::kDeviceKernelLaunch;
  };
  device_pass_list.push_back(tir::transform::Filter(fcond));
  device_pass_list.push_back(tir::transform::BindTarget(target));
  device_pass_list.push_back(tir::transform::LowerWarpMemory());
  device_pass_list.push_back(tir::transform::Simplify());
  device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  device_pass_list.push_back(tir::transform::LowerIntrin());
  return transform::Sequential(device_pass_list);
}
Codegen #
核心类 #
vm.builtin #
AllocNDArray #
参数:
- offset
- ShapeTuple
- DLDataType
# src/runtime/relax_vm/builtin.cc
// Register StorageObj::AllocNDArray as the VM builtin "vm.builtin.alloc_tensor".
TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method<Storage>(&StorageObj::AllocNDArray);
/*!
 * \brief Create an NDArray that views this storage's buffer starting at byte \p offset.
 *
 * No memory is allocated for the tensor data. The data pointer is set to
 * buffer.data + offset, and the container records this StorageObj as its manager
 * (IncRef + manager_ctx) so the backing buffer is kept while the tensor lives.
 *
 * \param offset Byte offset into the storage buffer.
 * \param shape Tensor shape.
 * \param dtype Element type, validated by VerifyDataType.
 * \return The NDArray view. Fails (ICHECK) if offset + data size exceeds buffer.size. The
 *         check runs only after `ret` owns the container, so the header is released on failure.
 */
runtime::NDArray StorageObj::AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype) {
  VerifyDataType(dtype);
  // critical zone: allocate header, cannot throw
  runtime::NDArray::Container* container =
      new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device);
  container->SetDeleter(StorageObj::Deleter);
  size_t needed_size = runtime::GetDataSize(container->dl_tensor);
  this->IncRef();
  // The manager context pointer must continue to point to the storage object
  // which owns the backing memory, and keeps track of the reference count.
  //
  // When we free a container we extract the storage object, decrement its
  // reference count, then destroy the container, but leave the underlying
  // buffer intact.
  container->manager_ctx = reinterpret_cast<void*>(this);
  // is this UB?
  // The only change we make w.r.t offset is modifying the data pointer
  // of the backing tensor to point into the buffer instead of its start.
  auto offset_ptr = reinterpret_cast<uint8_t*>(this->buffer.data) + offset;
  container->dl_tensor.data = reinterpret_cast<void*>(offset_ptr);
  runtime::NDArray ret(runtime::GetObjectPtr<Object>(container));
  // RAII in effect, now run the check.
  ICHECK(offset + needed_size <= this->buffer.size)
      << "storage allocation failure, attempted to allocate " << needed_size << " at offset "
      << offset << " in region that is " << this->buffer.size << "bytes";
  return ret;
}
CodeGenVM #
类功能:遍历Relax AST,为每个Relax函数生成VM可执行指令
- 即为每个relax函数中的每个VarBinding生成vm code(调用 vm.builtin.xxx / relax.run.xxx / fused_xxx 等函数)
类数据成员:
- builder_:内部的ExecBuilder
/*!
 * \brief Generates relax VM instructions for Relax functions, writing them into an ExecBuilder.
 *
 * Only the protected state is reproduced in this note. The public interface (constructor,
 * Run, Codegen and the VisitExpr_ overloads) is elided.
 */
class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
 public:
  // ... constructor, Run(), Codegen() and VisitExpr_ overloads elided in this note ...

 protected:
  /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*!
   * \brief Total number of virtual registers allocated.
   * \note The first two registers are reserved for special registers.
   */
  size_t registers_num_ = 0;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
  /*! \brief the context module. */
  IRModule ctx_mod_;
  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const Op& null_value_op_ = Op::Get("relax.null_value");
};
Run #
- 定义CodeGenVM,对mod中的每个relax function进行codegen
- 从结果模块中移除已经完成codegen的relax function, 只保留其余函数(如PrimFunc)
static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
IRModule res_mod = mod;
res_mod.CopyOnWrite();
# 1. 定义CodeGenVM
CodeGenVM codegen(builder, mod);
# 2. 移除relax function, 并且将其转换为TIR函数
for (const auto& [gvar, f] : mod->functions) {
if (auto* func = f.as<FunctionNode>()) {
codegen.Codegen(GetRef<Function>(func));
res_mod->Remove(gvar);
}
}
return res_mod;
}
Codegen #
- builder->emit函数声明: EmitFunction/DeclareFunction
- emit函数体:CodegenVM自己实现
- VisitExpr: 调用builder.emit得到Instruction, 并放入executable中
- builder->emit函数返回值: EmitRet
- builder->emit函数END: EndFunction
void Codegen(const Function& func) {
# 1. 获取函数参数、函数名
Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
Array<String> param_names;
for (Var param : func->params) {
param_names.push_back(param->name_hint());
}
# 2. 利用builder完成emit函数声明
builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
# 3. 遍历函数参数, 新建寄存器并插入到var_arg_map_中
for (size_t i = 0; i < func->params.size(); ++i) {
RegName r = NewRegister();
this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
}
# 3. visitexpr生成函数body
Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
# 4. 利用builder完成生成函数返回值、EndFunction
builder_->EmitRet(EnsureReg(ret));
builder_->EndFunction(gsymbol.value());
// reset register number to be 0;
registers_num_ = 0;
var_arg_map_.clear();
}
VisitExpr(func->body) #
SeqExprNode:
- 对于每个blocks中的每个binding, 更新var_arg_map_
- 对于seq_expr的返回值(VarNode),Visit, 并返回
Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Instruction::Arg value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
}
this->var_arg_map_.insert({binding->var, value});
}
}
Instruction::Arg ret_reg = this->VisitExpr(op->body);
return ret_reg;
}
CallNode:
- 得到call.op
- 对于已经注册的PackedFunc,调用EmitPackedFuncCall(call, name, dst_reg);
- 如果call.op == call_builtin_with_ctx_op_
- 如果call.op == 分配内存
- 如果call.op == 分配tensor
- 如果call.op == kill obj
Instruction::Arg VisitExpr_(const CallNode* call_node) final {
Call call = GetRef<Call>(call_node);
if (call_node->op == null_value_op_) {
return Instruction::Arg::Register(Instruction::kVoidRegister);
}
// allocate dst register.
RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
if (call->op.as<OpNode>()) {
# 判断是否packFunc已经注册
FCallPacked name = GetPackedFuncName(call);
if (!name.empty()) {
// If the operator has a registered packed function implementation, emit call to that packed
// function.
EmitPackedFuncCall(call, name, dst_reg);
} else if (call_node->op == call_builtin_with_ctx_op_) {
// TODO(relax-team) migrate most handling of op to
// directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
EmitCallBuiltinWithCtx(call, dst_reg);
} else if (call_node->op == alloc_storage_op_) {
EmitAllocStorage(call, dst_reg);
} else if (call_node->op == alloc_tensor_op_) {
EmitAllocTensor(call, dst_reg);
} else if (call_node->op == kill_object_op_) {
dst_reg = EmitKillObject(call);
} else {
// every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
// ops are handled in a pass when lowering them to TIR.
LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op;
}
} else {
EmitNormalCall(call, dst_reg);
}
return Instruction::Arg::Register(dst_reg);
}
- EmitPackedFuncCall:
void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
# 得到参数
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
# 调用builder完成EmitCall
builder_->EmitCall(name, args, dst_reg);
}
- EmitAllocTensor
void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
std::vector<Instruction::Arg> args;
args.reserve(4);
for (Expr arg : call_node->args) {
args.push_back(this->VisitExpr(arg));
}
builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
}
- EmitAllocStorage
void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
ICHECK_EQ(call_node->args.size(), 4);
// Handle args of the call
std::vector<Instruction::Arg> args;
args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
// buffer size, dtype, device index
for (auto arg : call_node->args) {
args.push_back(this->VisitExpr(arg));
}
builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
}
CodeGenVMTIR #
Run #
Codegen #
relax.ExecBuilder #
类功能:
- 辅助codegen,生成vm code
类数据成员:
- exec_:vm::Executable(继承自runtime::ModuleNode),保存了codegen的结果
 // Tail of the ExecBuilderNode class declaration. The class header and public API are not
 // part of this excerpt.
 private:
  /*!
   * \brief Convert a constant value into an instruction argument.
   * NOTE(review): presumably appends to exec_->constants, deduplicated via const_dedup_map_.
   * Confirm in exec_builder.cc.
   */
  vm::Instruction::Arg ConvertConstant_(TVMRetValue obj);
  /*! \brief The mutable internal executable. */
  ObjectPtr<vm::Executable> exec_; // mutable
  /*! \brief internal dedup map when creating index for a new constant */
  std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual> const_dedup_map_;
}
EmitFunction/DeclareFunction #
- vm函数声明:创建vmfunc, 放入exec_的func_table中
- 从exec_的func_map中根据name得到index;根据index从func_table中得到vmfunc(类型为VMFuncInfo)
- 设置vm_func信息:参数数目、参数名称、寄存器文件大小,以及(kVMFunc时)起始指令start_instr
/*!
 * \brief Declare \p func_name in the executable and fill in its VMFuncInfo.
 *
 * \param func_name Function name, the key into exec_->func_map.
 * \param num_inputs Number of arguments.
 * \param param_names Parameter names; left unset when not provided.
 * \param kind Function kind. Only kVMFunc records start_instr (the current instruction count).
 * \param init_register_size Initial register file size.
 */
void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs,
                                   Optional<Array<String>> param_names,
                                   vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) {
  // 1. Declare the function: creates its VMFuncInfo entry in exec_->func_table.
  this->DeclareFunction(func_name, kind);
  // 2. Look the entry up by name.
  auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name));
  // 3. Record argument count, parameter names, register file size and start_instr.
  vmfunc.num_args = num_inputs;
  if (param_names.defined()) {
    std::vector<std::string> names;
    for (auto name : param_names.value()) {
      names.push_back(name);
    }
    vmfunc.param_names = names;
  }
  vmfunc.register_file_size = init_register_size;
  if (kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
    // The function's first instruction is the next one to be emitted.
    vmfunc.start_instr = exec_->instr_offset.size();
  }
}
EmitCall #
- 通过exec_设置指令offset,并且添加指令数据
/*!
 * \brief Append a Call instruction.
 *
 * Encoding appended to instr_data (decoded by Executable::GetInstruction):
 *   [Opcode::Call, dst, func_idx, num_args, arg_0, ..., arg_{n-1}]
 * instr_offset records the word at which this instruction starts.
 *
 * \param func The callee; must be a function-table index.
 * \param args Call arguments, stored as their raw encoded words.
 * \param dst Destination register.
 */
void ExecBuilderNode::EmitCall(vm::Instruction::Arg func,
                               std::vector<vm::Instruction::Arg> args,
                               vm::RegName dst) {
  // The third word is decoded as a function index, so reject any other argument kind.
  ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx);
  // store instruction
  exec_->instr_offset.push_back(exec_->instr_data.size());
  exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
  exec_->instr_data.push_back(dst);
  exec_->instr_data.push_back(func.value());
  exec_->instr_data.push_back(args.size());
  for (Instruction::Arg arg : args) {
    exec_->instr_data.push_back(arg.data());
  }
}
EmitRet #
- push instr_offset
- push instr_data
/*!
 * \brief Append a Ret instruction: [Opcode::Ret, result].
 * \param result Register holding the return value; GetInstruction decodes it as a RegName.
 */
void ExecBuilderNode::EmitRet(vm::Instruction::Arg result) {
  // Only a register can be encoded here.
  ICHECK(result.kind() == vm::Instruction::ArgKind::kRegister);
  exec_->instr_offset.push_back(exec_->instr_data.size());
  exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Ret));
  exec_->instr_data.push_back(result.value());
}
EndFunction #
void ExecBuilderNode::EndFunction(const std::string& func_name) {
# 得到vmfunc, 标记end_instr
auto it = exec_->func_map.find(func_name);
VMFuncInfo& vmfunc = exec_->func_table.at(it->second);
if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
vmfunc.end_instr = exec_->instr_offset.size();
}
}
relax_vm.Executable #
类功能:保存Emit出的函数表、常量池与指令数据
数据成员:
/*! \brief The virtual machine's function table. */
std::vector<VMFuncInfo> func_table;
/*! \brief A map from function names to their index in func_table. */
std::unordered_map<std::string, Index> func_map;
/*! \brief The global constant pool. */
std::vector<TVMRetValue> constants;
/*! \brief instr_offset[i] is the position in instr_data where instruction i starts. */
std::vector<Index> instr_offset;
/*! \brief The encoded instruction words: each instruction is an opcode followed by its operands. */
std::vector<ExecWord> instr_data;
GetFunction #
/*!
 * \brief Look up one of the executable's built-in functions by name.
 *
 * "stats", "as_text" and "as_python" return the corresponding text dumps.
 * "vm_load_executable" and "vm_profiler_load_executable" return closures that create a VM
 * (plain or profiling), load this executable into it, and return the VM as a Module.
 * Any other name yields a null PackedFunc.
 */
PackedFunc Executable::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
  // Shared body of the two loader entry points; only the VM factory differs.
  auto make_loader = [sptr_to_self, this](bool with_profiler) {
    return PackedFunc([sptr_to_self, this, with_profiler](TVMArgs args, TVMRetValue* rv) {
      ObjectPtr<VirtualMachine> vm;
      if (with_profiler) {
        vm = VirtualMachine::CreateProfiler();
      } else {
        vm = VirtualMachine::Create();
      }
      ICHECK(sptr_to_self.get() == this);
      vm->LoadExecutable(GetObjectPtr<Executable>(this));
      *rv = Module(vm);
    });
  };

  if (name == "stats") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); });
  }
  if (name == "as_text") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); });
  }
  if (name == "as_python") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); });
  }
  if (name == "vm_load_executable") {
    return make_loader(/*with_profiler=*/false);
  }
  if (name == "vm_profiler_load_executable") {
    return make_loader(/*with_profiler=*/true);
  }
  return nullptr;
}
vm_load_executable #
关键部分:
- 创建VirtualMachine
- 调用vm->LoadExecutable
- 使用Module对VM进行包装
  // Excerpt of Executable::GetFunction: the "vm_load_executable" branch.
  } else if (name == "vm_load_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      // Create a fresh VM, load this executable into it, and return the VM wrapped as a Module.
      ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
      // sptr_to_self is captured by value (presumably to keep this executable alive while the
      // closure exists); it must refer to this object.
      ICHECK(sptr_to_self.get() == this);
      vm->LoadExecutable(GetObjectPtr<Executable>(this));
      *rv = Module(vm);
    });
  }
vm_profiler_load_executable #
GetInstruction #
- 判断opcode类型
- 从instr_data中获取指令信息,新建指令,并返回
/*!
 * \brief Decode instruction \p i from the flat instr_data buffer.
 *
 * instr_offset[i] gives the word where the instruction starts, and that word is the Opcode.
 * Layouts decoded here:
 *   Call: [op, dst, func_idx, num_args, args...]  (args is passed as a pointer into instr_data)
 *   Ret:  [op, result]
 *   Goto: [op, pc_offset]
 *   If:   [op, cond, false_offset]
 * Any other opcode is fatal.
 */
Instruction Executable::GetInstruction(Index i) const {
  Index offset = instr_offset[i];
  Opcode op = static_cast<Opcode>(instr_data[offset]);
  switch (op) {
    case Opcode::Call: {
      RegName dst = instr_data[offset + 1];
      Index func_idx = instr_data[offset + 2];
      Index num_args = instr_data[offset + 3];
      // NOTE(review): reinterprets ExecWord* as Instruction::Arg*, which relies on Arg having
      // the same representation as one ExecWord. Confirm in bytecode.h.
      ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
      return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
    }
    case Opcode::Ret: {
      RegName result = instr_data[offset + 1];
      return Instruction::Ret(result);
    }
    case Opcode::Goto: {
      Index pc_offset = instr_data[offset + 1];
      return Instruction::Goto(pc_offset);
    }
    case Opcode::If: {
      RegName cond = instr_data[offset + 1];
      Index false_offset = instr_data[offset + 2];
      return Instruction::If(cond, false_offset);
    }
    default:
      LOG(FATAL) << "should never hit this case: " << static_cast<int>(op);
      break;
  }
  return Instruction();
}
AsText #
export_library #
函数功能:将可执行程序导出为一个链接库
VirtualMachine :: ModuleNode #
类功能:负责对executable的vm code解释执行
class VirtualMachine(object):
    """Python front end of the Relax VM (excerpt).

    Loads ``rt_mod`` into a VM instance, profiling or not, and forwards item access to
    the resulting runtime module.
    """

    def __init__(
        self,
        rt_mod: "Union[tvm.runtime.Module, tvm.relax.Executable]",
        device: "Union[Device, List[Device]]",
        memory_cfg: "Optional[Union[str, Dict[Device, str]]]" = None,
        profile: bool = False,
    ) -> None:
        """Create the VM from ``rt_mod``.

        Parameters
        ----------
        rt_mod : tvm.runtime.Module or relax.Executable
            Module exposing ``vm_load_executable`` / ``vm_profiler_load_executable``.
        device : Device or list of Device
            Devices to run on. The device and allocator setup that uses ``device`` and
            ``memory_cfg`` is elided in this note.
        memory_cfg : str or dict, optional
            Allocator configuration (elided here).
        profile : bool
            Load through ``vm_profiler_load_executable`` instead of ``vm_load_executable``.
        """
        load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable"
        self.module = rt_mod[load_exec]()

    def __getitem__(self, key: str) -> "PackedFunc":
        """Return the VM function ``key``."""
        return self.module[key]
数据成员:
// include/tvm/runtime/vm/vm.h
// NOTE(review): this header and members such as inputs_, set_outputs_enabled_ and
// preresult_op_index_ appear to belong to the Relay VM (tvm::runtime::vm). The VM behind
// relax.VirtualMachine is declared separately (include/tvm/runtime/relax_vm/vm.h). Confirm
// before relying on these members for the Relax flow.
class TVM_DLL VirtualMachine : public runtime::ModuleNode {
 public:
  // ... public interface elided in this note ...

 protected:
  /*! \brief The virtual machine's packed function table. */
  std::vector<PackedFunc> packed_funcs_;
  /*! \brief The current stack of call frames. */
  std::vector<VMFrame> frames_;
  /*! \brief The function table index of the current function. */
  Index func_index_;
  /*! \brief The current pointer to the code section. */
  const Instruction* code_;
  /*! \brief The virtual machine PC. */
  Index pc_;
  /*! \brief The special return register. */
  ObjectRef return_register_;
  /*! \brief The executable the VM will operate on. */
  ObjectPtr<Executable> exec_;
  /*! \brief The function name to inputs mapping. */
  std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
  /*! \brief The function name to flag enabling scenario with set outputs. */
  std::unordered_map<std::string, bool> set_outputs_enabled_;
  /*! \brief The index of operation which destination is result. */
  Index preresult_op_index_ = -1;
  /*! \brief The function name to indices of output tensors in register file. */
  std::unordered_map<std::string, std::vector<Index>> output_tensor_reg_indices_;
  /*! \brief The function name to pre-allocated outputs mapping. */
  std::unordered_map<std::string, std::vector<ObjectRef>> outputs_;
  /*!
   * \brief The "physical" devices the VM can execute primitives on. All "device indexes"
   * are w.r.t. this vector. Each entry in this vector must match the corresponding entry
   * in the executable's "virtual" devices vector.
   */
  std::vector<Device> devices_;
  /*! \brief Memory allocators, one per device. */
  std::vector<Allocator*> allocators_;
  /*!
   * \brief The constant pool for runtime. It caches the device dependent
   * object to avoid reallocation of constants during inference.
   */
  std::vector<ObjectRef> const_pool_;
};
创建对象 #
auto vm = make_object<VirtualMachine>();
LoadExecutable #
- 赋值:exec_
- 赋值:packed_funcs_
void VirtualMachine::LoadExecutable(const ObjectPtr<Executable>& exec) {
exec_ = exec;
runtime::Module lib = exec_->GetLib();
for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs_.size() <= packed_index) {
packed_funcs_.resize(packed_index + 1);
}
tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, /*query_imports=*/true);
packed_funcs_[packed_index] = pf;
}
}
RuntimeModule #
vm_load_executable #
tvm.build核心类 #
主要包含两个部分:
- lower
- codegen
lower_primfunc #
函数功能:对PrimFunc进行降级
LowerPrimFunc 函数实现:
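- 创建PassList(CreatePassList)
- PHASE 0:用户通过tir.add_lower_pass注册的phase 0 pass
- PHASE 1:InjectPrefetch、StorageFlatten、LowerInitBlock、ManifestSharedMemoryLocalStage、FlattenBuffer、NarrowDataType 等
- PHASE 2:LoopPartition、VectorizeLoop、InjectVirtualThread、InjectDoubleBuffer、StorageRewrite、UnrollLoop
- PHASE 3:RenormalizeSplitPattern、Simplify、RemoveNoOp、RewriteUnsafeSelect、HoistIfThenElse
- 根据PassList对func进行降级(LowerWithPassList)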
// src/driver/driver_api.cc
/*!
 * \brief Lower a single PrimFunc into an IRModule.
 *
 * \param func The function to lower.
 * \param name Name used both as the global_symbol attribute and as the GlobalVar.
 * \param simple_mode Forwarded to CreatePassList (disables LoopPartition).
 * \return The lowered module.
 */
IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  // Attach the name as the function's global_symbol. (`f` was used below without being
  // defined in the previous excerpt.)
  tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name));
  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
  if (noalias) {
    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
  }
  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
  // Get the pass list
  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
  return LowerWithPassList(std::move(mod), pass_list);
}
PassList:
/*!
 * \brief Assemble the TIR lowering pass list (used here by LowerPrimFunc).
 *
 * Order: user phase-0 passes, built-in phase 1, user phase 1, built-in phase 2, user
 * phase 2, built-in phase 3, user phase 3. These are followed by optional
 * InstrumentBoundCheckers / InjectPTXLDG32, then CommonSubexprElimTIR, and finally the
 * optional InstrumentProfileIntrinsics, which must stay last.
 *
 * \param disable_loop_partition When true, LoopPartition is skipped (LowerPrimFunc passes its
 *        simple_mode flag here).
 * \return The ordered pass list.
 */
Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
  bool disable_storage_rewrite =
      pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
  bool instrument_bound_checkers =
      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
  bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
  bool enable_equiv_terms_in_cse_tir =
      pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();
  bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
  // Get any user-added passes
  Array<Array<ObjectRef>> add_lower_pass =
      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
          .value();
  bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();
  Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();
  // phase passes is of the form
  // [[phase_number, pass], [phase_number, pass]... ]
  // Phase numbers >= 3 all land in phase 3; negative numbers fail the CHECK.
  for (Array<ObjectRef> phase_pass : add_lower_pass) {
    const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
    ICHECK(phase_num)
        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
    int phase_num_val = phase_num->value;
    CHECK_GE(phase_num_val, 0);
    auto pass = Downcast<tvm::transform::Pass>(phase_pass[1]);
    // Copy the pass into the correct phase
    if (phase_num_val == 0) {
      user_lower_phase0.push_back(pass);
    } else if (phase_num_val == 1) {
      user_lower_phase1.push_back(pass);
    } else if (phase_num_val == 2) {
      user_lower_phase2.push_back(pass);
    } else if (phase_num_val >= 3) {
      user_lower_phase3.push_back(pass);
    }
  }
  // Construct the pass list, inserting the user provided passes at the end of the phase
  // PHASE 0
  Array<tvm::transform::Pass> pass_list = user_lower_phase0;
  // PHASE 1
  pass_list.push_back(tir::transform::InjectPrefetch());
  pass_list.push_back(tir::transform::TextureFlatten());
  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
  pass_list.push_back(tir::transform::LowerCrossThreadReduction());
  pass_list.push_back(tir::transform::LowerInitBlock());
  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
  pass_list.push_back(tir::transform::LiftThreadBinding());
  pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
  pass_list.push_back(tir::transform::CompactBufferAllocation());
  pass_list.push_back(tir::transform::LowerAutoCopy());
  pass_list.push_back(tir::transform::UnifyThreadBinding());
  pass_list.push_back(tir::transform::LowerMatchBuffer());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::InjectPermutedLayout());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::InjectSoftwarePipeline());
  pass_list.push_back(tir::transform::TransformMmaBufferLayout());
  pass_list.push_back(tir::transform::LowerOpaqueBlock());
  pass_list.push_back(tir::transform::FlattenBuffer());
  pass_list.push_back(tir::transform::FP8ComputeLegalize());
  pass_list.push_back(tir::transform::BF16ComputeLegalize());
  pass_list.push_back(tir::transform::NarrowDataType(32));
  pass_list.push_back(tir::transform::Simplify());
  // Add user-defined phase-1 passes
  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
  // PHASE 2
  if (!disable_loop_partition) {
    pass_list.push_back(tir::transform::LoopPartition());
  }
  // VectorizeLoop is always added; "tir.disable_vectorize" only toggles its argument.
  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
  pass_list.push_back(tir::transform::InjectVirtualThread());
  pass_list.push_back(tir::transform::InjectDoubleBuffer());
  if (!disable_storage_rewrite) {
    pass_list.push_back(tir::transform::StorageRewrite());
  }
  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
  if (use_async_copy) {
    pass_list.push_back(tir::transform::LowerAsyncDMA());
  }
  pass_list.push_back(tir::transform::UnrollLoop());
  // Add user-defined phase-2 passes
  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
  // PHASE 3
  pass_list.push_back(tir::transform::RenormalizeSplitPattern());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::RemoveNoOp());
  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
  pass_list.push_back(tir::transform::HoistIfThenElse());
  // Add user-defined phase-3 passes
  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
  if (instrument_bound_checkers) {
    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
  }
  if (ptx_ldg32) {
    pass_list.push_back(tir::transform::InjectPTXLDG32(true));
  }
  // CSE is always added; "tir.disable_cse_tir" only toggles its enable argument.
  pass_list.push_back(
      tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));
  // This pass instruments the loops with the profile builtin calls to capture the runtime
  // performance data (only enabled for Hexagon at the moment). To ensure that no other
  // optimizations are performed on the instrumented code, this pass must be added at the end
  // of the list.
  if (instrument_lwp) {
    pass_list.push_back(tir::transform::InstrumentProfileIntrinsics());
  }
  return pass_list;
}
codegen::build #
tvm::codegen::build
- 得到目标target的属性
- 调用目标target的build函数进行build
// src/target/codegen.cc
// Lower an IRModule of PrimFuncs to a runtime::Module for `target`.
runtime::Module Build(IRModule mod, Target target) {
  // 1. If the target kind registers a "TIRToRuntime" hook, that hook performs the whole
  //    lowering; return its module. Only index the attr map when the key is present:
  //    operator[] on a kind without the attribute fails instead of returning a default.
  auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
  if (target_attr_map.count(target->kind)) {
    return target_attr_map[target->kind](mod, target);
  }
  // 2. Otherwise dispatch to the globally registered "target.build.<kind>" function,
  //    e.g. "target.build.llvm" or "target.build.cuda".
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}
target.build.llvm #
// src/target/llvm/llvm_module.cc
// Build function for the "llvm" target kind, looked up by name from Build().
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      // 1. Allocate an empty LLVMModuleNode.
      auto n = make_object<LLVMModuleNode>();
      // 2. Generate LLVM code for the functions of `mod` into the node.
      n->Init(mod, target);
      // 3. Hand the node back wrapped in a runtime::Module.
      return runtime::Module(n);
    });
相关数据结构:
- LLVMModuleNode
- LLVMTarget
- CodeGenLLVM
LLVMModuleNode->Init #
- 创建LLVMTarget,根据LLVMTarget创建CodeGenLLVM实例
- 初始化codegenLLVM
- 向codegen中添加func
- 执行codegen
void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
# 1. 获取codegen实例
With<LLVMTarget> llvm_target(*(std::make_unique<LLVMInstance>()), target);
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(llvm_target.get());
# 2. 初始化codegen
cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.defined(),
target_c_runtime);
# 3. 向codegen中添加func
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
# 4. 执行codegen
module_owning_ptr_ = cg->Finish();
module_ = module_owning_ptr_.get();
}
LLVMTarget/LLVMInstance #
类功能:
- 包含针对特定设备进行 LLVM 代码生成所需的信息。 LLVMTarget类数据成员:
# src/target/llvm/llvm_instance.h
// LLVMTarget: LLVMTargetInfo plus the state used while generating LLVM code for a
// specific TVM target. Excerpt only; "..." elides the rest of the class.
class LLVMTarget : public LLVMTargetInfo {
...
private:
// NOTE(review): presumably LLVM command-line option values saved so they can be
// restored later — confirm in llvm_instance.cc.
std::vector<Option> saved_llvm_options_;
// Non-owning reference: the referenced LLVMInstance must outlive this LLVMTarget.
const LLVMInstance& instance_;
// Non-owning (weak) handle to the LLVMContext; the owning shared_ptr is held
// elsewhere (LLVMInstance keeps one in its ctx_ member).
std::weak_ptr<llvm::LLVMContext> ctx_;
// Static, i.e. shared by all LLVMTarget objects. NOTE(review): presumably records
// whether global LLVM state was modified — confirm.
static bool modified_llvm_state_;
};
LLVMTargetInfo类
- 此 TVM Target与 LLVM 代码生成相关的信息摘要。 数据成员:
// LLVMTargetInfo: the parts of a TVM Target that are relevant to LLVM code generation.
// Excerpt only; "..." elides the public interface.
class LLVMTargetInfo {
...
private:
// NOTE(review): triple_/cpu_/attrs_ presumably hold the target triple, CPU name and
// feature attributes taken from the TVM Target — confirm.
std::string triple_;
std::string cpu_;
std::vector<std::string> attrs_;
std::vector<Option> llvm_options_;
llvm::TargetOptions target_options_;
llvm::FastMathFlags fast_math_flags_;
llvm::CodeGenOpt::Level opt_level_;
// Defaults: position-independent relocation model, small code model.
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
// NOTE(review): presumably the machine returned by GetOrCreateTargetMachine(),
// created on first use and cached here — confirm.
std::shared_ptr<llvm::TargetMachine> target_machine_;
};
LLVM Instance:
- ctx_:
// LLVMInstance: owns an llvm::LLVMContext through a shared_ptr.
// Excerpt only; "..." elides the rest of the class.
class LLVMInstance {
...
private:
// Owning pointer to the context; other objects (e.g. LLVMTarget via a weak_ptr) may
// refer to it, so the instance must outlive them.
std::shared_ptr<llvm::LLVMContext> ctx_;
};
CodeGenLLVM #
数据成员:
- function_:当前函数
- builder_:
- module_:返回的module
- llvm_target_:LLVM target信息
- 数据类型
- 元数据
- linked module
# src/target/llvm/codegen_llvm.h
// ---- CodeGenLLVM data members (excerpt from codegen_llvm.h; class header not shown) ----
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
llvm::Function* function_;
// Internal builder
std::unique_ptr<IRBuilder> builder_;
// The llvm::Module being generated; NOTE(review): presumably what Finish() hands back.
std::unique_ptr<llvm::Module> module_;
// Data layout of module_ (created in InitTarget from the target machine's layout).
std::unique_ptr<llvm::DataLayout> data_layout_;
// Internal metabuilder
std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm target info
LLVMTarget* llvm_target_{nullptr};
// helpful data types, cached once in Init() (t_void_p_ is i8* in the global address space)
llvm::Type* t_void_{nullptr};
llvm::PointerType* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
llvm::Type* t_int16_{nullptr};
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
// meta data (branch-weight hint and TBAA root/alias-set nodes, created in Init())
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
// modules to be linked.
std::vector<std::unique_ptr<llvm::Module>> link_modules_;
/*! \brief native vector bits of the current target; 0 means "not set yet", in which
 *  case InitTarget() picks a per-architecture default. */
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
// The definition of local variable (TIR Var -> LLVM value, e.g. function arguments).
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// Map from TVM's GlobalVar to the llvm::Function that represents
// that function.
std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;
// Whether current function is restricted
bool is_restricted_{true};
// The analyzer information
std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias)
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// debug info for function being compiled (NOTE(review): no default initializer here)
llvm::DISubprogram* di_subprogram_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin();
const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin();
const Op& builtin_lookup_param_ = builtin::lookup_param();
const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered();
Create #
函数功能: 基于LLVM的各种CodeGen后端:
tvm.codegen.llvm.target_cpu
tvm.codegen.llvm.target_x86-64
tvm.codegen.llvm.target_arm
tvm.codegen.llvm.target_rocm
tvm.codegen.llvm.target_nvptx
tvm.codegen.llvm.target_hexagon
函数实现:
- 根据target名称确定packedfunc f
- 调用f创建CodeGenLLVM实例
// Instantiate the CodeGenLLVM subclass registered for the LLVM backend of `llvm_target`.
// Factories are looked up as "tvm.codegen.llvm.target_<backend>"; when no
// backend-specific factory exists, the generic "..._cpu" factory is used.
std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(LLVMTarget* llvm_target) {
  const std::string backend = llvm_target->GetOrCreateTargetMachine()->getTarget().getName();
  const std::string prefix = "tvm.codegen.llvm.target_";

  const PackedFunc* factory = runtime::Registry::Get(prefix + backend);
  if (factory == nullptr) {
    factory = runtime::Registry::Get(prefix + "cpu");
  }
  if (factory == nullptr) {
    LOG(FATAL) << "no factory function for codegen for target " << backend;
  }

  // The factory returns an owning raw pointer to a heap-allocated CodeGenLLVM.
  void* raw_codegen = (*factory)();
  if (raw_codegen == nullptr) {
    LOG(FATAL) << "unable to create codegen for target " << backend;
  }
  return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(raw_codegen));
}
CodeGenLLVM::Init #
函数功能:
- 虚函数,子类实现调用了该实现
# src/target/llvm/codegen_llvm.cc
void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
Optional<String> system_lib_prefix, bool dynamic_lookup,
bool target_c_runtime) {
llvm_target_ = llvm_target;
llvm::LLVMContext* ctx = llvm_target_->GetContext();
builder_.reset(new IRBuilder(*ctx));
module_.reset(new llvm::Module(module_name, *ctx));
md_builder_.reset(new llvm::MDBuilder(*ctx));
// types
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace());
t_int_ = llvm::Type::getInt32Ty(*ctx);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
// meta data
md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
InitTarget();
}
// Configure module_ for the target machine (triple and data layout) and, if
// native_vector_bits_ has not been set yet, choose a default from the architecture.
void CodeGenLLVM::InitTarget() {
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
module_->setTargetTriple(tm->getTargetTriple().str());
module_->setDataLayout(tm->createDataLayout());
data_layout_.reset(new llvm::DataLayout(module_.get()));
// 0 means "unset": defaults are x86_64 -> 512, x86 -> 256, ARM/AArch64 -> 128,
// anything else -> 128 with a warning.
if (native_vector_bits_ == 0) {
const auto& arch = tm->getTargetTriple().getArch();
if (arch == llvm::Triple::x86_64) {
// for avx512
native_vector_bits_ = 512;
} else if (arch == llvm::Triple::x86) {
native_vector_bits_ = 256;
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
native_vector_bits_ = 128;
} else {
native_vector_bits_ = 128;
std::string arch_name = std::string(tm->getTargetTriple().getArchName());
LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
}
}
// The remainder of the function is elided in this excerpt.
.......
}
CodeGenCPU::Init #
CodeGenNVPTX::Init #
AddFunctionsOrdered #
- 对函数进行排序
- 对函数进行申明
- 对函数进行添加
# src/target/llvm/codegen_llvm.h
// Abridged pseudo-code, not compilable as shown. Steps:
//   1. collect (GlobalVar, PrimFunc) pairs from [begin, end) into `funcs`
//      (NOTE(review): presumably converting each function with `pfunc`),
//   2. sort `funcs` (NOTE(review): presumably by symbol name, for deterministic output),
//   3. DeclareFunction for the functions, then AddFunction to emit their bodies.
// NOTE(review): LLVMModuleNode::Init calls this with only two arguments, so an overload
// or default for `pfunc` presumably exists — confirm in codegen_llvm.h.
void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
std::vector<std::tuple<GlobalVar, PrimFunc>> funcs;
sort(funcs)
DeclareFunction(gvar, func);
AddFunction(gvar, func);
}
DeclareFunction:
// Abridged pseudo-code: create the llvm::Function for `func` inside module_ and record it
// in functions_ under `gvar`, so later code (e.g. AddFunction) can find it.
// ftype, linkage_type, symbol_name and `function` come from code elided in this excerpt,
// and the return statement is not shown.
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
# 1. 创建llvm func
function = llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get());
# 2.
functions_[gvar.get()] = function;
}
AddFunction:
- 设置函数参数
- 添加函数体
- 设置函数返回值
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
llvm::Argument* v = &(*arg_it);
var_map_[f->params[i];.get()] = v;
v->setName(std::string(var->name_hint));
}
# 获取irbuilder, 设置插入点, 然后调用Visit进行codegen
llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
builder_->SetInsertPoint(entry);
this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0));
}
CodeGenSourceBase #
类描述:源码形式代码生成器(如 CodeGenC)的基类,维护输出流、变量命名表和 SSA 赋值表。 类数据成员:
- stream:
- 每个变量的名称:var_idmap_
// CodeGenSourceBase: base class of the source-emitting code generators (CodeGenC derives
// from it). It holds the output streams, the variable naming state and the SSA
// bookkeeping. Excerpt only; "..." elides the methods.
class CodeGenSourceBase {
...
protected:
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id, used to check if this entry is invalid. */
int scope_id;
};
/*! \brief the declaration stream (CodeGenC::Finish emits it before `stream`) */
std::ostringstream decl_stream;
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief the forward declaration stream */
std::ostringstream fwd_decl_stream;
/*! \brief name of each variable */
std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
/*! \brief NameSupply for allocation */
NameSupply name_supply_ = NameSupply("");
private:
/*! \brief assignment map of ssa (expression text -> SSAEntry) */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
/*! \brief The current indentation value */
int indent_{0};
};
CodeGenC #
类层次结构:
CodeGenC
CodeGenCUDA
CodeGenMetal
CodeGenCHost
CodeGenWebGPU
CodeGenOpenCL
CodeGenVivadoHLS
类简介:CodeGenC有两种模式:生成SSA形式的C代码或者正常形式的C代码 数据成员:
- let_binding_:let变量的绑定
- alloc_storage_scope_:分配的storage scope
- handle_data_type_:分配的buffer的数据类型
# src/target/source/codegen_c.h
// CodeGenC: prints TIR as C-like source. Expressions are visited into an std::ostream;
// statements are written to the inherited `stream`. It is the base of the
// CUDA/Metal/OpenCL/... source backends. Excerpt only; "......" elides the methods.
class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)>,
public CodeGenSourceBase {
public:
......
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation; filled when an Allocate is visited and
 *  consulted when printing handle parameters in AddFunction */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers (NOTE(review): presumably filled by
 *  RegisterHandleType — confirm) */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief Record of ops that have pre-defined global symbol. */
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// cache commonly used ops
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
Integer constants_byte_alignment_ = 16;
/*! \brief whether to print in SSA form (set by Init, read by PrintExpr) */
bool print_ssa_form_{false};
private:
/*! \brief set of volatile buf access (filled by the "volatile_scope" AttrStmt) */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
Init #
// Select the printing mode. When `output_ssa` is true, PrintExpr emits SSA ids instead
// of inline expressions.
void CodeGenC::Init(bool output_ssa) {
  print_ssa_form_ = output_ssa;
}
PrintStmt… #
- visitStmt:用于利用stream生成c source文件,所以会有一些语法字符串输出,也会包含Scope等
- visitExpr:
void PrintStmt(const Stmt& n) { VisitStmt(n); }
std::string PrintExpr(const PrimExpr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
// Print expression `n` into `os`. In SSA mode the expression is first rendered into a
// scratch buffer, and what SSAGetID returns for that text (and the expression's dtype)
// is written instead of the text itself.
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
  if (!print_ssa_form_) {
    VisitExpr(n, os);
    return;
  }
  std::ostringstream rendered;
  VisitExpr(n, rendered);
  os << SSAGetID(rendered.str(), n.dtype());
}
AddFunction #
- 处理函数参数
- BeginScope
- PrintStmt(f->body)
- EndScope
- PrintIndent()
void CodeGenC::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
// 1. 处理函数参数
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, stream);
}
} else {
PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
// VisitBody
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
【关键】 VisitStmt #
void PrintStmt(const Stmt& n) { VisitStmt(n); }
- AttrStmt节点,常用于:T.launch_thread, with T.attr等
- Allocate节点,常用于:T.allocate
- BufferStore节点:常用于:常规写入、T.Buffer、T.tvm_warp_activemask、T.tvm_warp_shuffle_down等
- evaluate节点(会被降级?):T.tvm_storage_sync(被device tir passes降级), T.tvm_thread_allreduce(被to_tir_runtime mixed mod降级)
- IfThenElse节点:常见条件如 threadIdx_x <= 32、threadIdx_x < 8、threadIdx_x == 0
AttrStmtNode #
函数描述:处理 AttrStmt 节点,根据 attr_key 绑定线程索引、记录 volatile buffer 或输出导入的 C 代码,然后继续打印 body。 函数流程:
- attr_key是thread_extent:T.launch_thread
- attr_key是volatile_scope:T.attr(red_result, “volatile_scope”, 1)
- attr_key是pragma_import_c
- 进一步PrintStmt(body)
// Record the side information carried by an AttrStmt, then print its body.
//  - thread_extent:   bind the thread index variable the first time a tagged axis is seen
//  - volatile_scope:  remember the buffer as volatile
//  - pragma_import_c: paste the given C code into the declaration stream
void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
  const auto& key = op->attr_key;
  if (key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    const bool tagged = iv->thread_tag.length() != 0;
    if (tagged && !var_idmap_.count(iv->var.get())) {
      BindThreadIndex(iv);
    }
  } else if (key == tir::attr::volatile_scope) {
    const VarNode* v = op->node.as<VarNode>();
    ICHECK(v);
    volatile_buf_.insert(v);
  } else if (key == tir::attr::pragma_import_c) {
    const StringImmNode* value = op->value.as<StringImmNode>();
    ICHECK(value != nullptr);
    decl_stream << value->value;
  }
  this->PrintStmt(op->body);
}
AllocateNode #
函数描述:将 Allocate 节点输出为定长数组定义(仅支持常量大小)。 函数流程:
- 为变量分配vid
- 获取分配变量空间大小和scope,加入alloc_storage_scope_
- 输出变量定义statement
- 注册Handle
// Print an Allocate as a fixed-size array definition, then print its body.
void CodeGenC::VisitStmt_(const AllocateNode* op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  // 1. The allocation size must be a compile-time constant.
  this->PrintIndent();
  size_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
  // 2. Look up the buffer's storage scope and remember it in alloc_storage_scope_.
  auto scope = GetPtrStorageScope(op->buffer_var);
  alloc_storage_scope_[op->buffer_var.get()] = scope;
  // 3. Emit "<scope qualifier><type> <vid>[<size>];".
  PrintStorageScope(scope, stream);
  PrintType(op->dtype, stream);
  stream << ' ' << vid << '[' << constant_size << "];\n";
  // 4. Record the element type of the new handle, then print the body.
  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}
EvaluateNode #
函数描述:处理 Evaluate 节点(仅为副作用而求值的表达式语句)。 函数流程:
- op.value为一个CallNode
- call.op == tvm_storage_sync
- call.op == tvm_struct_set()
- 调用PrintExpr
// Print an Evaluate statement (an expression evaluated for its side effects).
void CodeGenC::VisitStmt_(const EvaluateNode* op) {
// Constant integer expressions have no effect: print nothing.
if (is_const_int(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call) {
if (call->op.same_as(builtin::tvm_storage_sync())) {
this->PrintStorageSync(call);
return;
} else if (call->op.same_as(builtin::tvm_struct_set())) {
// tvm_struct_set has 4 args; args[2] selects the field kind and args[3] is the value.
// Emitted as "<field ref> = <optional cast><value>;".
// NOTE(review): args[2] is assumed to be an IntImm; as<IntImmNode>() returns null otherwise.
ICHECK_EQ(call->args.size(), 4);
int kind = call->args[2].as<IntImmNode>()->value;
std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind);
std::string value = PrintExpr(call->args[3]);
std::string cast;
if (kind == builtin::kArrStrides) {
// cast void* to int64_t*
cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "";
} else if (kind == builtin::kArrDeviceType) {
// cast int to enum
cast = "(DLDeviceType)";
}
this->PrintIndent();
this->stream << ref << " = " << cast << value << ";\n";
return;
}
}
// Generic case: print the expression as a statement (nothing if it renders empty).
std::string vid = this->PrintExpr(op->value);
if (vid != "") {
this->PrintIndent();
this->stream << vid << ";\n";
}
}
【关键】VisitExpr #
CallNode #
函数流程:
- 获取算子op
- 根据算子op类型,执行不同的codegen
- builtin.tvm_check_return
- builtin.ret
- builtin_call_extern_、builtin_call_pure_extern_
- bitwise_and
- large_uint_imm
- bitwise_xor
- bitwise_or
- bitwise_not
- shift_left
- shift_right
- if_then_else
- address_of
- tvm_struct_get
- isnullptr
- reinterpret
- isnan
- lookup_param
Finish #
// Return the generated source: the declaration stream followed by the main stream.
std::string CodeGenC::Finish() {
  std::string source = decl_stream.str();
  source += stream.str();
  return source;
}
target.build.cuda #
BuildCUDA函数实现流程:
- 遍历所有的PrimFun函数,添加到CodeGenCUDA中
- cg.Finish
- NVRTC编译
- 创建CUDAModule
// Build function for the "cuda" target kind: generate CUDA C source, compile it to PTX
// (or cubin), and wrap the result in a CUDA runtime module.
runtime::Module BuildCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);
  // 1. Print every PrimFunc; only device kernels are accepted here.
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(f);
  }
  // 2. Collect the source; an optional registered hook may post-process it.
  std::string code = cg.Finish();
  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  // 3. Compile with `target` entered as the current target scope. A registered
  //    "tvm_callback_cuda_compile" takes precedence over NVRTC. The scope is exited on
  //    both the normal and the exceptional path, so a failed compilation does not leave
  //    `target` entered.
  const auto* f_enter = Registry::Get("target.TargetEnterScope");
  const auto* f_exit = Registry::Get("target.TargetExitScope");
  ICHECK(f_enter != nullptr && f_exit != nullptr)
      << "target.TargetEnterScope / target.TargetExitScope are not registered";
  (*f_enter)(target);
  try {
    if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
      ptx = (*f)(code, target).operator std::string();
      // Dirty matching to check PTX vs cubin: output starting with '/' is treated as PTX.
      // TODO(tqchen) more reliable checks
      if (ptx[0] != '/') fmt = "cubin";
    } else {
      ptx = NVRTCCompile(code, cg.need_include_path());
    }
  } catch (...) {
    (*f_exit)(target);
    throw;
  }
  (*f_exit)(target);
  // 4. Create the CUDA module from the binary, the kernel info and the CUDA source.
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
CodeGenCUDA #
类描述:CodeGenC 的 CUDA 子类,负责生成 CUDA C 源码(由 BuildCUDA 调用)。 类数据属性:
// CodeGenCUDA: CodeGenC specialisation that prints CUDA C (used by BuildCUDA).
// Excerpt only; "..." elides the methods.
class CodeGenCUDA final : public CodeGenC {
public:
...
// Whether global barrier is needed.
bool need_global_barrier_{false};
// Global barrier state
std::string vid_global_barrier_state_;
// Global barrier expected node (the __shared__ counter declared for
// tvm_global_barrier_kinit).
std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
// WMMA fragment shape/layout per buffer, recorded from the fragment_shape /
// fragment_layout AttrStmts.
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
};
AddFunction #
直接调用CodeGenC的AddFunction
VisitStmt #
针对的stmt节点:
- ForNode(添加#pragma unroll宏定义)
- EvaluateNode
- AllocateNode(关键)
- AttrStmtNode
AttrStmtNode #
- 属性为:fragment_shape,添加到数据属性fragment_shapes中
- 属性为:fragment_layout,添加到数据属性fragment_layouts中
- 属性为:async_commit_queue_scope,VisitStmt(body), visitExpr(commit_group)
- 属性为:async_wait_queue_scope,visitExpr(wait_group), visitStmt(body)
- 其他情况(包括处理完 fragment_shape / fragment_layout 之后):回调CodeGenC的VisitStmt_(AttrStmtNode)
// CUDA handling of AttrStmt.
//  - fragment_shape / fragment_layout: record the WMMA fragment info, then fall through
//    to CodeGenC (which prints the body).
//  - async_commit_queue_scope: print the body, then ptx_commit_group; returns early.
//  - async_wait_queue_scope: print ptx_wait_group first, then the body of the nested
//    AttrStmt (the nested attribute itself is not visited); returns early.
// NOTE(review): `buffer`, `shape_str` and `layout_str` are not null-checked.
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
return;
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
// GetAsyncWaitAttributes yields (queue id, wait count).
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
this->VisitStmt(inner->body);
return;
}
CodeGenC::VisitStmt_(op);
}
AllocateNode #
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
//
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>();
//
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
}
if (scope == "shared.dyn") {
stream << ' ' << vid << "[];\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
if (scope.find("wmma.") == 0) {
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
}
if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) &&
scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' ' << vid << '[' << constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
EvaluateNode #
- op为tvm_global_barrier_kinit
- 回调CodeGenC的EvaluateNode
// tvm_global_barrier_kinit is expanded inline: declare the __shared__ barrier counter and
// let thread 0 zero it. Every other Evaluate is delegated to CodeGenC.
void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  const CallNode* call = op->value.as<CallNode>();
  const bool is_barrier_init =
      call != nullptr && call->op.same_as(builtin::tvm_global_barrier_kinit());
  if (!is_barrier_init) {
    CodeGenC::VisitStmt_(op);
    return;
  }
  PrintIndent();
  stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
  PrintIndent();
  stream << "if (threadIdx.x == 0) {\n";
  PrintIndent();
  stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
  PrintIndent();
  stream << "}\n";
}
ForNode #
// Loops reaching codegen must start at 0. Loops marked kUnrolled get a
// "#pragma unroll" line before the generic C loop printed by CodeGenC.
void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
  ICHECK(is_const_int(op->min, 0));
  const bool unrolled = op->kind == tir::ForKind::kUnrolled;
  if (unrolled) {
    this->PrintIndent();
    this->stream << "#pragma unroll\n";
  }
  CodeGenC::VisitStmt_(op);
}
VisitExpr #
CodeGenCUDA实现的VisitExpr:
- CastNode
- CallNode
- BroadcastNode
- ShuffleNode
- SelectNode
- RampNode
- FloatImmNode
CallNode #
函数流程:
- 默认开启warp_shuffle
- 根据不同的op类型,执行不同的codegen
- 【TensorCore】tvm_fill_fragment
- 【TensorCore】tvm_load_matrix_sync
- 【TensorCore】tvm_store_matrix_sync
- 【TensorCore】tvm_mma_sync
- 【TensorCore】tvm_bmma_sync
- 【TensorCore】ptx_mma
- 【TensorCore】ptx_mma_sp
- 【TensorCore】ptx_ldmatrix
- 【TensorCore】mma_store
- 【TensorCore】mma_fill
- ptx_cp_async
- ptx_commit_group
- ptx_wait_group
- ptx_ldg32
- 调用CodegenC的visit call
BroadcastNode #
ShuffleNode #
void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
to_shuffle[i] = PrintExpr(op->vectors[i]);
}
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0, e = op->indices.size(); i < e; ++i) {
const int64_t* val = as_const_int(op->indices[i]);
ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
if (i != 0) os << ", ";
os << to_shuffle[*val];
}
os << ')';
}
Finish #
NVRTCCompile #
CUDAModuleCreate #
测试用例 #
test_vm_build #
test_vm_codegen_only #
test_vm_codegen_tir #
test_vm_execbuilder #
单个函数:
def test_vm_execute():
    """Build a single-function executable whose body calls the packed func ``test.vm.add``."""
    ib = relax.ExecBuilder()
    with ib.function("func0", num_inputs=2):
        # %2 = test.vm.add(%0, %1); return %2
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    ex = ib.get()
    text = ex.as_text()
    assert "@func0:" in text
    assert "call test.vm.add in: %0, %1 dst: %2" in text
>>> print(ex.as_text())
@func0:
call test.vm.add in: %0, %1 dst: %2
ret %2
@test.vm.add packed_func;
>>> type(ex)
tvm.relax.vm_build.Executable
>>> type(ex.mod), ex.mod
(tvm.runtime.module.Module, Module(relax.Executable, 36cf178))
多个函数:
def test_vm_multiple_func():
    """Build an executable with two functions, each calling a different packed func."""
    ib = relax.ExecBuilder()
    with ib.function("func0", num_inputs=2):
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    with ib.function("func1", num_inputs=2):
        ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    ex = ib.get()
    text = ex.as_text()
    assert "@func0:" in text and "@func1:" in text
    assert "call test.vm.mul in: %0, %1 dst: %2" in text
>>> print(ex.as_text())
@func0:
call test.vm.add in: %0, %1 dst: %2
ret %2
@test.vm.add packed_func;
@func1:
call test.vm.mul in: %0, %1 dst: %2
ret %2
@test.vm.mul packed_func;
>>>
指令参数:
- 寄存器
- 立即数
- 函数
# Call argument given as an immediate value: ib.imm(-3) is printed as "i-3".
ib = relax.ExecBuilder()
with ib.function("func0", num_inputs=1):
    ib.emit_call("test.vm.add_scalar", args=[ib.imm(-3), ib.r(0)], dst=ib.r(1))
    ib.emit_ret(ib.r(1))
ex = ib.get()
>>> print(ex.as_text())
@func0:
call test.vm.add_scalar in: i-3, %0 dst: %1
ret %1
@test.vm.add_scalar packed_func;
>>>
函数参数为函数:
def test_vm_invoke_closure():
    """Build an executable whose ``main`` passes a function reference to make_closure."""
    ib = relax.ExecBuilder()
    with ib.function("lifted_func_1", num_inputs=4):
        # %6 = %3 + (%2 + (%0 + %1))
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(4))
        ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(4)], dst=ib.r(5))
        ib.emit_call("test.vm.add", args=[ib.r(3), ib.r(5)], dst=ib.r(6))
        ib.emit_ret(ib.r(6))
    with ib.function("main", num_inputs=2):
        # ib.f(...) passes a reference to lifted_func_1; %0 and %1 are passed along with it.
        ib.emit_call(
            "vm.builtin.make_closure", args=[ib.f("lifted_func_1"), ib.r(0), ib.r(1)], dst=ib.r(2)
        )
        ib.emit_ret(ib.r(2))
    ex = ib.get()
    text = ex.as_text()
    assert "call vm.builtin.make_closure in: f[lifted_func_1], %0, %1 dst: %2" in text
>>> print(ex.as_text())
@lifted_func_1:
call test.vm.add in: %0, %1 dst: %4
call test.vm.add in: %2, %4 dst: %5
call test.vm.add in: %3, %5 dst: %6
ret %6
@test.vm.add packed_func;
@main:
call vm.builtin.make_closure in: f[lifted_func_1], %0, %1 dst: %2
ret %2
@vm.builtin.make_closure packed_func;
>>>