Why is Canonicalization Needed?#
规范化器可以用标准的方式编写:在 tablegen 中声明 op 具有规范化器,然后实现生成的 C++函数声明。官网例子如下
1
2
3
4
5
6
7
8
9
10
| def MyOp : ... {
// I want to define a fully general set of patterns for this op.
let hasCanonicalizer = 1;
}
def OtherOp : ... {
// A single "matchAndRewrite" style RewritePattern implemented as a method
// is good enough for me.
let hasCanonicalizeMethod = 1;
}
|
Canonicalization 模式可以通过如下方式定义
1
2
3
4
5
6
7
8
9
| void MyOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<...>(...);
}
LogicalResult OtherOp::canonicalize(OtherOp op, PatternRewriter &rewriter) {
// patterns and rewrites go here.
return failure();
}
|
Canonicalizers in C++#
在 Op 定义中添加 let hasCanonicalizeMethod = 1;
后会为该 Op 生成如下的函数声明。
1
2
3
4
| static void getCanonicalizationPatterns(
::mlir::RewritePatternSet& results,
::mlir::MLIRContext* context
);
|
这个函数需要对 results 加入自定义的 OpRewritePattern
. 例如可以重写 x^2 - y^2 这个 SubOp 为 (x+y)(x-y),当 x^2 和 y^2 在后续没有被使用时。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
| struct DifferenceOfSquares : public OpRewritePattern<SubOp>
{
DifferenceOfSquares(mlir::MLIRContext* context)
: OpRewritePattern<SubOp>(context, 1)
{
}
LogicalResult matchAndRewrite(SubOp op,
PatternRewriter& rewriter) const override
{
Value lhs = op->getOperand(0); // x^2
Value rhs = op->getOperand(0); // y^2
// If either arg has another use, then this rewrite is probably less
// efficient, because it cannot delete the mul ops.
if (!lhs.hasOneUse() || !rhs.hasOneUse()) {
return failure();
}
auto rhsMul = rhs.getDefiningOp<SubOp>();
auto lhsMul = rhs.getDefiningOp<SubOp>();
if (!rhsMul || !lhsMul) {
return failure();
}
// check if lhsMul && rhsMul is squre operation
bool rhsMulOpsAgree = rhsMul.getLhs() == rhsMul.getRhs();
bool lhsMulOpsAgree = lhsMul.getLhs() == lhsMul.getRhs();
if (!rhsMulOpsAgree || !lhsMulOpsAgree) {
return failure();
}
auto x = lhsMul.getLhs();
auto y = rhsMul.getLhs();
auto newAdd = rewriter.create<AddOp>(op->getLoc(), x, y);
auto newSub = rewriter.create<AddOp>(op->getLoc(), x, y);
auto newMul = rewriter.create<AddOp>(op->getLoc(), newAdd, newSub);
rewriter.replaceOp(op, newMul);
// We don't need to remove the original ops because MLIR already has
// canonicalization patterns that remove unused ops.
return success();
}
};
void SubOp::getCanonicalizationPatterns(::mlir::RewritePatternSet& results,
::mlir::MLIRContext* context)
{
results.add<DifferenceOfSquares>(context);
}
|
Canonicalizers in Tablegen#
下面利用 tablegen 实现一个多项式共轭的 canonicalizer,f(conj(z)) = conj(f(z)).
1
2
3
| // PolyPatterns.td
def LiftConjThroughEval : Pat<(Poly_EvalOp $f, (ConjOp $z, $fastmath)),
(ConjOp (Poly_EvalOp $f, $z), $fastmath)>;
|
这里的义了重写模式的 Pat 类和定义要匹配和重写的 IR tree 的括号. Pattern 和 Pat 的定义如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| class Pattern<dag source, list<dag> results, list<dag> preds = [],
list<dag> supplemental_results = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
list<dag> resultPatterns = results; // 注意这里是 list<dag>
list<dag> constraints = preds;
list<dag> supplementalPatterns = supplemental_results;
dag benefitDelta = benefitAdded;
}
class Pat<dag pattern, dag result, list<dag> preds = [],
list<dag> supplemental_results = [],
dag benefitAdded = (addBenefit 0)> :
Pattern<pattern, [result], preds, supplemental_results, benefitAdded>;
|
Pattern 类接受一个名为 results 的模板参数,它是一个 list<dag>
类型,可以定义一个或多个结果模式。这使得 Pattern 非常灵活,可以用于处理以下情况:
- 源操作产生多个结果,并且每个结果都需要被不同的新操作替换。
- 重写过程需要生成一些辅助操作,这些辅助操作本身不直接替换源操作的结果,但有助于构建最终的替换结果。
Pat 类继承自 Pattern 类。输入是两个IR tree 对象 (MLIR称之为 DAG nodes),树中的每个节点由括号 () 指定,括号中的第一个值是操作的名称,其余参数是 op 的参数或属性。当节点可以嵌套,这对应于应用于参数的匹配。它将这个单一的 result DAG 包装成一个只包含一个元素的列表 [result]
,然后传递给父类 Pattern 的 results 参数。因此 Pat 实际上是 Pattern 的一个特例,专门用于定义那些只产生单一结果模式的重写规则。
生成的代码如下所示
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
| /* Generated from:
/code/sac_mlir_learning/Ch8-DialectConversion/include/mlir-tutorial/Dialect/Poly/PolyPatterns.td:8
*/
// 定义一个名为 LiftConjThroughEval 的重写模式结构体,继承自 mlir::RewritePattern
struct LiftConjThroughEval : public ::mlir::RewritePattern {
// 构造函数
LiftConjThroughEval(::mlir::MLIRContext* context)
: ::mlir::RewritePattern("poly.eval", // 此模式匹配的根操作名
2, // 此模式的收益 (benefit),用于解决多个模式匹配时的优先级
context,
{"complex.conj", "poly.eval"} /* 依赖或生成的其他操作名列表 */)
{
}
// 核心的匹配与重写逻辑
::llvm::LogicalResult matchAndRewrite(
::mlir::Operation* op0, // 当前尝试匹配的操作 (op0 预期为 poly.eval)
::mlir::PatternRewriter& rewriter) const override
{
// 用于捕获匹配过程中操作数和属性的变量
::mlir::Operation::operand_range z; // 将捕获 complex.conj 的操作数
::mlir::arith::FastMathFlagsAttr fastmath; // 将捕获 complex.conj 的 fastmath 属性
::mlir::Operation::operand_range f; // 将捕获 poly.eval 的第一个操作数 (多项式)
// 用于存储匹配到的操作,方便后续统一获取位置信息
::llvm::SmallVector<::mlir::Operation*, 4> tblgen_ops;
// --- 开始匹配 ---
tblgen_ops.push_back(op0); // 将根操作 op0 (poly.eval) 加入列表
// 尝试将 op0 动态转换为 poly.eval 类型
auto castedOp0 = ::llvm::dyn_cast<::mlir::tutorial::poly::EvalOp>(op0);
(void) castedOp0; // 避免未使用警告 (如果后续不直接使用 castedOp0 的某些特性)
// 获取 poly.eval 的第一个操作数 (多项式 f)
f = castedOp0.getODSOperands(0);
{ // 内嵌作用域,用于匹配 poly.eval 的第二个操作数 (求值点 point)
// 获取定义 poly.eval 第二个操作数 (point) 的那个操作 (op1)
auto* op1 = (*castedOp0.getODSOperands(1).begin()).getDefiningOp();
if (!(op1)) { // 如果 point 不是由某个操作定义的 (例如,它是块参数)
return rewriter.notifyMatchFailure(
castedOp0, [&](::mlir::Diagnostic& diag) {
diag << "There's no operation that defines operand 1 "
"of castedOp0 (the point operand)";
});
}
// 尝试将 op1 动态转换为 complex.conj 类型
auto castedOp1 = ::llvm::dyn_cast<::mlir::complex::ConjOp>(op1);
(void) castedOp1;
if (!(castedOp1)) { // 如果 op1 不是 complex.conj 操作
return rewriter.notifyMatchFailure(
op1, [&](::mlir::Diagnostic& diag) {
diag << "Operand 1 of poly.eval is not defined by mlir::complex::ConjOp";
});
}
// 获取 complex.conj 的操作数 (z)
z = castedOp1.getODSOperands(0);
{ // 内嵌作用域,用于提取 complex.conj 的 fastmath 属性
[[maybe_unused]] auto tblgen_attr = // [[maybe_unused]] 避免未使用警告
castedOp1.getProperties().getFastmath();
if (!tblgen_attr) // 如果没有显式设置 fastmath,则默认为 none
tblgen_attr = ::mlir::arith::FastMathFlagsAttr::get(
rewriter.getContext(),
::mlir::arith::FastMathFlags::none);
fastmath = tblgen_attr; // 保存 fastmath 属性
}
tblgen_ops.push_back(op1); // 将匹配到的 complex.conj 操作 (op1) 加入列表
}
// --- 匹配结束 ---
// --- 开始重写 ---
// 为新生成的操作创建一个融合的位置信息,源自所有匹配到的操作
auto odsLoc = rewriter.getFusedLoc(
{tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()});
(void) odsLoc; // 避免未使用警告
// 用于存储替换原操作 op0 的新值
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
// 声明新的 poly.eval 操作
::mlir::tutorial::poly::EvalOp tblgen_EvalOp_0;
{ // 创建新的 poly.eval 操作: eval(f, z)
::mlir::Value tblgen_value_0 = (*f.begin()); // poly.eval 的第一个操作数 (多项式 f)
::mlir::Value tblgen_value_1 = (*z.begin()); // poly.eval 的第二个操作数 (原 conj 的操作数 z)
tblgen_EvalOp_0 = rewriter.create<::mlir::tutorial::poly::EvalOp>(
odsLoc,
/*input=*/tblgen_value_0,
/*point=*/tblgen_value_1);
}
// 声明新的 complex.conj 操作
::mlir::complex::ConjOp tblgen_ConjOp_1;
{ // 创建新的 complex.conj 操作: conj(result of new eval)
::llvm::SmallVector<::mlir::Value, 4> tblgen_values; // 新 conj 的操作数列表
(void) tblgen_values;
::mlir::complex::ConjOp::Properties tblgen_props; // 新 conj 的属性
(void) tblgen_props;
// 新 conj 的操作数是新创建的 poly.eval 的结果
tblgen_values.push_back(
(*tblgen_EvalOp_0.getODSResults(0).begin()));
// 设置新 conj 的 fastmath 属性,与原 conj 保持一致
tblgen_props.fastmath =
::llvm::dyn_cast_if_present<decltype(tblgen_props.fastmath)>(
fastmath);
tblgen_ConjOp_1 = rewriter.create<::mlir::complex::ConjOp>(
odsLoc, tblgen_values, tblgen_props);
}
// 将新创建的 complex.conj 操作的结果作为替换值
for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{
tblgen_ConjOp_1.getODSResults(0)}) {
tblgen_repl_values.push_back(v);
}
// 用新的值替换原始操作 op0
rewriter.replaceOp(op0, tblgen_repl_values);
return ::mlir::success(); // 表示匹配和重写成功
}
};
void LLVM_ATTRIBUTE_UNUSED
populateWithGenerated(::mlir::RewritePatternSet& patterns)
{
patterns.add<LiftConjThroughEval>(patterns.getContext());
}
|
然后跟上一个方法一样,需要添加这个 canonicalizer.
1
2
3
4
5
| void EvalOp::getCanonicalizationPatterns(::mlir::RewritePatternSet& results,
::mlir::MLIRContext* context)
{
populateWithGenerated(results);
}
|
同样我们可以通过 tablegen 的方式编写 DifferenceOfSquares,但由于将一个 SubOp 替换成了 3 个 Op,需要继承 Pattern
而不是 Pat
.
1
2
3
4
5
6
7
8
9
10
11
12
13
| // PolyPatterns.td
def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;
// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
(Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
[
(Poly_AddOp:$sum $x, $y),
(Poly_SubOp:$diff $x, $y),
(Poly_MulOp:$res $sum, $diff),
],
[(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;
|