【TVM】Relax设计文档

架构 #

它旨在:作为第一个从架构级别理解relax的文档。有关具体方面的详细信息,请参阅其他设计文档。

关键目标 #

relax用于ML加速的三个关键目标:

G0:支持动态shape workloads #

具体来说,我们需要支持动态形状workloads,并优化新旧性能。

动态形状模型在当今的机器学习工作负载中无处不在。动态性可能来自 可变的输入大小(variable input size),或者 来自程序中缺少的信息(missing information)。

G1:支持携带高级语义的“计算图” #

大多数机器学习工程师都熟悉“计算图”及其优化(图中的每个操作都没有副作用)。虽然这种优化对大多数程序都很有用。 当我们开始处理随机数、状态和权重更新时,我们还需要能够表示包含更复杂语义的进程,例如控制、就地更新(inplace updates)和副作用(side effects)。

此外,一些高级优化可能需要我们处理mutations,例如scatter-gather操作的inplace更新。

我们需要找到一种方法,使大多数人能够编写计算图优化,同时仍然能够表示这些高级语义。

G2:统一跨层优化抽象 #

现在,TVM 在抽象之间有一个明确的边界。Relay以单次翻译方式完成到TIR的降级。 然而,我们发现我们对跨层执行优化有强烈需求。 例如,理想情况下,TensorIR 中的自动化决策应该为高层次的融合和布局决策提供信息。这种需求出现在我们的 TensorCore 自动调度应用进程以及 NPU 相关工作负载中。

关键设计点 #

import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tir_exp_func(x: T.handle, y: T.handle): ## <= D2
        X = T.match_buffer(x, (n,), "float32")
        Y = T.match_buffer(y, (n,), "float32")
        with T.grid(n) as i:
            Y[i] = T.exp(X[i]) 

    @R.function
    def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[_, "f32"]):
        # n, k above are implicitly defined by the signature
        # so we will be able to refer to n, k in the later part of the program
        with R.dataflow(): ### <= D0
            lv0 = R.match_shape(w, (k, m)) ## <= D1
            lv1: R.Tensor[(n, m), "f32"] = R.dot(x, lv0)
            lv2: R.Tensor[(n * m,), "f32"] = R.flatten(lv1) ## <= D1
            lv3: R.Shape = (n * m,)  ## <= D1 
            gv0: R.Tensor[lv2, "f32"] = R.call_tir(lv2, tir_exp_func, [lv3])   ## <= D2
            R.outputs(gv0)

        R.call_packed("custom_inplace_update", gv0)  ## <= D0, D2
        return gv0 

我们可以使用上面的代码片段来演示 relax 的关键设计点。请注意,script语法仍在不断发展,可能会发生变化。

D0:第一个构造的类是数据流块 #

大多数relax_func代码都封装在 R.dataflow() 构造中。数据流块中的所有操作是无副作用的,并且不包含高级控制流(如 if-then-else)或嵌套作用域。

数据流块可以有效地被视为嵌入在进程中的计算图。请注意:数据流块中的大部分绑定变量(lv0, lv1, lv2, lv3)都是“local”的,这意味着他们仅仅在块内是可见的。

  • 这些变量可以看作是计算图的“内部节点”。
  • 我们可以将变量标记为 output(gv0),在这种情况下,该变量将在进程的后面部分可见。这些输出变量可以看作是计算图中的输出节点。 请注意,R.call_packed("custom_inplace_update", gv0) 位于数据流块之外。数据流块之外的所有内容都可能产生副作用。因此,除非我们进行更仔细的分析,否则我们无法执行优化,例如根据拓扑顺序重新排序这些绑定 我们预计大多数优化将发生在数据流块级别

这些优化可以由熟悉计算图概念的 ML 工程师完成。 隔离(isolate)和表示(represent)有效组件的能力也为需要它们的地方提供了更高级优化的机会。

D1:第一个计算的类是shape deduction #

对于动态模型来说:shape deduction是很重要的。

  • 在动态形状设置下,我们通常需要在运行计算之前计算中间张量的形状
  • 此外,我们还需要处理形状本身与数据相关的情况(例如唯一);
  • 最后,大多数动态形状工作负载仍然包含大量(部分)静态形状,理想情况下,我们希望利用这些静态形状信息进行优化
from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # 1. 符号和静态形状推导 symbolic and static shape deduction
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4)) 
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
        lv2: R.Shape = (n * 4,)
        # 2. 外部不透明形状函数 external opaque shape function
        lv3: R.Shape = R.call_packed("myshape_func", lv2)
        lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1]) 
        # 3. 数据依赖情况 data dependent case
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)
        # 3. 重新match形状 re-match shape
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
        R.outputs(gv0)
    return gv0

以上进程涵盖了形状推导的典型场景(正如注释所言):重要的是,shape 现在与 Tensor 值一起成为计算的一部分。这反映了一个事实,即形状的计算可以在运行时进行。 而文本格式类型注解 lv0: R.Tensor[(n, 4), “f32”] 显示每个值的形状:这只是一种句法糖,从 IR 的角度来看,形状场 (n, 4) 不是lv0.checked_type的一部分。

  • lv0 的类型是 DynTensor(rank=2, dtype=“f32”)
  • shape 是附加到每个 Expr 的特殊值字段 我们做出了这个明确的选择,以简化类型推断,这样我们就不需要进入完全依赖的类型领域。

有两个与符号形状计算相关的关键结构:

  1. match_shape value = match_shape(lhs, pattern) 匹配形状构造采用 lhs 值和模式(符号整数表达式)。它有两个重载语义:
  • 当lhs是一个tensor的时候,它会将 lhs.shape 与pattern匹配,如果该变量首次出现在pattern中,则填充相应的符号整数变量,然后返回与 lhs 相同但 shape 字段更新为 pattern 的新 Tensor
  • lhs 也可以是直接匹配pattern的 Shape:当我们想要隔离出与任何张量值不对应的形状函数时,这很有用。 例如:
from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[_, "f32"], y: R.Tensor[_, "f32"]):
    with R.dataflow():
        # the match shape defines n, m because it appears for the first time
        lv0: R.Tensor[(n, m)] = R.match_shape(x, (n, m))
        # the second occurance of n, m will translate into an assertion 
        # that y's shape equals (n, m)
        lv1: R.Tensor[(n, m)] = R.match_shape(y, (n, m)) 
        # we can also call match_shape on shape expressions
        lv2: Shape = R.match_shape(R.shape_of(y), (n, m)) 
  1. 形状推导来自于符号整数元组(symbolic integers tuple)

在我们得到 n 和 m 等符号整数之后。我们可以将它们重新组合在一起以形成一个 Expr。符号整数表达式的任何元组都可以在 relax 中被识别为 Shape 值。因此 (n, m) 是一个形状值。

形状传播的途径: 重要的是,因为形状现在是值的一部分,所以在计算时才会出现。 编译时形状推断可以看作是编译时常数折叠(或部分评估)对与形状有关的操作。

有几种方法来表示形状计算:

  • W1:符号形状传播。一个形状可以被分解成符号整数(在上面的进程中是n或m),然后我们可以使用符号整数(n*4)的表达式来表示形状计算。值得注意的是,静态形状是(常量符号)整数的特例。然后,符号整数可以重新组合以形成形状值(例如(n* 4,))。
  • W2:不透明形状函数调用。我们还可以实现不透明的形状函数(myshape_func)。这些不透明的形状函数对于快速hack up运行时形状函数是有用的。
  • W3:对于依赖于数据的 shape(unique),我们将简单地遵循运行时调用 f(inputs)->output,该调用接受输入张量,分配并返回输出张量。然后,我们可以通过构造从 Tensor 值中获取 lv5 的形状match_shape

对写pass的影响: 许多优化过程都需要查看形状信息。现在,许多形状都可以是符号化的 (n, 4),最理想的优化pass将需要泛化一点以利用符号信息。 例如,在上面的程序中,我们知道所有的 与 n 相同的值。这种限制是非常有用的。此外,由于arith模块中的符号整数,我们可以重用 proves 的机制来检查符号表达式的等价性和推导性(例如 prove(n4 == n4))。

因为符号整数(tir.PrimExpr)急切常数折叠,当输入是静态形状时,计算结果也应该急切地折叠为常数整数,保留静态形状相关优化所需的属性。

因为我们现在可以在元组 (n, 4) 中表示混合符号静态形状,所以我们可以尝试利用静态信息进行额外的优化。

D2:可以与TensorIR和PackedFunc直接交互 #

我们做出的最后一个关键设计决策是允许高级 IR 能够直接交互并调用较低级别的 TensorIR 和 PackedFunc。TensorIR 函数和许多外部库采用目标传递约定(我们需要显式分配输出并作为参数传入函数)。我们使用 dps(destination passing) 来表示这个约定。dps 在低级 ML 优化中非常重要,因为它允许我们在可能的情况下一次性全局分配中间存储,并在没有活跃的内存分配的情况下执行计算

调用 dps 函数意味着在调用后,结果通过 函数参数(例如,以下示例中的 result)而不是函数的返回值传回

// not destination passing
int func(int x) {
  return 1;
}
// destination passing
void func(int x, int *result) {  
  *result = 1;
}

DPS 样式意味着本质上发生了发辫。我们需要一种方法将调用桥接到高级(纯)数据流领域,以便我们可以对一系列 tir 调用执行计算图样式重写。

  1. call_tir call_tir:是弥合gap的内在因素。该函数意味着“按照管理调用一个TIR函数”
def call_tir(output_shape: Shape, lowlevel_func: Expr, inputs: Tuple[Expr]) -> Expr:
    """Example code to demonstrate the semantics of call tir"""
    out_tensor = alloc_tensor(output_shape, current_expr.dtype)
    lowlevel_func(*inputs, out_tensor)
    return out_tensor

call_tir接受:输出形状、lowlevel_func(可以打包 func、tir PrimFunc)和输入元组。call_tir的语义可以通过上面的代码来证明。值得注意的是,当我们对call_tir降级时,我们不需要选择单独的输出张量分配。编译器可以选择创建中间张量的内存计划,并将内容捆绑在一起以实现有效的重用。

AST 设计 #

为了支持体系结构概述中的关键目标(G0:支持动态形状工作负载,G1:作为一等公民的数据流块),Relax 将以下构造添加到 AST。

class ShapeExpr(Expr):
    """corresponds to a shape containing symbolic PrimExpr"""
    values: List[PrimExpr]

class Var(Expr):
    """global scope visible vars"""
    vid: Id
    type_annotation: Optional[Type]

class DataflowVar(Var):
    """dataflow scope visible vars"""
    pass

class Binding(ObjectRef):
    """the base class of bindings"""
    pass

class VarBinding(Binding):
    """variable bindings, bind the value to the var"""
    var: Var
    value: Expr

class MatchShape(Binding):
    """binding represents to match a shape"""
    value: Expr
    pattern: List[PrimExpr]
    var: Var

class BindingBlock(Node):
    """base class of binding block, bindings inside can be impure (with side effect or control flow)"""
    bindings: List[Binding]

class DataflowBlock(BindingBlock):
    """dataflow block, bindings inside are pure (no side effect and no control flow)"""
    pass

class SeqExpr(Expr):
    """sequence of BindingBlocks, can serve as the body of a Function"""
    blocks: List[BindingBlock]
    body: Expr

class Function(BaseFunc):
    """represents a Relax function"""
    params: List[Var]
    body: Expr   
    ret_type: Type

class ExternFunc(BaseFunc):
    """extern function, which can represent a TIR PrimFunc or a PackedFunc."""
    global_symbol: String

relax Arch

  • 函数的body可以由SeqExpr表示;
  • 一个SeqExpr由BindingBlock列表组成;
  • DataflowBlock 是一种特殊的 BindingBlock,与纯计算图相同。DataflowBlock 内部的绑定没有副作用,也没有控制。
  • 一个BindingBlock包含一个Binding的列表;
  • Binding可以是VarBinding或者MatchShape
  • DataflowVar的作用域是DataflowBlockDataflowBlock 中的正常的 Var 转义到包含该块的作用域(可以是函数作用域或其他作用域,如 if 分支)。

下面是一个relax程序,relax_func包含seqExprseqExpr包含一个DataflowBlock(with 2 VarBinding)和BindingBlock(包含VarBinding

from tvm.script import relax as R

@R.func
def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[(k, m), "f32"]):
    # start a DataflowBlock
    with R.dataflow(): ## <= DataflowBlock
        lv0: R.Tensor[(n, m), "f32"] = R.dot(x, w) ## <= VarBinding, lv0 is a DataflowVar
        gv0: R.Tensor[(n * m,), "f32"] = R.flatten(lv0) ## <= VarBinding, gv0 is a Var that escapes to the outer scope
        R.outputs(gv0)

    # start a BindingBlock
    gv1 = R.call_packed("custom_inplace_update", gv0) ## <= side-effect binding
    return gv1

为什么要区分DataflowBlock和BindingBlock #

大多数编写Pass的人是没有编译器或 PL 背景的 ML 研究人员和 ML 工程师,因此他们基于一个简单的假设来编写传递,即PASS正在改变一个pure的计算图。

  • Relay 没有明确说明哪些表达式具有副作用(side-effects),哪些表达式是纯的,因此,在存在副作用的情况下,许多优化是不合理的。
  • relax中,DataflowBlock表示计算图该区域中的binding是pure的(无控制流,无side-effects)。明确地划分该计算图区域并将其作为一等公民,使终端用户可以轻松编写计算图PASS。
  • 由于“纯”和“不纯”区域之间的这种明确分离,函数的主体可以由一个或多个纯或不纯的块组成,因此 SeqExpr 的主体带有 Array<BindingBlock>

MatchShape也是Binding类型 #

MatchShape(value: Expr, pattern: List[PrimExpr], var: Var)

  • MatchShape有两个重载语义:
    • 假设x是一个2维张量 (1)MatchShape(x.shape, [m, n], var) → 将x.shape 与符号变量 (m, n)相匹配, 并将 Shape 返回给返回变量; (2)MatchShape(x, [m, n], var) → 将x.shape 与符号变量 (m, n)相匹配, 并将与张量 x 形状相同的二维张量(但具有显式形状字段 [m, n])返回到输出 var。

IR转换的含义 #

  • DataflowBlock 是一个独立的数据结构,其中包含Binding列表。PASS编写者可以使用 DataflowMutator 接口对纯数据流块进行访问和转换。对于那些只有 ML 背景且熟悉计算数据流的 pass writer 来说,它可能更人性化,因为他们只需要面对 DataflowBlock 这个简单的概念并override(重写) DataflowMutator 中的访问者。。
  • 一个BindingBlock包含Binding列表。ExprMutator 在 ANF 进程上工作,因此访问者无需记忆即可遍历绑定,并且由于没有递归,因此不会出现堆栈溢出。
  • ExprMutator 有一个内部 BlockBuilder,可以向新创建的块emit binding。为什么在ExprMutator内部有一个BlockBuilder呢?
    • BlockBuilder 提供了用于向块emit bindings的 API。我们经常可以看到这样的情况:我们想将多个bindings折叠成一个 (n → 1),或者我们想将绑定重写为多个绑定 (1 → n)。使用 BlockBuilder 在访问者中发出绑定可以很容易地做到这两点。
    • BlockBuilder 可以进行eager形状和类型推断,因此,在发出新bindings时,可以填充 LHS var 和 rhs expr 的 shape_ checked_type_ 字段。

relax中Visitor pattern的层级 #

Shape Computation设计 #

在 relax 中设计形状计算时,需要考虑以下关键目标:

关键目标 #

G0:第一类-支持动态符号整数形状 #

from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, m, 2), "f32"]):
    with R.dataflow():
        # symbolic and static shape deduction
	lv0: R.Tensor[(n, m * 2), "f32"] = R.reshape(x, (n, m*2)) 
	lv1: R.Tensor[(n * m * 2,), "f32"] = R.flatten(lv0)
	...

符号整数形状使我们能够在符号表达式中有效地表示形状计算及其关系。 上面的例子给出了这种情况的一个例子。编译器需要知道相同的 n 通过多个计算块进行线程处理。 推导、常数折叠和分析(证明相等和其他关系)的能力将实现更多的优化。

我们还需要从 API 层面轻松让开发人员(和用户)能够在 AST 构建、重写和检查期间直接查找、证明属性、运行符号计算。

G1:一般情况下的安全网络 #

虽然rank是固定的,但动态的符号形状关系涵盖了大多数用例。不可避免地,我们还需要能够涵盖可能不属于该类别的一般情况:

  • C0:动态形状关系,其中输出形状是依赖于输入的数据(例如唯一运算符)。
  • C1:张量的秩是未知的(可能发生在极少数循环的情况下)
  • C2:张量的dtype是未知的
  • C3:其他情况,低级库的不透明运行时对象(例如 PRNG 句柄、cuDNN 上下文)。

因此,重要的是要有一个“安全网络”解决方案,以便我们涵盖一般情况。

G2:未来的提升改进路线 #

与往常一样,有一些方法可以不断提高总体上进行形状支持的能力

  • 例如,构建有效地为动态rank运算符生成代码的能力。
  • 我们需要确保的一件事是,所提出的设计具有与未来高级改进兼容的路径。
  • 这个观点也与 G1 密切相关,因为在我们添加高级形状推导机制之前,我们总是可以暂时回退到安全网络。

形状限制表示 #

张量的形状(约束)由relax.Expr(RelayExpr)的两个字段表示。

  • checked_type_: Type, 存储泛型 rank 和 dtype 约束
  • shape_:Expr,存储用于在运行时计算表达式形状的方法。

DynTensorType #

checked_type_:存储表达式的编译时推导类型。Tensor Expr包含下面两个fields:

class DynTensorType: 
    # rank can be unknown
    rank: int
    # dtype can be unknown
    dtype: DataType

在大多数情况下,给定 Tensor expr 的秩和 dtype 是已知的。但是我们也允许具有未知等级和 dtype 的类型(由于考虑了 G1 和 G2)。

Shape Field #

DynTensorType:不包含形状信息。相反,Tensor 的形状存储在 Expr 的可选 shape_字段中。

对于一个Expr xx.shape_,包含下面的值:

  • V0:ShapeExpr,包含一个Array<PrimExpr>,表示x有一个已知的动态形状,可以从先前的符号值中推导出来。
  • V1:通用relax.Expr,可以调用不透明(形状)函数或形状推导内部函数。
  • V2:None,表示形状在编译时未知,需要在运行时查找。
from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # V0: symbolic and static shape deduction
		lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4)) 
		lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
		lv2: R.Shape = (n * 4,)
		# V1: external opaque shape function
		lv3: R.Shape = R.call_packed("myshape_func", lv2)
		lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1]) 
		# V2: data dependent case
		lv5: R.Tensor[_, "f32"] = R.unique(lv4)
		# re-match shape
		lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
		gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
		R.outputs(gv0)
    return gv0

上面的代码块显示了三种方案的示例。

  • V0 是我们希望提供一类支持的最常见情况
  • V1:用于处理可能无法用符号表达式relax.Expr的情况(例如动态排名情况)
  • V2:我们的安全网络(G1) 当形状推导器无法推导出形状时,无论是由于缺乏分析能力还是由于无法解决的障碍(在数据依赖的情况下),我们可以回退到 V2。

Shape修订 #

在依赖数据的计算或外部调用之后,我们可能需要能够恢复/优化形状信息以实现更多优化。 match_shape构造被用于执行此类改进

value = match_shape(lhs, pattern)

match shape构造采用 lhs 值和pattern(符号整数表达式)。它有两个覆盖的语义:

  • 当lhs是一个张量时,它将match lhs.shape到pattern,如果该变量首次出现在模式中,则填充相应的符号整数变量,然后返回与 lhs 相同但 shape 字段更新为模式的新 Tensor。
  • lhs 也可以是直接匹配pattern的 Shape。当我们想要隔离出与任何张量值不对应的形状函数时,这很有用。

形状传播 #

本节讨论 Relax 中的形状传播策略。

D0:安全链路–没有形状计算 #

我们需要构建的第一件事是在没有执行形状计算时进行安全链路计算。对无形状计算的支持创建了一个安全网,我们总是可以回退到这个安全网中

from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[_, _]):
    with R.dataflow():
        # computation without shape
		lv0: R.Tensor[_, _] = R.log(x) 
		lv1: R.Tensor[_, _] = R.flatten(lv0)
		# bring things back to the shape computatuion land
		lv2: R.Tensor[(m,), "f32"] = R.match_shape(lv1, (m,))
		gv0: R.Tensor[(m,), "f32"] = R.exp(lv2)
		R.outputs(gv0)
    return gv0

上面的代码片段示例显示了一个没有形状的 AST。

  • 在降级期间,这些调用不会翻译为为目标传递样式,因为无法获取形状信息并预先分配内存
  • 相反,它们被直接翻译为分配和返回结果张量的调用。

R.log:可以映射到运行时 PackedFunc 调用,该调用接受 NDArray x 并执行elementwise日志操作。

  • 我们甚至可以调度到常见的运行时库,例如 torch.log

VM 可以很容易地支持这些功能,因为 PackedFunc 调用返回对象。通过使用匹配形状和后续传播,我们可以将张量从无形状计算区域带到形状感知区域。

无形状计算绝不是处理事情的最有效方法。对于依赖于数据的计算以及与形状信息较弱的外部库的接口等情况,这是必需的。

重要的是,它创建了一个重要的安全网,使我们能够以渐进的方式探索先进的形状推导机制,而不必担心所有可能情况的 100% 覆盖率(回到 G2)。

D1:针对元OPS的编译器形状传播 #

处理具有二元广播关系的基础算子(如 R.add)的编译时形状传播。我们注册以下推导函数。

@register_op_attr("relax.add", "FInferShape")
def finfer_shape(call: Call) -> Shape:
    # note: if x.shape_ is None
    # x.shape() will return shape_of(x) that obtains shape in runtime.
    lhs = call.args[0].shape();
    rhs = call.args[1].shape(); # 如果x.shape_为空,将会返回运行期shape
    # pattern match to symbolic deduction case
    if isinstance(lhs, ShapeExpr) and isinstance(rhs, ShapeExpr):
	    # 符号shape
		# construct a symbolic shape according to lhs and rhs
		# Examples:
		# (n, m) + (m) => (n, m)
		# (n, 1, m) + (2, m) => (n, 2, m)
		...
		return deduced_shape_expr
    return call(op.get("binary_broadcast"), lhs,rhs)

然后,finfer_shape可用于形状传播。具体说来:

  • 它处理最常需要的符号形状 (G0) 情况
  • 允许回退到不透明形状的二元广播
  • 在最糟糕的情况中,我们仍然返回None,在这个例子中,我们回滚到我们的安全网络(G1) 请注意,可能需要为高级推导引入其他签名信息。这些推导HOOK可以作为单独的注册引入,而不会影响整体设计。

D2:跨transform的额外改进和形状预留 #

形状推导是不完美的。在低级调用转换为TIR的例子中。有时必须显式指定输出形状本身

为了补充这一点,我们需要能够利用额外的形状改进(通过match_shape),并在可能的情况下在转换中保留形状信息。具体说来:

  • 启用PASS以便于在分析PASS后注释其他形状信息。
  • 更改为低级函数后,从高级别显式保留输出形状信息。

现在我们需要确保我们有足够的 D2,以便它们与未来的改进 (G2) 兼容。有关更多讨论,请参见 A0。

D3:跨函数边界的形状传播 #

对子函数的调用可以由构造或融合PASS而生成。本节讨论跨函数边界的形状传播。

默认的函数调用惯例如下所示:

  • tir.PrimFunc:在运行时在被调用方中检查形状,调用方需要显式指定输出形状。
  • relax编译的PackedFunc:输入形状由被调用方检查,如果返回值的形状未知,则需要match_shape 对返回值的形状进行包装。

请注意,当前的调用约定在函数边界隔离了形状约束,因此它在运行时创建了安全网。

重要的是,调用者希望尽可能多地传播形状信息(尽管由于安全网的原因,这并不是绝对必要的)。这里有一些方法:

  • W0:在融合PASS中,形状信息最初在融合前可用,生成包含这些形状信息的子函数,保留从融合前到融合后的形状信息,必要时在调用方端插入匹配形状。
  • W1:利用子relax函数的常数传播进行形状推导
    • 将输入中的符号形状替换为输入参数的形状
    • Constant 评估进程,看看我们是否可以获得独立于其他执行的输出参数的形状。
  • W2:如果形状函数与值无关,则提取出该函数
    • 分析函数内部的计算
    • 如果形状计算延续独立的路径,则提取形状计算函数

请注意,W1/W2 并不总是可行的。只要我们有添加更多注释或依赖 D0 的后备安全网,就可以尽最大努力。 注意:这里的计划不够详细,无法立即采取行动,但考虑到 G2,我们可以继续进行必要的手段 (W0),然后稍后添加 W1/W2。

其它考虑 #

这些是额外的考虑因素,不属于直接范围,而是可能的未来目标,值得根据考虑因素 G2 进行调整。

A0:更高级的形状推导机制 #

有了清晰的安全网,我们可以利用高级重写,在必要时增强形状信息。请参阅以下示例:

from tvm.script import relax as R

@R.function
def before(x: R.Tensor[_, _], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
		# computation without shape
		lv0: R.Tensor[_, _] = R.log(x) 
		gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
		R.outputs(gv0)
    return gv0

@R.function
def after(x: R.Tensor[(m, n), "f32"], y: R.Tensor[(m, n), "f32"]):
    with R.dataflow():
		# computation without shape
		lv0: R.Tensor[(m, n), "f32"] = R.log(x) 
		gv0: R.Tensor[(m, n), "f32"] = R.ewise_add(lv0, y)
		R.outputs(gv0)
    return gv0

上面的代码显示了利用更高级的形状分析的可能转换。在这种情况下,请使用 ewise_add 来确定输入 x 的形状和类型。这些改进可以以增量方式添加,而无需更改整体体系结构。

A1:符号形状更接近Expr #

现在,符号形状位于自己的类型上,需要包装(ShapeExpr)才能用作relax.Expr。

Compilation MVP设计 #

关键目标 #

  • 将relax程序编译为VM能够执行的格式
  • 一个多阶段的编译pipeline,支持可组合的多阶段编译管道
    • 每个transformation是:IRModule –> IRModule
    • 用户可能会使用第三方库(如 cudnn)运行程序的一部分。我们需要有能力来优化剩余的部分

relax程序案例 #

我们以编译以下简单的 Relax 程序为例:

import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tirexp(a: ty.handle, b: ty.handle):
        n1, m1 = T.var("n1"), T.var("m1")
		X = T.match_buffer(x, (n1, m1))
		Y = T.match_buffer(y, (n1, m1))
		with T.block(n1, m1) as i, j:
		    Y[i, j] = T.exp(X[i, j])
    @R.function
    def myfunc(x: Tensor[(n, m)]):
        with R.dataflow():
		    lv0: Tensor[(n, m)] = R.call_tir((n, m), tirexp, [x])
		    gv0: Tensor[(m*n,)] = R.call_tir((m*n,), "flatten", [lv0])
		    R.outputs(gv0)
        return gv0

我们在relax中引入了一个新的intrinsic:relax.call_tir,并使用它来构建程序。CallTIR包含下面的特征:

  • 没有side effect
  • 有形状标注
  • 核心的表达式:call_tir(output_shape, func, [arg0, arg1, ...], Optional<shape_expr>) -> Expr
    • 这意味着我们采用一个 TIR 调用约定(目标传递样式)中的函数,例如:mylog(in, out),并且将它包装到call_tir函数中,call_tir(out_shape, mylog, in)将会返回输出
    • func 可以是一个TIR函数或者一个packed函数
    • shape_expr是一个可选的参数(可以传递整数),https://github.com/tlc-pack/relax/wiki/EmitTE-Staging-Integration#dynamic-shape-case
    • output_shape:是ShapeExpr或者一个Tuple

将该程序降级为VM指令的挑战 #

  • C0:每个call_tir都需要被降级(relax VM支持call指令,可用于直接调用packed func)–> 我们需要插入显式地输出内存分配(with 内存计划)
  • C1:符号形状 n 和 m 不是运行时可以表示的(relax VM只支持NDArrayShapeTuple运行时数据结构)–>我们需要在VM中使用堆来执行形状计算

首先,将call_tir降级为显式地内存形式 #

一个显式地内存形式程序有下列性质:

  • 显式地分配和kill存储器和tensors
  • 有side effect
  • 没有形状标注
  • 核心表达式:call(func_name, arg0, arg1, ...) -> optional<Expr>,将其映射为VM可以执行的Call指令

四个intrinsics/内置函数:

  • relax.vm.builtin.alloc_storage(size, device) -> storage:分配可用于创建张量的存储
  • relax.vm.builtin.alloc_tensor(storage, shape, offset, dtype) -> tensor:在存储中分配张量
  • relax.vm.builtin.free_storage(storage):释放分配的存储
  • relax.vm.builtin.free_tensor(tensor):释放分配的张量

由于alloc_storagealloc_tensor的参数是整数和dtype,并且它们不能表示为 Expr 作为 CallNode 的参数,alloc_storage 和 alloc_tensor 在 Relax 中被设计为内部函数,并包含包含 int 和 dtype 的属性

下面的程序是对call_tir降级后的表示:

from tvm.script import tir as T, relax as R
@R.function
def myfunc(x):
    # has side effect, so it's now in a BindingBlock instead of a DataflowBlock
    n, m = R.match_shape(x.shape)
		
    storage0 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor0 = relax.vm.builtin.alloc_tensor(storage0, shape=[n, m], offset=0, "f32")
    R.call_packed("tirexp"), x, tensor0)
		
    storage1 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor1 = relax.vm.builtin.alloc_tensor(storage1, shape=[m*n,], offset=0, "f32")
    R.call_packed("flatten"), tensor0, tensor1)
		
    R.call_packed("free_tensor"), tensor0)
    R.call_packed("free_storage"), storage0)
    return tensor1

下一步,通过VM堆操作实现shape降级 #

  • 三个内置函数:
    • relax.vm.builtin.alloc_heap(size) -> heap:分配具有特定大小的堆(NDArray)以执行形状计算(我们可以使用alloc_tensor实现相同的目标)
    • relax.vm.builtin.store_shape(shape, heap, idx0, ...):将形状存储到 vm 堆中的特定indices中
    • relax.vm.builtin.load_shape(heap, idx0, ...) -> shape:根据切片从 vm 堆构造shape (由于 store_shapeload_shape 包含索引(整数数组)作为其参数,并且它们不能表示为 Expr,因此它们在 Relax 中被设计为内部函数,并包含描述 int 和 dtype 的属性。
from tvm.script import tir as T, relax as R

# Program after shape lowering, shape降级后的程序
@R.function
def myfunc(x):
    shape_heap = relax.call_packed("vm.builtin.alloc_shape_heap", size=k) 
    relax.vm.builtin.store_shape(x.shape, shape_heap, 0, 1)
    sh = relax.vm.builtin.load_shape(shape_heap, 0, 1)
    # this product_shape function (to compute n*m) is generated as TIR primfunc when visiting ShapeExpr in the shape lowering pass
    shape_size = product_shape(sh) 
		
    storage0 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv0 = relax.vm.builtin.alloc_tensor(storage0, sh, 0, "f32")
    R.call_packed("tirexp"), x, gv0)
		
    sh1 = R.call_packed("load_shape"), heap, 0, 1)
    storage1 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv1 = relax.vm.builtin.alloc_tensor(storage1, sh1, 0, "f32")
    R.call_packed("flatten"), gv0, gv1)
		
    R.call_packed("free_tensor"), gv0)
    R.call_packed("free_storage"), storage0)
    return gv1

Relax构建和编译workflow #

vm.build(ir_mod, target)中,四个降级pass:

    passes = [relax.transform.ToNonDataflow()]
    passes.append(relax.transform.CallTIRRewrite())
    passes.append(relax.transform.VMMemoryLower())
    passes.append(relax.transform.VMShapeLower())

在对passes降级后,relax.vm.build(ir_mod, target) 调用tvm::build来构建IRModule中的所有的TIR primfuncs,并且使用CodeGenVM来visit IRModule中的所有的relax 函数,并在visit的时候生成可执行的VM。

VM设计 #

设计目标:

  • 设计一个灵活的基于寄存器的VM,用于执行包含动态shape和控制流的relax程序
  • 最小的指令集:
    • Call packed func作为其的核心指令
    • 内置packed function library(例如:shape_of(tensor)
  • 通过shape heap (NDArray) 操作完成形状计算
    • 假设张量A的形状在编译器是(m, n),relax程序中我们希望计算(j, k)=(m+1,n+1)。在运行期,A的形状将通过调用vm内置函数store_shape(A.shape)存储到shape堆的索引0和索引1中。m+1和n+1将会在shape降级pass生成的TIR函数计算,并且 j 和 k 将会被存储到index 2和 3中。

指令集 #

relax vm只包含四种指令:(Call,Ret、If、Goto)

  • Call:调用带有参数的packed函数,并选择性地将输出写入返回寄存器 call <packed_func> [arg0, ...] dst: optional<reg> 参数可以是 寄存器、整数立即数或常量。
  • Ret:返回结果寄存器中的值 ret <res_reg>
  • If:如果cond 寄存器是true,继续执行下一条指令(true 分支),否则将 pc(program counter) 增加 false_offset 以转到 false 分支。 If <cond> false_offset
  • Goto:通过pc_offset增加pc(程序计数器) Goto pc_offset

架构 #

ExecBuilder #

ExecBuilder 是我们用来为虚拟机Emit指令和构建可执行程序的工具。它本身不维护任何数据成员,但提供了操作内部Executable程序的方法。 下面是一个用于展示builder是如何工作的例子:

@tvm.register_func("test.vm.move")
def move(src):
    return src

@tvm.register_func("vm.add")
def add(a, b):
    ret = a.asnumpy() + b.asnumpy()
    return tvm.nd.array(ret)

@tvm.register_func("vm.mul")
def mul(a, b):
    ret = a.asnumpy() * b.asnumpy()
    return tvm.nd.array(ret)

from tvm import relax as rx
ib = rx.ExecBuilder()
with ib.function("main", num_inputs=1):
    ib.emit_call("vm.move", [ib.c(0)], dst=ib.r(1)) # 常量0->寄存器1
    ib.emit_call("vm.add", [ib.r(0), ib.imm(10)], dst=ib.r(2)) # 寄存器0 + 立即数10 -> 寄存器2
    ib.emit_call("vm.mul", [ib.r(2), ib.r(1)], dst=ib.r(3)) # 寄存器2 * 寄存器1 -> 寄存器3
    ib.emit_ret(ib.r(3))
executable = ib.get()

这里

  • ib.r(x) 表示当前帧中的寄存器 x;
  • ib.c(i)表示常量池中的第i个常量,当前的常量池支持NDArray和DLDataType;
  • ib.imm(val):表示等价于val的立即数(inline constant)

ib.function和惯例 #

我们使用ib.function(func_name, num_inputs) 来注释发出的代码中函数的作用域。 必须提供输入的数量,例如k,以便前k个寄存器将用于存储函数的输入。在上面的例子中,寄存器 r(0) 存储函数输入。

Check和格式化 #

由于上面提到的函数约定,这使得检查用户是否正确使用寄存器会很棒。因此,我们将在 ib.function 的出口处验证发出的指令的几件事:

  • 用户是否使用意外的寄存器作为输入。例如,当输入数量为2的时候,ib.r(3)被用于作为输入。在这种情况中,我们将会报错;
  • 用户是否遗漏任何输入寄存器。例如,输入的数量为3,但是用户只使用ib.r(0)ib.r(2)作为输入。在这种情况下,我们将发出警告。

检查是在 CheckExecutable 中完成的。 此外,如果用户使用任意寄存器索引,则不利于寄存器分配。比如说,ib.r(10000),虽然这是正确的,但我们不想在执行过程中分配 10000 个寄存器。因此,我们将有一个formalize PASS,以按使用顺序重命名这些寄存器,这是在ExecBuilderNode::Formalize 中实现的。

Executable #

可执行文档存储虚拟机所需的内容,包括:字节码、函数表、常量池等。它必须支持序列化和反序列化,以便我们可以将其传递给另一个设备并将其加载到 vm 中。

数据结构 #

我们精心设计了 Executable 的数据结构,使其可以轻松序列化/反序列化。

struct VMFunction {
  std::string name; 
  Index start_instr;
  Index num_args;
  Index register_file_size;
};

class ExecutableNode : public Object {
  /* the global function informations */
  std::vector<VMFunction> global_funcs;
  /* the constant pool */
  std::vector<TVMRetValue> constants;
  /* packed function names, corresponding to the
     func_idx in Call instruction */
  std::vector<std::string> func_names;
  /* the emitted byte code of instrucitons */
  std::vector<ExecWord> instr_data;
  /* since the instruction's length is variable,
     we need to store the offset for indexing */
  std::vector<Index> instr_offset;
  ...
};

Emit指令和获取指令 #

在这样的数据结构中,当我们发出类似 ib.emit_call("add", [ib.r(0), ib.r(1)], dst=ib.r(2)) 的调用指令时,我们将:

  • 如果之前没有packed函数名字为"add"的话,我们将"add"添加到packed_func_names中;
  • instr_data的实际大小push到instr_offset中,instr_offset.push_back(instr_data.size())用于索引指令;
  • 将指令的opcode、func index、arguments、destination作为字节码push到instr_data中; 获取指令很简单,我们只需使用 instr_offsetinstr_data中索引指令:
Instruction ExecutableNode::GetInstruction(Index i) const {
  size_t offset = instr_offset[i];
  Opcode op = static_cast<Opcode>(instr_data[offset]);
  switch (op) {
    // ...
    // dispatch according to the op code
  }
}

序列化和反序列化 #

在这种情况下,序列化和反序列化也更容易。由于我们不需要将 Instruction 从内存中的数据结构转换为字节码,因此我们只需要序列化 instr_datainstr_offset,这已经是字节码了。

文本格式 #

我们支持以文本格式dump可执行文档的内容,以便用户可以检查代码。例如:

ib = rx.rx.ExecBuilder()

with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))

exec0 = ib.get()
print(exec0.stats())
print(exec0.astext()) 

输出:

Relax VM executable statistics:
  Constant shapes (# 0): []
  Globals (#1): [func0]
  Packed functions (#3): [vm.op.add, vm.builtin.move, vm.builtin.print]

@func0:
  call  vm.op.add        in: %0, %1       dst: %2
  call  vm.builtin.move  in: %2           dst: %3
  call  vm.builtin.print in: %3           dst: void
  ret   ret %3

Dump as Python Code #

我们还支持将其转储回 python 代码,以便用户可以轻松hack它:

# output of exec.aspython()
ib = rx.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("time", dst=ib.r(100))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("time", dst=ib.r(101))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))

Virtual Machine #

VM 可以加载到包含生成的代码的executable和runtime module中,通过解释字节码,使用给定的输入执行特定函数。

# building
ib = relax.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))

with ib.function("func1", num_inputs=2):
    ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
ex = ib.get()

# execution
vm = relax.VirtualMachine(ex)
a = tvm.nd.array(np.random.rand(4,))
b = tvm.nd.array(np.random.rand(4,))
mul_res = vm["func1"](a, b)
add_res = vm["func0"](a, b)
np.testing.assert_allclose(add_res.asnumpy(), a.asnumpy() + b.asnumpy())
np.testing.assert_allclose(mul_res.asnumpy(), a.asnumpy() * b.asnumpy())

VM状态 #

内置函数(例如,alloc_storage)可能需要访问 VM 的内部状态(例如,allocators)。VMState 存储在一个特殊的寄存器中。

struct VMState {
  /*! \brief The memory allocators. */
  std::vector<Allocator*> allocators;
};

static constexpr RegName kVMStateRegister = 0x008D14FA4379015C; // magic number

// ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ...])
};

形状计算 #

VM 使用shape堆 (NDArray) 进行形状计算。(ShapeTuple 也是一种运行时数据类型。)

ib = relax.ExecBuilder()
shape = (32, 16)
x = tvm.nd.array(np.random.rand(*shape))
with ib.function("main", num_inputs=0):
    # alloc a shape heap of size 2, and store it in r(0)
    ib.emit_call("vm.builtin.alloc_shape_heap", args=[ib.imm(2)], dst=ib.r(0))
    # get the shape of tensor x, and store the ShapeTuple to r(1)
    ib.emit_call("vm.builtin.shape_of", args=[x], dst=ib.r(1))
    # store the shape of x to index 0 and index 1 of the shape heap
    # shape_heap[0] = 32, shape_heap[1] = 16
    ib.emit_call("vm.builtin.store_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)])
    # construct a ShapeTuple from the values in index 0 and index 1 of the shape heap
    ib.emit_call("vm.builtin.load_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))

内置的runtime packed函数 #

VM 运行时库是packed函数的集合,可以使用 Call 指令调用这些函数(以便 relax vm 具有最小指令集) https://github.com/octoml/relax/blob/relax/src/relax/vm/builtin.cc

EmitTE Staging集成 #

tensor expression

本文档是relax集成TE的工程计划草图。Relax 支持通过 call_tir 直接嵌入 TIR 功能。然而,通过TVMScript手动构建TIR函数仍然很困难。

TE(tensor expression)是一种DSL,我们传统上用它来构造很多算子。虽然我们正在向 TIR 迈进以进行调度,但 te 作为创建 TIR 函数的简洁 API 仍然非常有用。

现在,我们引入了一个 API create_prim_func,使我们能够有效地从 TE 创建 TIR 功能。下面的代码块给出了一个示例。

from tvm.script import tir as T

@T.prim_func
def tir_element_wise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    with T.block([128, 128]) as [i, j]:
        B[i, j] = A[i, j] * 2.0

A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda x, y: A[x, y] * 2, name="B")

func = te.create_prim_func([A, B]) 
tvm.ir.assert_structural_equal(func, tir_element_wise)

有关其他读数,请参阅下面的文档列表:

重要的是,我们利用 TE 编写了丰富的运算符库集合(例如 topi)。重用这些库对于快速创建工作负载和算子降级非常有用。

本说明概述了一种将 TE 机制紧密重用的方法,用于 Relax ast 暂存和高级算子降级。

关键目标 #

下面的代码片段显示了我们的设计目标:

bb = rx.BlockBuilder()
n, m = tir.var("n"), tir.var("m")
x : rx.Var = rx.var("x", shape=[n, m])
a : rx.Expr = bb.emit(x + 1)
# create a te tensor
# requires a.shape to be a known ShapeExpr, otherwise match shape is needed here
A : te.Tensor= rx.te_tensor(a)
# construct a te expression via arbitrary topi/te calls
# This will create the prim_func that is related to te, and add a call_tir to it
b : rx.Expr = bb.emit(te.compute((n, m), lambda x, y: A[x, y] * 2, name="B"))

# light weight decorator style that does automatic conversion
# (turn rx arguments into rx.te_tensor if needed)
# otherwise, simply pass things as te
# this decorator can be applied to topi functions in the future so they can directly
# be used to stage relax functions as well.
def te_func(X: te.Tensor, Y: te.Tensor):
    return te.compute((128, 128), lambda x, y: A[x, y] + B[x, y])

# directly call into te function to be able to construct the 
bb.emit_te(te_func, x, a)

关键步骤是:

  • D0:允许将一个带有已知符号形状的relax.Expr 包装成一个 te.Tensor
  • D1:重用基于 te 的库来创建算子实现
  • D2:emit后,调用 create_prim_func 创建 tir 函数,生成对 tir 函数的call_tir

动态shape情况 #

为了正确支持这种类型的 DSL,对动态形状有一些特殊的注意事项。参考下面的例子:

n, m = tir.var("n"), tir.var("m")
x : rx.Var = rx.var("x", shape=[n, 2*floordiv(m, 2)])
A : te.Tensor= rx.te_tensor(a)
B = bb.emit(te.compute([n, 2*floordiv(m, 2)], lambda i, j: A[i, j] + 1))

如果我们简单地遵循上面的算法,生成的call_tir可以如下所示:

@T.func
def ewise_fun(A, B):
    n = tir.var("n")
    m = tir.var("m")
    A = T.match_buffer(a, (n, 2*floordiv(m, 2)))
    B = T.match_buffer(b, (n, 2*floordiv(m, 2)))
    with T.block([n, 2 * 2*floordiv(m, 2)]) as [i, j]:
        B[i, j] = A[i, j] * 2.0

x : rx.Var = rx.var("x", shape=[n, 2 * 2*floordiv(m, 2)])
B = bb.call_tir([n, 2 * 2*floordiv(m, 2)], ewise_func, [x])

上述代码的主要问题是 2*floordiv(m, 2) 中的符号变量 m 未在 ewise_fun 中定义。 这是因为match_buffer只定义一个普通变量 n 的变量,而不是表达式 2*floordiv(m, 2)。请注意,我们当然可以增强 match 的行为以支持一些复杂的模式,例如 m*2,这有时是可取的。然而,像 2*floordiv(m, 2) 这样的复杂模式需要我们在函数的开头知道 m。

一种解决方案是:应用match shape来创建一个与 2*floordiv(m, 2) 匹配的新符号 var z。这将导致 tir 的信息丢失(形状是 2 的倍数)。

相反,我们增强了生成的 tircall_tir_dyn,以生成以下代码

@T.func
def ewise_fun(A, B, m: T.int64):
    n = tir.var("n")
    A = T.match_buffer(a, (n, 2*floordiv(m, 2)))
    B = T.match_buffer(b, (n, 2*floordiv(m, 2)))
    with T.block([n, 2 * 2*floordiv(m, 2)]) as [i, j]:
        B[i, j] = A[i, j] * 2.0

x : rx.Var = rx.var("x", shape=[n, 2*floordiv(m, 2)])
B = bb.call_tir_dyn([n, 2*floordiv(m, 2)], ewise_func, [x], ShapeExpr([m]))

我们首先推断未绑定的 TIR 变量,例如 m,并将其作为附加参数添加到函数中。我们使用增强的call_tir_dyn约定(如下所示,通过添加symbolic_int_shape),还允许指示符号整数值的 ShapeExpr 解压缩并调用函数。

def call_tir_dyn(shape, lowlevel_func, inputs, symbolic_var_tuple):
    out = alloc_tensor(shape)
    lowlevel_func(*inputs, out, *symbolic_var_tuple)
    return out

需要单独的 symbolic_var_tuple,主要是因为我们需要,因为:

  • C0:我们需要在输出值之后传递符号 int 提示
  • C1:relax符号整数本身的设计约束不显示为 rx.Expr

请注意,如果我们最终决定允许进一步混合 TIR/relax expr,则可以取消 C1,但这需要更多的设计考虑(另请参阅其他考虑因素)。

有几种不同的方法可以在编译和 VM 中实现这一点。目前使用的方法是

  • 将 args unpack约定引入寄存器,这意味着我们在进入打包的 fn 之前解包 r1(当它是 ShapeTuple 或 Array 时)。我们通过引入runtime intrinsic(vm.call_tir_dyn)来解决这个。
r1 = [2, 4]
call fn, r0, *r1
# semantics
call fn, r0, 2, 4

另一种解决方案是通过寄存器掩码指示解包的 args 调用约定,并与 python 的解包调用 conv 保持一致。

其它考虑 #

请注意,有许多变体可以使staging和重载更容易。目前,该说明概述了可管理工作负载的可能第一步。我们可以考虑未来可能的后续步骤。 其它可能的想法:

  • 允许计算作为宏扩展的 relax 脚本格式的宏语法糖,这当然是可行的。然而,当我们开始构建导入器和降级机制时,我们可能需要一个 staging构建器样式的 API
  • 允许 rx.Expr 像 te.Tensor 一样直接运算。这将导致更深层次的 TIR/Relax expr。这将是一个长期的考虑因素,需要对设计权衡进行更多思考