Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,11 @@ LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
/// based on the vscale_range function attribute.
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);

/// Derive the unsigned range that a compare operand of the form
/// `mul nuw V, V` implies for \p V.
///
/// If \p LHS or \p RHS is `mul nuw V, V`, the returned range for \p V is
/// [0, 2^ceil(bitwidth(V)/2)).
///
/// \returns the range above, or std::nullopt if \p V does not have an integer
/// type, if neither \p LHS nor \p RHS is `mul nuw V, V`, or if
/// ceil(bitwidth(V)/2) is not smaller than the bit width of \p V (e.g. i1).
LLVM_ABI std::optional<ConstantRange>
getRangeForNuwMulSquare(const Value *V, const Value *LHS, const Value *RHS);

/// Determine the possible constant range of an integer or vector of integer
/// value. This is intended as a cheap, non-recursive check.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned,
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Analysis/LazyValueInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,9 @@ std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
return ValueLatticeElement::getOverdefined();

unsigned BitWidth = Ty->getScalarSizeInBits();
if (auto Range = getRangeForNuwMulSquare(Val, LHS, RHS))
return ValueLatticeElement::getRange(*Range);

APInt Offset(BitWidth, 0);
if (matchICmpOperand(Offset, LHS, Val, EdgePred))
return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
Expand Down
41 changes: 40 additions & 1 deletion llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
return nullptr;
}

std::optional<ConstantRange> llvm::getRangeForNuwMulSquare(const Value *V,
const Value *LHS,
const Value *RHS) {
if (!V->getType()->isIntegerTy())
return std::nullopt;

if (!match(LHS, m_NUWMul(m_Specific(V), m_Specific(V))) &&
!match(RHS, m_NUWMul(m_Specific(V), m_Specific(V))))
return std::nullopt;

unsigned BitWidth = V->getType()->getScalarSizeInBits();
unsigned LimitBits = (BitWidth + 1) / 2;
if (LimitBits >= BitWidth)
return std::nullopt;

APInt Upper = APInt::getOneBitSet(BitWidth, LimitBits);
return ConstantRange::getNonEmpty(APInt::getZero(BitWidth), Upper);
}

static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
const APInt &DemandedElts,
APInt &DemandedLHS, APInt &DemandedRHS) {
Expand Down Expand Up @@ -976,6 +995,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
Value *LHS = Cmp->getOperand(0);
Value *RHS = Cmp->getOperand(1);

if (auto Range = getRangeForNuwMulSquare(V, LHS, RHS))
Known = Known.unionWith(Range->toKnownBits());

// Handle icmp pred (trunc V), C
if (match(LHS, m_Trunc(m_Specific(V)))) {
KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
Expand Down Expand Up @@ -10383,7 +10405,16 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
Value *Arg = I->getArgOperand(0);
ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
// Currently we just use information from comparisons.
if (!Cmp || Cmp->getOperand(0) != V)
if (!Cmp)
continue;

if (auto Range = getRangeForNuwMulSquare(V, Cmp->getOperand(0),
Cmp->getOperand(1))) {
CR = CR.intersectWith(*Range);
continue;
}

if (Cmp->getOperand(0) != V)
continue;
// TODO: Set "ForSigned" parameter via Cmp->isSigned()?
ConstantRange RHS =
Expand Down Expand Up @@ -10514,6 +10545,14 @@ void llvm::findValuesAffectedByCondition(
}
}

// If Op is a square of the form `mul nuw X, X`, also report the squared
// operand X through AddAffected.
auto AddNuwSquareOperand = [&AddAffected](Value *Op) {
Value *SquareOp = nullptr;
// m_Value binds the first multiplicand; m_Deferred then requires the second
// multiplicand to be that same value.
if (match(Op, m_NUWMul(m_Value(SquareOp), m_Deferred(SquareOp))))
AddAffected(SquareOp);
};
// Apply the check to both A and B.
AddNuwSquareOperand(A);
AddNuwSquareOperand(B);

if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
AddAffected(X);
} else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
Expand Down
87 changes: 87 additions & 0 deletions llvm/test/Transforms/CorrelatedValuePropagation/mul-nuw-square.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=correlated-propagation -S < %s | FileCheck %s

declare void @llvm.assume(i1)

; Positive test: an assumed compare on `mul nuw i8 %s, %s` lets the later
; `icmp ult i8 %s, 16` fold to true (16 == 2^ceil(8/2)).
define i1 @assume_mul_nuw_square_i8(i8 %s) {
; CHECK-LABEL: @assume_mul_nuw_square_i8(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i8 %s, %s
%cond = icmp ule i8 %mul, 120
call void @llvm.assume(i1 %cond)
%cmp = icmp ult i8 %s, 16
ret i1 %cmp
}

; Positive test with an odd bit width: for i5 the bound is 2^ceil(5/2) == 8,
; so `icmp ult i5 %s, 8` folds to true.
define i1 @assume_mul_nuw_square_i5(i5 %s) {
; CHECK-LABEL: @assume_mul_nuw_square_i5(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i5 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ult i5 [[MUL]], 15
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i5 %s, %s
%cond = icmp ult i5 %mul, 15
call void @llvm.assume(i1 %cond)
%cmp = icmp ult i5 %s, 8
ret i1 %cmp
}

; Positive test: the compare on `mul nuw i8 %s, %s` feeds a branch rather than
; an assume, and its other operand is not a constant. The `%s u< 16` compare
; folds to true on both the true and the false edge.
define i1 @branch_mul_nuw_square(i8 %s, i8 %num) {
; CHECK-LABEL: @branch_mul_nuw_square(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], [[NUM:%.*]]
; CHECK-NEXT: br i1 [[COND]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: true:
; CHECK-NEXT: ret i1 true
; CHECK: false:
; CHECK-NEXT: ret i1 true
;
%mul = mul nuw i8 %s, %s
%cond = icmp ule i8 %mul, %num
br i1 %cond, label %true, label %false

true:
%cmp = icmp ult i8 %s, 16
ret i1 %cmp

false:
%cmp2 = icmp ult i8 %s, 16
ret i1 %cmp2
}

; Negative test: the multiply is a square but lacks the nuw flag, so the
; `%s u< 16` compare must be left in place.
define i1 @assume_mul_square_no_nuw(i8 %s) {
; CHECK-LABEL: @assume_mul_square_no_nuw(
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[S:%.*]], [[S]]
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
; CHECK-NEXT: ret i1 [[CMP]]
;
%mul = mul i8 %s, %s
%cond = icmp ule i8 %mul, 120
call void @llvm.assume(i1 %cond)
%cmp = icmp ult i8 %s, 16
ret i1 %cmp
}

; Negative test: the multiply has nuw but its operands differ (%s * %t), so it
; is not a square and the `%s u< 16` compare must be left in place.
define i1 @assume_mul_nuw_not_square(i8 %s, i8 %t) {
; CHECK-LABEL: @assume_mul_nuw_not_square(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i8 [[S:%.*]], [[T:%.*]]
; CHECK-NEXT: [[COND:%.*]] = icmp ule i8 [[MUL]], 120
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[S]], 16
; CHECK-NEXT: ret i1 [[CMP]]
;
%mul = mul nuw i8 %s, %t
%cond = icmp ule i8 %mul, 120
call void @llvm.assume(i1 %cond)
%cmp = icmp ult i8 %s, 16
ret i1 %cmp
}