Function registration: `_init_api_prefix` iterates over all global functions registered on the C++ side, finds those whose names start with the prefix "tir.schedule", and exposes them as attributes of the Python module.
@_register_object("tir.LoopRV")classLoopRV(Object):"""A random variable that refers to a loop"""def__init__(self)->None:"""Construct a new LoopRV."""self.__init_handle_by_constructor__(_ffi_api.LoopRV# type: ignore # pylint: disable=no-member)"""FFI APIs for tvm.tir.schedule"""importtvm._ffitvm._ffi._init_api("tir.schedule",__name__)# pylint: disable=protected-access
@_register_object("tir.BlockRV")classBlockRV(Object):"""A random variable that refers to a block"""def__init__(self)->None:"""Construct a new BlockRV."""self.__init_handle_by_constructor__(_ffi_api.BlockRV# type: ignore # pylint: disable=no-member)
```python
import tvm
from tvm import te
import numpy as np

# Declare some variables for use later
n = te.var("n")
m = te.var("m")

# Declare a matrix element-wise multiply
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
print(type(A))

s = te.create_schedule([C.op])
# lower turns the computation from its definition into a callable IRModule
tvm.lower(s, [A, B, C], simple_mode=True).show()
```
tvm.lower
The `tvm.lower` function lowers a computation (here, a TE schedule) into the lower-level TensorIR representation, and returns an `IRModule`.
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (m, n), strides=("stride", "stride"), buffer_type="auto")
        B_1 = T.match_buffer(B, (m, n), strides=("stride", "stride"), buffer_type="auto")
        C_1 = T.match_buffer(C, (m, n), strides=("stride", "stride"), buffer_type="auto")
        for i, j in T.grid(m, n):
            C_2 = T.Buffer((C_1.strides[0] * m,), data=C_1.data, buffer_type="auto")
            A_2 = T.Buffer((A_1.strides[0] * m,), data=A_1.data, buffer_type="auto")
            B_2 = T.Buffer((B_1.strides[0] * m,), data=B_1.data, buffer_type="auto")
            C_2[i * C_1.strides[0] + j * C_1.strides[1]] = A_2[i * A_1.strides[0] + j * A_1.strides[1]] * B_2[i * B_1.strides[0] + j * B_1.strides[1]]
```
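The return value is an ordinary `IRModule`, so it can be inspected like any other module. A small sketch (purely illustrative, reusing `s`, `A`, `B`, `C` from above):

```python
mod = tvm.lower(s, [A, B, C], simple_mode=True)
print(type(mod))           # <class 'tvm.ir.module.IRModule'>
print(mod["main"].params)  # the three buffer handles for A, B, C
```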
fuse

`fuse` merges two consecutive loops into a single loop.

```python
from tvm import IRModule  # needed to wrap the PrimFunc below

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j] * 2, name="B")
func = te.create_prim_func([A, B])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Fuse --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        # with T.block("root"):
        for i, j in T.grid(m, n):
            with T.block("B"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[v_i, v_j])
                T.writes(B[v_i, v_j])
                B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
```
```python
sch = tvm.tir.Schedule(ir_module)
block_B = sch.get_block("B")
i, j = sch.get_loops(block_B)
sch.fuse(i, j)
sch.mod.show()

# ---------- TensorIR After Fuse --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        # with T.block("root"):
        for i_j_fused in range(m * n):
            with T.block("B"):
                v_i = T.axis.spatial(m, i_j_fused % (n * m) // n)
                v_j = T.axis.spatial(n, i_j_fused % n)
                T.reads(A[v_i, v_j])
                T.writes(B[v_i, v_j])
                B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
```
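Fusion only reorders iteration, so the result must stay numerically identical. A quick sanity check (illustrative only; the concrete 4x8 shapes are picked here, not from the original):

```python
import numpy as np

lib = tvm.build(sch.mod, target="llvm")
a_np = np.random.rand(4, 8).astype("float32")
a = tvm.nd.array(a_np)
b = tvm.nd.array(np.zeros((4, 8), dtype="float32"))
lib["main"](a, b)
np.testing.assert_allclose(b.numpy(), a_np * 2, rtol=1e-5)
```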
split

`split` breaks one loop into a nest of loops according to the given factors.

```python
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
func = te.create_prim_func([A, B])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Split --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m = T.int32()
        A = T.match_buffer(var_A, (m,))
        B = T.match_buffer(var_B, (m,))
        # with T.block("root"):
        for i in range(m):
            with T.block("B"):
                v_i = T.axis.spatial(m, i)
                T.reads(A[v_i])
                T.writes(B[v_i])
                B[v_i] = A[v_i] * T.float32(2.0)
```
```python
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("B")
(i,) = sch.get_loops(block_b)
sch.split(i, factors=[None, 32])
sch.mod.show()

# ---------- TensorIR After Split --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m = T.int32()
        A = T.match_buffer(var_A, (m,))
        B = T.match_buffer(var_B, (m,))
        # with T.block("root"):
        for i_0, i_1 in T.grid((m + 31) // 32, 32):
            with T.block("B"):
                v_i = T.axis.spatial(m, i_0 * 32 + i_1)
                T.where(i_0 * 32 + i_1 < m)
                T.reads(A[v_i])
                T.writes(B[v_i])
                B[v_i] = A[v_i] * T.float32(2.0)
```
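Because `m` is symbolic, `split` rounds the outer extent up to `(m + 31) // 32` and guards the tail iterations with a `T.where` predicate. With a static extent that the factor divides evenly, no predicate is needed; a minimal sketch (illustrative assumption: a fresh 128-element module):

```python
A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] * 2, name="B")
func = te.create_prim_func([A, B]).with_attr("global_symbol", "main")
sch = tvm.tir.Schedule(IRModule({"main": func}))
(i,) = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[None, 32])  # 128 = 4 * 32, so the loop splits exactly
sch.mod.show()                    # no T.where appears in the printed TensorIR
```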
loop_partition

`loop_partition` carves a loop into consecutive sub-ranges, each of which becomes its own loop.

```python
m = 128
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")
func = te.create_prim_func([A, B])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Loop Partition --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i in range(128):
            with T.block("B"):
                v_i = T.axis.spatial(128, i)
                T.reads(A[v_i])
                T.writes(B[v_i])
                B[v_i] = A[v_i] * T.float32(2.0)
```
```python
sch = tvm.tir.Schedule(ir_module)
block_B = sch.get_block("B")
[i] = sch.get_loops(block_B)  # returns a list of LoopRV
sch.loop_partition(i, [2, 64])
sch.mod.show()

# ---------- TensorIR After Loop Partition --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        with T.block("root"):
            T.reads()
            T.writes()
            with T.block("B_i_common"):
                T.reads()
                T.writes()
                with T.block("B_i0_partition"):
                    T.reads()
                    T.writes()
                    for i0 in range(2):
                        with T.block("B_i0"):
                            v_i = T.axis.spatial(2, i0)
                            T.reads(A[0:2])
                            T.writes(B[0:2])
                            B[v_i] = A[v_i] * T.float32(2.0)
                with T.block("B_i1_partition"):
                    T.reads()
                    T.writes()
                    for i1 in range(2, 66):
                        with T.block("B_i1"):
                            v_i = T.axis.spatial((2, 66), i1)
                            T.reads(A[2:66])
                            T.writes(B[2:66])
                            B[v_i] = A[v_i] * T.float32(2.0)
                with T.block("B_i2_partition"):
                    T.reads()
                    T.writes()
                    for i2 in range(66, 128):
                        with T.block("B_i2"):
                            v_i = T.axis.spatial((66, 128), i2)
                            T.reads(A[66:128])
                            T.writes(B[66:128])
                            B[v_i] = A[v_i] * T.float32(2.0)
```
parallel

`parallel` marks a data-parallel loop to be executed across threads.

```python
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Parallel --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[v_i, v_j], B[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = A[v_i, v_j] * B[v_i, v_j]
```
```python
sch = tvm.tir.Schedule(ir_module)
block_c = sch.get_block("C")
i, j = sch.get_loops(block_c)
sch.parallel(i)
sch.mod.show()

# ---------- TensorIR After Parallel --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        for i in T.parallel(m):
            for j in range(n):
                with T.block("C"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j], B[v_i, v_j])
                    T.writes(C[v_i, v_j])
                    C[v_i, v_j] = A[v_i, v_j] * B[v_i, v_j]
```
vectorize

`vectorize` marks a loop for SIMD vectorization (reusing the `ir_module` from the fuse example above).

```python
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
sch.vectorize(j)
sch.mod.show()

# ---------- TensorIR After Vectorize --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        # with T.block("root"):
        for i in range(m):
            for j in T.vectorized(n):
                with T.block("B"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j])
                    T.writes(B[v_i, v_j])
                    B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
```
unroll

`unroll` marks a loop to be fully unrolled during lowering.

```python
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
sch.unroll(i)
sch.mod.show()

# ---------- TensorIR After Unroll --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        B = T.match_buffer(var_B, (m, n))
        # with T.block("root"):
        for i in T.unroll(m):
            for j in range(n):
                with T.block("B"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j])
                    T.writes(B[v_i, v_j])
                    B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
```
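In practice these binding primitives are usually combined on a split loop nest rather than applied in isolation. A minimal sketch (illustrative assumption: a fresh static 128x128 module so every extent is constant):

```python
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] * 2, name="B")
func = te.create_prim_func([A, B]).with_attr("global_symbol", "main")
sch = tvm.tir.Schedule(IRModule({"main": func}))
i, j = sch.get_loops(sch.get_block("B"))
j0, j1 = sch.split(j, factors=[None, 8])  # 128 = 16 * 8, exact split
sch.parallel(i)     # distribute rows across CPU threads
sch.unroll(j0)      # unroll the middle loop
sch.vectorize(j1)   # map the innermost 8 lanes to SIMD
sch.mod.show()
```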
compute_at

`compute_at` moves a producer block under a given loop of its consumer, so the intermediate is computed where it is used.

```python
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j] * 2, name="B")
C = te.compute((m, n), lambda i, j: B[i, j] + 1, name="C")
func = te.create_prim_func([A, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Compute_at --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        B = T.alloc_buffer((m, n))
        for i, j in T.grid(m, n):
            with T.block("B"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[v_i, v_j])
                T.writes(B[v_i, v_j])
                B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(B[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = B[v_i, v_j] + T.float32(1.0)
```
```python
sch = tvm.tir.Schedule(ir_module)
block = sch.get_block("B")
loop, _ = sch.get_loops(sch.get_block("C"))
sch.compute_at(block, loop, preserve_unit_loops=False)
# The same layout can be reached from the consumer side:
#     block = sch.get_block("C")
#     loop, _ = sch.get_loops(sch.get_block("B"))
#     sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
sch.mod.show()

# ---------- TensorIR After Compute_at --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        B = T.alloc_buffer((m, n))
        for i in range(m):
            for ax0 in range(n):
                with T.block("B"):
                    v_i, v_j = T.axis.remap("SS", [i, ax0])
                    T.reads(A[v_i, v_j])
                    T.writes(B[v_i, v_j])
                    B[v_i, v_j] = A[v_i, v_j] * T.float32(2.0)
            for j in range(n):
                with T.block("C"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(B[v_i, v_j])
                    T.writes(C[v_i, v_j])
                    C[v_i, v_j] = B[v_i, v_j] + T.float32(1.0)
```
compute_inline

`compute_inline` removes a producer block entirely by substituting its expression into the consumers.

```python
sch = tvm.tir.Schedule(ir_module)
block = sch.get_block("B")
# same effect: sch.reverse_compute_inline(sch.get_block("C"))
sch.compute_inline(block)
sch.mod.show()

# ---------- TensorIR After Compute_inline --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        for i, j in T.grid(m, n):
            with T.block("C"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[v_i, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = A[v_i, v_j] * T.float32(2.0) + T.float32(1.0)
```
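For contrast, the consumer-side form mentioned in the comment above; a minimal sketch (illustrative, starting again from the original `ir_module`):

```python
sch = tvm.tir.Schedule(ir_module)
# Pull consumer C's epilogue (+1) into producer B instead of pushing B into C.
sch.reverse_compute_inline(sch.get_block("C"))
sch.mod.show()  # body computes A[v_i, v_j] * 2.0 + 1.0 directly into C's buffer
```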
decompose_reduction

```python
l = te.var("l")
A = te.placeholder((m, l), name="A")
B = te.placeholder((l, n), name="B")
k = te.reduce_axis((0, l), name="l")
C = te.compute((m, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = IRModule({"main": func})
ir_module.show()

# ---------- TensorIR Before Decompose Reduction --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, l = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, l))
        n = T.int32()
        B = T.match_buffer(var_B, (l, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        for i, j, l_1 in T.grid(m, n, l):
            with T.block("C"):
                v_i, v_j, v_l = T.axis.remap("SSR", [i, j, l_1])
                T.reads(A[v_i, v_l], B[v_l, v_j])
                T.writes(C[v_i, v_j])
                with T.init():
                    C[v_i, v_j] = T.float32(0.0)
                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_l] * B[v_l, v_j]
```
Calling `decompose_reduction` splits block C into an initialization block and an update block, and hoists the initialization block in front of loop `i`; the resulting TensorIR is shown below.
```python
sch = tvm.tir.Schedule(ir_module)
block_c = sch.get_block("C")
i, j, k = sch.get_loops(block_c)
sch.decompose_reduction(block_c, i)
sch.mod.show()

# ---------- TensorIR After Decompose Reduction --------------
@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, l = T.int32(), T.int32()
        A = T.match_buffer(var_A, (m, l))
        n = T.int32()
        B = T.match_buffer(var_B, (l, n))
        C = T.match_buffer(var_C, (m, n))
        # with T.block("root"):
        for i_init, j_init in T.grid(m, n):
            with T.block("C_init"):
                v_i, v_j = T.axis.remap("SS", [i_init, j_init])
                T.reads()
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = T.float32(0.0)
        for i, j, l_1 in T.grid(m, n, l):
            with T.block("C_update"):
                v_i, v_j, v_l = T.axis.remap("SSR", [i, j, l_1])
                T.reads(C[v_i, v_j], A[v_i, v_l], B[v_l, v_j])
                T.writes(C[v_i, v_j])
                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_l] * B[v_l, v_j]
```