From 73d741cded897ba867b1e226521c2fe3b9589ba1 Mon Sep 17 00:00:00 2001 From: Nicholas Mosier Date: Sat, 20 Jan 2024 17:28:41 +0000 Subject: [PATCH] fix bug --- llvm/lib/CodeGen/SafeStack.cpp | 121 ++++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/llvm/lib/CodeGen/SafeStack.cpp b/llvm/lib/CodeGen/SafeStack.cpp index 0a26247a4d16..6b2fc20b0364 100644 --- a/llvm/lib/CodeGen/SafeStack.cpp +++ b/llvm/lib/CodeGen/SafeStack.cpp @@ -146,6 +146,12 @@ class SafeStack { SmallVectorImpl &Returns, SmallVectorImpl &StackRestorePoints); + // Wrap the return value of all llvm.stacksave intrinsics inside a struct + // that additionally contains the current value of the unsafe stack pointer. + // Unwrap the call stack pointer and unsafe stack pointer before each + // llvm.stackrestore and llvm.eh.sjlj.longjmp. + void instrumentStackSaves(Function &F, Value *UnsafeStackPtr); + /// Calculate the allocation size of a given alloca. Returns 0 if the /// size can not be statically determined. uint64_t getStaticAllocaAllocationSize(const AllocaInst* AI); @@ -651,6 +657,93 @@ Value *SafeStack::moveStaticAllocasToUnsafeStack( return StaticTop; } +void SafeStack::instrumentStackSaves(Function &F, Value *UnsafeStackPtr) { + // This routine instruments llvm.stacksave intrinsics so that they work with SafeStack. + // llvm.stacksave is supposed to remember the current state of the function stack. + // Generally, this entails simply returning the current value of the stack pointer. + // However, under SafeStack, we have *two* stack pointers: the call (i.e., safe) stack pointer and the unsafe stack pointer. + // In general, under SafeStack, the return value of llvm.stacksave needs to remember the values of *both* of these. + // + // We achieve this with the following. + // Immediately after each llvm.stacksave, we allocate a new anonymous struct {ptr, ptr} on the stack. 
+ // We store (a) the return value of llvm.stacksave in the first member and (b) the current unsafe stack pointer in the second member. + // We then replace all uses of llvm.stacksave's return value with a pointer to the struct. + // + // However, llvm.stackrestore and llvm.eh.sjlj.longjmp (the only two consumers of llvm.stacksave I am aware of) + // of course still expect the call stack pointer as their argument, not a pointer to our pair of stack pointers. + // Thus, we must unwrap the struct, pass the call stack pointer directly to the consumer intrinsic, and insert + // a store to the unsafe stack pointer global variable to properly restore the unsafe stack state. + + // Identify the list of intrinsics requiring instrumentation + SmallVector Worklist; + for (Instruction &I : instructions(&F)) { + auto *II = dyn_cast(&I); + if (!II) + continue; + switch (II->getIntrinsicID()) { + case Intrinsic::stacksave: + // NHM-FIXME +#if defined(EXPENSIVE_CHECKS) || 1 + // Validate our assumption that only llvm.stackrestore intrinsics use llvm.stacksave directly. + assert(none_of(II->uses(), [] (const Use& U) { + auto *UseII = dyn_cast(U.getUser()); + return UseII && UseII->getIntrinsicID() != Intrinsic::stackrestore; + })); +#endif + case Intrinsic::stackrestore: + case Intrinsic::eh_sjlj_longjmp: + Worklist.push_back(II); + break; + } + } + + // Wrapper struct type for storing both the call stack pointer as well as the unsafe stack pointer. + auto *StackPtrPairTy = StructType::get(StackPtrTy, StackPtrTy); + + for (IntrinsicInst *II : Worklist) { + if (II->getIntrinsicID() == Intrinsic::stacksave) { + // For llvm.stacksave intrinsics, wrap the return value (the call stack pointer) with the unsafe stack pointer + // in the anonymous struct, and replace subsequent uses with this struct. 
+ IRBuilder<> IRB(II->getParent(), std::next(II->getIterator())); + AllocaInst *StackPtrPair = IRB.CreateAlloca(StackPtrPairTy, nullptr, "stacksave_pair"); + Value *FirstPtr = IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(0)}, "stacksave_first", true); + auto *Store = IRB.CreateStore(II, FirstPtr); + Value *SecondPtr = IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(1)}, "stacksave_second", true); + LoadInst *LI = IRB.CreateLoad(StackPtrTy, UnsafeStackPtr, "stacksave_unsafe_stack_ptr"); + IRB.CreateStore(LI, SecondPtr); + II->replaceUsesWithIf(StackPtrPair, [Store] (const Use &U) { return U.getUser() != Store; }); + } else if (II->getIntrinsicID() == Intrinsic::stackrestore) { + // For llvm.stackrestore intrinsics, unwrap the stack pointer argument, which should be of our wrapper struct type + // containing both the call stack pointer as well as the unsafe stack pointer. Pass the call stack pointer to llvm.stackrestore and + // update the unsafe stack pointer global variable with a regular store. + IRBuilder<> IRB(II); + Value *StackPtrPair = II->getArgOperand(0); + Value *FirstPtr = IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(0)}); + Value *FirstVal = IRB.CreateLoad(StackPtrTy, FirstPtr); + Value *SecondPtr = IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(1)}); + Value *SecondVal = IRB.CreateLoad(StackPtrTy, SecondPtr); + II->setArgOperand(0, FirstVal); + IRB.CreateStore(SecondVal, UnsafeStackPtr); + } else if (II->getIntrinsicID() == Intrinsic::eh_sjlj_longjmp) { + // For llvm.eh.sjlj.longjmp instructions, the argument is a jump buffer (void *[5]). + // It expects the call stack pointer at index 2. However, our instrumentation of llvm.stacksave results in it being a pointer to our stack pointer struct. 
+ // We convert it back to its canonical form by (a) extracting the call stack pointer from our struct, (b) storing it into the jump buffer at index 2, + // and (c) storing the unsafe stack pointer from our struct into the unsafe stack pointer global variable. + IRBuilder<> IRB(II); + Value *Buf = II->getArgOperand(0); + ArrayType *BufTy = ArrayType::get(IRB.getPtrTy(), 5); + Value *BufStackElt = IRB.CreateGEP(BufTy, Buf, {IRB.getInt32(0), IRB.getInt32(2)}, "", true); + Value *StackPtrPair = IRB.CreateLoad(IRB.getPtrTy(), BufStackElt); + LoadInst *FirstVal = IRB.CreateLoad(StackPtrTy, IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(0)})); + LoadInst *SecondVal = IRB.CreateLoad(StackPtrTy, IRB.CreateGEP(StackPtrPairTy, StackPtrPair, {IRB.getInt32(0), IRB.getInt32(1)})); + IRB.CreateStore(FirstVal, BufStackElt); + IRB.CreateStore(SecondVal, UnsafeStackPtr); + } else { + llvm_unreachable("Unexpected intrinsic!"); + } + } +} + void SafeStack::moveDynamicAllocasToUnsafeStack( Function &F, Value *UnsafeStackPtr, AllocaInst *DynamicTop, ArrayRef DynamicAllocas) { @@ -694,29 +787,6 @@ void SafeStack::moveDynamicAllocasToUnsafeStack( AI->replaceAllUsesWith(NewAI); AI->eraseFromParent(); } - - if (!DynamicAllocas.empty()) { - // Now go through the instructions again, replacing stacksave/stackrestore. 
- for (Instruction &I : llvm::make_early_inc_range(instructions(&F))) { - auto *II = dyn_cast(&I); - if (!II) - continue; - - if (II->getIntrinsicID() == Intrinsic::stacksave) { - IRBuilder<> IRB(II); - Instruction *LI = IRB.CreateLoad(StackPtrTy, UnsafeStackPtr); - LI->takeName(II); - II->replaceAllUsesWith(LI); - II->eraseFromParent(); - } else if (II->getIntrinsicID() == Intrinsic::stackrestore) { - IRBuilder<> IRB(II); - Instruction *SI = IRB.CreateStore(II->getArgOperand(0), UnsafeStackPtr); - SI->takeName(II); - assert(II->use_empty()); - II->eraseFromParent(); - } - } - } } bool SafeStack::ShouldInlinePointerAddress(CallInst &CI) { @@ -768,6 +838,10 @@ bool SafeStack::run() { // instrumentation to restore the unsafe stack pointer when necessary. SmallVector StackRestorePoints; + IRBuilder<> IRB(&F.front(), F.begin()->getFirstInsertionPt()); + + instrumentStackSaves(F, TL.getSafeStackPointerLocation(IRB)); + // Find all static and dynamic alloca instructions that must be moved to the // unsafe stack, all return instructions and stack restore points. findInsts(F, StaticAllocas, DynamicAllocas, ByValArguments, Returns, @@ -784,7 +858,6 @@ bool SafeStack::run() { if (!StackRestorePoints.empty()) ++NumUnsafeStackRestorePointsFunctions; - IRBuilder<> IRB(&F.front(), F.begin()->getFirstInsertionPt()); // Calls must always have a debug location, or else inlining breaks. So // we explicitly set a artificial debug location here. if (DISubprogram *SP = F.getSubprogram()) -- 2.34.1