
[DAGCombiner] Turn (neg (max x, (neg x))) into (min x, (neg x)) #120666

Open: wants to merge 6 commits into base branch main
Conversation

mshockwave
Member

@mshockwave mshockwave commented Dec 20, 2024

This pattern was originally spotted in 429.mcf by @topperc.

We already have a DAGCombiner pattern to turn (neg (abs x)) into (min x, (neg x)). But in some cases (neg (max x, (neg x))) is formed by an expanded abs followed by a neg that is generated only after the abs expansion. This patch adds a separate pattern to match cases like this, as well as its inverse pattern: (neg (min X, (neg X))) --> (max X, (neg X)).

I think this pattern is applicable to both signed and unsigned min/max. Here are the Alive2 proofs:

We already have a rule to turn `(neg (abs x))` into `(min x, (neg x))`.
But in some cases `(neg (max x, (neg x)))` is formed by an expanded
`abs` followed by a `neg` that is only generated after the expansion.
This patch adds a separate pattern to match this kind of case.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Dec 20, 2024
@llvmbot
Member

llvmbot commented Dec 20, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Min-Yih Hsu (mshockwave)

Changes

This pattern was originally spotted in 429.mcf by @topperc.

We already have a DAGCombiner pattern to turn (neg (abs x)) into (min x, (neg x)). But in some cases (neg (max x, (neg x))) is formed by an expanded abs followed by a neg that is generated only after the abs expansion. This patch adds a separate pattern to match cases like this.

I think this pattern is applicable to both signed and unsigned min/max. Here are the Alive2 proofs:


Full diff: https://github.com/llvm/llvm-project/pull/120666.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14)
  • (modified) llvm/test/CodeGen/RISCV/neg-abs.ll (+222)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..3cb33bdd02ef39 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3949,6 +3949,20 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
         return Result;
 
+    // Similar to the previous rule, but this time targeting an expanded abs.
+    // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
+    // Note that this is applicable to both signed and unsigned min/max.
+    SDValue X;
+    if (LegalOperations &&
+        sd_match(N1,
+                 m_OneUse(m_AnyOf(m_SMax(m_Value(X), m_Neg(m_Deferred(X))),
+                                  m_UMax(m_Value(X), m_Neg(m_Deferred(X))))))) {
+      unsigned MinOpc = N1->getOpcode() == ISD::SMAX ? ISD::SMIN : ISD::UMIN;
+      if (hasOperation(MinOpc, VT))
+        return DAG.getNode(MinOpc, DL, VT, X,
+                           DAG.getNode(ISD::SUB, DL, VT, N0, X));
+    }
+
     // Fold neg(splat(neg(x)) -> splat(x)
     if (VT.isVector()) {
       SDValue N1S = DAG.getSplatValue(N1, true);
diff --git a/llvm/test/CodeGen/RISCV/neg-abs.ll b/llvm/test/CodeGen/RISCV/neg-abs.ll
index 7d6a6d7ed4ce64..c1695c88f1f384 100644
--- a/llvm/test/CodeGen/RISCV/neg-abs.ll
+++ b/llvm/test/CodeGen/RISCV/neg-abs.ll
@@ -258,3 +258,225 @@ define i64 @neg_abs64_multiuse(i64 %x, ptr %y) {
   %neg = sub nsw i64 0, %abs
   ret i64 %neg
 }
+
+define i32 @expanded_neg_abs32(i32 %x) {
+; RV32I-LABEL: expanded_neg_abs32:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    neg a1, a0
+; RV32I-NEXT:    blt a0, a1, .LBB6_2
+; RV32I-NEXT:  # %bb.1:
+; RV32I-NEXT:    mv a1, a0
+; RV32I-NEXT:  .LBB6_2:
+; RV32I-NEXT:    neg a0, a1
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs32:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    neg a1, a0
+; RV32ZBB-NEXT:    min a0, a0, a1
+; RV32ZBB-NEXT:    ret
+;
+; RV64I-LABEL: expanded_neg_abs32:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    sext.w a1, a0
+; RV64I-NEXT:    negw a0, a0
+; RV64I-NEXT:    blt a1, a0, .LBB6_2
+; RV64I-NEXT:  # %bb.1:
+; RV64I-NEXT:    mv a0, a1
+; RV64I-NEXT:  .LBB6_2:
+; RV64I-NEXT:    negw a0, a0
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs32:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    sext.w a1, a0
+; RV64ZBB-NEXT:    negw a0, a0
+; RV64ZBB-NEXT:    max a0, a0, a1
+; RV64ZBB-NEXT:    negw a0, a0
+; RV64ZBB-NEXT:    ret
+  %n = sub i32 0, %x
+  %t = call i32 @llvm.smax.i32(i32 %n, i32 %x)
+  %r = sub i32 0, %t
+  ret i32 %r
+}
+
+define i32 @expanded_neg_abs32_unsigned(i32 %x) {
+; RV32I-LABEL: expanded_neg_abs32_unsigned:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    neg a1, a0
+; RV32I-NEXT:    bltu a0, a1, .LBB7_2
+; RV32I-NEXT:  # %bb.1:
+; RV32I-NEXT:    mv a1, a0
+; RV32I-NEXT:  .LBB7_2:
+; RV32I-NEXT:    neg a0, a1
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs32_unsigned:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    neg a1, a0
+; RV32ZBB-NEXT:    minu a0, a0, a1
+; RV32ZBB-NEXT:    ret
+;
+; RV64I-LABEL: expanded_neg_abs32_unsigned:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    sext.w a1, a0
+; RV64I-NEXT:    negw a0, a0
+; RV64I-NEXT:    bltu a1, a0, .LBB7_2
+; RV64I-NEXT:  # %bb.1:
+; RV64I-NEXT:    mv a0, a1
+; RV64I-NEXT:  .LBB7_2:
+; RV64I-NEXT:    negw a0, a0
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs32_unsigned:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    sext.w a1, a0
+; RV64ZBB-NEXT:    negw a0, a0
+; RV64ZBB-NEXT:    maxu a0, a0, a1
+; RV64ZBB-NEXT:    negw a0, a0
+; RV64ZBB-NEXT:    ret
+  %n = sub i32 0, %x
+  %t = call i32 @llvm.umax.i32(i32 %n, i32 %x)
+  %r = sub i32 0, %t
+  ret i32 %r
+}
+
+define i64 @expanded_neg_abs64(i64 %x) {
+; RV32I-LABEL: expanded_neg_abs64:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    snez a2, a0
+; RV32I-NEXT:    neg a3, a1
+; RV32I-NEXT:    sub a2, a3, a2
+; RV32I-NEXT:    neg a3, a0
+; RV32I-NEXT:    beq a2, a1, .LBB8_2
+; RV32I-NEXT:  # %bb.1:
+; RV32I-NEXT:    slt a4, a1, a2
+; RV32I-NEXT:    beqz a4, .LBB8_3
+; RV32I-NEXT:    j .LBB8_4
+; RV32I-NEXT:  .LBB8_2:
+; RV32I-NEXT:    sltu a4, a0, a3
+; RV32I-NEXT:    bnez a4, .LBB8_4
+; RV32I-NEXT:  .LBB8_3:
+; RV32I-NEXT:    mv a2, a1
+; RV32I-NEXT:    mv a3, a0
+; RV32I-NEXT:  .LBB8_4:
+; RV32I-NEXT:    snez a0, a3
+; RV32I-NEXT:    add a0, a2, a0
+; RV32I-NEXT:    neg a1, a0
+; RV32I-NEXT:    neg a0, a3
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs64:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    snez a2, a0
+; RV32ZBB-NEXT:    neg a3, a1
+; RV32ZBB-NEXT:    sub a2, a3, a2
+; RV32ZBB-NEXT:    neg a3, a0
+; RV32ZBB-NEXT:    beq a2, a1, .LBB8_2
+; RV32ZBB-NEXT:  # %bb.1:
+; RV32ZBB-NEXT:    slt a4, a1, a2
+; RV32ZBB-NEXT:    beqz a4, .LBB8_3
+; RV32ZBB-NEXT:    j .LBB8_4
+; RV32ZBB-NEXT:  .LBB8_2:
+; RV32ZBB-NEXT:    sltu a4, a0, a3
+; RV32ZBB-NEXT:    bnez a4, .LBB8_4
+; RV32ZBB-NEXT:  .LBB8_3:
+; RV32ZBB-NEXT:    mv a2, a1
+; RV32ZBB-NEXT:    mv a3, a0
+; RV32ZBB-NEXT:  .LBB8_4:
+; RV32ZBB-NEXT:    snez a0, a3
+; RV32ZBB-NEXT:    add a0, a2, a0
+; RV32ZBB-NEXT:    neg a1, a0
+; RV32ZBB-NEXT:    neg a0, a3
+; RV32ZBB-NEXT:    ret
+;
+; RV64I-LABEL: expanded_neg_abs64:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    neg a1, a0
+; RV64I-NEXT:    blt a0, a1, .LBB8_2
+; RV64I-NEXT:  # %bb.1:
+; RV64I-NEXT:    mv a1, a0
+; RV64I-NEXT:  .LBB8_2:
+; RV64I-NEXT:    neg a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs64:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    neg a1, a0
+; RV64ZBB-NEXT:    min a0, a0, a1
+; RV64ZBB-NEXT:    ret
+  %n = sub i64 0, %x
+  %t = call i64 @llvm.smax.i64(i64 %n, i64 %x)
+  %r = sub i64 0, %t
+  ret i64 %r
+}
+
+define i64 @expanded_neg_abs64_unsigned(i64 %x) {
+; RV32I-LABEL: expanded_neg_abs64_unsigned:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    snez a2, a0
+; RV32I-NEXT:    neg a3, a1
+; RV32I-NEXT:    sub a2, a3, a2
+; RV32I-NEXT:    neg a3, a0
+; RV32I-NEXT:    beq a2, a1, .LBB9_2
+; RV32I-NEXT:  # %bb.1:
+; RV32I-NEXT:    sltu a4, a1, a2
+; RV32I-NEXT:    beqz a4, .LBB9_3
+; RV32I-NEXT:    j .LBB9_4
+; RV32I-NEXT:  .LBB9_2:
+; RV32I-NEXT:    sltu a4, a0, a3
+; RV32I-NEXT:    bnez a4, .LBB9_4
+; RV32I-NEXT:  .LBB9_3:
+; RV32I-NEXT:    mv a2, a1
+; RV32I-NEXT:    mv a3, a0
+; RV32I-NEXT:  .LBB9_4:
+; RV32I-NEXT:    snez a0, a3
+; RV32I-NEXT:    add a0, a2, a0
+; RV32I-NEXT:    neg a1, a0
+; RV32I-NEXT:    neg a0, a3
+; RV32I-NEXT:    ret
+;
+; RV32ZBB-LABEL: expanded_neg_abs64_unsigned:
+; RV32ZBB:       # %bb.0:
+; RV32ZBB-NEXT:    snez a2, a0
+; RV32ZBB-NEXT:    neg a3, a1
+; RV32ZBB-NEXT:    sub a2, a3, a2
+; RV32ZBB-NEXT:    neg a3, a0
+; RV32ZBB-NEXT:    beq a2, a1, .LBB9_2
+; RV32ZBB-NEXT:  # %bb.1:
+; RV32ZBB-NEXT:    sltu a4, a1, a2
+; RV32ZBB-NEXT:    beqz a4, .LBB9_3
+; RV32ZBB-NEXT:    j .LBB9_4
+; RV32ZBB-NEXT:  .LBB9_2:
+; RV32ZBB-NEXT:    sltu a4, a0, a3
+; RV32ZBB-NEXT:    bnez a4, .LBB9_4
+; RV32ZBB-NEXT:  .LBB9_3:
+; RV32ZBB-NEXT:    mv a2, a1
+; RV32ZBB-NEXT:    mv a3, a0
+; RV32ZBB-NEXT:  .LBB9_4:
+; RV32ZBB-NEXT:    snez a0, a3
+; RV32ZBB-NEXT:    add a0, a2, a0
+; RV32ZBB-NEXT:    neg a1, a0
+; RV32ZBB-NEXT:    neg a0, a3
+; RV32ZBB-NEXT:    ret
+;
+; RV64I-LABEL: expanded_neg_abs64_unsigned:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    neg a1, a0
+; RV64I-NEXT:    bltu a0, a1, .LBB9_2
+; RV64I-NEXT:  # %bb.1:
+; RV64I-NEXT:    mv a1, a0
+; RV64I-NEXT:  .LBB9_2:
+; RV64I-NEXT:    neg a0, a1
+; RV64I-NEXT:    ret
+;
+; RV64ZBB-LABEL: expanded_neg_abs64_unsigned:
+; RV64ZBB:       # %bb.0:
+; RV64ZBB-NEXT:    neg a1, a0
+; RV64ZBB-NEXT:    minu a0, a0, a1
+; RV64ZBB-NEXT:    ret
+  %n = sub i64 0, %x
+  %t = call i64 @llvm.umax.i64(i64 %n, i64 %x)
+  %r = sub i64 0, %t
+  ret i64 %r
+}

Contributor

@arsenm arsenm left a comment

Can you also do the globalisel version?

unsigned MinOpc = N1->getOpcode() == ISD::SMAX ? ISD::SMIN : ISD::UMIN;
if (hasOperation(MinOpc, VT))
return DAG.getNode(MinOpc, DL, VT, X,
DAG.getNode(ISD::SUB, DL, VT, N0, X));
Contributor

I think this can preserve flags

Member Author

Good catch. I simply capture and reuse the SDValue of the first sub. It's done now.

@@ -3949,6 +3949,20 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
return Result;

// Similar to the previous rule, but this time targeting an expanded abs.
// (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
Contributor

Could also do the converse:

(sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))

but I don't know if there's a real need for it.

Member Author

I think it's fine to add it in this patch. It's done now.

@mshockwave
Copy link
Member Author

Can you also do the globalisel version?

Of course. Let me do that in a separate PR.

SDValue X;
SDValue S0;
auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
if (LegalOperations &&
Contributor

Technically you should check that the max is legal, but I doubt the legality of min differs from that of max in practice.

Member Author

@mshockwave mshockwave Dec 23, 2024

I doubt in practice the legality of min is different than max

This, and even if only max is illegal, I think it'll be expanded (while min is not), which makes this transformation even more profitable.

Contributor

It depends in which combiner phase, eventually only legal operations can be emitted

Member Author

Ah, if you're referring to the max generated by this rule, I already checked its legality (through NewOpc) below.

Comment on lines 3966 to 3981
switch (N1->getOpcode()) {
case ISD::SMAX:
NewOpc = ISD::SMIN;
break;
case ISD::UMAX:
NewOpc = ISD::UMIN;
break;
case ISD::SMIN:
NewOpc = ISD::SMAX;
break;
case ISD::UMIN:
NewOpc = ISD::UMAX;
break;
default:
llvm_unreachable("unrecognized opcode");
}
Contributor

This should be a helper function instead of a switch with breaks. Do we not have a min<->max helper already? I swear I've written one several times.

Member Author

@mshockwave mshockwave Dec 23, 2024

Do we not have a min<->max helper already? I swear I've written one several times

That was what I searched for in the first place too, but I don't think there is one, so I extracted the logic into ISD::getInverseMinMaxOpcode.

mshockwave added a commit to mshockwave/llvm-project that referenced this pull request Dec 23, 2024
mshockwave added a commit to mshockwave/llvm-project that referenced this pull request Dec 23, 2024
Contributor

@arsenm arsenm left a comment

lgtm with nits

@@ -1506,6 +1506,8 @@ inline bool isBitwiseLogicOp(unsigned Opcode) {
return Opcode == ISD::AND || Opcode == ISD::OR || Opcode == ISD::XOR;
}

NodeType getInverseMinMaxOpcode(unsigned MinMaxOpc);
Contributor

Document

Member Author

Done

SDValue X;
SDValue S0;
auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
if (LegalOperations &&
Contributor

It depends in which combiner phase, eventually only legal operations can be emitted

%t = call i64 @llvm.umin.i64(i64 %n, i64 %x)
%r = sub i64 0, %t
ret i64 %r
}
Contributor

Test some vectors

Member Author

Done.

And run this combiner rule in both the pre- and post-legalization phases.
@mshockwave
Member Author

mshockwave commented Dec 26, 2024

Note that I just made this rule applicable to the pre-legalization phase as well. My argument is that since I already check the legality of the newly generated operation, it shouldn't matter whether the rule is run before or after legalization. Plus, this supports RISC-V's fixed vectors out of the box (scalable vectors are not covered, though: although SDPatternMatch can apply the same pattern to both scalar and VP operations through MatchContext, doing so would require changing visitSUB's signature to accept a MatchContext argument).
