# from tvm.script import ir as I
# from tvm.script import relax as R
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def my_conv2d(x: T.Buffer((T.int64(4), T.int64(1), T.int64(28), T.int64(28)), "float32"), B: T.Buffer((T.int64(32), T.int64(1), T.int64(3), T.int64(3)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(1)), "float32"), compute: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        conv2d = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)))
        for n, co, oh, ow, k, r, s in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26), T.int64(1), T.int64(3), T.int64(3)):
            with T.block("conv2d"):
                v_n, v_co, v_oh, v_ow, v_k, v_r, v_s = T.axis.remap("SSSSRRR", [n, co, oh, ow, k, r, s])
                T.reads(x[v_n, v_k, v_oh + v_r, v_ow + v_s], B[v_co, v_k, v_r, v_s])
                T.writes(conv2d[v_n, v_co, v_oh, v_ow])
                with T.init():
                    conv2d[v_n, v_co, v_oh, v_ow] = T.float32(0.0)
                conv2d[v_n, v_co, v_oh, v_ow] = conv2d[v_n, v_co, v_oh, v_ow] + x[v_n, v_k, v_oh + v_r, v_ow + v_s] * B[v_co, v_k, v_r, v_s]
        for n, co, oh, ow in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26)):
            with T.block("compute"):
                v_n, v_co, v_oh, v_ow = T.axis.remap("SSSS", [n, co, oh, ow])
                T.reads(conv2d[v_n, v_co, v_oh, v_ow], C[T.int64(0), v_co, T.int64(0), T.int64(0)])
                T.writes(compute[v_n, v_co, v_oh, v_ow])
                compute[v_n, v_co, v_oh, v_ow] = conv2d[v_n, v_co, v_oh, v_ow] + C[T.int64(0), v_co, T.int64(0), T.int64(0)]

    @T.prim_func(private=True)
    def my_flatten(lv2: T.Buffer((T.int64(4), T.int64(32), T.int64(13), T.int64(13)), "float32"), compute: T.Buffer((T.int64(4), T.int64(5408)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for n, i in T.grid(T.int64(4), T.int64(5408)):
            with T.block("compute"):
                v_n, v_i = T.axis.remap("SS", [n, i])
                T.reads(lv2[v_n, v_i // T.int64(169), v_i % T.int64(169) // T.int64(13), v_i % T.int64(13)])
                T.writes(compute[v_n, v_i])
                compute[v_n, v_i] = lv2[v_n, v_i // T.int64(169), v_i % T.int64(169) // T.int64(13), v_i % T.int64(13)]

    @T.prim_func(private=True)
    def my_linear(lv3: T.Buffer((T.int64(4), T.int64(5408)), "float32"), B: T.Buffer((T.int64(100), T.int64(5408)), "float32"), C: T.Buffer((T.int64(1), T.int64(100)), "float32"), compute: T.Buffer((T.int64(4), T.int64(100)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute_1 = T.alloc_buffer((T.int64(4), T.int64(100)))
        for i, j, FI in T.grid(T.int64(4), T.int64(100), T.int64(5408)):
            with T.block("compute"):
                v_i, v_j, v_FI = T.axis.remap("SSR", [i, j, FI])
                T.reads(lv3[v_i, v_FI], B[v_j, v_FI])
                T.writes(compute_1[v_i, v_j])
                with T.init():
                    compute_1[v_i, v_j] = T.float32(0.0)
                compute_1[v_i, v_j] = compute_1[v_i, v_j] + lv3[v_i, v_FI] * B[v_j, v_FI]
        for i, j in T.grid(T.int64(4), T.int64(100)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(C[T.int64(0), v_j], compute_1[v_i, v_j])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = C[T.int64(0), v_j] + compute_1[v_i, v_j]

    @T.prim_func(private=True)
    def my_linear1(lv5: T.Buffer((T.int64(4), T.int64(100)), "float32"), B: T.Buffer((T.int64(10), T.int64(100)), "float32"), C: T.Buffer((T.int64(1), T.int64(10)), "float32"), compute: T.Buffer((T.int64(4), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute_1 = T.alloc_buffer((T.int64(4), T.int64(10)))
        for i, j, FI in T.grid(T.int64(4), T.int64(10), T.int64(100)):
            with T.block("compute"):
                v_i, v_j, v_FI = T.axis.remap("SSR", [i, j, FI])
                T.reads(lv5[v_i, v_FI], B[v_j, v_FI])
                T.writes(compute_1[v_i, v_j])
                with T.init():
                    compute_1[v_i, v_j] = T.float32(0.0)
                compute_1[v_i, v_j] = compute_1[v_i, v_j] + lv5[v_i, v_FI] * B[v_j, v_FI]
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(C[T.int64(0), v_j], compute_1[v_i, v_j])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = C[T.int64(0), v_j] + compute_1[v_i, v_j]

    @T.prim_func(private=True)
    def my_maxpool2d(lv1: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32"), maxpool2d: T.Buffer((T.int64(4), T.int64(32), T.int64(13), T.int64(13)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for n, co, oh, ow, i, j in T.grid(T.int64(4), T.int64(32), T.int64(13), T.int64(13), T.int64(2), T.int64(2)):
            with T.block("maxpool2d"):
                v_n, v_co, v_oh, v_ow, v_i, v_j = T.axis.remap("SSSSRR", [n, co, oh, ow, i, j])
                T.reads(lv1[v_n, v_co, v_oh * T.int64(2) + v_i, v_ow * T.int64(2) + v_j])
                T.writes(maxpool2d[v_n, v_co, v_oh, v_ow])
                with T.init():
                    maxpool2d[v_n, v_co, v_oh, v_ow] = T.float32(-340282346638528859811704183484516925440.0)
                maxpool2d[v_n, v_co, v_oh, v_ow] = T.max(maxpool2d[v_n, v_co, v_oh, v_ow], lv1[v_n, v_co, v_oh * T.int64(2) + v_i, v_ow * T.int64(2) + v_j])

    @T.prim_func(private=True)
    def my_relu(lv: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32"), compute: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2, v_i3], T.float32(0.0))

    @T.prim_func(private=True)
    def my_relu1(lv4: T.Buffer((T.int64(4), T.int64(100)), "float32"), compute: T.Buffer((T.int64(4), T.int64(100)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(4), T.int64(100)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv4[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv4[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def my_softmax(lv6: T.Buffer((T.int64(4), T.int64(10)), "float32"), compute: T.Buffer((T.int64(4), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute_1 = T.alloc_buffer((T.int64(4),))
        compute_2 = T.alloc_buffer((T.int64(4), T.int64(10)))
        compute_3 = T.alloc_buffer((T.int64(4),))
        for i, c in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute"):
                v_i, v_c = T.axis.remap("SR", [i, c])
                T.reads(lv6[v_i, v_c])
                T.writes(compute_1[v_i])
                with T.init():
                    compute_1[v_i] = T.float32(-340282346638528859811704183484516925440.0)
                compute_1[v_i] = T.max(compute_1[v_i], lv6[v_i, v_c])
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv6[v_i, v_j], compute_1[v_i])
                T.writes(compute_2[v_i, v_j])
                compute_2[v_i, v_j] = T.exp(lv6[v_i, v_j] - compute_1[v_i])
        for i, c in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_2"):
                v_i, v_c = T.axis.remap("SR", [i, c])
                T.reads(compute_2[v_i, v_c])
                T.writes(compute_3[v_i])
                with T.init():
                    compute_3[v_i] = T.float32(0.0)
                compute_3[v_i] = compute_3[v_i] + compute_2[v_i, v_c]
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_3"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(compute_2[v_i, v_j], compute_3[v_i])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = compute_2[v_i, v_j] / compute_3[v_i]

    @R.function
    def main(x: R.Tensor((4, 1, 28, 28), dtype="float32")) -> R.Tensor((4, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.my_conv2d, (x, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((4, 32, 26, 26), dtype="float32"))
            lv1 = R.call_tir(cls.my_relu, (lv,), out_sinfo=R.Tensor((4, 32, 26, 26), dtype="float32"))
            lv2 = R.call_tir(cls.my_maxpool2d, (lv1,), out_sinfo=R.Tensor((4, 32, 13, 13), dtype="float32"))
            lv3 = R.call_tir(cls.my_flatten, (lv2,), out_sinfo=R.Tensor((4, 5408), dtype="float32"))
            lv4 = R.call_tir(cls.my_linear, (lv3, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((4, 100), dtype="float32"))
            lv5 = R.call_tir(cls.my_relu1, (lv4,), out_sinfo=R.Tensor((4, 100), dtype="float32"))
            lv6 = R.call_tir(cls.my_linear1, (lv5, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5]), out_sinfo=R.Tensor((4, 10), dtype="float32"))
            lv7 = R.call_tir(cls.my_softmax, (lv6,), out_sinfo=R.Tensor((4, 10), dtype="float32"))
            gv: R.Tensor((4, 10), dtype="float32") = lv7
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
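
The script above is printed output: the constant weights referenced as metadata["relax.expr.Constant"][i] live in a metadata section that the printer omits, so the text cannot be re-parsed on its own. Given the original Module object in a Python session, though, it can be compiled and executed directly; since main already lowers every operator to a call_tir into the private TIR functions, no further op legalization should be needed. A minimal sketch, assuming an LLVM CPU target and a random input batch:

import numpy as np
import tvm
from tvm import relax

# Compile the IRModule and load it into the Relax virtual machine.
ex = relax.build(Module, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Run the whole network on a random batch of four 1x28x28 images.
x = tvm.nd.array(np.random.rand(4, 1, 28, 28).astype("float32"))
out = vm["main"](x)

# my_softmax normalizes each row, so every row of the (4, 10) output
# should sum to roughly 1.0.
print(out.numpy().sum(axis=1))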