Skip to content

Commit ca6470c

Browse files
committed
[LV] Support argmin/argmax with strict predicates.
Extend handleMultiUseReductions to support strict predicates (>, <), which must select the first matching index rather than the last one selected for non-strict predicates. When a strict predicate is detected, the transformation converts the FindLastIV reduction to FindFirstIV by: 1. Checking the IV range to ensure it does not include the sentinel value (max). 2. Creating a new reduction with the appropriate FindFirstIV kind (FindFirstIVSMin or FindFirstIVUMin, based on the IV range). 3. Replacing the old reduction recipe with the new one.
1 parent 13f7b30 commit ca6470c

File tree

14 files changed

+1079
-257
lines changed

14 files changed

+1079
-257
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Loop;
2828
class PredicatedScalarEvolution;
2929
class ScalarEvolution;
3030
class SCEV;
31+
class SCEVAddRecExpr;
3132
class StoreInst;
3233

3334
/// These are the kinds of recurrences that we support.
@@ -310,6 +311,11 @@ class RecurrenceDescriptor {
310311
isFindLastIVRecurrenceKind(Kind);
311312
}
312313

314+
/// Returns true if \p AR's range is valid for either FindFirstIV or
315+
/// FindLastIV reductions i.e. if the sentinel value is outside \p AR's range.
316+
static bool isValidIVRangeForFindIV(const SCEVAddRecExpr *AR, bool IsSigned,
317+
bool IsFindFirstIV, ScalarEvolution &SE);
318+
313319
/// Returns the type of the recurrence. This type can be narrower than the
314320
/// actual type of the Phi if the recurrence has been type-promoted.
315321
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,36 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
715715
return InstDesc(I, RecurKind::AnyOf);
716716
}
717717

718+
bool RecurrenceDescriptor::isValidIVRangeForFindIV(const SCEVAddRecExpr *AR,
719+
bool IsSigned,
720+
bool IsFindFirstIV,
721+
ScalarEvolution &SE) {
722+
const ConstantRange IVRange =
723+
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
724+
unsigned NumBits = AR->getType()->getIntegerBitWidth();
725+
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
726+
727+
if (IsFindFirstIV) {
728+
if (IsSigned)
729+
ValidRange =
730+
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
731+
APInt::getSignedMaxValue(NumBits) - 1);
732+
else
733+
ValidRange = ConstantRange::getNonEmpty(APInt::getMinValue(NumBits),
734+
APInt::getMaxValue(NumBits) - 1);
735+
} else {
736+
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
737+
: APInt::getMinValue(NumBits);
738+
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
739+
}
740+
741+
LLVM_DEBUG(dbgs() << "LV: " << (IsFindFirstIV ? "FindFirstIV" : "FindLastIV")
742+
<< " valid range is " << ValidRange << ", and the range of "
743+
<< *AR << " is " << IVRange << "\n");
744+
745+
return ValidRange.contains(IVRange);
746+
}
747+
718748
// We are looking for loops that do something like this:
719749
// int r = 0;
720750
// for (int i = 0; i < n; i++) {
@@ -792,49 +822,24 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
792822
// [Signed|Unsigned]Max(<recurrence type>) for FindFirstIV.
793823
// TODO: This range restriction can be lifted by adding an additional
794824
// virtual OR reduction.
795-
auto CheckRange = [&](bool IsSigned) {
796-
const ConstantRange IVRange =
797-
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
798-
unsigned NumBits = Ty->getIntegerBitWidth();
799-
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
800-
if (isFindLastIVRecurrenceKind(Kind)) {
801-
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
802-
: APInt::getMinValue(NumBits);
803-
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
804-
} else {
805-
if (IsSigned)
806-
ValidRange =
807-
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
808-
APInt::getSignedMaxValue(NumBits) - 1);
809-
else
810-
ValidRange = ConstantRange::getNonEmpty(
811-
APInt::getMinValue(NumBits), APInt::getMaxValue(NumBits) - 1);
812-
}
813-
814-
LLVM_DEBUG(dbgs() << "LV: "
815-
<< (isFindLastIVRecurrenceKind(Kind) ? "FindLastIV"
816-
: "FindFirstIV")
817-
<< " valid range is " << ValidRange
818-
<< ", and the range of " << *AR << " is " << IVRange
819-
<< "\n");
820-
821-
// Ensure the induction variable does not wrap around by verifying that
822-
// its range is fully contained within the valid range.
823-
return ValidRange.contains(IVRange);
824-
};
825+
bool IsFindFirstIV = isFindFirstIVRecurrenceKind(Kind);
825826
if (isFindLastIVRecurrenceKind(Kind)) {
826-
if (CheckRange(true))
827+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
828+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/true, IsFindFirstIV, SE))
827829
return RecurKind::FindLastIVSMax;
828-
if (CheckRange(false))
830+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
831+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/false, IsFindFirstIV, SE))
829832
return RecurKind::FindLastIVUMax;
830833
return std::nullopt;
831834
}
832835
assert(isFindFirstIVRecurrenceKind(Kind) &&
833836
"Kind must either be a FindLastIV or FindFirstIV");
834837

835-
if (CheckRange(true))
838+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
839+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/true, IsFindFirstIV, SE))
836840
return RecurKind::FindFirstIVSMin;
837-
if (CheckRange(false))
841+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
842+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/false, IsFindFirstIV, SE))
838843
return RecurKind::FindFirstIVUMin;
839844
return std::nullopt;
840845
};

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8506,8 +8506,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
85068506

85078507
// Apply mandatory transformation to handle reductions with multiple in-loop
85088508
// uses if possible, bail out otherwise.
8509-
if (!VPlanTransforms::runPass(VPlanTransforms::handleMultiUseReductions,
8510-
*Plan))
8509+
if (!VPlanTransforms::handleMultiUseReductions(*Plan, *PSE.getSE(), OrigLoop))
85118510
return nullptr;
85128511
// Apply mandatory transformation to handle FP maxnum/minnum reduction with
85138512
// NaNs if possible, bail out otherwise.

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414
#include "LoopVectorizationPlanner.h"
1515
#include "VPlan.h"
16+
#include "VPlanAnalysis.h"
1617
#include "VPlanCFG.h"
1718
#include "VPlanDominatorTree.h"
1819
#include "VPlanPatternMatch.h"
1920
#include "VPlanTransforms.h"
21+
#include "VPlanUtils.h"
2022
#include "llvm/Analysis/LoopInfo.h"
2123
#include "llvm/Analysis/LoopIterator.h"
2224
#include "llvm/Analysis/ScalarEvolution.h"
@@ -1120,7 +1122,48 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
11201122
return true;
11211123
}
11221124

1123-
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
1125+
/// Try to convert FindLastIV to FindFirstIV reduction when using a strict
1126+
/// predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
1127+
static VPReductionPHIRecipe *
1128+
tryConvertToFindFirstIV(VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1129+
VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1130+
Type *Ty = VPTypeAnalysis(Plan).inferScalarType(FindLastIVPhiR);
1131+
unsigned NumBits = Ty->getIntegerBitWidth();
1132+
1133+
// Determine the reduction kind and sentinel based on the IV range.
1134+
RecurKind NewKind;
1135+
VPValue *NewSentinel;
1136+
auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue(IVOp, SE, L));
1137+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1138+
AR, /*IsSigned=*/true, /*IsFindFirstIV=*/true, SE)) {
1139+
NewKind = RecurKind::FindFirstIVSMin;
1140+
NewSentinel = Plan.getConstantInt(APInt::getSignedMaxValue(NumBits));
1141+
} else if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1142+
AR, /*IsSigned=*/false, /*IsFindFirstIV=*/true, SE)) {
1143+
NewKind = RecurKind::FindFirstIVUMin;
1144+
NewSentinel = Plan.getConstantInt(APInt::getMaxValue(NumBits));
1145+
} else {
1146+
return nullptr;
1147+
}
1148+
1149+
// Create the new FindFirstIV reduction recipe.
1150+
assert(!FindLastIVPhiR->isInLoop() && !FindLastIVPhiR->isOrdered());
1151+
ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor()};
1152+
auto *FindFirstIVPhiR =
1153+
new VPReductionPHIRecipe(nullptr, NewKind, *NewSentinel, Style,
1154+
FindLastIVPhiR->hasUsesOutsideReductionChain());
1155+
FindFirstIVPhiR->addOperand(FindLastIVPhiR->getBackedgeValue());
1156+
1157+
FindFirstIVPhiR->insertBefore(FindLastIVPhiR);
1158+
VPInstruction *FindLastIVResult =
1159+
findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1160+
FindLastIVPhiR->replaceAllUsesWith(FindFirstIVPhiR);
1161+
FindLastIVResult->setOperand(2, NewSentinel);
1162+
return FindFirstIVPhiR;
1163+
}
1164+
1165+
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
1166+
const Loop *L) {
11241167
for (auto &PhiR : make_early_inc_range(
11251168
Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis())) {
11261169
auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1203,33 +1246,41 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
12031246
FindIVPhiR->getRecurrenceKind()))
12041247
return false;
12051248

1249+
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
1250+
"cannot handle inloop/ordered reductions yet");
1251+
12061252
// TODO: Support cases where IVOp is the IV increment.
12071253
if (!match(IVOp, m_TruncOrSelf(m_VPValue(IVOp))) ||
12081254
!isa<VPWidenIntOrFpInductionRecipe>(IVOp))
12091255
return false;
12101256

1211-
CmpInst::Predicate RdxPredicate = [RdxKind]() {
1257+
// Check if the predicate is compatible with the reduction kind.
1258+
bool IsValidPredicate = [RdxKind, Pred]() {
12121259
switch (RdxKind) {
12131260
case RecurKind::UMin:
1214-
return CmpInst::ICMP_UGE;
1261+
return Pred == CmpInst::ICMP_UGE || Pred == CmpInst::ICMP_UGT;
12151262
case RecurKind::UMax:
1216-
return CmpInst::ICMP_ULE;
1263+
return Pred == CmpInst::ICMP_ULE || Pred == CmpInst::ICMP_ULT;
12171264
case RecurKind::SMax:
1218-
return CmpInst::ICMP_SLE;
1265+
return Pred == CmpInst::ICMP_SLE || Pred == CmpInst::ICMP_SLT;
12191266
case RecurKind::SMin:
1220-
return CmpInst::ICMP_SGE;
1267+
return Pred == CmpInst::ICMP_SGE || Pred == CmpInst::ICMP_SGT;
12211268
default:
12221269
llvm_unreachable("unhandled recurrence kind");
12231270
}
12241271
}();
12251272

1226-
// TODO: Strict predicates need to find the first IV value for which the
1227-
// predicate holds, not the last.
1228-
if (Pred != RdxPredicate)
1273+
if (!IsValidPredicate)
12291274
return false;
12301275

1231-
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
1232-
"cannot handle inloop/ordered reductions yet");
1276+
// For strict predicates, try to convert FindLastIV to
1277+
// FindFirstIV.
1278+
bool IsStrictPredicate = ICmpInst::isLT(Pred) || ICmpInst::isGT(Pred);
1279+
if (IsStrictPredicate) {
1280+
FindIVPhiR = tryConvertToFindFirstIV(Plan, FindIVPhiR, IVOp, SE, L);
1281+
if (!FindIVPhiR)
1282+
return false;
1283+
}
12331284

12341285
// The reduction using MinMaxPhiR needs adjusting to compute the correct
12351286
// result:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
163163
return cast<VPExpressionRecipe>(this)->mayHaveSideEffects();
164164
case VPDerivedIVSC:
165165
case VPFirstOrderRecurrencePHISC:
166+
case VPReductionPHISC:
166167
case VPPredInstPHISC:
167168
case VPVectorEndPointerSC:
168169
return false;

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ struct VPlanTransforms {
157157
const TargetLibraryInfo &TLI);
158158

159159
/// Try to legalize reductions with multiple in-loop uses. Currently only
160-
/// min/max reductions used by FindLastIV reductions are supported. Otherwise
161-
/// return false.
162-
static bool handleMultiUseReductions(VPlan &Plan);
160+
/// min/max reductions used by FindLastIV and FindFirstIV reductions are
161+
/// supported. Otherwise return false.
162+
static bool handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
163+
const Loop *L);
163164

164165
/// Try to have all users of fixed-order recurrences appear after the recipe
165166
/// defining their previous value, by either sinking users or hoisting recipes

0 commit comments

Comments
 (0)