【TVM】Runtime Source Code Analysis

Reference articles:

Related source files:

  • 【relax.build】python/tvm/relax/vm_build.py

  • 【tvm.build】python/tvm/driver/build_module.py

  • src/relax/backend/vm/codegen_vm.cc

  • tests/python/relax/test_vm_build.py

  • tests/python/relax/test_vm_codegen_only.py

  • tests/python/relax/test_vm_codegen_tir.py

  • tests/python/relax/test_vm_execbuilder.py

Core APIs:

  • relax.build
  • relax.VirtualMachine

API Interfaces #

relax.build #

Function purpose: build an IRModule into an executable that the relax VM can run.

Function parameters:

  • mod
  • target: the target platform
  • exec_mode: execution mode (bytecode or compiled)

Return value:

  • ex: tvm.relax.Executable (an executable that can be loaded by the virtual machine…)

Function flow:

  1. relax function transform
  2. relax function codegen ->
  3. tir function build -> tvm.Module
  4. vmlink -> executable

Function implementation:
def build(
    mod: tvm.IRModule,
    target: Union[str, tvm.target.Target],...
):
	# 1. Run the passes to transform the module
    passes = []
    passes.append(relax.transform.RewriteDataflowReshape())
    passes.append(relax.transform.ToNonDataflow())
    passes.append(relax.transform.RemovePurityChecking())
    passes.append(relax.transform.CallTIRRewrite())
    passes.append(relax.transform.StaticPlanBlockMemory())

    if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
        passes.append(relax.transform.RewriteCUDAGraph())

    passes.append(relax.transform.VMBuiltinLower())
    passes.append(relax.transform.VMShapeLower())
    passes.append(relax.transform.AttachGlobalSymbol())
    seq = tvm.transform.Sequential(passes)
    new_mod = seq(mod)

    # 2. builder collects the executable
    builder = relax.ExecBuilder()
    leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode) # codegen: relax2vm code
    tir_mod = _filter_tir(leftover_mod) # filter out the TIR functions
    return _vmlink(builder, target, tir_mod, ext_libs, params, system_lib=system_lib)
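A minimal usage sketch (assuming `mod` is a Relax IRModule, e.g. written with the TVMScript frontend, and that it defines a `main` entry function):

import numpy as np
import tvm
from tvm import relax

ex = relax.build(mod, target="llvm", exec_mode="bytecode")  # Executable
vm = relax.VirtualMachine(ex, tvm.cpu())                    # load it into a VM
x = tvm.nd.array(np.random.rand(3, 4).astype("float32"))
out = vm["main"](x)                                         # run the compiled entry function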

_vmcodegen:

if exec_mode == "bytecode":
	return _ffi_api.VMCodeGen(builder, mod)  # type:ignore
if exec_mode == "compiled":
	return _ffi_api.VMTIRCodeGen(builder, mod)  # type: ignore

_vmlink:

def _vmlink(
    builder: "relax.ExecBuilder",
    target: Union[str, tvm.target.Target],
    tir_mod: Optional[tvm.IRModule] = None,...
):
	# only uses tir_mod here; returns a tvm.runtime.Module
	lib = tvm.build(
            tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib)
        )
    # uses the builder and the tvm.runtime.Module; returns an Executable
    return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))

VMCodeGen #

Function purpose: generate executable relax VM code for the relax.Functions in the given mod and add it to exec_builder.

Function parameters:

  • exec_builder: the Builder used to collect the executables
  • mod
IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
  return CodeGenVM::Run(exec_builder, mod);
}

TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);

Concrete implementation: CodeGenVM

VMTIRCodeGen #

Function purpose:

  1. Generate executable relax VM code for the relax.Functions in the given mod and add it to exec_builder
  2. Then create the additional TIR functions.

Function parameters:

  • exec_builder: the Builder used to collect the executables
IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) {
  return CodeGenVMTIR::Run(exec_builder, mod);
}

TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen);

Concrete implementation: CodeGenVMTIR

VMLink #

Function purpose: link the libraries together. Return value: tvm.runtime.Module

Module VMLink(ExecBuilder builder, 
			  Target target, Optional<Module> lib, 
			  Array<Module> ext_libs,
              Map<String, runtime::NDArray> params) {
  // TODO(relax-team) Revisit the param and ext_lib options.
  ObjectPtr<Executable> executable = builder->Get();
  std::unordered_map<std::string, runtime::NDArray> conv_params;
  for (const auto& [name, param] : params) {
    conv_params[name] = param;
  }
  Module combined_lib = codegen::CreateMetadataModule(
      conv_params, lib.value(), ext_libs, target,

      // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support.
      // Before jumping into details, only support cpp runtime for now.
      relay::Runtime::Create("cpp"),
      relay::Executor::Create("graph"),  // TODO(@sunggg): pass arbitrarily executor. CPP runtime
                                         // won't use this anyways.
      relay::backend::ExecutorCodegenMetadata());
  executable->Import(combined_lib);
  return Module(executable);
}

TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink);

relax.VirtualMachine #

relax.build->tvm.build #

tvm.build:

def build(inputs, args, target, target_host, runtime, name):
	# 1. Lower the input funcs/modules to TIR
	input_mod = lower(inputs, name=name)
	target_input_mod = {target: input_mod}
	# 2. Annotate the runtime on each module
	annotated_mods = {}
	for tar, mod in target_input_mod.items():
		annotated_mods[tar] = mod.with_attr("runtime", runtime)
	# 3. Canonicalize the target map and the host target
	annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)
	# 4. Turn the TIR modules into a runtime module
	rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
	# 5. Wrap the result into an OperatorModule and return it
	return OperatorModule.from_module(rt_mod_host, ir_module_by_target=annotated_mods, name=name)
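A small end-to-end sketch of this path using the te frontend (illustrative; `add_one` is an arbitrary example name):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

lowered = tvm.lower(s, [A, B], name="add_one")  # IRModule containing the lowered TIR PrimFunc
lib = tvm.build(s, [A, B], target="llvm")       # runtime.Module produced by tir_to_runtime + codegen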

1. lower #

# python/tvm/driver/build_module.py:lower
def lower(inp, args, name, ...):
    if isinstance(inp, IRModule):
        return ffi.lower_module(inp, simple_mode)
    if isinstance(inp, PrimFunc):
        return ffi.lower_primfunc(inp, name, simple_mode)
    if isinstance(inp, te.Schedule):
        return ffi.lower_schedule(inp, args, name, binds, simple_mode)

The corresponding global funcs registered on the C++ side:

// src/driver/driver_api.cc
TVM_REGISTER_GLOBAL("driver.lower_primfunc")
    .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) {
      return LowerPrimFunc(std::move(func), name, simple_mode);
    });

2. tir_to_runtime #

// src/driver/driver_api.cc
TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
    .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
      return TIRToRuntime(inputs_arg, host_target);
    });

TIRToRuntime:

  1. Transform the ir_mod, split it into host_mod and device_mod, then transform each of the split modules
  2. Run codegen on mhost_all to obtain the runtime::Module: target.build.llvm / target.build.cuda
// src/driver/driver_api.cc
  // 1. Transform the mixed module and split it
  auto pair = SplitMixedModule(ir_module, target, target_host);
  auto& host_mod = pair.first;
  auto& device_mod = pair.second;
  // 2. Run codegen on host_mod and device_mod
  if (overrides_host_target && non_host_target_kind) {
	device_modules.push_back(codegen::Build(host_mod, it.first));
  } else {
	mhost_all->Update(host_mod);
  }

  if (device_mod->functions.size() != 0) {
	device_modules.push_back(codegen::Build(device_mod, it.first));
  }

SplitMixedModule #

std::pair<IRModule, IRModule> SplitMixedModule(
			IRModule mod_mixed, 
			const Target& target_arg,
            const Target& target_host_arg) {
  Target target = target_arg, target_host = target_host_arg;
  CheckAndUpdateHostConsistency(&target, &target_host);
  // 1. Transform the mixed module
  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
  // 2. Derive host_mod and run its transforms
  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
  // 3. Derive device_mod and run its transforms
  IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));

  auto keys = target->GetKeys();

  CheckAndUpdateHostConsistency(&target, &target_host);

  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
  if (target_is_gpu && device_mod->functions.size() == 0) {
    DLOG(WARNING) << "Specified target " << target->str()
                  << " but cannot find device code. Did you forget to bind?";
  }

  return {host_mod, device_mod};
}
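The effect of this split is visible from Python: the host module returned by tvm.build imports the device module. A sketch reusing the te example above, assuming a CUDA-enabled build of TVM:

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

lib = tvm.build(s, [A, B], target="cuda")   # host module ...
dev_mod = lib.imported_modules[0]           # ... with the device module imported into it
print(dev_mod.type_key)                     # "cuda"
print(dev_mod.get_source())                 # generated device source
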
MixedModulePassManager #

Passes included:

  • VerifyVTCMLimit

  • LowerVtcmAlloc

  • BindTarget

  • VerifyMemory

  • AnnotateEntryFunc

  • ThreadSync

  • MergeDynamicSharedMemoryAllocations

  • InferFragment

  • LowerThreadAllreduce

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
  transform::PassContext pass_ctx = transform::PassContext::Current();

  Array<Pass> mixed_pass_list;

  // VerifyVTCMLimit must occur before LowerVtcmAlloc
  mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
  // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
  mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

  mixed_pass_list.push_back(tir::transform::BindTarget(target));

  mixed_pass_list.push_back(tir::transform::VerifyMemory());

  mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());

  bool detect_global_barrier =
      pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
  if (detect_global_barrier) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }

  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
  mixed_pass_list.push_back(tir::transform::InferFragment());
  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

  if (use_async_copy) {
    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
  }

  bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
  if (ptx_ldg32) {
    mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
  }

  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
  mixed_pass_list.push_back(tir::transform::SplitHostDevice());

  bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                          .value_or(relay::Executor::Create("graph", {}))
                          ->GetAttr<Bool>("unpacked-api")
                          .value_or(Bool(false));
  if (unpacked_api) {
    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
  } else {
    mixed_pass_list.push_back(tir::transform::MakePackedAPI());
  }
  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
  mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

  mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

  return transform::Sequential(mixed_pass_list);
}
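Several of the passes above are gated on PassContext config options (e.g. "tir.detect_global_barrier", "tir.use_async_copy", "tir.ptx_ldg32"). A usage sketch, reusing the te schedule from earlier:

with tvm.transform.PassContext(config={"tir.use_async_copy": True}):
    lib = tvm.build(s, [A, B], target="cuda")
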
HostModulePassManager #

Passes included:

  • BindTarget
  • LowerTVMBuiltin
  • LowerCustomDatatypes
  • LowerIntrin
  • LowerDeviceStorageAccessInfo
  • CombineContextCall
transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
  transform::PassContext pass_ctx = transform::PassContext::Current();
  bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

  Array<tvm::transform::Pass> host_pass_list;

  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
           CallingConv::kDeviceKernelLaunch;
  };
  host_pass_list.push_back(tir::transform::Filter(fcond));

  ICHECK(mixed_mod.defined()) << "This module must be defined";

  host_pass_list.push_back(tir::transform::BindTarget(target_host));

  host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
  host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  host_pass_list.push_back(tir::transform::LowerIntrin());
  host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  host_pass_list.push_back(tir::transform::CombineContextCall());

  if (enable_debug) {
    host_pass_list.push_back(tir::transform::InstallDebugSpans());
  }

  return transform::Sequential(host_pass_list);
}
DeviceModulePassManager #

Passes included:

  • Filter
  • BindTarget
  • LowerWarpMemory
  • Simplify
  • LowerCustomDatatypes
  • LowerDeviceStorageAccessInfo
  • LowerIntrin
transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
  Array<Pass> device_pass_list;
  runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
    return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
           CallingConv::kDeviceKernelLaunch;
  };
  device_pass_list.push_back(tir::transform::Filter(fcond));

  device_pass_list.push_back(tir::transform::BindTarget(target));

  device_pass_list.push_back(tir::transform::LowerWarpMemory());
  device_pass_list.push_back(tir::transform::Simplify());
  device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
  device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
  device_pass_list.push_back(tir::transform::LowerIntrin());

  return transform::Sequential(device_pass_list);
}

Codegen #

Core classes #

vm.builtin #

AllocNDArray #

Parameters:

  • offset
  • ShapeTuple
  • DLDataType
// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method<Storage>(&StorageObj::AllocNDArray);

runtime::NDArray StorageObj::AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype) {
  VerifyDataType(dtype);

  // critical zone: allocate header, cannot throw
  runtime::NDArray::Container* container =
      new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device);

  container->SetDeleter(StorageObj::Deleter);
  size_t needed_size = runtime::GetDataSize(container->dl_tensor);
  this->IncRef();
  // The manager context pointer must continue to point to the storage object
  // which owns the backing memory, and keeps track of the reference count.
  //
  // When we free a container we extract the storage object, decrement its
  // reference count, then destroy the container, but leave the underlying
  // buffer intact.
  container->manager_ctx = reinterpret_cast<void*>(this);

  // is this UB?
  // The only change we make w.r.t offset is modifying the data pointer
  // of the backing tensor to point into the buffer instead of its start.
  auto offset_ptr = reinterpret_cast<uint8_t*>(this->buffer.data) + offset;
  container->dl_tensor.data = reinterpret_cast<void*>(offset_ptr);

  runtime::NDArray ret(runtime::GetObjectPtr<Object>(container));
  // RAII in effect, now run the check.

  ICHECK(offset + needed_size <= this->buffer.size)
      << "storage allocation failure, attempted to allocate " << needed_size << " at offset "
      << offset << " in region that is " << this->buffer.size << "bytes";

  return ret;
}

CodeGenVM #

Class purpose: the class that walks the Relax AST and generates executable VM code for Relax functions.

  • That is, it emits VM code (vm.builtin.xxx / relax.run.xxx / fused_xxx / ...) for every VarBinding in every relax function

Class data members:

  • builder_: the internal ExecBuilder
class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
 public:...
 protected:
   /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*!
   * \brief Total number of virtual registers allocated.
   * \note The first two registers are reserved for special registers.
   */
  size_t registers_num_ = 0;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
  /*! \brief the context module. */
  IRModule ctx_mod_;
  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const Op& null_value_op_ = Op::Get("relax.null_value");
};

Run #

  1. Construct a CodeGenVM and run codegen on every relax function in the mod
  2. Remove the codegen'ed relax functions from the ir_mod, keeping only the PrimFuncs
  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
    IRModule res_mod = mod;
    res_mod.CopyOnWrite();
    // 1. Construct the CodeGenVM
    CodeGenVM codegen(builder, mod);
    // 2. Codegen each relax function and remove it from the module
    for (const auto& [gvar, f] : mod->functions) {
      if (auto* func = f.as<FunctionNode>()) {
        codegen.Codegen(GetRef<Function>(func));
        res_mod->Remove(gvar);
      }
    }
    return res_mod;
  }

Codegen #

  1. builder->emit the function declaration: EmitFunction/DeclareFunction
  2. Emit the function body: done by CodeGenVM itself
    • VisitExpr: calls builder.emit to obtain Instructions and put them into the executable
  3. builder->emit the function return value: EmitRet
  4. builder->emit the function end: EndFunction
  void Codegen(const Function& func) {
    // 1. Get the function parameters and the function name
    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    Array<String> param_names;
    for (Var param : func->params) {
      param_names.push_back(param->name_hint());
    }
    // 2. Use the builder to emit the function declaration
    builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names);
    // 3. Allocate a new register for each parameter and insert it into var_arg_map_
    for (size_t i = 0; i < func->params.size(); ++i) {
      RegName r = NewRegister();
      this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)});
    }
    // 4. VisitExpr generates the function body
    Instruction::Arg ret = ExprFunctor::VisitExpr(func->body);
    // 5. Use the builder to emit the return value and EndFunction
    builder_->EmitRet(EnsureReg(ret));
    builder_->EndFunction(gsymbol.value());
    // reset register number to be 0;
    registers_num_ = 0;
    var_arg_map_.clear();
  }

VisitExpr(func->body) #

SeqExprNode:

  • For every binding in every block, visit the bound value and update var_arg_map_
  • Visit the return value of the seq_expr (a VarNode) and return it
  Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
    for (auto block : op->blocks) {
      for (Binding binding : block->bindings) {
        Instruction::Arg value;
        if (auto* var_binding = binding.as<VarBindingNode>()) {
          value = this->VisitExpr(var_binding->value);
        } else if (auto* match_cast = binding.as<MatchCastNode>()) {
          value = this->VisitExpr(match_cast->value);
        }
        this->var_arg_map_.insert({binding->var, value});
      }
    }

    Instruction::Arg ret_reg = this->VisitExpr(op->body);
    return ret_reg;
  }

CallNode:

  1. Get call.op
  2. For ops with a registered PackedFunc, call EmitPackedFuncCall(call, name, dst_reg);
  3. If call.op == call_builtin_with_ctx_op_
  4. If call.op == alloc_storage_op_ (storage allocation)
  5. If call.op == alloc_tensor_op_ (tensor allocation)
  6. If call.op == kill_object_op_
  Instruction::Arg VisitExpr_(const CallNode* call_node) final {
    Call call = GetRef<Call>(call_node);

    if (call_node->op == null_value_op_) {
      return Instruction::Arg::Register(Instruction::kVoidRegister);
    }

    // allocate dst register.
    RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister();
    if (call->op.as<OpNode>()) {
      // check whether a packed function has been registered for this op
      FCallPacked name = GetPackedFuncName(call);
      if (!name.empty()) {
        // If the operator has a registered packed function implementation, emit call to that packed
        // function.
        EmitPackedFuncCall(call, name, dst_reg);
      } else if (call_node->op == call_builtin_with_ctx_op_) {
        // TODO(relax-team) migrate most handling of op to
        // directly map to call_builtin_with_ctx before codegen and simplify vm codegen.
        EmitCallBuiltinWithCtx(call, dst_reg);
      } else if (call_node->op == alloc_storage_op_) {
        EmitAllocStorage(call, dst_reg);
      } else if (call_node->op == alloc_tensor_op_) {
        EmitAllocTensor(call, dst_reg);
      } else if (call_node->op == kill_object_op_) {
        dst_reg = EmitKillObject(call);
      } else {
        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
        // ops are handled in a pass when lowering them to TIR.
        LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op;
      }
    } else {
      EmitNormalCall(call, dst_reg);
    }
    return Instruction::Arg::Register(dst_reg);
  }
  • EmitPackedFuncCall:
  void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
    // collect the arguments
    std::vector<Instruction::Arg> args = VisitArray(call_node->args);
    // use the builder to emit the call
    builder_->EmitCall(name, args, dst_reg);
  }
  • EmitAllocTensor
  void EmitAllocTensor(const Call& call_node, RegName dst_reg) {
    std::vector<Instruction::Arg> args;
    args.reserve(4);
    for (Expr arg : call_node->args) {
      args.push_back(this->VisitExpr(arg));
    }
    builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
  }
  • EmitAllocStorage
  void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
    ICHECK_EQ(call_node->args.size(), 4);
    // Handle args of the call
    std::vector<Instruction::Arg> args;
    args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
    // buffer size, dtype, device index
    for (auto arg : call_node->args) {
      args.push_back(this->VisitExpr(arg));
    }
    builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg);
  }

CodeGenVMTIR #

Run #

Codegen #

relax.ExecBuilder #

Class purpose:

  • Helps codegen assemble the generated VM code

Class data members:

  • exec_: the Executable (which derives from runtime::Module) that stores the codegen result
private:
  vm::Instruction::Arg ConvertConstant_(TVMRetValue obj);
  /*! \brief The mutable internal executable. */
  ObjectPtr<vm::Executable> exec_;  // mutable
  /*! \brief internal dedup map when creating index for a new constant */
  std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual> const_dedup_map_;
}

EmitFunction/DeclareFunction #

  1. Declare the VM function: create a vmfunc and put it into exec_'s func_table
  2. Use the name to look up the index in exec_'s func_map; use the index to get the vmfunc (of type VMFuncInfo) from func_table
  3. Set the vmfunc info: number of arguments, argument names, register file size, start_instr
void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs,
                                   Optional<Array<String>> param_names,
                                   vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) {
  // 1. Declare the function: create a vmfunc and put it into exec_'s func_table
  this->DeclareFunction(func_name, kind);
  // 2. Get the vmfunc back from exec_
  auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name));
  // 3. Set the vmfunc's argument count, argument names, register file size and start_instr
  vmfunc.num_args = num_inputs;
  if (param_names.defined()) {
    std::vector<std::string> names;
    for (auto name : param_names.value()) {
      names.push_back(name);
    }
    vmfunc.param_names = names;
  }
  vmfunc.register_file_size = init_register_size;
  if (kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
    vmfunc.start_instr = exec_->instr_offset.size();
  }
}

EmitCall #

  • Records the instruction offset in exec_ and appends the instruction data
void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, 
							   std::vector<vm::Instruction::Arg> args,
                               vm::RegName dst) {
  // store instruction
  exec_->instr_offset.push_back(exec_->instr_data.size());
  exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Call));
  exec_->instr_data.push_back(dst);
  exec_->instr_data.push_back(func.value());
  exec_->instr_data.push_back(args.size());
  for (Instruction::Arg arg : args) {
    exec_->instr_data.push_back(arg.data());
  }
}
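The resulting encoding is simple: instr_offset marks where each instruction starts and instr_data holds the flattened words. An illustrative Python mirror of EmitCall/EmitRet (not TVM API, just the layout):

OP_CALL, OP_RET = 1, 2           # placeholder opcode values
instr_offset, instr_data = [], []

def emit_call(func_idx, args, dst):
    instr_offset.append(len(instr_data))   # start of this instruction
    instr_data.extend([OP_CALL, dst, func_idx, len(args), *args])

def emit_ret(result_reg):
    instr_offset.append(len(instr_data))
    instr_data.extend([OP_RET, result_reg])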

EmitRet #

  • push instr_offset
  • push instr_data
void ExecBuilderNode::EmitRet(vm::Instruction::Arg result) {
  exec_->instr_offset.push_back(exec_->instr_data.size());
  exec_->instr_data.push_back(static_cast<ExecWord>(Opcode::Ret));
  exec_->instr_data.push_back(result.value());
}

EndFunction #

void ExecBuilderNode::EndFunction(const std::string& func_name) {
  // look up the vmfunc and mark end_instr
  auto it = exec_->func_map.find(func_name);
  VMFuncInfo& vmfunc = exec_->func_table.at(it->second);

  if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) {
    vmfunc.end_instr = exec_->instr_offset.size();
  }
}

relax_vm.Executable #

Class purpose: contains the emitted function instructions and their implementations.

Data members:

  /*! \brief The virtual machine's function table. */
  std::vector<VMFuncInfo> func_table;
  /*! \brief A map from globals (as strings) to their index in the function map. */
  std::unordered_map<std::string, Index> func_map;
  /*! \brief The global constant pool. */
  std::vector<TVMRetValue> constants;
  /*! \brief The offset of instruction. */
  std::vector<Index> instr_offset;
  /*! \brief The byte data of instruction. */
  std::vector<ExecWord> instr_data;

GetFunction #

PackedFunc Executable::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
  if (name == "stats") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); });
  } else if (name == "as_text") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); });
  } else if (name == "as_python") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); });
  } else if (name == "vm_load_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
      ICHECK(sptr_to_self.get() == this);
      vm->LoadExecutable(GetObjectPtr<Executable>(this));
      *rv = Module(vm);
    });
  } else if (name == "vm_profiler_load_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
      ICHECK(sptr_to_self.get() == this);
      vm->LoadExecutable(GetObjectPtr<Executable>(this));
      *rv = Module(vm);
    });
  }
  return nullptr;
}

vm_load_executable #

Key steps:

  1. Create the VirtualMachine
  2. Call vm->LoadExecutable
  3. Wrap the VM in a Module
  } else if (name == "vm_load_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
      ICHECK(sptr_to_self.get() == this);
      vm->LoadExecutable(GetObjectPtr<Executable>(this));
      *rv = Module(vm);
    });
  } 

vm_profiler_load_executable #

GetInstruction #

  • Determine the opcode
  • Read the instruction fields from instr_data, construct the Instruction, and return it
Instruction Executable::GetInstruction(Index i) const {
  Index offset = instr_offset[i];
  Opcode op = static_cast<Opcode>(instr_data[offset]);
  switch (op) {
    case Opcode::Call: {
      RegName dst = instr_data[offset + 1];
      Index func_idx = instr_data[offset + 2];
      Index num_args = instr_data[offset + 3];
      ExecWord* args = const_cast<ExecWord*>(&instr_data[offset + 4]);
      return Instruction::Call(func_idx, num_args, reinterpret_cast<Instruction::Arg*>(args), dst);
    }
    case Opcode::Ret: {
      RegName result = instr_data[offset + 1];
      return Instruction::Ret(result);
    }
    case Opcode::Goto: {
      Index pc_offset = instr_data[offset + 1];
      return Instruction::Goto(pc_offset);
    }
    case Opcode::If: {
      RegName cond = instr_data[offset + 1];
      Index false_offset = instr_data[offset + 2];
      return Instruction::If(cond, false_offset);
    }
    default:
      LOG(FATAL) << "should never hit this case: " << static_cast<int>(op);
      break;
  }
  return Instruction();
}

AsText #
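A usage sketch for inspecting the emitted program (assuming the Python Executable wrappers around the packed functions shown above):

ex = relax.build(mod, target="llvm")
print(ex.as_text())   # human-readable listing of the VM functions and instructions
print(ex.stats())     # summary statistics of the executable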

export_library #

Function purpose: export the executable as a shared library.
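A usage sketch (the file name is arbitrary):

ex.export_library("vm_lib.so")
rt_mod = tvm.runtime.load_module("vm_lib.so")
vm = relax.VirtualMachine(rt_mod, tvm.cpu())   # VirtualMachine accepts either rt_mod or Executable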

VirtualMachine :: ModuleNode #

Class purpose: interprets and executes the VM code of the executable.

class VirtualMachine(object):
    def __init__(
        self,
        rt_mod: Union[tvm.runtime.Module, "tvm.relax.Executable"],
        ...
	):
        load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable"
        self.module = rt_mod[load_exec]()
    
    def __getitem__(self, key: str) -> PackedFunc:
        return self.module[key]

Data members:

// include/tvm/runtime/vm/vm.h
class TVM_DLL VirtualMachine : public runtime::ModuleNode {
 public:
 ...
 protected:
  /*! \brief The virtual machine's packed function table. */
  std::vector<PackedFunc> packed_funcs_;
  /*! \brief The current stack of call frames. */
  std::vector<VMFrame> frames_;
  /*! \brief The function table index of the current function. */
  Index func_index_;
  /*! \brief The current pointer to the code section. */
  const Instruction* code_;
  /*! \brief The virtual machine PC. */
  Index pc_;
  /*! \brief The special return register. */
  ObjectRef return_register_;
  /*! \brief The executable the VM will operate on. */
  ObjectPtr<Executable> exec_;
  /*! \brief The function name to inputs mapping. */
  std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
  /*! \brief The function name to flag enabling scenario with set outputs. */
  std::unordered_map<std::string, bool> set_outputs_enabled_;
  /*! \brief The index of operation which destination is result. */
  Index preresult_op_index_ = -1;
  /*! \brief The function name to indices of output tensors in register file. */
  std::unordered_map<std::string, std::vector<Index>> output_tensor_reg_indices_;
  /*! \brief The function name to pre-allocated outputs mapping. */
  std::unordered_map<std::string, std::vector<ObjectRef>> outputs_;
  /*!
   * \brief The "physical" devices the VM can execute primitives on. All "device indexes"
   * are w.r.t. this vector. Each entry in this vector must match the corresponding entry
   * in the executable's "virtual" devices vector.
   */
  std::vector<Device> devices_;
  // Memory allocators, one per device
  std::vector<Allocator*> allocators_;
  /*!
   * \brief The constant pool for runtime. It caches the device dependent
   * object to avoid rellocation of constants during inference.
   */
  std::vector<ObjectRef> const_pool_;
};

Creating the object #

auto vm = make_object<VirtualMachine>();

LoadExecutable #

  • Assigns exec_
  • Populates packed_funcs_
void VirtualMachine::LoadExecutable(const ObjectPtr<Executable>& exec) {
  exec_ = exec;

  runtime::Module lib = exec_->GetLib();

  for (const auto& it : exec_->primitive_map) {
    const auto& packed_name = it.first;
    auto packed_index = static_cast<size_t>(it.second);
    if (packed_funcs_.size() <= packed_index) {
      packed_funcs_.resize(packed_index + 1);
    }
    tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, /*query_imports=*/true);
    packed_funcs_[packed_index] = pf;
  }
}

RuntimeModule #

vm_load_executable #

tvm.build core classes #

It mainly consists of two parts:

  • lower
  • codegen

lower_primfunc #

Function purpose: lower a PrimFunc.

LowerPrimFunc implementation:

  1. Create the pass list
    • PHASE 1: Flatten, LowerInitBlock, SharedMemory, Narrow, ...
    • PHASE 2: LoopPartition, VectorizeLoop, InjectVirtualThread, UnrollLoop
    • PHASE 3: RemoveNoOp
  2. Lower the func using the pass list
// src/driver/driver_api.cc
IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) {
  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), func}}));

  // Get the pass list
  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
  return LowerWithPassList(std::move(mod), pass_list);
}

PassList:

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
  transform::PassContext pass_ctx = transform::PassContext::Current();

  bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
  bool disable_storage_rewrite =
      pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
  bool instrument_bound_checkers =
      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
  bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
  bool enable_equiv_terms_in_cse_tir =
      pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

  bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();

  // Get any user-added passes
  Array<Array<ObjectRef>> add_lower_pass =
      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
          .value();

  bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();

  Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
  Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();

  // phase passes is of the form
  // [[phase_number, pass], [phase_number, pass]... ]
  for (Array<ObjectRef> phase_pass : add_lower_pass) {
    const IntImmNode* phase_num = phase_pass[0].as<IntImmNode>();
    ICHECK(phase_num)
        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
    int phase_num_val = phase_num->value;

    CHECK_GE(phase_num_val, 0);

    auto pass = Downcast<tvm::transform::Pass>(phase_pass[1]);
    // Copy the pass into the correct phase
    if (phase_num_val == 0) {
      user_lower_phase0.push_back(pass);
    } else if (phase_num_val == 1) {
      user_lower_phase1.push_back(pass);
    } else if (phase_num_val == 2) {
      user_lower_phase2.push_back(pass);
    } else if (phase_num_val >= 3) {
      user_lower_phase3.push_back(pass);
    }
  }

  // Construct the pass list, inserting the user provided passes at the end of the phase

  // PHASE 0
  Array<tvm::transform::Pass> pass_list = user_lower_phase0;

  // PHASE 1
  pass_list.push_back(tir::transform::InjectPrefetch());
  pass_list.push_back(tir::transform::TextureFlatten());
  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
  pass_list.push_back(tir::transform::LowerCrossThreadReduction());
  pass_list.push_back(tir::transform::LowerInitBlock());
  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
  pass_list.push_back(tir::transform::LiftThreadBinding());
  pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
  pass_list.push_back(tir::transform::CompactBufferAllocation());
  pass_list.push_back(tir::transform::LowerAutoCopy());
  pass_list.push_back(tir::transform::UnifyThreadBinding());
  pass_list.push_back(tir::transform::LowerMatchBuffer());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::InjectPermutedLayout());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::InjectSoftwarePipeline());
  pass_list.push_back(tir::transform::TransformMmaBufferLayout());
  pass_list.push_back(tir::transform::LowerOpaqueBlock());
  pass_list.push_back(tir::transform::FlattenBuffer());
  pass_list.push_back(tir::transform::FP8ComputeLegalize());
  pass_list.push_back(tir::transform::BF16ComputeLegalize());
  pass_list.push_back(tir::transform::NarrowDataType(32));
  pass_list.push_back(tir::transform::Simplify());

  // Add user-defined phase-1 passes
  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());

  // PHASE 2
  if (!disable_loop_partition) {
    pass_list.push_back(tir::transform::LoopPartition());
  }

  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
  pass_list.push_back(tir::transform::InjectVirtualThread());
  pass_list.push_back(tir::transform::InjectDoubleBuffer());
  if (!disable_storage_rewrite) {
    pass_list.push_back(tir::transform::StorageRewrite());
  }
  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

  if (use_async_copy) {
    pass_list.push_back(tir::transform::LowerAsyncDMA());
  }
  pass_list.push_back(tir::transform::UnrollLoop());

  // Add user-defined phase-2 passes
  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());

  // PHASE 3
  pass_list.push_back(tir::transform::RenormalizeSplitPattern());
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::RemoveNoOp());
  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
  pass_list.push_back(tir::transform::HoistIfThenElse());

  // Add user-defined phase-3 passes
  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());

  if (instrument_bound_checkers) {
    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
  }

  if (ptx_ldg32) {
    pass_list.push_back(tir::transform::InjectPTXLDG32(true));
  }

  pass_list.push_back(
      tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

  // This pass instruments the loops with the profile builtin calls to capture the runtime
  // performance data (only enabled for Hexagon at the moment). To ensure that no other
  // optimizations are performed on the instrumented code, this pass must be added at the end
  // of the list.
  if (instrument_lwp) {
    pass_list.push_back(tir::transform::InstrumentProfileIntrinsics());
  }

  return pass_list;
}
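User passes are injected into these phases via the "tir.add_lower_pass" option parsed above. A sketch (dump_pass is a hypothetical pass that just prints each PrimFunc, reusing the te schedule from earlier):

@tvm.tir.transform.prim_func_pass(opt_level=0)
def dump_pass(func, mod, ctx):
    print(func)   # inspect the PrimFunc at this point of the lowering
    return func

with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, dump_pass)]}):
    tvm.lower(s, [A, B])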

codegen::build #

tvm::codegen::build

  • Look up the target's attributes
  • Call the target's registered build function to perform the build
// src/target/codegen.cc
runtime::Module Build(IRModule mod, Target target) {
  // 1. Look up the target's attributes
  auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
  target_attr_map[target->kind](mod, target);
  // 2. Call the target's registered build function ("target.build.<kind>")
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  return (*bf)(mod, target);
}

target.build.llvm #

// src/target/llvm/llvm_module.cc
TVM_REGISTER_GLOBAL("target.build.llvm")
    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
      // 1. Construct the LLVMModuleNode
      auto n = make_object<LLVMModuleNode>();
      // 2. Initialize it with the IRModule
      n->Init(mod, target);
      // 3. Return it wrapped in a runtime::Module
      return runtime::Module(n);
    });

Related data structures:

  • LLVMModuleNode
  • LLVMTarget
  • CodeGenLLVM

LLVMModuleNode->Init #

  1. Create an LLVMTarget and use it to create a CodeGenLLVM instance
  2. Initialize the CodeGenLLVM
  3. Add the funcs to the codegen
  4. Run the codegen
void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
  // 1. Get the codegen instance
  With<LLVMTarget> llvm_target(*(std::make_unique<LLVMInstance>()), target);
  std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(llvm_target.get());
  // 2. Initialize the codegen
  cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.defined(),
	   target_c_runtime);
  // 3. Add the funcs to the codegen
  cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
  // 4. Run the codegen
  module_owning_ptr_ = cg->Finish();
  module_ = module_owning_ptr_.get();
}
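The generated LLVM IR / assembly can be inspected from Python once the module is built; a sketch:

lib = tvm.build(s, [A, B], target="llvm")
print(lib.get_source("ll"))    # LLVM IR held by the LLVMModuleNode
print(lib.get_source("asm"))   # target assembly after LLVM codegen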

LLVMTarget/LLVMInstance #

Class purpose:

  • Holds the information used by LLVM code generation for a specific device

LLVMTarget data members:
// src/target/llvm/llvm_instance.h
class LLVMTarget : public LLVMTargetInfo {
...
 private:
  std::vector<Option> saved_llvm_options_;
  const LLVMInstance& instance_;
  std::weak_ptr<llvm::LLVMContext> ctx_;
  static bool modified_llvm_state_;
};

The LLVMTargetInfo class:

  • A summary of the information about this TVM Target that is relevant to LLVM code generation.

Data members:
class LLVMTargetInfo {
...
 private:
  std::string triple_;
  std::string cpu_;
  std::vector<std::string> attrs_;
  std::vector<Option> llvm_options_;
  llvm::TargetOptions target_options_;
  llvm::FastMathFlags fast_math_flags_;
  llvm::CodeGenOpt::Level opt_level_;
  llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
  llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
  std::shared_ptr<llvm::TargetMachine> target_machine_;
};

LLVMInstance:

  • ctx_: the shared llvm::LLVMContext
class LLVMInstance {
...
 private:
  std::shared_ptr<llvm::LLVMContext> ctx_;
};

CodeGenLLVM #

Data members:

  • function_: the current function
  • builder_: the internal IRBuilder
  • module_: the module to be returned
  • llvm_target_: the LLVM target info
  • helper data types
  • metadata
  • modules to be linked
// src/target/llvm/codegen_llvm.h
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
  std::unique_ptr<llvm::DataLayout> data_layout_;
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
  // llvm target info
  LLVMTarget* llvm_target_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
  llvm::PointerType* t_void_p_{nullptr};
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
  // meta data
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module>> link_modules_;
  /*! \brief native vector bits of current targetx*/
  int native_vector_bits_{0};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
  // The definition of local variable.
  std::unordered_map<const VarNode*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;

  // Map from TVM's GlobalVar to the llvm::Function that represents
  // that function.
  std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;

  // Whether current function is restricted
  bool is_restricted_{true};
  // The analyzer information
  std::unique_ptr<arith::Analyzer> analyzer_;
  // set of var that are not restricted(can alias)
  std::unordered_set<const VarNode*> alias_var_set_;
  // set of volatile buffer.
  std::unordered_set<const VarNode*> volatile_buf_;
  // deep comparison of PrimExpr
  ExprDeepEqual deep_equal_;
  // binding of let variables. Enables duplicate var defs that map to same value
  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
  // debug info for function being compiled
  llvm::DISubprogram* di_subprogram_;
  // Cache potential common path ops to slightly improve lookup time.
  // global symbol table.
  OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
  const Op& builtin_call_extern_ = builtin::call_extern();
  const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
  const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin();
  const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin();
  const Op& builtin_lookup_param_ = builtin::lookup_param();
  const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered();

Create #

Function purpose: create one of the LLVM-based codegen backends, registered as:

tvm.codegen.llvm.target_cpu
tvm.codegen.llvm.target_x86-64
tvm.codegen.llvm.target_arm
tvm.codegen.llvm.target_rocm
tvm.codegen.llvm.target_nvptx
tvm.codegen.llvm.target_hexagon

Implementation:

  • Determine the factory PackedFunc f from the target name
  • Call f to create the CodeGenLLVM instance
std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(LLVMTarget* llvm_target) {
  std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName();
  std::string factory_template = "tvm.codegen.llvm.target_";
  void* handle = nullptr;
  if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) {
    handle = (*f)();
  } else if (const PackedFunc* f = runtime::Registry::Get(factory_template + "cpu")) {
    handle = (*f)();
  } else {
    LOG(FATAL) << "no factory function for codegen for target " << target;
  }
  if (handle) {
    return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle));
  } else {
    LOG(FATAL) << "unable to create codegen for target " << target;
  }
}

CodeGenLLVM::Init #

Function purpose:

  • A virtual function; subclass implementations call back into this base implementation
// src/target/llvm/codegen_llvm.cc
void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target,
                       Optional<String> system_lib_prefix, bool dynamic_lookup,
                       bool target_c_runtime) {
  llvm_target_ = llvm_target;
  llvm::LLVMContext* ctx = llvm_target_->GetContext();
  builder_.reset(new IRBuilder(*ctx));
  module_.reset(new llvm::Module(module_name, *ctx));
  md_builder_.reset(new llvm::MDBuilder(*ctx));
  // types
  t_void_ = llvm::Type::getVoidTy(*ctx);
  t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace());
  t_int_ = llvm::Type::getInt32Ty(*ctx);
  t_char_ = llvm::Type::getInt8Ty(*ctx);
  t_int8_ = llvm::Type::getInt8Ty(*ctx);
  t_int16_ = llvm::Type::getInt16Ty(*ctx);
  t_int32_ = llvm::Type::getInt32Ty(*ctx);
  t_int64_ = llvm::Type::getInt64Ty(*ctx);
  t_float64_ = llvm::Type::getDoubleTy(*ctx);
  // meta data
  md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1);
  md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa");
  md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_);
  InitTarget();
}

void CodeGenLLVM::InitTarget() {
  llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
  module_->setTargetTriple(tm->getTargetTriple().str());
  module_->setDataLayout(tm->createDataLayout());
  data_layout_.reset(new llvm::DataLayout(module_.get()));
  if (native_vector_bits_ == 0) {
    const auto& arch = tm->getTargetTriple().getArch();
    if (arch == llvm::Triple::x86_64) {
      // for avx512
      native_vector_bits_ = 512;
    } else if (arch == llvm::Triple::x86) {
      native_vector_bits_ = 256;
    } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
      native_vector_bits_ = 128;
    } else {
      native_vector_bits_ = 128;
      std::string arch_name = std::string(tm->getTargetTriple().getArchName());
      LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
    }
  }
  .......
}
CodeGenCPU::Init #
CodeGenNVPTX::Init #

AddFunctionsOrdered #

  • Sort the functions
  • Declare the functions
  • Add the functions
// src/target/llvm/codegen_llvm.h
void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
  std::vector<std::tuple<GlobalVar, PrimFunc>> funcs;
  // ... collect the functions and sort them by global symbol name ...
  // declare every function first, then add the bodies
  DeclareFunction(gvar, func);
  AddFunction(gvar, func);
}

DeclareFunction:

llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
  // 1. Create the llvm::Function
  function = llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get());
  // 2. Record the GlobalVar -> llvm::Function mapping
  functions_[gvar.get()] = function;
}

AddFunction:

  1. Set up the function parameters
  2. Add the function body
  3. Set the function return value
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
  for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
    llvm::Argument* v = &(*arg_it);
	var_map_[f->params[i].get()] = v;
	v->setName(std::string(f->params[i]->name_hint));
  }
  // get the IRBuilder, set the insert point, then call Visit to run codegen
  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_);
  builder_->SetInsertPoint(entry);
  this->VisitStmt(f->body);

  builder_->CreateRet(ConstInt32(0));
}

CodeGenSourceBase #

Class description:

Class data members:

  • stream: the output stream to be printed
  • var_idmap_: the name of each variable
class CodeGenSourceBase {
 ...
 protected:
  /*! \brief entry in ssa assign map */
  struct SSAEntry {
    /*! \brief The value id */
    std::string vid;
    /*! \brief The scope id, used to check if this entry is invalid. */
    int scope_id;
  };
  /*! \brief the declaration stream */
  std::ostringstream decl_stream;
  /*! \brief the stream to be printed */
  std::ostringstream stream;
  /*! \brief the forward declaration stream */
  std::ostringstream fwd_decl_stream;
  /*! \brief name of each variable */
  std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
  /*! \brief NameSupply for allocation */
  NameSupply name_supply_ = NameSupply("");

 private:
  /*! \brief assignment map of ssa */
  std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
  /*! \brief array to check whether we are inside certain scope */
  std::vector<bool> scope_mark_;
  /*! \brief The current indentation value */
  int indent_{0};
};

CodeGenC #

Class hierarchy:

CodeGenC
	CodeGenCUDA
	CodeGenMetal
	CodeGenCHost
	CodeGenWebGPU
	CodeGenOpenCL
	CodeGenVivadoHLS

Class summary: CodeGenC has two modes: it generates C code either in SSA form or in normal form.

Data members:

  • let_binding_: bindings of let variables
  • alloc_storage_scope_: the storage scope of allocations
  • handle_data_type_: the data type of allocated buffers
// src/target/source/codegen_c.h
class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
                 public StmtFunctor<void(const Stmt&)>,
                 public CodeGenSourceBase {
public:
  ......
  /*! \brief restrict keyword */
  std::string restrict_keyword_{""};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const VarNode*, DataType> handle_data_type_;
  /*! \brief Record of ops that have pre-defined global symbol. */
  OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
  // cache commonly used ops
  const Op& builtin_call_extern_ = builtin::call_extern();
  const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
  Integer constants_byte_alignment_ = 16;
  /*! \brief whether to print in SSA form */
  bool print_ssa_form_{false};

private:
  /*! \brief set of volatile buf access */
  std::unordered_set<const VarNode*> volatile_buf_;
  // deep comparison of PrimExpr
  ExprDeepEqual deep_equal_;
  // binding of let variables. Enables duplicate var defs that map to same value
  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

Init #

void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; }

PrintStmt… #

  • VisitStmt: emits the C source file into the stream, so it prints syntax strings, scopes, etc.
  • VisitExpr:
void PrintStmt(const Stmt& n) { VisitStmt(n); }

std::string PrintExpr(const PrimExpr& n) {
  std::ostringstream os;
  PrintExpr(n, os);
  return os.str();
}
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
  if (print_ssa_form_) {
    std::ostringstream temp;
    VisitExpr(n, temp);
    os << SSAGetID(temp.str(), n.dtype());
  } else {
    VisitExpr(n, os);
  }
}

AddFunction #

  • Handle the function parameters
  • BeginScope
  • PrintStmt(f->body)
  • EndScope
  • PrintIndent()
void CodeGenC::AddFunction(const PrimFunc& f) {
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
  // 1. handle the function parameters
  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
    if (i != 0) stream << ", ";
    if (v.dtype().is_handle()) {
      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      PrintType(GetType(v), stream);
      // Register handle data type
      // TODO(tvm-team): consider simply keep type info in the
      // type annotation(via a normalizing rewriting).
      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  // VisitBody
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}
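The C-source path is easiest to observe through the "c" target, whose host codegen derives from CodeGenC; a sketch:

lib = tvm.build(s, [A, B], target="c")
print(lib.get_source())   # the C source emitted by the CodeGenC-based backend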

【Key】VisitStmt #

void PrintStmt(const Stmt& n) { VisitStmt(n); }
  • AttrStmt nodes, typically from: T.launch_thread, with T.attr, etc.
  • Allocate nodes, typically from: T.allocate
  • BufferStore nodes, typically from: ordinary stores, T.Buffer, T.tvm_warp_activemask, T.tvm_warp_shuffle_down, etc.
  • Evaluate nodes (may be lowered away?): T.tvm_storage_sync (lowered by the device TIR passes), T.tvm_thread_allreduce (lowered by the tir_to_runtime mixed-module passes)
  • IfThenElse nodes: if threadIdx_x <= 32, < 8, == 0, etc.

AttrStmtNode #

Function description:

Function flow:

  • attr_key is thread_extent: T.launch_thread
  • attr_key is volatile_scope: T.attr(red_result, "volatile_scope", 1)
  • attr_key is pragma_import_c
  • then continue with PrintStmt(body)
void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag.length() != 0) {
      if (!var_idmap_.count(iv->var.get())) {
        BindThreadIndex(iv);
      }
    }
  } else if (op->attr_key == tir::attr::volatile_scope) {
    const VarNode* v = op->node.as<VarNode>();
    ICHECK(v);
    volatile_buf_.insert(v);
  } else if (op->attr_key == tir::attr::pragma_import_c) {
    const StringImmNode* value = op->value.as<StringImmNode>();
    ICHECK(value != nullptr);
    decl_stream << value->value;
  }
  this->PrintStmt(op->body);
}

AllocateNode #

Function description:

Function flow:

  1. Allocate a vid for the variable
  2. Get the allocation size and scope, and add the scope to alloc_storage_scope_
  3. Print the variable-definition statement
  4. Register the handle type
void CodeGenC::VisitStmt_(const AllocateNode* op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  // 1. Get the allocation size
  this->PrintIndent();
  size_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
  // 2. Get the scope of the allocated var and add it to alloc_storage_scope_
  auto scope = GetPtrStorageScope(op->buffer_var);
  alloc_storage_scope_[op->buffer_var.get()] = scope;
  // 3. emit "local"/"shared"/...
  PrintStorageScope(scope, stream);
  // 3. emit "float"/"int"/...
  PrintType(op->dtype, stream);
  // 3. emit vid << "[" << constant_size << "];"
  stream << ' ' << vid << '[' << constant_size << "];\n";

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

EvaluateNode #

Function description:

Function flow:

  • op.value is a CallNode
  • call.op == tvm_storage_sync
  • call.op == tvm_struct_set()
  • otherwise call PrintExpr
void CodeGenC::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  
  const CallNode* call = op->value.as<CallNode>();
  if (call) {
    if (call->op.same_as(builtin::tvm_storage_sync())) {
      this->PrintStorageSync(call);
      return;
    } else if (call->op.same_as(builtin::tvm_struct_set())) {
      ICHECK_EQ(call->args.size(), 4);
      int kind = call->args[2].as<IntImmNode>()->value;
      std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind);
      std::string value = PrintExpr(call->args[3]);
      std::string cast;
      if (kind == builtin::kArrStrides) {
        // cast void* to int64_t*
        cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "";
      } else if (kind == builtin::kArrDeviceType) {
        // cast int to enum
        cast = "(DLDeviceType)";
      }
      this->PrintIndent();
      this->stream << ref << " = " << cast << value << ";\n";
      return;
    }
  }
  std::string vid = this->PrintExpr(op->value);
  if (vid != "") {
    this->PrintIndent();
    this->stream << vid << ";\n";
  }
}

【Key】VisitExpr #

CallNode #

Function flow:

  1. Get the operator op
  2. Dispatch to different codegen depending on the op type:
  • builtin.tvm_check_return
  • builtin.ret
  • builtin_call_extern_、builtin_call_pure_extern_
  • bitwise_and
  • large_uint_imm
  • bitwise_xor
  • bitwise_or
  • bitwise_not
  • shift_left
  • shift_right
  • if_then_else
  • address_of
  • tvm_struct_get
  • isnullptr
  • reinterpret
  • isnan
  • lookup_param

Finish #

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }

target.build.cuda #

BuildCUDA implementation flow:

  1. Iterate over all PrimFuncs and add them to CodeGenCUDA
  2. cg.Finish
  3. Compile with NVRTC
  4. Create the CUDAModule
runtime::Module BuildCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);
  // 1. Iterate over all PrimFuncs and add them to CodeGenCUDA
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(f);
  }
  // 2. Finish codegen to obtain the CUDA source
  std::string code = cg.Finish();

  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  const auto* f_enter = Registry::Get("target.TargetEnterScope");
  (*f_enter)(target);
  // 3. Compile with NVRTC (or a user-registered compile callback)
  if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code, target).operator std::string();
    // Dirty matching to check PTX vs cubin.
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code, cg.need_include_path());
  }
  const auto* f_exit = Registry::Get("target.TargetExitScope");
  (*f_exit)(target);
  // 4. Create the CUDAModule
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
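Both callback hooks referenced above can be supplied from Python via the global registry. A sketch of the post-processing hook (the compile hook "tvm_callback_cuda_compile" is registered the same way):

@tvm.register_func("tvm_callback_cuda_postproc", override=True)
def cuda_postproc(code, target):
    # inspect or rewrite the generated CUDA C before it is handed to NVRTC
    print(code)
    return code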

CodeGenCUDA #

Class description:

Class data members:

class CodeGenCUDA final : public CodeGenC {
 public:
 ...
  // Whether global barrier is needed.
  bool need_global_barrier_{false};
  // Global barrier state
  std::string vid_global_barrier_state_;
  // Global barrier expected node.
  std::string vid_global_barrier_expect_;
  // whether enable fp16
  bool enable_fp16_{false};
  // whether enable bf16
  bool enable_bf16_{false};
  // whether enable fp8
  bool enable_fp8_{false};
  // whether enable int8
  bool enable_int8_{false};
  // whether enable warp shuffle intrinsics
  bool enable_warp_shuffle_{false};
  // whether need math_constants.h
  bool need_math_constants_h_{false};
  // whether need mma.h
  bool need_mma_h_{false};
  // Op attribute map
  OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

  std::unordered_map<const VarNode*, std::string> fragment_shapes;
  std::unordered_map<const VarNode*, std::string> fragment_layouts;
  
  friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
  void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
                      std::ostream& os);
  int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
};

AddFunction #

Directly calls CodeGenC's AddFunction.

VisitStmt #

Handled stmt nodes:

  • ForNode (emits a #pragma unroll directive)
  • EvaluateNode
  • AllocateNode (key)
  • AttrStmtNode
AttrStmtNode #
  • attr_key is fragment_shape: record the shape string in the fragment_shapes member
  • attr_key is fragment_layout: record the layout string in the fragment_layouts member
  • attr_key is async_commit_queue_scope: VisitStmt(body), then VisitExpr(commit_group)
  • attr_key is async_wait_queue_scope: VisitExpr(wait_group), then VisitStmt(body)
  Any other attribute falls back to CodeGenC's VisitStmt_(AttrStmtNode)
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::fragment_shape) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* shape_str = op->value.as<StringImmNode>();
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* layout_str = op->value.as<StringImmNode>();
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
    const IntImmNode* queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
    auto wait_cnt = wait_attrs.second;
    auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}
AllocateNode #
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  // Query the buffer's storage scope
  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  const VarNode* buffer = op->buffer_var.as<VarNode>();
  // WMMA fragment scopes are printed specially; other scopes print storage scope + element type
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
             op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
             op->dtype == DataType::BFloat(16))
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
             op->dtype == DataType::Int(32))
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
  } else {
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";

    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}
EvaluateNode #
  • If the call op is tvm_global_barrier_kinit, emit the global-barrier state initialization
  • Otherwise, fall back to CodeGenC's EvaluateNode handling
void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  const CallNode* call = op->value.as<CallNode>();
  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
    PrintIndent();
    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
    PrintIndent();
    stream << "if (threadIdx.x == 0) {\n";
    PrintIndent();
    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
    PrintIndent();
    stream << "}\n";
  } else {
    CodeGenC::VisitStmt_(op);
  }
}
ForNode #
void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
  ICHECK(is_const_int(op->min, 0));
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  CodeGenC::VisitStmt_(op);
}

VisitExpr #

VisitExpr overloads implemented by CodeGenCUDA:

  • CastNode
  • CallNode
  • BroadcastNode
  • ShuffleNode
  • SelectNode
  • RampNode
  • FloatImmNode
CallNode #

Function flow:

  1. Enable warp_shuffle for ops marked as needing warp shuffle intrinsics (op_need_warp_shuffle_)
  2. Dispatch to a different codegen path depending on the op type:
  • [TensorCore] tvm_fill_fragment
  • [TensorCore] tvm_load_matrix_sync
  • [TensorCore] tvm_store_matrix_sync
  • [TensorCore] tvm_mma_sync
  • [TensorCore] tvm_bmma_sync
  • [TensorCore] ptx_mma
  • [TensorCore] ptx_mma_sp
  • [TensorCore] ptx_ldmatrix
  • [TensorCore] mma_store
  • [TensorCore] mma_fill
  • ptx_cp_async
  • ptx_commit_group
  • ptx_wait_group
  • ptx_ldg32
  3. Otherwise, fall back to CodeGenC's CallNode visitor
BroadcastNode #
ShuffleNode #
void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
  std::vector<std::string> to_shuffle(op->vectors.size());
  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
    ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
    to_shuffle[i] = PrintExpr(op->vectors[i]);
  }
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0, e = op->indices.size(); i < e; ++i) {
    const int64_t* val = as_const_int(op->indices[i]);
    ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
    if (i != 0) os << ", ";
    os << to_shuffle[*val];
  }
  os << ')';
}

Finish #

NVRTCCompile #

CUDAModuleCreate #

Test Cases #

test_vm_build #

test_vm_codegen_only #

test_vm_codegen_tir #

test_vm_execbuilder #

A single function:

def test_vm_execute():
    ib = relax.ExecBuilder()
    with ib.function("func0", num_inputs=2):
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    ex = ib.get()
>>> print(ex.as_text())
@func0:
  call  test.vm.add      in: %0, %1       dst: %2
  ret   %2
@test.vm.add packed_func;
>>> type(ex)
tvm.relax.vm_build.Executable
>>> type(ex.mod), ex.mod
(tvm.runtime.module.Module, Module(relax.Executable, 36cf178))
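
Following the rest of test_vm_execbuilder.py, the Executable can be loaded into a relax.VirtualMachine and executed; test.vm.add is a packed function that TVM registers for these tests:

import numpy as np
import tvm
from tvm import relax

# Continuation of the test above: run the generated bytecode on the relax VM.
vm = relax.VirtualMachine(ex, tvm.cpu())
a = tvm.nd.array(np.random.rand(4,))
b = tvm.nd.array(np.random.rand(4,))
res = vm["func0"](a, b)  # executes "call test.vm.add in: %0, %1 dst: %2"
np.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy())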

Multiple functions:

def test_vm_multiple_func():
    ib = relax.ExecBuilder()
    with ib.function("func0", num_inputs=2):
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    with ib.function("func1", num_inputs=2):
        ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
        ib.emit_ret(ib.r(2))
    ex = ib.get()
>>> print(ex.as_text())
@func0:
  call  test.vm.add      in: %0, %1       dst: %2
  ret   %2
@test.vm.add packed_func;

@func1:
  call  test.vm.mul      in: %0, %1       dst: %2
  ret   %2
@test.vm.mul packed_func;
>>> 

Instruction arguments can be:

  • a register
  • an immediate
  • a function
# a call argument that is an immediate
with ib.function("func0", num_inputs=1):
	ib.emit_call("test.vm.add_scalar", args=[ib.imm(-3), ib.r(0)], dst=ib.r(1))
	ib.emit_ret(ib.r(1))
ex = ib.get()
>>> print(ex.as_text())
@func0:
  call  test.vm.add_scalar in: i-3, %0      dst: %1
  ret   %1
@test.vm.add_scalar packed_func;
>>> 

A call argument that is a function (closure):

def test_vm_invoke_closure():
    ib = relax.ExecBuilder()
    with ib.function("lifted_func_1", num_inputs=4):
        ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(4))
        ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(4)], dst=ib.r(5))
        ib.emit_call("test.vm.add", args=[ib.r(3), ib.r(5)], dst=ib.r(6))
        ib.emit_ret(ib.r(6))
    with ib.function("main", num_inputs=2):
        ib.emit_call(
            "vm.builtin.make_closure", args=[ib.f("lifted_func_1"), ib.r(0), ib.r(1)], dst=ib.r(2)
        )
        ib.emit_ret(ib.r(2))
    ex = ib.get()
>>> print(ex.as_text())
@lifted_func_1:
  call  test.vm.add      in: %0, %1       dst: %4
  call  test.vm.add      in: %2, %4       dst: %5
  call  test.vm.add      in: %3, %5       dst: %6
  ret   %6
@test.vm.add packed_func;
@main:
  call  vm.builtin.make_closure in: f[lifted_func_1], %0, %1 dst: %2
  ret   %2
@vm.builtin.make_closure packed_func;
>>>
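
Following the same test, "main" returns the VMClosure built by vm.builtin.make_closure (capturing %0 and %1), and the closure can then be invoked with the remaining arguments of lifted_func_1; a sketch based on test_vm_invoke_closure:

import numpy as np
import tvm
from tvm import relax

# Continuation of test_vm_invoke_closure: build the closure, then invoke it.
vm = relax.VirtualMachine(ex, tvm.cpu())
w, x, y, z = (tvm.nd.array(np.random.rand(2, 3)) for _ in range(4))
clo = vm["main"](w, x)              # closure capturing w and x
res = vm.invoke_closure(clo, y, z)  # calls lifted_func_1(w, x, y, z)
np.testing.assert_allclose(res.numpy(), w.numpy() + x.numpy() + y.numpy() + z.numpy())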