Skip to content

Commit c32d32c

Browse files
committed
Add some canonicalizers for muxes and registers
This commit adds four new canonicalization patterns to FIRRTL. mux(cond, 0, b) -> and(not(cond), b) mux(cond, 1, b) -> or(cond, b) mux(cond, a, 0) -> and(cond, a) mux(cond, a, 1) -> or(not(cond), a) These canonicalizations are already present for the comb dialect, but we want to run these canonicalizers before lowering layers, which can obscure constant ops behind ports. The problem with these mux canonicalizers is, they conflict with a register canonicalizer. This register canonicalizer converts a register to a constant, if the register's next-value is a mux of either the register itself, or a constant. For example: connect(reg, mux(reset, 0, reg)) ==> reg -> 0 These new canonicalizers would transform the connect to: connect(reg, and(not(reset), reg)) ...which prevents the register canonicalizer from running. To get this behaviour back, this PR adds four additional canonicalizations for both registers and reg resets: For registers, the canonicalizers are: connect(reg, and(reg, x)) ==> reg -> 0 connect(reg, and(x, reg)) ==> reg -> 0 connect(reg, or(reg, x)) ==> reg -> 1 connect(reg, or(x, reg)) ==> reg -> 1 For regresets, we have the same canonicalizers, but with an additional check: the reset value must be a constant zero or one. reset(reg) = 0 ==> connect(reg, and(reg, x)) ==> reg -> 0 reset(reg) = 0 ==> connect(reg, and(x, reg)) ==> reg -> 0 reset(reg) = 1 ==> connect(reg, or(reg, x)) ==> reg -> 1 reset(reg) = 1 ==> connect(reg, or(x, reg)) ==> reg -> 1
1 parent 638681a commit c32d32c

File tree

3 files changed

+208
-11
lines changed

3 files changed

+208
-11
lines changed

include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -657,31 +657,31 @@ def MuxNEQ : Pat<
657657
[(EqualTypes $x, $y), (KnownWidth $x)]>;
658658

659659
// mux(cond, 0, b) -> and(~cond, b)
660-
def MuxLHS0 : Pat<
660+
def MuxLhsZero : Pat<
661661
(MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b),
662662
(MoveNameHint $old, (AndPrimOp (NotPrimOp $cond), $b)),
663663
[(ZeroConstantOp $a),
664664
(EqualTypes $cond, $a),
665665
(EqualTypes $cond, $b)]>;
666666

667667
// mux(cond, 1, b) -> or(cond, b)
668-
def MuxLHS1 : Pat<
668+
def MuxLhsOne : Pat<
669669
(MuxPrimOp:$old $cond, (ConstantOp:$a $_), $b),
670670
(MoveNameHint $old, (OrPrimOp $cond, $b)),
671671
[(OneConstantOp $a),
672672
(EqualTypes $cond, $a),
673673
(EqualTypes $cond, $b)]>;
674674

675675
// mux(cond, a, 0) -> and(cond, a)
676-
def MuxRHS0 : Pat<
676+
def MuxRhsZero : Pat<
677677
(MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)),
678678
(MoveNameHint $old, (AndPrimOp $cond, $a)),
679679
[(ZeroConstantOp $b),
680680
(EqualTypes $cond, $a),
681681
(EqualTypes $cond, $b)]>;
682682

683683
// mux(cond, a, 1) -> or(~cond, a)
684-
def MuxRHS1 : Pat<
684+
def MuxRhsOne : Pat<
685685
(MuxPrimOp:$old $cond, $a, (ConstantOp:$b $_)),
686686
(MoveNameHint $old, (OrPrimOp (NotPrimOp $cond), $a)),
687687
[(OneConstantOp $b),

lib/Dialect/FIRRTL/FIRRTLFolds.cpp

Lines changed: 127 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,8 @@ void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
14761476
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
14771477
patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
14781478
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS,
1479-
patterns::MuxPadSel, patterns::MuxLHS0, patterns::MuxLHS1,
1480-
patterns::MuxRHS0, patterns::MuxRHS1>(context);
1479+
patterns::MuxPadSel, patterns::MuxLhsZero, patterns::MuxLhsOne,
1480+
patterns::MuxRhsZero, patterns::MuxRhsOne>(context);
14811481
}
14821482

14831483
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
@@ -2256,6 +2256,72 @@ struct FoldResetMux : public mlir::RewritePattern {
22562256
};
22572257
} // namespace
22582258

2259+
namespace {
2260+
/// This canonicalizer provides the following patterns:
2261+
///
2262+
/// reset(reg) = 0 ==> connect(reg, and(reg, x)) ==> reg -> 0
2263+
/// reset(reg) = 0 ==> connect(reg, and(x, reg)) ==> reg -> 0
2264+
///
2265+
/// reset(reg) = 1 ==> connect(reg, or(reg, x)) ==> reg -> 1
2266+
/// reset(reg) = 1 ==> connect(reg, or(x, reg)) ==> reg -> 1
2267+
struct RegResetAndOrOfSelf : public mlir::OpRewritePattern<RegResetOp> {
2268+
using OpRewritePattern::OpRewritePattern;
2269+
LogicalResult matchAndRewrite(RegResetOp op,
2270+
PatternRewriter &rewriter) const override {
2271+
// Do not fold the register away if it is important.
2272+
if (hasDontTouch(op.getOperation()) || !AnnotationSet(op).empty() ||
2273+
op.isForceable())
2274+
return failure();
2275+
2276+
// This canonicalization only applies when the register holds 1 bit.
2277+
auto type = dyn_cast<UIntType>(op.getResult().getType());
2278+
if (!type || type.getWidthOrSentinel() != 1)
2279+
return failure();
2280+
2281+
// This canonicalization only applies when the reset is a constant.
2282+
auto reset =
2283+
dyn_cast_or_null<ConstantOp>(op.getResetValue().getDefiningOp());
2284+
if (!reset)
2285+
return failure();
2286+
2287+
auto value = reset.getValue();
2288+
2289+
// Find the one true connect, or bail.
2290+
auto connect = getSingleConnectUserOf(op.getResult());
2291+
if (!connect)
2292+
return failure();
2293+
2294+
auto *src = connect.getSrc().getDefiningOp();
2295+
if (!src)
2296+
return failure();
2297+
2298+
if (value == 0) {
2299+
if (auto srcAnd = dyn_cast<AndPrimOp>(src)) {
2300+
if (srcAnd.getLhs().getDefiningOp() == op ||
2301+
srcAnd.getRhs().getDefiningOp() == op) {
2302+
rewriter.eraseOp(connect);
2303+
replaceOpAndCopyName(rewriter, op, reset.getResult());
2304+
return success();
2305+
}
2306+
}
2307+
}
2308+
2309+
if (value == 1) {
2310+
if (auto srcOr = dyn_cast<OrPrimOp>(src)) {
2311+
if (srcOr.getLhs().getDefiningOp() == op ||
2312+
srcOr.getRhs().getDefiningOp() == op) {
2313+
rewriter.eraseOp(connect);
2314+
replaceOpAndCopyName(rewriter, op, reset.getResult());
2315+
return success();
2316+
}
2317+
}
2318+
}
2319+
2320+
return failure();
2321+
}
2322+
};
2323+
} // namespace
2324+
22592325
static bool isDefinedByOneConstantOp(Value v) {
22602326
if (auto c = v.getDefiningOp<ConstantOp>())
22612327
return c.getValue().isOne();
@@ -2279,7 +2345,9 @@ canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
22792345

22802346
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
22812347
MLIRContext *context) {
2282-
results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2348+
results
2349+
.add<patterns::RegResetWithZeroReset, FoldResetMux, RegResetAndOrOfSelf>(
2350+
context);
22832351
results.add(canonicalizeRegResetWithOneReset);
22842352
results.add(demoteForceableIfUnused<RegResetOp>);
22852353
}
@@ -3156,10 +3224,63 @@ static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
31563224
return success();
31573225
}
31583226

3227+
/// If a register latches to a constant, replace the register with a constant.
3228+
/// Recognizes the following and/or structures:
3229+
///
3230+
/// connect(reg, and(reg, x)) ==> reg -> 0
3231+
/// connect(reg, and(x, reg)) ==> reg -> 0
3232+
///
3233+
/// connect(reg, or(reg, x)) ==> reg -> 1
3234+
/// connect(reg, or(x, reg)) ==> reg -> 1
3235+
static LogicalResult foldRegAndOrOfSelf(RegOp reg, PatternRewriter &rewriter) {
3236+
// This canonicalization only applies when the register holds 1 bit.
3237+
auto type = dyn_cast<UIntType>(reg.getResult().getType());
3238+
if (!type || type.getWidthOrSentinel() != 1)
3239+
return failure();
3240+
3241+
// Find the one true connect, or bail.
3242+
auto connect = getSingleConnectUserOf(reg.getResult());
3243+
if (!connect)
3244+
return failure();
3245+
3246+
auto *src = connect.getSrc().getDefiningOp();
3247+
if (!src)
3248+
return failure();
3249+
3250+
// connect(reg, and(reg, x)) ==> reg -> 0
3251+
// connect(reg, and(x, reg)) ==> reg -> 0
3252+
if (auto srcAnd = dyn_cast<AndPrimOp>(src)) {
3253+
if (srcAnd.getLhs().getDefiningOp() == reg ||
3254+
srcAnd.getRhs().getDefiningOp() == reg) {
3255+
auto attr = getIntAttr(type, APInt(1, 0));
3256+
replaceOpWithNewOpAndCopyName<ConstantOp>(rewriter, reg, type, attr);
3257+
rewriter.eraseOp(connect);
3258+
return success();
3259+
}
3260+
}
3261+
3262+
// connect(reg, or(reg, x)) ==> reg -> 1
3263+
// connect(reg, or(x, reg)) ==> reg -> 1
3264+
if (auto srcOr = dyn_cast<OrPrimOp>(src)) {
3265+
if (srcOr.getLhs().getDefiningOp() == reg ||
3266+
srcOr.getRhs().getDefiningOp() == reg) {
3267+
auto attr = getIntAttr(type, APInt(1, 1));
3268+
replaceOpWithNewOpAndCopyName<ConstantOp>(rewriter, reg, type, attr);
3269+
rewriter.eraseOp(connect);
3270+
return success();
3271+
}
3272+
}
3273+
3274+
return failure();
3275+
}
3276+
31593277
LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3160-
if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3161-
succeeded(foldHiddenReset(op, rewriter)))
3162-
return success();
3278+
if (!hasDontTouch(op.getOperation()) && !op.isForceable()) {
3279+
if (succeeded(foldHiddenReset(op, rewriter)))
3280+
return success();
3281+
if (succeeded(foldRegAndOrOfSelf(op, rewriter)))
3282+
return success();
3283+
}
31633284

31643285
if (succeeded(demoteForceableIfUnused(op, rewriter)))
31653286
return success();

test/Dialect/FIRRTL/canonicalization.mlir

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,11 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
571571
out %out3: !firrtl.uint<1>,
572572
out %out4: !firrtl.uint<4>,
573573
out %out5: !firrtl.uint<1>,
574-
out %out6: !firrtl.uint<1>) {
574+
out %out6: !firrtl.uint<1>,
575+
out %out7: !firrtl.uint<1>,
576+
out %out8: !firrtl.uint<1>,
577+
out %out9: !firrtl.uint<1>,
578+
out %out10: !firrtl.uint<1>) {
575579
// CHECK: firrtl.matchingconnect %out, %in
576580
%0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
577581
firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4>
@@ -634,6 +638,32 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
634638
// CHECK-NEXT: mux4cell(%[[SEL]],
635639
%17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
636640
firrtl.matchingconnect %out6, %17 : !firrtl.uint<1>
641+
642+
// mux(cond, 0, x) -> and(~cond, x)
643+
// CHECK: [[V1:%.+]] = firrtl.not %cond
644+
// CHECK-NEXT: [[V2:%.+]] = firrtl.and [[V1]], %val1
645+
// CHECK-NEXT: firrtl.matchingconnect %out7, [[V2]]
646+
%18 = firrtl.mux (%cond, %c0_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
647+
firrtl.connect %out7, %18 : !firrtl.uint<1>, !firrtl.uint<1>
648+
649+
// mux(cond, 1, x) -> or(cond, x)
650+
// CHECK: [[V:%.+]] = firrtl.or %cond, %val1
651+
// CHECK-NEXT: firrtl.matchingconnect %out8, [[V]]
652+
%19 = firrtl.mux (%cond, %c1_ui1, %val1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
653+
firrtl.connect %out8, %19 : !firrtl.uint<1>, !firrtl.uint<1>
654+
655+
// mux(cond, x, 0) -> and(cond, x)
656+
// CHECK: [[V:%.+]] = firrtl.and %cond, %val1
657+
// CHECK-NEXT: firrtl.matchingconnect %out9, [[V]]
658+
%20 = firrtl.mux (%cond, %val1, %c0_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
659+
firrtl.connect %out9, %20 : !firrtl.uint<1>, !firrtl.uint<1>
660+
661+
// mux(cond, x, 1) -> or(~cond, x)
662+
// CHECK: [[V1:%.+]] = firrtl.not %cond
663+
// CHECK-NEXT: [[V2:%.+]] = firrtl.or [[V1]], %val1
664+
// CHECK-NEXT: firrtl.matchingconnect %out10, [[V2]]
665+
%21 = firrtl.mux (%cond, %val1, %c1_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
666+
firrtl.connect %out10, %21 : !firrtl.uint<1>, !firrtl.uint<1>
637667
}
638668

639669
// CHECK-LABEL: firrtl.module @Pad
@@ -2618,6 +2648,52 @@ firrtl.module @constReg9(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, i
26182648
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
26192649
}
26202650

2651+
// Check that a register driven by an and(en, reg) is folded to a constant zero.
2652+
// CHECK-LABEL: @constRegAnd
2653+
firrtl.module @constRegAnd(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
2654+
// CHECK-NOT: firrtl.reg
2655+
// CHECK: firrtl.matchingconnect %out, %c0_ui1
2656+
%r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1>
2657+
%0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
2658+
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
2659+
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
2660+
}
2661+
2662+
// Check that a register driven by an or(en, reg) is folded to a constant one.
2663+
// CHECK-LABEL: @constRegOr
2664+
firrtl.module @constRegOr(in %clock: !firrtl.clock, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
2665+
// CHECK-NOT: firrtl.reg
2666+
// CHECK: firrtl.matchingconnect %out, %c1_ui1
2667+
%r = firrtl.reg %clock {firrtl.random_init_start = 0 : ui64} : !firrtl.clock, !firrtl.uint<1>
2668+
%0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
2669+
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
2670+
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
2671+
}
2672+
2673+
// Check that a regreset driven by an and(en, reg) is folded to a constant zero, when the reset is zero.
2674+
// CHECK-LABEL: @constRegResetAnd
2675+
firrtl.module @constRegResetAnd(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
2676+
// CHECK-NOT: firrtl.reg
2677+
// CHECK: firrtl.matchingconnect %out, %c0_ui1
2678+
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
2679+
%r = firrtl.regreset %clock, %reset, %c0_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
2680+
%0 = firrtl.and %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
2681+
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
2682+
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
2683+
}
2684+
2685+
// Check that a regreset driven by an or(en, reg) is folded to a constant one, when the reset is one.
2686+
// CHECK-LABEL: @constRegResetOr
2687+
firrtl.module @constRegResetOr(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %en: !firrtl.uint<1>, out %out: !firrtl.uint<1>) {
2688+
// CHECK-NOT: firrtl.reg
2689+
// CHECK: firrtl.matchingconnect %out, %c1_ui1
2690+
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
2691+
%r = firrtl.regreset %clock, %reset, %c1_ui1 : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>
2692+
%0 = firrtl.or %en, %r : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
2693+
firrtl.connect %r, %0 : !firrtl.uint<1>, !firrtl.uint<1>
2694+
firrtl.matchingconnect %out, %r : !firrtl.uint<1>
2695+
}
2696+
26212697
firrtl.module @BitCast(out %o:!firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>> ) {
26222698
%a = firrtl.wire : !firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>>
26232699
%b = firrtl.bitcast %a : (!firrtl.bundle<valid: uint<1>, ready: uint<1>, data: uint<1>>) -> (!firrtl.uint<3>)

0 commit comments

Comments
 (0)