【TVM】Notes on TIR Schedule and DLight

The AST of TIR is rooted at PrimFunc.

TIR Schedule #

Inspecting module information:

  • Get a block: sch.get_block
  • Get the loops of a block: sch.get_loops

Thread-related:

  • (thread binding) sch.bind: sch.bind(bx, "blockIdx.x"), sch.bind(tx, "threadIdx.x")

Loop-related:

  • (loop split) sch.split
  • (loop reorder) sch.reorder
  • (unroll annotation) sch.annotate(tx, ann_key="pragma_auto_unroll_max_step"/"pragma_unroll_explicit")
  • (vectorize) sch.vectorize
  • (fuse) sch.fuse

Block-related:

  • sch.reindex
  • sch.cache_read, sch.cache_write

Storage-related:

  • sch.transform_layout (a combined usage sketch of these primitives follows below)
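
A minimal usage sketch of these primitives, assuming a recent TVM build with TVMScript; the toy add function below is made up purely for illustration:

import tvm
from tvm.script import tir as T

@T.prim_func
def add(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
    for i in range(1024):
        with T.block("B"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A[vi] + T.float32(1)

sch = tvm.tir.Schedule(add)
blk = sch.get_block("B")                    # parse: get the block
(i,) = sch.get_loops(blk)                   # parse: get its loops
bx, tx = sch.split(i, factors=[None, 128])  # loop: split into 8 x 128
sch.bind(bx, "blockIdx.x")                  # thread: bind loops to GPU axes
sch.bind(tx, "threadIdx.x")
print(sch.mod.script())                     # inspect the transformed TIR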

Thread #

  • Thread allocation and launch:
    • High-level primitive: thread_binding
    • Low-level primitives: allocate, launch_thread (the relationship between the two levels is sketched in the next section)

allocate and launch_thread (T.thread_binding) #
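
To relate the two levels, one can lower a scheduled function and compare the scripts before and after. A small sketch (the copy function is hypothetical, and the exact printed form may vary across TVM versions):

import tvm
from tvm.script import tir as T

@T.prim_func
def copy(A: T.Buffer((2048,), "float32"), B: T.Buffer((2048,), "float32")):
    # high-level form: the loop itself carries the thread semantics
    for bx in T.thread_binding(2048, thread="blockIdx.x"):
        with T.block("copy"):
            vi = T.axis.spatial(2048, bx)
            B[vi] = A[vi]

# after the default TIR lowering passes, the thread_binding loop is rewritten
# into the low-level form: an explicit T.launch_thread("blockIdx.x", 2048)
# (plus T.allocate for any temporaries), as in the lowered scripts later on
print(tvm.lower(copy).script())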

thread reduce #

Provided APIs #

NormalizePrimFunc #

// src/tir/schedule/transform.cc
TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc);

Function flow:

  1. Check
  2. Construct index_map_inputs and index_map_outputs
  3. Transform the block layout according to the index map
  4. Check whether the block is a reduction block

Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
  BlockRV root_block = sch->GetBlock("root");
  Array<BlockRV> blocks = sch->GetChildBlocks(root_block);
  // 1. check
  for (const BlockRV& block : blocks) {
    StmtSRef block_sref = sch->GetSRef(block);
    Array<StmtSRef> loops = GetLoops(block_sref);
    Array<PrimExpr> binds = GetBlockRealize(sch->state(), block_sref)->iter_values;
    if (loops.size() != binds.size()) {
      return NullOpt;
    }
    for (int i = 0, n = loops.size(); i < n; ++i) {
      const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
      if (binds[i].get() != loop->loop_var.get()) {
        return NullOpt;
      }
      if (!is_zero(loop->min)) {
        return NullOpt;
      }
    }
  }
  Array<Array<LoopRV>> block_loops;
  Array<Array<IterVar>> block_iters;
  Array<IntImm> block_is_reduction;
  // process each block
  for (const BlockRV& block : blocks) {
    Array<IterVar> iters = sch->Get(block)->iter_vars;
    bool has_spatial_iter = false;
    // 2. construct index_map_inputs and index_map_outputs
    Array<Var> index_map_inputs;
    Array<PrimExpr> index_map_outputs;
    for (const IterVar& iter : sch->Get(block)->iter_vars) {
      Var var = iter->var.copy_with_suffix("");
      index_map_inputs.push_back(var);
      if (!is_one(iter->dom->extent)) {
        index_map_outputs.push_back(var);
        if (iter->iter_type == IterVarType::kDataPar) {
          has_spatial_iter = true;
        }
      }
    }
    if (index_map_outputs.empty() || !has_spatial_iter) {
      index_map_outputs.insert(index_map_outputs.begin(), tir::make_const(DataType::Int(64), 0));
    }
    // 3. transform the block layout according to the index map
    sch->TransformBlockLayout(block, IndexMap(index_map_inputs, index_map_outputs));
    block_loops.push_back(sch->GetLoops(block));
    block_iters.push_back(sch->Get(block)->iter_vars);
    // 4. check whether the block is a reduction block
    bool is_reduction = IsReductionBlock(sch->state(),         //
                                         sch->GetSRef(block),  //
                                         sch->GetSRef(root_block));
    block_is_reduction.push_back(Bool(is_reduction));
  }
  return Array<ObjectRef>{blocks, block_loops, block_iters, block_is_reduction};
}

is_reduction #

How a block is determined to be a reduction block (a simplified Python illustration follows the list):

  1. Collect all the reduction axes;
  2. Collect all the write buffers;
  3. Collect all the alloc buffers;
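
For intuition only, here is a much-simplified Python check along the lines of the steps above; it is not TVM's IsReductionBlock, which also inspects the alloc buffers and more:

import tvm

def _vars_in(expr):
    found = set()
    tvm.tir.stmt_functor.post_order_visit(
        expr, lambda node: found.add(node) if isinstance(node, tvm.tir.Var) else None
    )
    return found

def looks_like_reduction_block(block: tvm.tir.Block) -> bool:
    # 1. collect the reduction axes
    reduce_vars = {
        iv.var for iv in block.iter_vars if iv.iter_type == tvm.tir.IterVar.CommReduce
    }
    if not reduce_vars:
        return False
    # 2. the regions the block writes must not be indexed by any reduction axis
    for region in block.writes:
        for rng in region.region:
            if _vars_in(rng.min) & reduce_vars:
                return False
    return True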

DLight #

Matmul #

Source path: python/tvm/dlight/gpu/matmul.py

Schedule flow (a sketch of applying the dlight GPU rules follows this list):

  1. Check for tensor core support
  2. Get the schedule config
  3. Do the schedule
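
For context, applying the dlight GPU rules (Matmul among them) to a Relax module looks roughly like this; the stand-in module below is built only to keep the snippet self-contained, and any Relax module with TIR matmuls is handled the same way:

import tvm
from tvm import dlight as dl, relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((128, 256), "float16"))
w = relax.Var("w", R.Tensor((256, 512), "float16"))
with bb.function("main", [x, w]):
    with bb.dataflow():
        gv = bb.emit_output(relax.op.matmul(x, w))
    bb.emit_func_output(gv)
mod = relax.transform.LegalizeOps()(bb.get())  # lower relax.matmul into a TIR PrimFunc

with tvm.target.Target("cuda"):
    mod = dl.ApplyDefaultSchedule(  # each rule's apply() is tried on every PrimFunc
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.Fallback(),
    )(mod)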

GEMV #

Example:

from tvm import relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
E = relax.Var("A", R.Tensor([10, 20, 1, 1], "int64"))
F = relax.Var("B", R.Tensor([10, 20, 1, 30], "int64"))
with bb.function("matmul_5"):
    with bb.dataflow():
        lv0 = relax.op.matmul(E, F)
        gv = bb.emit_output(lv0)
    bb.emit_func_output(gv, [E, F])

Implementation:

  1. Normalize the blocks to get block_infos
    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        sch = tir.Schedule(func)
        block_infos = normalize_prim_func(sch)
        block_infos = try_inline_contiguous_spatial(sch, block_infos)
        if len(block_infos) == 1:
            epilogue = None
        elif len(block_infos) == 2:
            epilogue = block_infos[1]
            if not epilogue.is_injective():
                return None
        else:
            return None

        block_info = block_infos[0]
        if len(block_info.iters) not in [2, 3]:
            # either [B, S, R] = [B, S, R] * [B, R]
            # or [S, R] = [S, R] * [R]
            return None
        block = block_info.block_rv
        vector_input_buffers = is_gemv(sch, block_info)
        if vector_input_buffers is None:
            return None

        # Step 1. Normalize the block, merge spatial and reduction iters
        is_inner_reduction = normalize(sch, block_info)

        # Step 2. Do the scheduling
        if is_inner_reduction:
            self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue)
            return sch
        else:
            # TODO: Need to handle GEMV with KN layout
            return None

normalize #

sch_inner_reduction #

Reduction #

Example. The TIR script below has three spatial axes and one reduce axis (of extent 1):

@T.prim_func(private=True)
def main(
    A: T.Buffer((T.int64(10), T.int64(20), T.int64(1)), "int64"),
    B: T.Buffer((T.int64(10), T.int64(1), T.int64(30)), "int64"),
    matmul: T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "int64"),
):
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1, i2, k in T.grid(T.int64(10), T.int64(20), T.int64(30), T.int64(1)):
        with T.block("matmul"):
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            T.reads(A[v_i0, v_i1, v_k], B[v_i0, v_k, v_i2])
            T.writes(matmul[v_i0, v_i1, v_i2])
            with T.init():
                matmul[v_i0, v_i1, v_i2] = T.int64(0)
            matmul[v_i0, v_i1, v_i2] = (
                matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i0, v_k, v_i2]
            )
# 1. Norm
>>> block_infos = dl.base.analysis.normalize_prim_func(sch)
[BlockInfo("matmul", "SSS", [10, 20, 30])] # only three spatial axes remain (the unit-extent reduce axis was dropped)
>>> block_info.is_reduction(), block_stmt.writes, dl.gpu.reduction._get_reduction_expr(block_stmt)
T.bool(True) # note: this is True
[matmul[v0, v1, v2]] 
A[v0, v1, T.int64(0)] * B[v0, T.int64(0), v2]
>>> block_stmt = sch.get(block) # the normalized block no longer has a reduce axis
with T.block("matmul", no_realize=True):
    v0 = T.axis.spatial(T.int64(10))
    v1 = T.axis.spatial(T.int64(20))
    v2 = T.axis.spatial(T.int64(30))
    A = T.Buffer((T.int64(10), T.int64(20), T.int64(1)), "int64")
    B = T.Buffer((T.int64(10), T.int64(1), T.int64(30)), "int64")
    T.reads(A[v0, v1, T.int64(0)], B[v0, T.int64(0), v2])
    matmul = T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "int64")
    T.writes(matmul[v0, v1, v2])
    with T.init():
        matmul[v0, v1, v2] = T.int64(0)
    matmul[v0, v1, v2] = matmul[v0, v1, v2] + A[v0, v1, T.int64(0)] * B[v0, T.int64(0), v2]
>>> dl.base.analysis.detect_dominant_read(block_stmt)
v0 * T.int64(20) + v1
>>> access # the axes being read: v0 and v1
IterSum([
  IterSplit(IterMark(v0, extent=T.int64(10)), lower_factor=T.int64(1), extent=T.int64(10), scale=T.int64(20)),
  IterSplit(IterMark(v1, extent=T.int64(20)), lower_factor=T.int64(1), extent=T.int64(20), scale=T.int64(1))
], T.int64(0))

Schedule flow:

  1. Normalize the func to get block_infos
  2. Check the reduction block
  3. Fuse the spatial and reduction iters
  4. Do the schedule
	# 1. get block_infos
	sch = tir.Schedule(func)
	block_infos = normalize_prim_func(sch)
	# 2. check 
	if (
		(not block_info.is_reduction())
		or len(block_stmt.writes) != 1
		or _get_reduction_expr(block_stmt) is None
	):
		return None
	# 3. fuse the spatial and reduction iters
	is_inner_reduction, c_factor = self._normalize(
		sch,
		block_info,
		arith.normalize_to_iter_sum(
			detect_dominant_read(block_stmt),
			input_iters={i.var: i.dom for i in block_stmt.iter_vars},
		),
	)
	# 4. do the schedule
	if is_inner_reduction:
		self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
	else:
		self._sch_inner_spatial(sch, target, block, c_factor, epilogue)

_normalize #

schedule #

Function signature:

  • Parameters: the module's schedule (sch), the target, the block, the unroll factor, and the epilogue info
  • Returns: the schedule (sch)
def _sch_inner_spatial(
	self,
	sch: tir.Schedule,
	target: Target,
	block: tir.schedule.BlockRV,
	unroll_spatial_factor: Optional[int],
	epilogue_info: Optional[BlockInfo],
):

_sch_inner_reduction #

_sch_inner_spatial #

Transpose #

build #

Reduction #

  1. Loop unrolling, with per-thread accumulation before the warp/block reduce
  2. Two-stage warp reduce (the data flow is sketched in NumPy after this list):
  • shfl within each warp, with the per-warp result stored to shared memory
  • one group of threads continues with shfl to produce the final result
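
The data flow of this two-stage reduce can be mimicked in plain NumPy (an illustration only; array shifts stand in for __shfl_down_sync, and the 256/32/8 sizes match the example below):

import numpy as np

partials = np.random.rand(256).astype("float32")  # one partial sum per thread

def warp_reduce(lanes):
    # model shfl_down + add: lane i accumulates the value of lane i + offset
    lanes = lanes.copy()
    for offset in (16, 8, 4, 2, 1):
        shifted = np.zeros_like(lanes)
        shifted[: len(lanes) - offset] = lanes[offset:]
        lanes += shifted
    return lanes[0]  # lane 0 ends up holding the warp sum

# stage 1: each of the 8 warps reduces its 32 lanes; results go to shared memory
warp_sums = np.array([warp_reduce(partials[w * 32 : (w + 1) * 32]) for w in range(8)])

# stage 2: the first 8 lanes of one warp reload the warp sums and reduce again
staging = np.zeros(32, dtype="float32")
staging[:8] = warp_sums
block_sum = warp_reduce(staging)

assert np.isclose(block_sum, partials.sum(), rtol=1e-4)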

lower to tir mod: the TIR script after lowering:

@T.prim_func
def main(
	A: T.Buffer((T.int64(2048), T.int64(8192)), "float32"),
	A_red: T.Buffer((T.int64(2048),), "float32"),
):
	T.func_attr(
		{"op_pattern": 3, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}
	)
	blockIdx_x = T.launch_thread("blockIdx.x", 2048) # blockIdx
	
	A_red_rf_local = T.allocate([1], "float32", "local")
	cross_thread_A_red = T.allocate([1], "float32", "local")
	
	threadIdx_x = T.launch_thread("threadIdx.x", 256) # threadIdx
	A_red_rf_local_1 = T.Buffer((T.int64(1),), data=A_red_rf_local, scope="local")
	A_red_rf_local_1[0] = T.float32(0)
	
	A_1 = T.Buffer((T.int64(16777216),), data=A.data) # view over the parameter data
	# loop unrolling
	A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 8192 + threadIdx_x]
	A_red_rf_local_1[0] = (
		A_red_rf_local_1[0] + A_1[blockIdx_x * 8192 + threadIdx_x + 256]
	)
	A_red_rf_local_1[0] = (
		A_red_rf_local_1[0] + A_1[blockIdx_x * 8192 + threadIdx_x + 512]
	)
	A_red_rf_local_1[0] = (
		A_red_rf_local_1[0] + A_1[blockIdx_x * 8192 + threadIdx_x + 768]
	)
	A_red_rf_local_1[0] = (
		A_red_rf_local_1[0] + A_1[blockIdx_x * 8192 + threadIdx_x + 1024]
	)
	...
	cross_thread_A_red_1 = T.Buffer((1,), data=cross_thread_A_red, scope="local")
	# reduce across the 256 threads
	with T.attr(
		T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
		"reduce_scope",
		T.reinterpret("handle", T.uint64(0)),
	):
		T.tvm_thread_allreduce(
			T.uint32(1),
			A_red_rf_local_1[0],
			T.bool(True),
			cross_thread_A_red_1[0],
			threadIdx_x,
		)
	# write back to the output parameter
	if threadIdx_x == 0:
		A_red_1 = T.Buffer((T.int64(2048),), data=A_red.data)
		A_red_1[blockIdx_x] = cross_thread_A_red_1[0]

【tvm.build -> tir_to_runtime】After lowering the mixed module:

  • it contains the device kernel main_kernel and the host entry function main
@I.ir_module
class Module:
    I.module_attrs({"runtime": None})
    @T.prim_func
    def main_kernel(A: T.handle("float32", "global"), A_red: T.handle("float32", "global")):
    
        T.func_attr({"calling_conv": 2, "target": T.target({"arch": "sm_75", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "registers_per_block": 65536, "tag": "", "thread_warp_size": 32}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)})
        
        blockIdx_x = T.launch_thread("blockIdx.x", 2048)
        A_red_rf_local = T.allocate([1], "float32", "local")
        red_result = T.allocate([1], "float32", "shared")
        T.attr(red_result, "volatile_scope", 1)
        threadIdx_x = T.launch_thread("threadIdx.x", 256)
        A_red_rf_local_1 = T.Buffer((1,), data=A_red_rf_local, scope="local")
        A_red_rf_local_1[0] = T.float32(0)
        A_1 = T.Buffer((T.int64(4194304),), data=A)
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 256]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 512]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 768]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1024]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1280]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1536]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1792]
        red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
        with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
            red_buf0 = T.allocate([1], "float32", "local")
            mask = T.allocate([1], "uint32", "local")
            t0 = T.allocate([1], "float32", "local")
            red_buf0_1 = T.allocate([1], "float32", "local")
            mask_1 = T.allocate([1], "uint32", "local")
            t0_1 = T.allocate([1], "float32", "local")
            red_buf_staging = T.allocate([8], "float32", "shared")
            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
            red_buf0_2[0] = A_red_rf_local_1[0]
            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
            mask_2[0] = T.tvm_warp_activemask()
            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
            if threadIdx_x % 32 == 0:
                red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
            T.tvm_storage_sync("shared")
            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
            if threadIdx_x < 8:
                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
            mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
            mask_3[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(255))
            t0_3 = T.Buffer((1,), data=t0, scope="local")
            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            if threadIdx_x == 0:
                red_result_1[0] = red_buf0_3[0]
            T.tvm_storage_sync("shared")
        if threadIdx_x == 0:
            A_red_1 = T.Buffer((T.int64(2048),), data=A_red)
            A_red_1[blockIdx_x] = red_result_1[0]

    @T.prim_func
    def main(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
        T.func_attr({"calling_conv": 1, "op_pattern": 3, "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_entry_func": T.bool(True), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        assert num_args == 2, "main: num_args should be 2"
        arg_type_ids_1 = T.decl_buffer((2,), "int32", data=arg_type_ids)
        var_A_code: T.int32 = arg_type_ids_1[0]
        var_A_red_code: T.int32 = arg_type_ids_1[1]
        var_A: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
        var_A_red: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
        A: T.handle("float32", "global") = T.tvm_struct_get(var_A, 0, 1, "handle")
        T.attr(A, "storage_alignment", 64)
        main_var_A_shape: T.handle("int64") = T.tvm_struct_get(var_A, 0, 2, "handle")
        main_var_A_shape_1 = T.decl_buffer((2,), "int64", data=main_var_A_shape)
        main_var_A_strides: T.handle("int64") = T.tvm_struct_get(var_A, 0, 3, "handle")
        main_var_A_strides_1 = T.decl_buffer((0,), "int64", data=main_var_A_strides)
        dev_id: T.int32 = T.tvm_struct_get(var_A, 0, 9, "int32")
        A_red: T.handle("float32", "global") = T.tvm_struct_get(var_A_red, 0, 1, "handle")
        T.attr(A_red, "storage_alignment", 64)
        main_var_A_red_shape: T.handle("int64") = T.tvm_struct_get(var_A_red, 0, 2, "handle")
        main_var_A_red_shape_1 = T.decl_buffer((1,), "int64", data=main_var_A_red_shape)
        main_var_A_red_strides: T.handle("int64") = T.tvm_struct_get(var_A_red, 0, 3, "handle")
        main_var_A_red_strides_1 = T.decl_buffer((0,), "int64", data=main_var_A_red_strides)
        assert var_A_code == 3 or var_A_code == 13 or var_A_code == 7 or var_A_code == 4, "main: Expect arg[0] to be pointer"
        assert var_A_red_code == 3 or var_A_red_code == 13 or var_A_red_code == 7 or var_A_red_code == 4, "main: Expect arg[1] to be pointer"
        T.attr("default", "device_id", dev_id)
        T.attr("default", "device_type", 2)
        assert 2 == T.tvm_struct_get(var_A, 0, 4, "int32"), "main.var_A.ndim is expected to equal 2"
        assert 2 == T.tvm_struct_get(var_A, 0, 4, "int32"), "main.var_A.ndim is expected to equal 2"
        assert T.tvm_struct_get(var_A, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(var_A, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(var_A, 0, 7, "uint16") == T.uint16(1), "main.var_A.dtype is expected to be float32"
        assert main_var_A_shape_1[0] == T.int64(2048), "Argument main.var_A.shape[0] has an unsatisfied constraint: T.int64(2048) == main_var_A_shape[0]"
        assert main_var_A_shape_1[1] == T.int64(2048), "Argument main.var_A.shape[1] has an unsatisfied constraint: T.int64(2048) == main_var_A_shape[1]"
        if not T.isnullptr(main_var_A_strides):
            assert T.int64(1) == main_var_A_strides_1[1] and T.int64(2048) == main_var_A_strides_1[0], "main.var_A.strides: expected to be compact array"
            T.evaluate(0)
        assert T.uint64(0) == T.tvm_struct_get(var_A, 0, 8, "uint64"), "Argument main.var_A.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(var_A, 0, 8, \"uint64\")"
        assert T.tvm_struct_get(var_A, 0, 10, "int32") == 2, "Argument main.var_A.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(var_A, 0, 10, \"int32\")"
        assert 1 == T.tvm_struct_get(var_A_red, 0, 4, "int32"), "main.var_A_red.ndim is expected to equal 1"
        assert 1 == T.tvm_struct_get(var_A_red, 0, 4, "int32"), "main.var_A_red.ndim is expected to equal 1"
        assert T.tvm_struct_get(var_A_red, 0, 5, "uint8") == T.uint8(2) and T.tvm_struct_get(var_A_red, 0, 6, "uint8") == T.uint8(32) and T.tvm_struct_get(var_A_red, 0, 7, "uint16") == T.uint16(1), "main.var_A_red.dtype is expected to be float32"
        assert main_var_A_red_shape_1[0] == T.int64(2048), "Argument main.var_A_red.shape[0] has an unsatisfied constraint: T.int64(2048) == main_var_A_red_shape[0]"
        if not T.isnullptr(main_var_A_red_strides):
            assert T.int64(1) == main_var_A_red_strides_1[0], "main.var_A_red.strides: expected to be compact array"
            T.evaluate(0)
        assert T.uint64(0) == T.tvm_struct_get(var_A_red, 0, 8, "uint64"), "Argument main.var_A_red.byte_offset has an unsatisfied constraint: T.uint64(0) == T.tvm_struct_get(var_A_red, 0, 8, \"uint64\")"
        assert T.tvm_struct_get(var_A_red, 0, 10, "int32") == 2, "Argument main.var_A_red.device_type has an unsatisfied constraint: 2 == T.tvm_struct_get(var_A_red, 0, 10, \"int32\")"
        assert dev_id == T.tvm_struct_get(var_A_red, 0, 9, "int32"), "Argument main.var_A_red.device_id has an unsatisfied constraint: dev_id == T.tvm_struct_get(var_A_red, 0, 9, \"int32\")"
        
        A_1 = T.decl_buffer((T.int64(2048), T.int64(2048)), data=A)
        A_red_1 = T.decl_buffer((T.int64(2048),), data=A_red)
        T.call_packed("__tvm_set_device", 2, dev_id)
        with T.attr(0, "compute_scope", "main_compute_"):
            T.call_packed("main_kernel", A, A_red, 2048, 256)
        T.ret(0)

after device mod transform:

  • T.tvm_warp_activemask -> T.tir.cuda.__activemask
  • T.tvm_warp_shuffle_down -> T.tir.cuda.__shfl_down_sync
  • threadIdx_x % 32 and threadIdx_x // 32 are lowered to T.truncmod and T.shift_right
@I.ir_module
class Module:
    I.module_attrs({"runtime": None})
    @T.prim_func
    def main_kernel(A: T.handle("float32", "global"), A_red: T.handle("float32", "global")):
        T.func_attr({"calling_conv": 2, "target": T.target({"arch": "sm_75", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "registers_per_block": 65536, "tag": "", "thread_warp_size": 32}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)})
        blockIdx_x = T.launch_thread("blockIdx.x", 2048)
        A_red_rf_local = T.allocate([1], "float32", "local")
        red_result = T.allocate([1], "float32", "shared")
        T.attr(red_result, "volatile_scope", 1)
        threadIdx_x = T.launch_thread("threadIdx.x", 256)
        A_red_rf_local_1 = T.Buffer((1,), data=A_red_rf_local, scope="local")
        A_red_rf_local_1[0] = T.float32(0)
        A_1 = T.Buffer((T.int64(4194304),), data=A)
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 256]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 512]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 768]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1024]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1280]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1536]
        A_red_rf_local_1[0] = A_red_rf_local_1[0] + A_1[blockIdx_x * 2048 + threadIdx_x + 1792]
        red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
        # attr node
        with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
            red_buf0 = T.allocate([1], "float32", "local")
            mask = T.allocate([1], "uint32", "local")
            t0 = T.allocate([1], "float32", "local")
            red_buf0_1 = T.allocate([1], "float32", "local")
            mask_1 = T.allocate([1], "uint32", "local")
            t0_1 = T.allocate([1], "float32", "local")
            red_buf_staging = T.allocate([8], "float32", "shared")
            # seq stmt node
            # ... BufferStore stmt node
            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
            red_buf0_2[0] = A_red_rf_local_1[0]
            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
            mask_2[0] = T.tir.cuda.__activemask()
            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
            t0_2[0] = T.tir.cuda.__shfl_down_sync(mask_2[0], red_buf0_2[0], 16, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tir.cuda.__shfl_down_sync(mask_2[0], red_buf0_2[0], 8, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tir.cuda.__shfl_down_sync(mask_2[0], red_buf0_2[0], 4, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tir.cuda.__shfl_down_sync(mask_2[0], red_buf0_2[0], 2, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            t0_2[0] = T.tir.cuda.__shfl_down_sync(mask_2[0], red_buf0_2[0], 1, 32)
            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
            # if then else node
            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
            if T.truncmod(threadIdx_x, 32) == 0:
                red_buf_staging_1[T.shift_right(threadIdx_x, 5)] = red_buf0_2[0]
            # evaluate node
            T.tvm_storage_sync("shared")
            # if then else node
            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
            if threadIdx_x < 8:
                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
            mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
            mask_3[0] = T.bitwise_and(T.tir.cuda.__activemask(), T.uint32(255))
            t0_3 = T.Buffer((1,), data=t0, scope="local")
            t0_3[0] = T.tir.cuda.__shfl_down_sync(mask_3[0], red_buf0_3[0], 4, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            t0_3[0] = T.tir.cuda.__shfl_down_sync(mask_3[0], red_buf0_3[0], 2, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            t0_3[0] = T.tir.cuda.__shfl_down_sync(mask_3[0], red_buf0_3[0], 1, 32)
            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
            # if then else node
            if threadIdx_x == 0:
                red_result_1[0] = red_buf0_3[0]
            T.tvm_storage_sync("shared")
        if threadIdx_x == 0:
            A_red_1 = T.Buffer((T.int64(2048),), data=A_red)
            A_red_1[blockIdx_x] = red_result_1[0]

Unroll #

Replacing the in-warp add with shfl #

lower to tir mod #

The TIR script after lowering; this fragment is a stmt.AttrStmt node:

  • node: ObjectRef -> tvm.tir.expr.CommReducer
  • attr_key: String
  • value: PrimExpr -> tvm.tir.expr.Call
  • body: stmt.Evaluate
with T.attr(
	T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), # node
	"reduce_scope", # attr_key
	T.reinterpret("handle", T.uint64(0)), # value
): # body
	T.tvm_thread_allreduce(
		T.uint32(1),
		A_red_rf_local_1[0],
		T.bool(True),
		cross_thread_A_red_1[0],
		threadIdx_x,
	)
>>> seq_stmt[9].node, type(seq_stmt[9].node)
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
tvm.tir.expr.CommReducer
>>> seq_stmt[9].attr_key, 
'reduce_scope',
>>> seq_stmt[9].value, type(seq_stmt[9].value)
T.reinterpret("handle", T.uint64(0))
tvm.tir.expr.Call
>>> seq_stmt[9].body
...

Analyzing the Evaluate node:

  • value:expr.Call
>>> seq_stmt[9].body.value,
T.tvm_thread_allreduce(T.uint32(1), A_red_rf_local_1[0], T.bool(True), cross_thread_A_red_1[0], threadIdx_x),
>>> type(seq_stmt[9].body.value)
tvm.tir.expr.Call
>>> seq_stmt[9].body.value.op # call.op
Op(tir.tvm_thread_allreduce)
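
Instead of indexing into seq_stmt by hand, the same AttrStmt can be located programmatically with the stmt visitor API (a small hypothetical helper):

import tvm

def find_reduce_scopes(prim_func: tvm.tir.PrimFunc):
    found = []
    def visit(node):
        # collect every AttrStmt whose attr_key marks a reduction scope
        if isinstance(node, tvm.tir.AttrStmt) and node.attr_key == "reduce_scope":
            found.append(node)
    tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return found

# usage: attrs = find_reduce_scopes(lowered_mod["main"]); then inspect
# attrs[0].node, attrs[0].value and attrs[0].body.value.op as in the REPL above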

lower to mixed transform mod #

In tir_to_runtime, the split-mixed-module step applies the mixed-module sequence of passes.

  • among these passes, T.tvm_thread_allreduce is lowered into the following:
with T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
	red_buf0 = T.allocate([1], "float32", "local")
	mask = T.allocate([1], "uint32", "local")
	t0 = T.allocate([1], "float32", "local")
	red_buf0_1 = T.allocate([1], "float32", "local")
	mask_1 = T.allocate([1], "uint32", "local")
	t0_1 = T.allocate([1], "float32", "local")
	red_buf_staging = T.allocate([8], "float32", "shared")
	red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
	red_buf0_2[0] = A_red_rf_local_1[0]
	mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
	mask_2[0] = T.tvm_warp_activemask()
	t0_2 = T.Buffer((1,), data=t0_1, scope="local")
	# warp shuffle
	t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
	red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
	t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
	red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
	t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
	red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
	t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
	red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
	t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
	red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
	# combine the per-warp partial sums
	red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
	if threadIdx_x % 32 == 0:
		red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
	T.tvm_storage_sync("shared")
	red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
	if threadIdx_x < 8:
		red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
	mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
	mask_3[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(255))
	t0_3 = T.Buffer((1,), data=t0, scope="local")
	t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
	red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
	t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
	red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
	t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
	red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
	# write to ret
	if threadIdx_x == 0:
		red_result_1[0] = red_buf0_3[0]
	T.tvm_storage_sync("shared")

codegen #

AttrStmtNode #

Data members:

  • node: ObjectRef -> tvm.tir.expr.CommReducer
  • attr_key: String
  • value: PrimExpr -> tvm.tir.expr.Call
  • body: stmt.Evaluate
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
	CodeGenC::VisitStmt_(op);
}

EvaluateNode #

Data members:

  • value:expr.Call
VisitExpr:Call #

Function flow:

  1. warp_shuffle is enabled by default
  2. Dispatch to a dedicated codegen path depending on the op:
  • 【TensorCore】tvm_fill_fragment
  • 【TensorCore】tvm_load_matrix_sync
  • 【TensorCore】tvm_store_matrix_sync
  • 【TensorCore】tvm_mma_sync
  • 【TensorCore】tvm_bmma_sync
  • 【TensorCore】ptx_mma
  • 【TensorCore】ptx_mma_sp
  • 【TensorCore】ptx_ldmatrix
  • 【TensorCore】mma_store
  • 【TensorCore】mma_fill
  • ptx_cp_async
  • ptx_commit_group
  • ptx_wait_group
  • ptx_ldg32
  3. Otherwise, fall back to CodeGenC's VisitExpr_ for the Call node

Case study #

reduce: mean #

The mean is implemented in two steps (a TE sketch of constructing an equivalent PrimFunc follows this list):

  1. Sum: sum(axis=-1)
  2. Division: divide
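
One way (a sketch, not necessarily how this note produced it) to obtain an equivalent PrimFunc is through TE, summing over the last axis and then scaling by 1/8192:

import tvm
from tvm import te, topi

A = te.placeholder((2048, 8192), "float32", name="A")
A_red = topi.sum(A, axis=-1)           # step 1: the reduction
T_divide = topi.divide(A_red, 8192.0)  # step 2: the element-wise division
mean_func = te.create_prim_func([A, T_divide])
print(mean_func.script())

The TIR this note actually works with is the following prim_func: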
@T.prim_func(private=True)
def mean(
	A: T.Buffer((T.int64(2048), T.int64(8192)), "float32"),
	T_divide: T.Buffer((T.int64(2048),), "float32"),
):
	T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
	# with T.block("root"):
	A_red = T.alloc_buffer((T.int64(2048),)) # computes the sum, (2048, 8192) -> (2048,)
	for ax0, k1 in T.grid(T.int64(2048), T.int64(8192)):
		with T.block("A_red"):
			v_ax0, v_k1 = T.axis.remap("SR", [ax0, k1])
			T.reads(A[v_ax0, v_k1])
			T.writes(A_red[v_ax0])
			with T.init():
				A_red[v_ax0] = T.float32(0)
			A_red[v_ax0] = A_red[v_ax0] + A[v_ax0, v_k1]
	for ax0 in range(T.int64(2048)): # computes the mean, (2048,) -> (2048,)
		with T.block("T_divide"):
			v_ax0 = T.axis.spatial(T.int64(2048), ax0)
			T.reads(A_red[v_ax0])
			T.writes(T_divide[v_ax0])
			T_divide[v_ax0] = A_red[v_ax0] * T.float32(0.0001220703125)

After the dlight schedule:

  1. blockIdx is bound to the first dimension of the input A (2048)
  2. 【after build: loop unrolling plus a straight sum】threadIdx is capped at 256; the reduce axis is first reduced to 256 partial sums, with each thread handling 32 elements in a for loop
    • the unroll annotation "pragma_unroll_explicit": 1 requests explicit loop unrolling
    • the maximum unroll step is 256 ("pragma_auto_unroll_max_step")
  3. 【after build: shfl reduce inside each warp into shared memory; one group of threads then reads shared memory and shfl-reduces again to get the sum】since step 2 leaves 256 partial sums, a further reduce is needed: 256 threads sum those 256 elements
@T.prim_func(private=True)
def mean(
	A: T.Buffer((T.int64(2048), T.int64(8192)), "float32"),
	T_divide: T.Buffer((T.int64(2048),), "float32"),
):
	T.func_attr(
		{"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}
	)
	# with T.block("root"):
	A_red_local = T.alloc_buffer((T.int64(2048),), scope="local")
	A_red_rf_local = T.alloc_buffer((T.int64(256), T.int64(2048)), scope="local")
	# sch: thread_binding
	for ax0_fused in T.thread_binding(T.int64(2048), thread="blockIdx.x"):
		for ax1_fused_1 in T.thread_binding(
			T.int64(256),
			thread="threadIdx.x",
			# sch: unroll
			annotations={
				"pragma_auto_unroll_max_step": 256,
				"pragma_unroll_explicit": 1,
			},
		): # loop unrolling, (2048, 8192) -> (2048, 256); the 32-iteration reduce loop per thread is unrolled
			with T.block("A_red_rf_init"):
				vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
				T.reads()
				T.writes(A_red_rf_local[vax1_fused_1, v0])
				A_red_rf_local[vax1_fused_1, v0] = T.float32(0)
			for ax1_fused_0, u in T.grid(T.int64(32), 1):
				with T.block("A_red_rf_update"):
					vax1_fused_1, v0, vax1_fused_0 = T.axis.remap(
						"SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]
					)
					T.reads(
						A_red_rf_local[vax1_fused_1, v0],
						A[v0, vax1_fused_0 * T.int64(256) + vax1_fused_1],
					)
					T.writes(A_red_rf_local[vax1_fused_1, v0])
					A_red_rf_local[vax1_fused_1, v0] = (
						A_red_rf_local[vax1_fused_1, v0]
						+ A[v0, vax1_fused_0 * T.int64(256) + vax1_fused_1]
					)
		for ax1_fused in range(T.int64(1)):
			for ax0 in T.thread_binding(T.int64(256), thread="threadIdx.x"): # sum the 256 partial results
				with T.block("A_red"):
					vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
					T.reads(A_red_rf_local[vax1_fused_1, v0])
					T.writes(A_red_local[v0])
					with T.init():
						A_red_local[v0] = T.float32(0)
					A_red_local[v0] = (
						A_red_local[v0] + A_red_rf_local[vax1_fused_1, v0]
					)
		with T.block("T_divide"):
			v0 = T.axis.spatial(T.int64(2048), ax0_fused) # compute the mean
			T.reads(A_red_local[v0])
			T.writes(T_divide[v0])
			T_divide[v0] = A_red_local[v0] * T.float32(0.0001220703125)

After build:

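# "After" refers to the dlight-scheduled IRModule above; "target" is e.g. tvm.target.Target("cuda")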
rt_mod = relax.build(After, target)
source = rt_mod.mod.imported_modules[0].imported_modules[0].get_source()

Build strategy:

  • variable allocation
  • loop unrolling
  • shuffle
extern "C" __global__ void __launch_bounds__(256) mean_kernel(float* __restrict__ A, float* __restrict__ T_divide) {
  float A_red_rf_local[1];
  __shared__ float red_result[1];
  // 1. loop unrolling
  A_red_rf_local[0] = 0.000000e+00f;
  A_red_rf_local[0] = (A_red_rf_local[0] + A[((((int)blockIdx.x) * 8192) + ((int)threadIdx.x))]);
  A_red_rf_local[0] = (A_red_rf_local[0] + A[(((((int)blockIdx.x) * 8192) + ((int)threadIdx.x)) + 256)]);
  A_red_rf_local[0] = (A_red_rf_local[0] + A[(((((int)blockIdx.x) * 8192) + ((int)threadIdx.x)) + 512)]);
  A_red_rf_local[0] = (A_red_rf_local[0] + A[(((((int)blockIdx.x) * 8192) + ((int)threadIdx.x)) + 768)]);
  ...
  A_red_rf_local[0] = (A_red_rf_local[0] + A[(((((int)blockIdx.x) * 8192) + ((int)threadIdx.x)) + 7936)]);
  // 2. sum: first warp reduce -> 8 partial sums
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  float red_buf0_1[1];
  uint mask_1[1];
  float t0_1[1];
  __shared__ float red_buf_staging[8];
  red_buf0_1[0] = A_red_rf_local[0];
  mask_1[0] = __activemask();
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  if ((((int)threadIdx.x) % 32) == 0) {
    red_buf_staging[(((int)threadIdx.x) >> 5)] = red_buf0_1[0];
  }
  __syncthreads();
  // 2. sum: second warp reduce -> 1 result
  if (((int)threadIdx.x) < 8) {
    red_buf0[0] = red_buf_staging[((int)threadIdx.x)];
  }
  mask[0] = (__activemask() & (uint)255);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  if (((int)threadIdx.x) == 0) {
    ((volatile float*)red_result)[0] = red_buf0[0];
  }
  __syncthreads();
  // 3. multiply
  if (((int)threadIdx.x) == 0) {
    T_divide[((int)blockIdx.x)] = (((volatile float*)red_result)[0] * 1.220703e-04f);
  }
}