[DAGCombiner] Turn (neg (max x, (neg x))) into (min x, (neg x)) #120666
Conversation
We already have a rule to turn `(neg (abs x))` into `(min x, (neg x))`. But in some cases `(neg (max x, (neg x)))` is formed by an expanded `abs` followed by a `neg` that is only generated after the expansion. This patch adds a separate pattern to match such cases.
@llvm/pr-subscribers-llvm-selectiondag

Author: Min-Yih Hsu (mshockwave)

Changes

This pattern was originally spotted in 429.mcf by @topperc. We already have a DAGCombiner pattern to turn `(neg (abs x))` into `(min x, (neg x))`. But in some cases `(neg (max x, (neg x)))` is formed by an expanded `abs` followed by a `neg` that is only generated after the expansion. This patch adds a separate pattern to match such cases. I think this pattern is applicable to both signed and unsigned min/max. Here are the Alive2 proofs:
Full diff: https://github.com/llvm/llvm-project/pull/120666.diff

2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..3cb33bdd02ef39 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3949,6 +3949,20 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
return Result;
+ // Similar to the previous rule, but this time targeting an expanded abs.
+ // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
+ // Note that this is applicable to both signed and unsigned min/max.
+ SDValue X;
+ if (LegalOperations &&
+ sd_match(N1,
+ m_OneUse(m_AnyOf(m_SMax(m_Value(X), m_Neg(m_Deferred(X))),
+ m_UMax(m_Value(X), m_Neg(m_Deferred(X))))))) {
+ unsigned MinOpc = N1->getOpcode() == ISD::SMAX ? ISD::SMIN : ISD::UMIN;
+ if (hasOperation(MinOpc, VT))
+ return DAG.getNode(MinOpc, DL, VT, X,
+ DAG.getNode(ISD::SUB, DL, VT, N0, X));
+ }
+
// Fold neg(splat(neg(x)) -> splat(x)
if (VT.isVector()) {
SDValue N1S = DAG.getSplatValue(N1, true);
diff --git a/llvm/test/CodeGen/RISCV/neg-abs.ll b/llvm/test/CodeGen/RISCV/neg-abs.ll
index 7d6a6d7ed4ce64..c1695c88f1f384 100644
--- a/llvm/test/CodeGen/RISCV/neg-abs.ll
+++ b/llvm/test/CodeGen/RISCV/neg-abs.ll
@@ -258,3 +258,225 @@ define i64 @neg_abs64_multiuse(i64 %x, ptr %y) {
%neg = sub nsw i64 0, %abs
ret i64 %neg
}
+
+define i32 @expanded_neg_abs32(i32 %x) {
+; RV32I-LABEL: expanded_neg_abs32:
+; RV32I: # %bb.0:
+; RV32I-NEXT: neg a1, a0
+; RV32I-NEXT: blt a0, a1, .LBB6_2
+; RV32I-NEXT: # %bb.1:
+; RV32I-NEXT: mv a1, a0
+; RV32I-NEXT: .LBB6_2:
+; RV32I-NEXT: neg a0, a1
+; RV32I-NEXT: ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs32:
+; RV32ZBB: # %bb.0:
+; RV32ZBB-NEXT: neg a1, a0
+; RV32ZBB-NEXT: min a0, a0, a1
+; RV32ZBB-NEXT: ret
+;
+; RV64I-LABEL: expanded_neg_abs32:
+; RV64I: # %bb.0:
+; RV64I-NEXT: sext.w a1, a0
+; RV64I-NEXT: negw a0, a0
+; RV64I-NEXT: blt a1, a0, .LBB6_2
+; RV64I-NEXT: # %bb.1:
+; RV64I-NEXT: mv a0, a1
+; RV64I-NEXT: .LBB6_2:
+; RV64I-NEXT: negw a0, a0
+; RV64I-NEXT: ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs32:
+; RV64ZBB: # %bb.0:
+; RV64ZBB-NEXT: sext.w a1, a0
+; RV64ZBB-NEXT: negw a0, a0
+; RV64ZBB-NEXT: max a0, a0, a1
+; RV64ZBB-NEXT: negw a0, a0
+; RV64ZBB-NEXT: ret
+ %n = sub i32 0, %x
+ %t = call i32 @llvm.smax.i32(i32 %n, i32 %x)
+ %r = sub i32 0, %t
+ ret i32 %r
+}
+
+define i32 @expanded_neg_abs32_unsigned(i32 %x) {
+; RV32I-LABEL: expanded_neg_abs32_unsigned:
+; RV32I: # %bb.0:
+; RV32I-NEXT: neg a1, a0
+; RV32I-NEXT: bltu a0, a1, .LBB7_2
+; RV32I-NEXT: # %bb.1:
+; RV32I-NEXT: mv a1, a0
+; RV32I-NEXT: .LBB7_2:
+; RV32I-NEXT: neg a0, a1
+; RV32I-NEXT: ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs32_unsigned:
+; RV32ZBB: # %bb.0:
+; RV32ZBB-NEXT: neg a1, a0
+; RV32ZBB-NEXT: minu a0, a0, a1
+; RV32ZBB-NEXT: ret
+;
+; RV64I-LABEL: expanded_neg_abs32_unsigned:
+; RV64I: # %bb.0:
+; RV64I-NEXT: sext.w a1, a0
+; RV64I-NEXT: negw a0, a0
+; RV64I-NEXT: bltu a1, a0, .LBB7_2
+; RV64I-NEXT: # %bb.1:
+; RV64I-NEXT: mv a0, a1
+; RV64I-NEXT: .LBB7_2:
+; RV64I-NEXT: negw a0, a0
+; RV64I-NEXT: ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs32_unsigned:
+; RV64ZBB: # %bb.0:
+; RV64ZBB-NEXT: sext.w a1, a0
+; RV64ZBB-NEXT: negw a0, a0
+; RV64ZBB-NEXT: maxu a0, a0, a1
+; RV64ZBB-NEXT: negw a0, a0
+; RV64ZBB-NEXT: ret
+ %n = sub i32 0, %x
+ %t = call i32 @llvm.umax.i32(i32 %n, i32 %x)
+ %r = sub i32 0, %t
+ ret i32 %r
+}
+
+define i64 @expanded_neg_abs64(i64 %x) {
+; RV32I-LABEL: expanded_neg_abs64:
+; RV32I: # %bb.0:
+; RV32I-NEXT: snez a2, a0
+; RV32I-NEXT: neg a3, a1
+; RV32I-NEXT: sub a2, a3, a2
+; RV32I-NEXT: neg a3, a0
+; RV32I-NEXT: beq a2, a1, .LBB8_2
+; RV32I-NEXT: # %bb.1:
+; RV32I-NEXT: slt a4, a1, a2
+; RV32I-NEXT: beqz a4, .LBB8_3
+; RV32I-NEXT: j .LBB8_4
+; RV32I-NEXT: .LBB8_2:
+; RV32I-NEXT: sltu a4, a0, a3
+; RV32I-NEXT: bnez a4, .LBB8_4
+; RV32I-NEXT: .LBB8_3:
+; RV32I-NEXT: mv a2, a1
+; RV32I-NEXT: mv a3, a0
+; RV32I-NEXT: .LBB8_4:
+; RV32I-NEXT: snez a0, a3
+; RV32I-NEXT: add a0, a2, a0
+; RV32I-NEXT: neg a1, a0
+; RV32I-NEXT: neg a0, a3
+; RV32I-NEXT: ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs64:
+; RV32ZBB: # %bb.0:
+; RV32ZBB-NEXT: snez a2, a0
+; RV32ZBB-NEXT: neg a3, a1
+; RV32ZBB-NEXT: sub a2, a3, a2
+; RV32ZBB-NEXT: neg a3, a0
+; RV32ZBB-NEXT: beq a2, a1, .LBB8_2
+; RV32ZBB-NEXT: # %bb.1:
+; RV32ZBB-NEXT: slt a4, a1, a2
+; RV32ZBB-NEXT: beqz a4, .LBB8_3
+; RV32ZBB-NEXT: j .LBB8_4
+; RV32ZBB-NEXT: .LBB8_2:
+; RV32ZBB-NEXT: sltu a4, a0, a3
+; RV32ZBB-NEXT: bnez a4, .LBB8_4
+; RV32ZBB-NEXT: .LBB8_3:
+; RV32ZBB-NEXT: mv a2, a1
+; RV32ZBB-NEXT: mv a3, a0
+; RV32ZBB-NEXT: .LBB8_4:
+; RV32ZBB-NEXT: snez a0, a3
+; RV32ZBB-NEXT: add a0, a2, a0
+; RV32ZBB-NEXT: neg a1, a0
+; RV32ZBB-NEXT: neg a0, a3
+; RV32ZBB-NEXT: ret
+;
+; RV64I-LABEL: expanded_neg_abs64:
+; RV64I: # %bb.0:
+; RV64I-NEXT: neg a1, a0
+; RV64I-NEXT: blt a0, a1, .LBB8_2
+; RV64I-NEXT: # %bb.1:
+; RV64I-NEXT: mv a1, a0
+; RV64I-NEXT: .LBB8_2:
+; RV64I-NEXT: neg a0, a1
+; RV64I-NEXT: ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs64:
+; RV64ZBB: # %bb.0:
+; RV64ZBB-NEXT: neg a1, a0
+; RV64ZBB-NEXT: min a0, a0, a1
+; RV64ZBB-NEXT: ret
+ %n = sub i64 0, %x
+ %t = call i64 @llvm.smax.i64(i64 %n, i64 %x)
+ %r = sub i64 0, %t
+ ret i64 %r
+}
+
+define i64 @expanded_neg_abs64_unsigned(i64 %x) {
+; RV32I-LABEL: expanded_neg_abs64_unsigned:
+; RV32I: # %bb.0:
+; RV32I-NEXT: snez a2, a0
+; RV32I-NEXT: neg a3, a1
+; RV32I-NEXT: sub a2, a3, a2
+; RV32I-NEXT: neg a3, a0
+; RV32I-NEXT: beq a2, a1, .LBB9_2
+; RV32I-NEXT: # %bb.1:
+; RV32I-NEXT: sltu a4, a1, a2
+; RV32I-NEXT: beqz a4, .LBB9_3
+; RV32I-NEXT: j .LBB9_4
+; RV32I-NEXT: .LBB9_2:
+; RV32I-NEXT: sltu a4, a0, a3
+; RV32I-NEXT: bnez a4, .LBB9_4
+; RV32I-NEXT: .LBB9_3:
+; RV32I-NEXT: mv a2, a1
+; RV32I-NEXT: mv a3, a0
+; RV32I-NEXT: .LBB9_4:
+; RV32I-NEXT: snez a0, a3
+; RV32I-NEXT: add a0, a2, a0
+; RV32I-NEXT: neg a1, a0
+; RV32I-NEXT: neg a0, a3
+; RV32I-NEXT: ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs64_unsigned:
+; RV32ZBB: # %bb.0:
+; RV32ZBB-NEXT: snez a2, a0
+; RV32ZBB-NEXT: neg a3, a1
+; RV32ZBB-NEXT: sub a2, a3, a2
+; RV32ZBB-NEXT: neg a3, a0
+; RV32ZBB-NEXT: beq a2, a1, .LBB9_2
+; RV32ZBB-NEXT: # %bb.1:
+; RV32ZBB-NEXT: sltu a4, a1, a2
+; RV32ZBB-NEXT: beqz a4, .LBB9_3
+; RV32ZBB-NEXT: j .LBB9_4
+; RV32ZBB-NEXT: .LBB9_2:
+; RV32ZBB-NEXT: sltu a4, a0, a3
+; RV32ZBB-NEXT: bnez a4, .LBB9_4
+; RV32ZBB-NEXT: .LBB9_3:
+; RV32ZBB-NEXT: mv a2, a1
+; RV32ZBB-NEXT: mv a3, a0
+; RV32ZBB-NEXT: .LBB9_4:
+; RV32ZBB-NEXT: snez a0, a3
+; RV32ZBB-NEXT: add a0, a2, a0
+; RV32ZBB-NEXT: neg a1, a0
+; RV32ZBB-NEXT: neg a0, a3
+; RV32ZBB-NEXT: ret
+;
+; RV64I-LABEL: expanded_neg_abs64_unsigned:
+; RV64I: # %bb.0:
+; RV64I-NEXT: neg a1, a0
+; RV64I-NEXT: bltu a0, a1, .LBB9_2
+; RV64I-NEXT: # %bb.1:
+; RV64I-NEXT: mv a1, a0
+; RV64I-NEXT: .LBB9_2:
+; RV64I-NEXT: neg a0, a1
+; RV64I-NEXT: ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs64_unsigned:
+; RV64ZBB: # %bb.0:
+; RV64ZBB-NEXT: neg a1, a0
+; RV64ZBB-NEXT: minu a0, a0, a1
+; RV64ZBB-NEXT: ret
+ %n = sub i64 0, %x
+ %t = call i64 @llvm.umax.i64(i64 %n, i64 %x)
+ %r = sub i64 0, %t
+ ret i64 %r
+}
Can you also do the globalisel version?
unsigned MinOpc = N1->getOpcode() == ISD::SMAX ? ISD::SMIN : ISD::UMIN;
if (hasOperation(MinOpc, VT))
  return DAG.getNode(MinOpc, DL, VT, X,
                     DAG.getNode(ISD::SUB, DL, VT, N0, X));
I think this can preserve flags
Good catch. I simply capture and reuse the SDValue of the first sub. It's done now.
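For reference, a minimal sketch of the capture-and-reuse idea, with names matching the updated snippets quoted later in this thread (a fragment in visitSUB's context, not the verbatim diff):

```cpp
// Bind the inner (sub 0, X) to S0 while matching, then hand S0 itself to
// the new min/max node; reusing the matched node preserves its flags
// (e.g. nsw) instead of dropping them on a freshly built SUB.
SDValue X, S0;
auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
// ... after a successful sd_match of (max X, NegPat):
//   return DAG.getNode(MinOpc, DL, VT, X, S0);
```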
@@ -3949,6 +3949,20 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
  if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
    return Result;

  // Similar to the previous rule, but this time targeting an expanded abs.
  // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
Could also do the converse:
(sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
but I don't know if there's a real need for it.
I think it's fine to add it in this patch. It's done now.
Preserve the flags from the first sub
Of course. Let me do that in a separate PR.
SDValue X;
SDValue S0;
auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
if (LegalOperations &&
Technically you should check that the max is legal, but I doubt the legality of min differs from that of max in practice.
> I doubt the legality of min differs from that of max in practice

This, and even if only max is illegal, I think it'll be expanded (while min is not), which makes this transformation even more profitable.
It depends on which combiner phase; eventually, only legal operations can be emitted.
Ah, if you're referring to the max generated by this rule, I already checked its legality (through `NewOpc`) below.
switch (N1->getOpcode()) {
case ISD::SMAX:
  NewOpc = ISD::SMIN;
  break;
case ISD::UMAX:
  NewOpc = ISD::UMIN;
  break;
case ISD::SMIN:
  NewOpc = ISD::SMAX;
  break;
case ISD::UMIN:
  NewOpc = ISD::UMAX;
  break;
default:
  llvm_unreachable("unrecognized opcode");
}
This should be a helper function instead of a switch with breaks. Do we not have a min<->max helper already? I swear I've written one several times.
> Do we not have a min<->max helper already? I swear I've written one several times

That was what I tried to search for in the first place too, but I don't think there is one. So I extracted it into `ISD::getInverseMinMaxOpcode`.
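The extracted helper presumably looks roughly like this (a sketch reconstructed from the switch quoted above and the declaration quoted below, not the verbatim landed definition):

```cpp
// Map a min/max opcode to its counterpart: SMAX <-> SMIN, UMAX <-> UMIN.
ISD::NodeType ISD::getInverseMinMaxOpcode(unsigned MinMaxOpc) {
  switch (MinMaxOpc) {
  case ISD::SMAX:
    return ISD::SMIN;
  case ISD::UMAX:
    return ISD::UMIN;
  case ISD::SMIN:
    return ISD::SMAX;
  case ISD::UMIN:
    return ISD::UMAX;
  default:
    llvm_unreachable("unrecognized opcode");
  }
}
```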
This is the GISel version of llvm#120666.
lgtm with nits
@@ -1506,6 +1506,8 @@ inline bool isBitwiseLogicOp(unsigned Opcode) {
  return Opcode == ISD::AND || Opcode == ISD::OR || Opcode == ISD::XOR;
}

NodeType getInverseMinMaxOpcode(unsigned MinMaxOpc);
Document
Done
  %t = call i64 @llvm.umin.i64(i64 %n, i64 %x)
  %r = sub i64 0, %t
  ret i64 %r
}
Test some vectors
Done.
And run this combiner rule in both pre- and post-legalization phases.
Note that I just made this rule applicable to the pre-legalization phases as well. My argument is that since I already check the legality of the newly generated operation, it shouldn't matter whether the rule runs before or after legalization. Plus, this supports RISC-V's fixed vectors out of the box. (Scalable vectors are not covered, though: while SDPatternMatch can apply the same pattern to both scalar and VP operations through MatchContext, doing so would require changing visitSUB's signature to accept a MatchContext argument.)
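Putting the review threads together, the combine presumably ends up in roughly this shape (a sketch assembled from the snippets quoted in this thread, not the verbatim landed code):

```cpp
// (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X)) and the converse,
// for both signed and unsigned min/max. S0 reuses the matched inner sub,
// so its flags carry over; hasOperation gates on the target's support for
// the new opcode in both pre- and post-legalization runs.
SDValue X, S0;
auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
if (sd_match(N1, m_OneUse(m_AnyOf(m_SMax(m_Value(X), NegPat),
                                  m_UMax(m_Value(X), NegPat),
                                  m_SMin(m_Value(X), NegPat),
                                  m_UMin(m_Value(X), NegPat))))) {
  unsigned NewOpc = ISD::getInverseMinMaxOpcode(N1->getOpcode());
  if (hasOperation(NewOpc, VT))
    return DAG.getNode(NewOpc, DL, VT, X, S0);
}
```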
This pattern was originally spotted in 429.mcf by @topperc.

We already have a DAGCombiner pattern to turn `(neg (abs x))` into `(min x, (neg x))`. But in some cases `(neg (max x, (neg x)))` is formed by an expanded `abs` followed by a `neg` that is generated only after the `abs` expansion. This patch adds a separate pattern to match cases like this, as well as its inverse pattern: `(neg (min X, (neg X))) --> (max X, (neg X))`.

I think this pattern is applicable to both signed and unsigned min/max. Here are the Alive2 proofs: