@@ -8633,26 +8633,172 @@ void GradientUtils::eraseFictiousPHIs() {
86338633 phis.emplace_back (pair.first , pair.second );
86348634 fictiousPHIs.clear ();
86358635
8636+ auto *i8Ty = Type::getInt8Ty (newFunc->getContext ());
8637+
8638+ // Helper: create an i8 alloca in the entry block that dominates all uses.
8639+ auto createEntryAllocaBytes = [&](PHINode *pp, Value *byteCount) -> Value * {
8640+ IRBuilder<> B (&*pp->getFunction ()->getEntryBlock ().getFirstInsertionPt ());
8641+
8642+ Value *sz = byteCount;
8643+ if (sz->getType ()->isIntegerTy () && sz->getType ()->getIntegerBitWidth () != 64 )
8644+ sz = B.CreateZExtOrTrunc (sz, B.getInt64Ty (), " enzyme.sz64" );
8645+ else if (!sz->getType ()->isIntegerTy ())
8646+ sz = B.CreatePtrToInt (sz, B.getInt64Ty (), " enzyme.sz64" );
8647+
8648+ AllocaInst *AI = B.CreateAlloca (i8Ty, sz, " enzyme_fromstack_alloca" );
8649+ // Alignment is not critical here; 16 is usually safe and helps codegen.
8650+ AI->setAlignment (Align (16 ));
8651+
8652+ Value *rep = AI;
8653+
8654+ // If pointer address spaces differ, cast appropriately.
8655+ if (pp->getType () != AI->getType () && pp->getType ()->isPointerTy () &&
8656+ AI->getType ()->isPointerTy ()) {
8657+ unsigned ppAS = cast<PointerType>(pp->getType ())->getAddressSpace ();
8658+ unsigned aiAS = cast<PointerType>(AI->getType ())->getAddressSpace ();
8659+ if (ppAS != aiAS) {
8660+ rep = B.CreateAddrSpaceCast (AI, pp->getType (), " enzyme.asc" );
8661+ } else {
8662+ rep = B.CreateBitCast (AI, pp->getType (), " enzyme.bc" );
8663+ }
8664+ }
8665+
8666+ return rep;
8667+ };
8668+
8669+ // Helper: handle malloc/calloc tagged !enzyme_fromstack by turning into alloca
8670+ auto tryRematFromstackMalloc = [&](PHINode *pp, Value *orig) -> bool {
8671+ auto *CI = dyn_cast<CallInst>(orig);
8672+ if (!CI)
8673+ return false ;
8674+
8675+ // Must be explicitly tagged as fromstack
8676+ if (!CI->getMetadata (" enzyme_fromstack" ))
8677+ return false ;
8678+
8679+ Function *Callee = CI->getCalledFunction ();
8680+ if (!Callee)
8681+ return false ;
8682+
8683+ StringRef name = Callee->getName ();
8684+ if (name != " malloc" && name != " calloc" )
8685+ return false ;
8686+
8687+ // Map operands into the new function safely (won't assert if already-new)
8688+ IRBuilder<> B (&*pp->getFunction ()->getEntryBlock ().getFirstInsertionPt ());
8689+
8690+ Value *byteCount = nullptr ;
8691+ if (name == " malloc" ) {
8692+ byteCount = getNewIfOriginal (const_cast <Value *>(CI->getArgOperand (0 )));
8693+ if (byteCount->getType ()->isIntegerTy () &&
8694+ byteCount->getType ()->getIntegerBitWidth () != 64 )
8695+ byteCount = B.CreateZExtOrTrunc (byteCount, B.getInt64Ty (), " malloc.sz64" );
8696+ } else { // calloc(nmemb, size)
8697+ Value *nmemb = getNewIfOriginal (const_cast <Value *>(CI->getArgOperand (0 )));
8698+ Value *esize = getNewIfOriginal (const_cast <Value *>(CI->getArgOperand (1 )));
8699+ if (nmemb->getType ()->getIntegerBitWidth () != 64 )
8700+ nmemb = B.CreateZExtOrTrunc (nmemb, B.getInt64Ty (), " calloc.nmemb64" );
8701+ if (esize->getType ()->getIntegerBitWidth () != 64 )
8702+ esize = B.CreateZExtOrTrunc (esize, B.getInt64Ty (), " calloc.esize64" );
8703+ byteCount = B.CreateMul (nmemb, esize, " calloc.bytes" );
8704+ }
8705+
8706+ Value *rep = createEntryAllocaBytes (pp, byteCount);
8707+
8708+ // calloc should be zero-initialized
8709+ if (name == " calloc" ) {
8710+ // rep might be a casted pointer; memset wants an i8*
8711+ Value *raw = rep;
8712+ if (raw->getType () != PointerType::getUnqual (i8Ty) && raw->getType ()->isPointerTy ()) {
8713+ raw = B.CreateBitCast (raw, PointerType::getUnqual (i8Ty), " calloc.i8p" );
8714+ }
8715+ B.CreateMemSet (raw, B.getInt8 (0 ), byteCount, MaybeAlign (16 ));
8716+ }
8717+
8718+ pp->replaceAllUsesWith (rep);
8719+ erase (pp);
8720+ return true ;
8721+ };
8722+
8723+ // Optional: rematerialize “cheap pure” instruction placeholders (the integer
8724+ // case you had before). Keep it conservative.
8725+ auto tryRematCheapPureInst = [&](PHINode *pp, Value *orig) -> bool {
8726+ auto *I = dyn_cast<Instruction>(orig);
8727+ if (!I)
8728+ return false ;
8729+
8730+ if (isa<PHINode>(I) || isa<AllocaInst>(I) || isa<LoadInst>(I) ||
8731+ isa<StoreInst>(I) || isa<CallBase>(I) || isa<InvokeInst>(I) ||
8732+ isa<AtomicRMWInst>(I) || isa<AtomicCmpXchgInst>(I))
8733+ return false ;
8734+
8735+ if (!I->mayReadFromMemory () && !I->mayWriteToMemory () && !I->mayHaveSideEffects ()) {
8736+ Instruction *cl = I->clone ();
8737+ cl->setName (I->getName () + " .remat" );
8738+
8739+ // Insert near the start of the entry block so it dominates all uses.
8740+ cl->insertBefore (&*newFunc->getEntryBlock ().getFirstInsertionPt ());
8741+
8742+ for (unsigned op = 0 ; op < cl->getNumOperands (); ++op) {
8743+ Value *v = cl->getOperand (op);
8744+ cl->setOperand (op, getNewIfOriginal (v));
8745+ }
8746+
8747+ pp->replaceAllUsesWith (cl);
8748+ erase (pp);
8749+ return true ;
8750+ }
8751+
8752+ return false ;
8753+ };
8754+
8755+ // 3) Replace/erase fictitious phis
86368756 for (auto pair : phis) {
8637- auto pp = pair.first ;
8757+ PHINode *pp = pair.first ;
8758+ Value *orig = pair.second ;
8759+
8760+ if (!pp)
8761+ continue ;
8762+
86388763 if (pp->getNumUses () != 0 ) {
8764+ // First: handle fromstack malloc/calloc
8765+ if (tryRematFromstackMalloc (pp, orig)) {
8766+ continue ;
8767+ }
8768+
8769+ // Second: handle cheap pure instruction placeholders (like your old add)
8770+ if (tryRematCheapPureInst (pp, orig)) {
8771+ continue ;
8772+ }
8773+
8774+ // If still used, this is a genuine bug; keep the diagnostics.
86398775 if (CustomErrorHandler) {
86408776 std::string str;
86418777 raw_string_ostream ss (str);
8642- ss << " Illegal replace ficticious phi for: " << *pp << " of "
8643- << *pair.second ;
8644- CustomErrorHandler (str.c_str (), wrap (pair.second ),
8778+ ss << " Illegal replace ficticious phi for: " << *pp << " of " << *orig;
8779+ CustomErrorHandler (str.c_str (), wrap (orig),
86458780 ErrorType::IllegalReplaceFicticiousPHIs, this ,
86468781 wrap (pp), nullptr );
86478782 } else {
86488783 llvm::errs () << " mod:" << *oldFunc->getParent () << " \n " ;
86498784 llvm::errs () << " oldFunc:" << *oldFunc << " \n " ;
86508785 llvm::errs () << " newFunc:" << *newFunc << " \n " ;
8651- llvm::errs () << " pp: " << *pp << " of " << *pair.second << " \n " ;
8652- assert (pp->getNumUses () == 0 );
8786+ llvm::errs () << " pp: " << *pp << " of " << *orig << " \n " ;
8787+
8788+ // Extra helpful debug: show direct users
8789+ for (auto *U : pp->users ()) {
8790+ if (auto *UI = dyn_cast<Instruction>(U)) {
8791+ llvm::errs () << " use: " << *UI << " \n " ;
8792+ } else {
8793+ llvm::errs () << " use: " << *U << " \n " ;
8794+ }
8795+ }
86538796 }
86548797 }
8655- pp->replaceAllUsesWith (UndefValue::get (pp->getType ()));
8798+
8799+ // Last resort: replace with undef (keeps compilation going, but derivative may be wrong)
8800+ if (!pp->use_empty ())
8801+ pp->replaceAllUsesWith (UndefValue::get (pp->getType ()));
86568802 erase (pp);
86578803 }
86588804}
0 commit comments