diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index b730a36488780..0c87fa2522b4b 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -663,6 +663,11 @@ LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, /// based on the vscale_range function attribute. LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth); +/// If \p LHS or \p RHS is `mul nuw V, V`, return the implied unsigned range for +/// \p V: [0, 2^ceil(bitwidth(V)/2)). +LLVM_ABI std::optional<ConstantRange> +getRangeForNuwMulSquare(const Value *V, const Value *LHS, const Value *RHS); + /// Determine the possible constant range of an integer or vector of integer /// value. This is intended as a cheap, non-recursive check. LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned, diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index df75999eb6080..462459bf56f6a 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -1353,6 +1353,9 @@ std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition( return ValueLatticeElement::getOverdefined(); unsigned BitWidth = Ty->getScalarSizeInBits(); + if (auto Range = getRangeForNuwMulSquare(Val, LHS, RHS)) + return ValueLatticeElement::getRange(*Range); + APInt Offset(BitWidth, 0); if (matchICmpOperand(Offset, LHS, Val, EdgePred)) return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI, diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 045cbab221ac3..c65851d6f641b 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -119,6 +119,25 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { return nullptr; } +std::optional<ConstantRange> llvm::getRangeForNuwMulSquare(const Value *V, + const Value *LHS, + const Value *RHS) { + if
(!V->getType()->isIntegerTy()) + return std::nullopt; + + if (!match(LHS, m_NUWMul(m_Specific(V), m_Specific(V))) && + !match(RHS, m_NUWMul(m_Specific(V), m_Specific(V)))) + return std::nullopt; + + unsigned BitWidth = V->getType()->getScalarSizeInBits(); + unsigned LimitBits = (BitWidth + 1) / 2; + if (LimitBits >= BitWidth) + return std::nullopt; + + APInt Upper = APInt::getOneBitSet(BitWidth, LimitBits); + return ConstantRange::getNonEmpty(APInt::getZero(BitWidth), Upper); +} + static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf, const APInt &DemandedElts, APInt &DemandedLHS, APInt &DemandedRHS) { @@ -976,6 +995,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp, Value *LHS = Cmp->getOperand(0); Value *RHS = Cmp->getOperand(1); + if (auto Range = getRangeForNuwMulSquare(V, LHS, RHS)) + Known = Known.unionWith(Range->toKnownBits()); + // Handle icmp pred (trunc V), C if (match(LHS, m_Trunc(m_Specific(V)))) { KnownBits DstKnown(LHS->getType()->getScalarSizeInBits()); @@ -10383,7 +10405,16 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, Value *Arg = I->getArgOperand(0); ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg); // Currently we just use information from comparisons. - if (!Cmp || Cmp->getOperand(0) != V) + if (!Cmp) + continue; + + if (auto Range = getRangeForNuwMulSquare(V, Cmp->getOperand(0), + Cmp->getOperand(1))) { + CR = CR.intersectWith(*Range); + continue; + } + + if (Cmp->getOperand(0) != V) continue; // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
ConstantRange RHS = @@ -10514,6 +10545,14 @@ void llvm::findValuesAffectedByCondition( } } + auto AddNuwSquareOperand = [&AddAffected](Value *Op) { + Value *SquareOp = nullptr; + if (match(Op, m_NUWMul(m_Value(SquareOp), m_Deferred(SquareOp)))) + AddAffected(SquareOp); + }; + AddNuwSquareOperand(A); + AddNuwSquareOperand(B); + + if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)))) AddAffected(X); } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) { diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll new file mode 100644 index 0000000000000..afec6387d7301 --- /dev/null +++ b/llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll @@ -0,0 +1,87 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes=correlated-propagation -S < %s | FileCheck %s + +declare void @llvm.assume(i1) + +define i1 @assume_mul_nuw_square_i8(i8 %s) { +; CHECK-LABEL: @assume_mul_nuw_square_i8( +; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]] +; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120 +; CHECK-NEXT: call void @llvm.assume(i1 [[COND]]) +; CHECK-NEXT: ret i1 true +; + %mul = mul nuw i8 %s, %s + %cond = icmp ule i8 %mul, 120 + call void @llvm.assume(i1 %cond) + %cmp = icmp ult i8 %s, 16 + ret i1 %cmp +} + +define i1 @assume_mul_nuw_square_i5(i5 %s) { +; CHECK-LABEL: @assume_mul_nuw_square_i5( +; CHECK-NEXT: [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]] +; CHECK-NEXT: [[COND:%.*]] = icmp ult i5 [[MUL]], 15 +; CHECK-NEXT: call void @llvm.assume(i1 [[COND]]) +; CHECK-NEXT: ret i1 true +; + %mul = mul nuw i5 %s, %s + %cond = icmp ult i5 %mul, 15 + call void @llvm.assume(i1 %cond) + %cmp = icmp ult i5 %s, 8 + ret i1 %cmp +} + +define i1 @branch_mul_nuw_square(i8 %s, i8 %num) { +; CHECK-LABEL: @branch_mul_nuw_square( +; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]] +; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]] +; 
CHECK-NEXT: br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]] +; CHECK: true: +; CHECK-NEXT: ret i1 true +; CHECK: false: +; CHECK-NEXT: ret i1 true +; + %mul = mul nuw i8 %s, %s + %cond = icmp ule i8 %mul, %num + br i1 %cond, label %true, label %false + +true: + %cmp = icmp ult i8 %s, 16 + ret i1 %cmp + +false: + %cmp2 = icmp ult i8 %s, 16 + ret i1 %cmp2 +} + +; negative test: missing nuw on the multiply. +define i1 @assume_mul_square_no_nuw(i8 %s) { +; CHECK-LABEL: @assume_mul_square_no_nuw( +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[S:%.*]], [[S]] +; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120 +; CHECK-NEXT: call void @llvm.assume(i1 [[COND]]) +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16 +; CHECK-NEXT: ret i1 [[CMP]] +; + %mul = mul i8 %s, %s + %cond = icmp ule i8 %mul, 120 + call void @llvm.assume(i1 %cond) + %cmp = icmp ult i8 %s, 16 + ret i1 %cmp +} + +; negative test: multiply is not a square. +define i1 @assume_mul_nuw_not_square(i8 %s, i8 %t) { +; CHECK-LABEL: @assume_mul_nuw_not_square( +; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[T:%.*]] +; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120 +; CHECK-NEXT: call void @llvm.assume(i1 [[COND]]) +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16 +; CHECK-NEXT: ret i1 [[CMP]] +; + %mul = mul nuw i8 %s, %t + %cond = icmp ule i8 %mul, 120 + call void @llvm.assume(i1 %cond) + %cmp = icmp ult i8 %s, 16 + ret i1 %cmp +}