【TVM】Relax.Pipeline源码分析

RewriteDataflowReshape #

pass功能:

  • 将所有的类似reshape的call_tir转换为VM reshape操作符call。VM reshape operator将会被降级为运行期间的CreateView操作,而不是做一些真实的数据拷贝工作。
  • 这里的reshape类的算子包含:reshape、expand_dims、flatten等

类实现:

/*!
 * \brief Run DataflowReshapeRewriter over a single function.
 * \param f The function to rewrite.
 * \param mod The module passed to the rewriter's constructor.
 * \return The expression produced by applying the rewriter to f.
 */
Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) {
  return DataflowReshapeRewriter(mod)(f);
}

namespace transform {

/*!
 * \brief Build the function-level pass registered as "RewriteDataflowReshape".
 * \return A pass created with CreateFunctionPass whose callback rewrites one
 *         function at a time.
 */
Pass RewriteDataflowReshape() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // Delegate to the (Function, IRModule) overload; the pass context is not used.
        return Downcast<Function>(RewriteDataflowReshape(f, m));
      };
  return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {});
}

// Expose the pass factory under the global name "relax.transform.RewriteDataflowReshape".
TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape")
    .set_body_typed(RewriteDataflowReshape);

}  // namespace transform

CallTIRRewrite #

// src/relax/transform/call_tir_rewrite.cc
/*!
 * \brief Build the function-level pass named "CallTIRRewrite".
 * \return A pass created with CreateFunctionPass whose callback rewrites one
 *         function at a time.
 */
Pass CallTIRRewrite() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) { 
		      // Only the function is used; the module and pass context are ignored.
		      return Downcast<Function>(CallTIRRewrite(f)); 
	  };
  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
}

TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);

调用CallTIRMutator实现:

// Visit `e` with a freshly constructed CallTIRMutator and return the result.
Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }

LowerAllocTensor #

实例 #

@R.function
def main():
	x = R.builtin.alloc_tensor(R.shape([16, 32]), "float32", 0)
	return x

after LowerAllocTensor()(Before)

@R.function
def main():
	storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8") # 分配storage
	x = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32") # 分配tensor
	return x

VMBuiltinLower #

源文件位置:src/relax/backend/vm/vm_builtin_lower.cc;pass类型:Mutator;pass功能:将Relax内置算子调用降级为VM内置函数(vm.builtin.*)或relax.vm.*算子调用;global注册:

  • 核心类:VMBuiltinLowerMutator
// Visit `e` with a freshly constructed VMBuiltinLowerMutator and return the result.
Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }

/*!
 * \brief Build the function-level pass named "VMBuiltinLower".
 * \return A pass created with CreateFunctionPass whose callback lowers one
 *         function at a time.
 */
Pass VMBuiltinLower() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) { 
	      // Only the function is used; the module and pass context are ignored.
	      return Downcast<Function>(VMBuiltinLower(f)); 
	  };
  return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
}

// Expose the pass factory under the global name "relax.transform.VMBuiltinLower".
TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
  • VMBuiltinLowerMutator属性
// Expression mutator used by VMBuiltinLower. The members below cache the
// operators it matches on and the operators / extern functions it emits.
// NOTE: this listing is abridged; the `...` stands for members not shown here.
class VMBuiltinLowerMutator : public ExprMutator {
 public:
  using ExprMutator::VisitExpr_;
  ...
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const StructInfo object_sinfo_ = ObjectStructInfo();
  // An empty tuple struct info.
  const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
  // Operators to pattern match.
  const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
  const Op& reshape_op_ = Op::Get("relax.reshape");
  const Op& shape_of_op_ = Op::Get("relax.shape_of");
  const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
  const Op& make_closure_op_ = Op::Get("relax.make_closure");
  const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
  const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
  // Memory alloc / kill operators to pattern match.
  const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage");
  const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor");
  const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage");
  const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor");
  // Operators to lower to.
  const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& vm_kill_object_op_ = Op::Get("relax.vm.kill_object");
  // Extern functions (named "vm.builtin.*") referenced by the lowered calls.
  const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
  const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
  const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
  const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
  const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
  const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
  const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};

VisitExpr #

  // Dispatch a call to the matching lowering helper based on its operator.
  // Calls whose operator is not listed are returned as-is (after the
  // post-order visit).
  Expr VisitExpr_(const CallNode* call_node) final {
    // post-order mutation
    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));

    if (call->op == call_tir_dyn_op_) {
      return CallTIRDyn(call);
    } else if (call->op == reshape_op_) {
      return Reshape(call);
    } else if (call->op == shape_of_op_) {
      return ShapeOf(call);
    } else if (call->op == to_vdevice_op_) {
      return ToDevice(call);
    } else if (call->op == make_closure_op_) {
      return MakeClosure(call);
    } else if (call->op == invoke_closure_op_) {
      return InvokeClosure(call);
    } else if (call->op == alloc_tensor_op_) {
      // relax.builtin.alloc_tensor is not handled by this pass: report a fatal error.
      LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression "
                 << GetRef<Call>(call_node) << ".  "
                 << "This operation should have been lowered earlier "
                 << "using the 'relax.transform.LowerAllocTensor' pass.";
    // The branches below handle memory ops: storage allocation, tensor
    // allocation, and killing of storage / tensor objects.
    } else if (call->op == mem_alloc_storage_op_) {
      return MakeMemAllocStorage(call);
    } else if (call->op == mem_alloc_tensor_op_) {
      return MakeMemAllocTensor(call);
    } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
      return MakeMemKillObject(call);
    } else {
      return call;
    }
  }

Mem #

MakeMemAllocStorage #

MakeMemAllocTensor #

relax.memory.alloc_tensor 降级为 relax.vm.alloc_tensor

  // Lower a relax.memory.alloc_tensor call to a relax.vm.alloc_tensor call.
  // All four arguments are forwarded in their original order; the second is
  // downcast to PrimValue and the fourth to DataTypeImm before forwarding.
  Expr MakeMemAllocTensor(const Call& call) {
    const auto& in_args = call->args;
    PrimValue offset_arg = Downcast<PrimValue>(in_args[1]);
    DataTypeImm dtype_arg = Downcast<DataTypeImm>(in_args[3]);
    return Call(vm_alloc_tensor_op_, {in_args[0], offset_arg, in_args[2], dtype_arg}, Attrs());
  }

调用的ops:const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");

VMShapeLower #

源文件位置:src/relax/backend/vm/vm_shape_lower.cc global注册:

/*!
 * \brief Build the module-level pass named "VMShapeLower".
 * \param emit_err_ctx Forwarded unchanged to VMShapeLowerMutator::Lower.
 * \return A pass created with CreateModulePass.
 */
Pass VMShapeLower(bool emit_err_ctx) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule mod, PassContext pc) { 
	      // The whole module is lowered at once; the pass context is not used.
	      return VMShapeLowerMutator::Lower(mod, emit_err_ctx); 
	  };
  return CreateModulePass(pass_func, 0, "VMShapeLower", {});
}

// Expose the pass factory under the global name "relax.transform.VMShapeLower".
TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) {
  return VMShapeLower(emit_err_ctx);
});

VMShapeLowerMutator:Mutator,重写Rewrite,对IRModule进行重写

class VMShapeLowerMutator
    : public ExprMutator,
      public StructInfoFunctor<void(const StructInfo&, Expr, bool, bool, const String&,
                                    std::vector<MatchShapeTodoItem>*)> {
 public:
  /*!
   * \brief Lower every Relax function of a module.
   * \param mod The module to process.
   * \param emit_err_ctx Forwarded to the mutator's constructor.
   * \return The module held by the mutator's builder after all updates.
   */
  static IRModule Lower(IRModule mod, bool emit_err_ctx) {
    // 1. Create a single mutator that is reused for every function.
    VMShapeLowerMutator mutator(mod, emit_err_ctx);
    // 2. Walk all functions; only entries that are a FunctionNode (the Relax
    //    Function accepted by Rewrite) are rewritten, and each rewritten
    //    function replaces the old one in the builder's module.
    for (auto& kv : mod->functions) {
      if (auto* func = kv.second.as<FunctionNode>()) {
        Function updated_func = mutator.Rewrite(kv.first, GetRef<Function>(func));
        mutator.builder_->UpdateFunction(kv.first, updated_func);
      }
    }
    return mutator.builder_->GetContextIRModule();
  }

源码分析 #

算法流程:

  1. 预处理:PrimExprSlot 集合,我们扫描函数并为每个 PrimExpr 分配 PrimExprSlot。在上面的例子中,从slot索引到 expr 的结果映射将为 {0: m, 1: n+1, 2: n}。注意到:"n+1"也会获得一个slot。PrimExprSlot还带有辅助字段,用于跟踪其值是否易于计算。 每个匹配点的步骤:
  2. 步骤一:调用CheckMatchCast:将会递归式的unpack StructInfo,并且生成静态信息检查。请注意,此步骤仅生成用于检查类型和 ndim 信息的函数,而不生成符号形状变量。符号shape-matching的结果将被返回作为vector<MatchShapeTodoItem>。这是因为符号形状匹配可能无法在一轮中完成。重要的是,CheckMatchCast 还处理元组解包。
  3. 步骤二:调用RunMatch来生成匹配符号形状的语句。在上面的例子中,第一轮会将 M、N 的值存储到其对应的插槽中。RunMatch 可能会返回outstanding的项。
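  4. 步骤三:调用EmitOutstandingPrimExprCompute,然后以RunMatch(match_todos, true)再运行一轮匹配,处理步骤二返回的outstanding项(对应下文Rewrite与VisitBinding_中的调用顺序)。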

Rewrite #

  // Unit rewrite function per function.
  //
  // The rewritten body consists of three groups of binding blocks, in order:
  //   1. a block that allocates the shape heap and checks the parameters,
  //   2. the blocks of the visited original body,
  //   3. a block that checks the return value.
  //
  // \param gvar The global var of the function; its name_hint is used in error contexts.
  // \param func The function to rewrite.
  // \return A new Function with the same params, ret_struct_info, is_pure and attrs.
  Function Rewrite(GlobalVar gvar, Function func) {
    // prepare mapping and heap var
    slot_vec_.clear();
    slot_map_.clear();
    PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_);
    // The heap size equals the number of collected slots.
    heap_size_ = IntImm(ShapeDType(), static_cast<int64_t>(slot_vec_.size()));
    VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_);
    shape_heap_ = shape_heap_binding->var;

    // prepare slot information
    this->PopulateSlotInfo();

    Array<BindingBlock> blocks;

    builder_->BeginScope(func->params);

    {
      // Check the parameter section.
      builder_->BeginBindingBlock();
      // The heap allocation is emitted first in this block.
      this->builder_->EmitNormalized(shape_heap_binding);
      std::vector<MatchShapeTodoItem> match_todos;
      size_t num_input = func->params.size();
      if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
        // If the function has the attribute 'num_input', do shape checking on for the real inputs
        // and skip weights.
        num_input = static_cast<size_t>(opt_num_input.value()->value);
      }
      for (size_t i = 0; i < func->params.size(); ++i) {
        StructInfo sinfo = GetStructInfo(func->params[i]);
        std::ostringstream err_ctx;
        err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
                << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") ";
        // The fourth argument is true only for parameters at index >= num_input.
        this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input, err_ctx.str(),
                             &match_todos);
      }
      // insert heap generation logic.
      // First round of matching; the items it returns are matched again
      // after the outstanding PrimExpr computations have been emitted.
      match_todos = this->RunMatch(match_todos, false);
      this->EmitOutstandingPrimExprCompute();
      this->RunMatch(match_todos, true);

      BindingBlock pre_block = builder_->EndBlock();
      blocks.push_back(pre_block);
    }

    // new body.
    auto body_seq = Downcast<SeqExpr>(this->VisitWithNewScope(func->body, func->params));
    blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end());

    {
      // Insert the return value check
      builder_->BeginBindingBlock();
      std::ostringstream err_ctx;
      err_ctx << "ErrorContext(fn=" << gvar->name_hint
              << ", loc=return, annotation=" << func->ret_struct_info << ") ";
      std::vector<MatchShapeTodoItem> match_todos;
      // NOTE: the return value's shape computation must already be defined.
      this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, false, err_ctx.str(),
                           &match_todos);
      // For the same reason only a single RunMatch round (second argument true) is run here.
      this->RunMatch(match_todos, true);
      BindingBlock post_block = builder_->EndBlock();
      blocks.push_back(post_block);
    }

    auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body));
    // create a new function
    return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs);
  }

visitExpr #

PrimValueNode #

  // Rewrite a PrimValue. Constant values are kept; any other value is replaced
  // by a call to builtin_make_prim_value_ that takes the shape heap.
  Expr VisitExpr_(const PrimValueNode* op) final {
    using runtime::relax_vm::MakeShapeCode;
    // Constant values (IntImm or FloatImm) can be preserved.
    bool is_const_value =
        op->value->IsInstance<IntImmNode>() || op->value->IsInstance<FloatImmNode>();
    if (is_const_value) {
      return GetRef<Expr>(op);
    }

    Array<Expr> args = {shape_heap_};
    auto [code, value_or_index] = MakeSymbolicShapeArg(op->value);
    args.push_back(code);
    args.push_back(value_or_index);

    // Arguments are (heap, code, value_or_index); the call carries the
    // original struct info of the PrimValue.
    Call call(builtin_make_prim_value_, args, Attrs(), {Downcast<StructInfo>(op->struct_info_)});
    return call;
  }

ShapeExprNode #

  // Rewrite a ShapeExpr. A shape whose values are all IntImm is kept; any other
  // shape is replaced by a call to builtin_make_shape_ that takes the shape heap.
  Expr VisitExpr_(const ShapeExprNode* op) final {
    using runtime::relax_vm::MakeShapeCode;
    // Constant shape can be preserved.
    bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) {
      return e->IsInstance<IntImmNode>();
    });
    if (is_const_shape) {
      return GetRef<Expr>(op);
    }

    // Leading arguments: the heap and the number of dimensions.
    Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
    // Then one (code, value_or_index) pair per dimension.
    for (PrimExpr expr : op->values) {
      auto [code, value_or_index] = MakeSymbolicShapeArg(expr);
      args.push_back(code);
      args.push_back(value_or_index);
    }

    // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
    // The result is annotated as a shape with the same number of dimensions.
    Call call(builtin_make_shape_, args, Attrs(),
              {ShapeStructInfo(static_cast<int>(op->values.size()))});
    return call;
  }

MatchCastNode #

  // Emit the checks for a match_cast binding, then re-emit the binding itself.
  void VisitBinding_(const MatchCastNode* binding) final {
    Expr value = ExprMutator::VisitExpr(binding->value);
    std::vector<MatchShapeTodoItem> match_todos;
    std::ostringstream err_ctx;
    err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") ";
    // always_check=false
    this->CheckMatchCast(binding->struct_info, value, false, false, err_ctx.str(), &match_todos);

    // Same sequence as in the parameter section of Rewrite: a first matching
    // round, then the outstanding PrimExpr computations, then a second round.
    match_todos = this->RunMatch(match_todos, false);
    this->EmitOutstandingPrimExprCompute();
    this->RunMatch(match_todos, true);

    // These checks are emitted as extra, in codegen
    // match-cast is simply ignored and treated as a normal binding.
    builder_->EmitNormalized(GetRef<MatchCast>(binding));
  }

相关builtin #

alloc_shape_heap #

// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap);
/*!
 * \brief Allocate the shape heap as a 1-D int64 NDArray.
 * \param ctx_ptr Opaque context pointer; it is cast to VirtualMachine*.
 * \param size Number of int64 elements to allocate.
 * \return The array returned by the selected allocator's Empty().
 */
NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
  VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr);
  // 1. Use the host allocator, which is always the last element.
  size_t host_device_index = vm->devices.size() - 1;
  // 2. Special case: when the first device is Hexagon, use index 0 instead.
  if (vm->devices[0].device_type == kDLHexagon) {
    host_device_index = 0;
  } 
  auto* alloc = vm->allocators[host_device_index];
  // 3. Allocate through alloc->Empty on the selected device.
  return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]);
}

check_tensor_info #

简介:检测参数arg是否是Tensor(dtype, ndim)

  • arg: 输入参数
  • ndim:Tensor的ndim,可以是-1(表示不知道)
  • dtype:期待的数据类型
  • error_ctx
// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body(CheckTensorInfo);
// Packed-function entry for "vm.builtin.check_tensor_info".
// Accepted argument layouts:
//   (arg, ndim, err_ctx)         -> dtype is set to DataType::Void()
//   (arg, ndim, dtype, err_ctx)
// NOTE(review): this listing appears abridged; `ptr` is computed but the
// checks that would use it are not shown here.
void CheckTensorInfo(TVMArgs args, TVMRetValue* rv) {
  ObjectRef arg = args[0];
  int ndim = args[1];
  DataType dtype;
  Optional<String> err_ctx;

  if (args.size() == 3) {
    // No dtype was passed.
    dtype = DataType::Void();
    err_ctx = args[2].operator Optional<String>();
  } else {
    dtype = args[2];
    err_ctx = args[3].operator Optional<String>();
  }

  // Try to view the argument as an NDArray container.
  auto* ptr = arg.as<NDArray::ContainerType>();
}

match_shape #

简介:内置的runtime函数。签名:

  • 此函数提供 runtime 形状填充和对match-cast的检测支持。
  • 当形状变量第一次出现时,我们应该加载形状并填充变量。
  • 当形状变量已经出现时,我们应该断言它已经等于现有的形状值。
  • (如果断定所有的code标识都是AssertEqualToImm,允许传递shape_heap为nullptr)
MatchShape(input_shape, shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n], err_ctx)

// Per-dimension action codes passed to MatchShape as c[i]; r[i] is the operand.
enum class MatchShapeCode : int {
  kAssertEqualToImm = 0, // immediate: assert input_shape[i] == r[i]
  kStoreToHeap = 1, // first occurrence of a symbolic dim: shape_heap[r[i]] = input_shape[i]
  kNoOp = 2, // do nothing
  kAssertEqualToLoad = 3, // assert input_shape[i] == shape_heap[r[i]]
};

具体实现:

// src/runtime/relax_vm/builtin.cc
TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape);
// Packed-function entry for "vm.builtin.match_shape".
// Argument layout: (input, heap, n, c[0], r[0], ..., c[n-1], r[n-1], err_ctx).
// NOTE(review): this listing is abridged (see the `...` after CHECK_EQ); the
// kAssertEqualToLoad branch shows only the code check, not the comparison.
void MatchShape(TVMArgs args, TVMRetValue* rv) {
  // 1. The first argument is either a tensor (its shape is used) or a shape.
  ShapeTuple input_shape;
  if (args[0].IsObjectRef<NDArray>()) {
    input_shape = args[0].operator NDArray().Shape();
  } else {
    input_shape = args[0];
  }
  // 2. The second argument is the heap, as a DLTensor; it may be nullptr.
  DLTensor* heap = args[1];
  int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data);
  // 3. The third argument is the number of (code, reg) pairs; the pairs start at index kBeginCode.
  int64_t size = args[2];
  const int64_t kBeginCode = 3;
  // a function that lazily get context for error reporting
  // The error context follows the last (code, reg) pair.
  const int64_t kErrorContextOffset = kBeginCode + size * 2;
  Optional<String> err_ctx = args[kErrorContextOffset];

  for (int64_t i = 0; i < size; ++i) {
    // Decode the action code of dimension i.
    MatchShapeCode code = static_cast<MatchShapeCode>(args[kBeginCode + i * 2].operator int());
    // Decode its operand (an immediate or a heap index, depending on the code).
    int64_t reg = args[kBeginCode + i * 2 + 1];

    if (code == MatchShapeCode::kAssertEqualToImm) {
      CHECK_EQ(input_shape[i], reg)...
    } else if (code == MatchShapeCode::kStoreToHeap) {
      heap_data[reg] = input_shape[i];
    } else if (code == MatchShapeCode::kNoOp) {
      // Nothing to do for this dimension.
    } else {
      ICHECK(code == MatchShapeCode::kAssertEqualToLoad);
    }
  }
}

make_shape #

简介:内置的runtime函数,根据立即数或shape_heap中的值构造shape。签名:

  • 如果code都是UseImm的话,允许shape_heap为nullptr
MakeShape(shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n]).

// Per-dimension codes passed to MakeShape as c[i]; r[i] is the operand.
enum class MakeShapeCode : int {
  // presumably: use r[i] directly as the dimension value — TODO confirm
  kUseImm = 0,
  // presumably: load the dimension value from shape_heap[r[i]] — TODO confirm
  kLoadShape = 1,
};

实例分析1:lower #

before:

@tvm.script.ir_module
class Before:
	@R.function
	def main(x: R.Tensor(["n", 2, "m"], "float32")):
		R.func_attr({"relax.force_pure": True})
		return x

after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)

sindex = {
	"n": 0,
	"m": 1,
}
@tvm.script.ir_module
class Expected:
	@R.function
	def main(x: R.Tensor(["n", 2, "m"], "float32")):
		R.func_attr({"relax.force_pure": True})
		# 1. 分配shape堆, 使用R.builtin_with_ctx, 因为需要ctx上下文, 得到vm runtime
		shape_heap = R.call_builtin_with_ctx(
			"vm.builtin.alloc_shape_heap",
			[R.prim_value(2)],
			sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
		)
		# 2. 检查输入x的tensor类型是否匹配, x == Tensor(ndim=3, dtype="float32")
		_ = R.call_packed(
			"vm.builtin.check_tensor_info",
			x, 
			3,R.dtype("float32"), # ndim == 3, dtype == "float32"
			"",
			sinfo_args=[R.Tuple()],
		)
		# 3. match_shape
		_ = R.call_packed(
			"vm.builtin.match_shape", # 内置的runtime函数, 运行期确定sidx的值
			x, # 第一个参数: tensor
			shape_heap, 3, # shape_heap, shape大小为3
			MS.STORE_TO_HEAP, sindex["n"],
			MS.ASSERT_EQUAL_TO_IMM, 2,
			MS.STORE_TO_HEAP, sindex["m"],
			"",
			sinfo_args=[R.Tuple()],
		)
		return x

实例分析2:running #

参考:Relax设计文档#将该程序降级为VM指令的挑战。运行期数据结构包含:

  • ShapeTuple和NDArray。
  1. 输入和irmodule
# Input
inputs = tvm.nd.array(
    np.array([1,2,3]).astype(np.int64)
)
# IRModule: main takes a 2-D int32 tensor with symbolic dims (m, n),
# builds ones of that shape and returns it cast to int64.
d = tvm.tir.Var("m", "int64")
p = tvm.tir.Var("n", "int64")
nums = relax.Var("shape_d", R.Tensor([d, p], "int32"))
# The builder must exist before it is used below.
bb = relax.BlockBuilder()
with bb.function("main"):
    with bb.dataflow():
        b = relax.op.ones(R.shape((d, p)), "int32")
        gv = bb.emit_output(relax.op.astype(b, "int64"))
    bb.emit_func_output(gv, [nums])
mod = bb.get()
  2. after pipeline, lower shape
mod_after_pipeline = relax.pipeline.get_pipeline("zero")(mod)
mod_after_builtinLower = relax.transform.VMBuiltinLower()(mod_after_pipeline)
mod_after_vmshapeLower = relax.transform.VMShapeLower()(mod_after_builtinLower)
  3. build and running
rt_mod = relax.build(mod_after_pipeline, tvm.target.Target("llvm"))

vm = relax.VirtualMachine(rt_mod, tvm.cpu())
inputs = tvm.nd.array(
    np.array([[1,2,3], [4,5,6]]).astype(np.int32)
)
vm["main"](inputs)

分析:

  • vm调用relax函数main, 传递Inputs作为输入
@main:
  # 
  call  vm.builtin.alloc_shape_heap in: %vm, i3      dst: %1
  call  vm.builtin.check_tensor_info in: %0, i2, c[0], c[1] dst: %void
  call  vm.builtin.match_shape in: %0, %1, i2, i1, i0, i1, i1, c[1] dst: %void
  #
  call  shape_func       in: %1           dst: %void
  call  vm.builtin.make_shape in: %1, i1, i1, i2 dst: %2
  
  call  vm.builtin.alloc_storage in: %vm, %2, i0, c[2], c[3] dst: %3
  call  vm.builtin.make_shape in: %1, i2, i1, i0, i1, i1 dst: %4
  
  call  vm.builtin.alloc_tensor in: %3, i0, %4, c[4] dst: %5
  call  vm.builtin.null_value in:              dst: %3
  call  vm.builtin.make_shape in: %1, i2, i1, i1, i1, i0 dst: %6
  
  call  vm.builtin.call_tir_dyn in: f[fused_ones_cast], %5, %6 dst: %void
  call  vm.builtin.match_shape in: %5, %1, i2, i3, i0, i3, i1, c[5] dst: %void
  ret   %5