From 902c43a9397355a18be42ca995229f7f9483d9c9 Mon Sep 17 00:00:00 2001 From: Peter Collingbourne Date: Fri, 5 Dec 2025 15:01:55 -0800 Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20change?= =?UTF-8?q?s=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.6-beta.1 [skip ci] --- llvm/lib/Transforms/Scalar/SROA.cpp | 83 +++++++++++++++++-- .../SROA/protected-field-pointer.ll | 73 ++++++++++++++++ 2 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 llvm/test/Transforms/SROA/protected-field-pointer.ll diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 3a70830cf8c0e..1102699aa04e9 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -62,6 +62,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" @@ -535,9 +536,11 @@ class Slice { public: Slice() = default; - Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable) + Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable, + Value *ProtectedFieldDisc) : BeginOffset(BeginOffset), EndOffset(EndOffset), - UseAndIsSplittable(U, IsSplittable) {} + UseAndIsSplittable(U, IsSplittable), + ProtectedFieldDisc(ProtectedFieldDisc) {} uint64_t beginOffset() const { return BeginOffset; } uint64_t endOffset() const { return EndOffset; } @@ -550,6 +553,10 @@ class Slice { bool isDead() const { return getUse() == nullptr; } void kill() { UseAndIsSplittable.setPointer(nullptr); } + // When this access is via an llvm.protected.field.ptr intrinsic, contains + // the second argument to the intrinsic, the discriminator. + Value *ProtectedFieldDisc = nullptr; + /// Support for ordering ranges. 
/// /// This provides an ordering over ranges such that start offsets are @@ -641,6 +648,10 @@ class AllocaSlices { /// Access the dead users for this alloca. ArrayRef getDeadUsers() const { return DeadUsers; } + /// Access the users for this alloca that are llvm.protected.field.ptr + /// intrinsics. + ArrayRef getPFPUsers() const { return PFPUsers; } + /// Access Uses that should be dropped if the alloca is promotable. ArrayRef getDeadUsesIfPromotable() const { return DeadUseIfPromotable; @@ -701,6 +712,10 @@ class AllocaSlices { /// they come from outside of the allocated space. SmallVector DeadUsers; + /// Users that are llvm.protected.field.ptr intrinsics. These will be RAUW'd + /// to their first argument if we rewrite the alloca. + SmallVector PFPUsers; + /// Uses which will become dead if can promote the alloca. SmallVector DeadUseIfPromotable; @@ -1029,6 +1044,10 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor { /// Set to de-duplicate dead instructions found in the use walk. SmallPtrSet VisitedDeadInsts; + // When this access is via an llvm.protected.field.ptr intrinsic, contains + // the second argument to the intrinsic, the discriminator. + Value *ProtectedFieldDisc = nullptr; + public: SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) : PtrUseVisitor(DL), @@ -1074,7 +1093,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor { EndOffset = AllocSize; } - AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable)); + AS.Slices.push_back( + Slice(BeginOffset, EndOffset, U, IsSplittable, ProtectedFieldDisc)); } void visitBitCastInst(BitCastInst &BC) { @@ -1274,6 +1294,27 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor { return; } + if (II.getIntrinsicID() == Intrinsic::protected_field_ptr) { + // We only handle loads and stores as users of llvm.protected.field.ptr. + // Other uses may add items to the worklist, which will cause + // ProtectedFieldDisc to be tracked incorrectly. 
+ AS.PFPUsers.push_back(&II); + ProtectedFieldDisc = II.getArgOperand(1); + for (Use &U : II.uses()) { + this->U = &U; + if (auto *LI = dyn_cast(U.getUser())) + visitLoadInst(*LI); + else if (auto *SI = dyn_cast(U.getUser())) + visitStoreInst(*SI); + else + PI.setAborted(&II); + if (PI.isAborted()) + break; + } + ProtectedFieldDisc = nullptr; + return; + } + Base::visitIntrinsicInst(II); } @@ -4948,7 +4989,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { NewSlices.push_back( Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PLoad->getOperandUse(PLoad->getPointerOperandIndex()), - /*IsSplittable*/ false)); + /*IsSplittable*/ false, nullptr)); LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() << ", " << NewSlices.back().endOffset() << "): " << *PLoad << "\n"); @@ -5104,10 +5145,12 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { LLVMContext::MD_access_group}); // Now build a new slice for the alloca. + // ProtectedFieldDisc==nullptr is a lie, but it doesn't matter because we + // already determined that all accesses are consistent. NewSlices.push_back( Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PStore->getOperandUse(PStore->getPointerOperandIndex()), - /*IsSplittable*/ false)); + /*IsSplittable*/ false, nullptr)); LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() << ", " << NewSlices.back().endOffset() << "): " << *PStore << "\n"); @@ -5875,6 +5918,30 @@ SROA::runOnAlloca(AllocaInst &AI) { return {Changed, CFGChanged}; } + for (auto &P : AS.partitions()) { + // For now, we can't split if a field is accessed both via protected field + // and not, because that would mean that we would need to introduce sign and + // auth operations to convert between the protected and non-protected uses, + // and this pass doesn't know how to do that. Also, this case is unlikely to + // occur in normal code. 
+ std::optional ProtectedFieldDisc; + auto SliceHasMismatch = [&](Slice &S) { + if (auto *II = dyn_cast(S.getUse()->getUser())) + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + return false; + if (!ProtectedFieldDisc) + ProtectedFieldDisc = S.ProtectedFieldDisc; + return *ProtectedFieldDisc != S.ProtectedFieldDisc; + }; + for (Slice &S : P) + if (SliceHasMismatch(S)) + return {Changed, CFGChanged}; + for (Slice *S : P.splitSliceTails()) + if (SliceHasMismatch(*S)) + return {Changed, CFGChanged}; + } + // Delete all the dead users of this alloca before splitting and rewriting it. for (Instruction *DeadUser : AS.getDeadUsers()) { // Free up everything used by this instruction. @@ -5892,6 +5959,12 @@ SROA::runOnAlloca(AllocaInst &AI) { clobberUse(*DeadOp); Changed = true; } + for (IntrinsicInst *PFPUser : AS.getPFPUsers()) { + PFPUser->replaceAllUsesWith(PFPUser->getArgOperand(0)); + + DeadInsts.push_back(PFPUser); + Changed = true; + } // No slices to split. Leave the dead alloca for a later pass to clean up. 
if (AS.begin() == AS.end()) diff --git a/llvm/test/Transforms/SROA/protected-field-pointer.ll b/llvm/test/Transforms/SROA/protected-field-pointer.ll new file mode 100644 index 0000000000000..e0e3ce0435c78 --- /dev/null +++ b/llvm/test/Transforms/SROA/protected-field-pointer.ll @@ -0,0 +1,73 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes=sroa -S < %s | FileCheck %s + +define void @slice(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) { +; CHECK-LABEL: define void @slice( +; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) { +; CHECK-NEXT: store ptr [[PTR1]], ptr [[OUT1]], align 8 +; CHECK-NEXT: store ptr [[PTR2]], ptr [[OUT2]], align 8 +; CHECK-NEXT: ret void +; + %alloca = alloca { ptr, ptr } + + %protptrptr1.1 = call ptr @llvm.protected.field.ptr.p0(ptr %alloca, i64 1, i1 true) + store ptr %ptr1, ptr %protptrptr1.1 + %protptrptr1.2 = call ptr @llvm.protected.field.ptr.p0(ptr %alloca, i64 1, i1 true) + %ptr1a = load ptr, ptr %protptrptr1.2 + + %gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1 + %protptrptr2.1 = call ptr @llvm.protected.field.ptr.p0(ptr %gep, i64 2, i1 true) + store ptr %ptr2, ptr %protptrptr2.1 + %protptrptr2.2 = call ptr @llvm.protected.field.ptr.p0(ptr %gep, i64 2, i1 true) + %ptr2a = load ptr, ptr %protptrptr2.2 + + store ptr %ptr1a, ptr %out1 + store ptr %ptr2a, ptr %out2 + ret void +} + +define ptr @mixed(ptr %ptr) { +; CHECK-LABEL: define ptr @mixed( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[ALLOCA:%.*]] = alloca ptr, align 8 +; CHECK-NEXT: store ptr [[PTR]], ptr [[ALLOCA]], align 8 +; CHECK-NEXT: [[PROTPTRPTR1_2:%.*]] = call ptr @llvm.protected.field.ptr.p0(ptr [[ALLOCA]], i64 1, i1 true) +; CHECK-NEXT: [[PTR1A:%.*]] = load ptr, ptr [[PROTPTRPTR1_2]], align 8 +; CHECK-NEXT: ret ptr [[PTR1A]] +; + %alloca = alloca ptr + + store ptr %ptr, ptr %alloca + %protptrptr1.2 = call ptr @llvm.protected.field.ptr.p0(ptr %alloca, 
i64 1, i1 true) + %ptr1a = load ptr, ptr %protptrptr1.2 + + ret ptr %ptr1a +} + +define void @split_non_promotable(ptr %ptr1, ptr %ptr2, ptr %out1, ptr %out2) { +; CHECK-LABEL: define void @split_non_promotable( +; CHECK-SAME: ptr [[PTR1:%.*]], ptr [[PTR2:%.*]], ptr [[OUT1:%.*]], ptr [[OUT2:%.*]]) { +; CHECK-NEXT: [[ALLOCA_SROA_2:%.*]] = alloca ptr, align 8 +; CHECK-NEXT: store volatile ptr [[PTR2]], ptr [[ALLOCA_SROA_2]], align 8 +; CHECK-NEXT: [[PTR2A:%.*]] = load volatile ptr, ptr [[ALLOCA_SROA_2]], align 8 +; CHECK-NEXT: store ptr [[PTR1]], ptr [[OUT1]], align 8 +; CHECK-NEXT: store ptr [[PTR2A]], ptr [[OUT2]], align 8 +; CHECK-NEXT: ret void +; + %alloca = alloca { ptr, ptr } + + %protptrptr1.1 = call ptr @llvm.protected.field.ptr.p0(ptr %alloca, i64 1, i1 true) + store ptr %ptr1, ptr %protptrptr1.1 + %protptrptr1.2 = call ptr @llvm.protected.field.ptr.p0(ptr %alloca, i64 1, i1 true) + %ptr1a = load ptr, ptr %protptrptr1.2 + + %gep = getelementptr { ptr, ptr }, ptr %alloca, i64 0, i32 1 + %protptrptr2.1 = call ptr @llvm.protected.field.ptr.p0(ptr %gep, i64 2, i1 true) + store volatile ptr %ptr2, ptr %protptrptr2.1 + %protptrptr2.2 = call ptr @llvm.protected.field.ptr.p0(ptr %gep, i64 2, i1 true) + %ptr2a = load volatile ptr, ptr %protptrptr2.2 + + store ptr %ptr1a, ptr %out1 + store ptr %ptr2a, ptr %out2 + ret void +}