Skip to content

Commit 9d7d668

Browse files
committed
fix the crash in eraseFictiousPHIs
1 parent 53a46ca commit 9d7d668

File tree

1 file changed

+153
-7
lines changed

1 file changed

+153
-7
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 153 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8633,26 +8633,172 @@ void GradientUtils::eraseFictiousPHIs() {
86338633
phis.emplace_back(pair.first, pair.second);
86348634
fictiousPHIs.clear();
86358635

8636+
auto *i8Ty = Type::getInt8Ty(newFunc->getContext());
8637+
8638+
// Helper: create an i8 alloca in the entry block that dominates all uses.
8639+
auto createEntryAllocaBytes = [&](PHINode *pp, Value *byteCount) -> Value * {
8640+
IRBuilder<> B(&*pp->getFunction()->getEntryBlock().getFirstInsertionPt());
8641+
8642+
Value *sz = byteCount;
8643+
if (sz->getType()->isIntegerTy() && sz->getType()->getIntegerBitWidth() != 64)
8644+
sz = B.CreateZExtOrTrunc(sz, B.getInt64Ty(), "enzyme.sz64");
8645+
else if (!sz->getType()->isIntegerTy())
8646+
sz = B.CreatePtrToInt(sz, B.getInt64Ty(), "enzyme.sz64");
8647+
8648+
AllocaInst *AI = B.CreateAlloca(i8Ty, sz, "enzyme_fromstack_alloca");
8649+
// Alignment is not critical here; 16 is usually safe and helps codegen.
8650+
AI->setAlignment(Align(16));
8651+
8652+
Value *rep = AI;
8653+
8654+
// If pointer address spaces differ, cast appropriately.
8655+
if (pp->getType() != AI->getType() && pp->getType()->isPointerTy() &&
8656+
AI->getType()->isPointerTy()) {
8657+
unsigned ppAS = cast<PointerType>(pp->getType())->getAddressSpace();
8658+
unsigned aiAS = cast<PointerType>(AI->getType())->getAddressSpace();
8659+
if (ppAS != aiAS) {
8660+
rep = B.CreateAddrSpaceCast(AI, pp->getType(), "enzyme.asc");
8661+
} else {
8662+
rep = B.CreateBitCast(AI, pp->getType(), "enzyme.bc");
8663+
}
8664+
}
8665+
8666+
return rep;
8667+
};
8668+
8669+
// Helper: handle malloc/calloc tagged !enzyme_fromstack by turning into alloca
8670+
auto tryRematFromstackMalloc = [&](PHINode *pp, Value *orig) -> bool {
8671+
auto *CI = dyn_cast<CallInst>(orig);
8672+
if (!CI)
8673+
return false;
8674+
8675+
// Must be explicitly tagged as fromstack
8676+
if (!CI->getMetadata("enzyme_fromstack"))
8677+
return false;
8678+
8679+
Function *Callee = CI->getCalledFunction();
8680+
if (!Callee)
8681+
return false;
8682+
8683+
StringRef name = Callee->getName();
8684+
if (name != "malloc" && name != "calloc")
8685+
return false;
8686+
8687+
// Map operands into the new function safely (won't assert if already-new)
8688+
IRBuilder<> B(&*pp->getFunction()->getEntryBlock().getFirstInsertionPt());
8689+
8690+
Value *byteCount = nullptr;
8691+
if (name == "malloc") {
8692+
byteCount = getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(0)));
8693+
if (byteCount->getType()->isIntegerTy() &&
8694+
byteCount->getType()->getIntegerBitWidth() != 64)
8695+
byteCount = B.CreateZExtOrTrunc(byteCount, B.getInt64Ty(), "malloc.sz64");
8696+
} else { // calloc(nmemb, size)
8697+
Value *nmemb = getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(0)));
8698+
Value *esize = getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(1)));
8699+
if (nmemb->getType()->getIntegerBitWidth() != 64)
8700+
nmemb = B.CreateZExtOrTrunc(nmemb, B.getInt64Ty(), "calloc.nmemb64");
8701+
if (esize->getType()->getIntegerBitWidth() != 64)
8702+
esize = B.CreateZExtOrTrunc(esize, B.getInt64Ty(), "calloc.esize64");
8703+
byteCount = B.CreateMul(nmemb, esize, "calloc.bytes");
8704+
}
8705+
8706+
Value *rep = createEntryAllocaBytes(pp, byteCount);
8707+
8708+
// calloc should be zero-initialized
8709+
if (name == "calloc") {
8710+
// rep might be a casted pointer; memset wants an i8*
8711+
Value *raw = rep;
8712+
if (raw->getType() != PointerType::getUnqual(i8Ty) && raw->getType()->isPointerTy()) {
8713+
raw = B.CreateBitCast(raw, PointerType::getUnqual(i8Ty), "calloc.i8p");
8714+
}
8715+
B.CreateMemSet(raw, B.getInt8(0), byteCount, MaybeAlign(16));
8716+
}
8717+
8718+
pp->replaceAllUsesWith(rep);
8719+
erase(pp);
8720+
return true;
8721+
};
8722+
8723+
// Optional: rematerialize “cheap pure” instruction placeholders (the integer
8724+
// case you had before). Keep it conservative.
8725+
auto tryRematCheapPureInst = [&](PHINode *pp, Value *orig) -> bool {
8726+
auto *I = dyn_cast<Instruction>(orig);
8727+
if (!I)
8728+
return false;
8729+
8730+
if (isa<PHINode>(I) || isa<AllocaInst>(I) || isa<LoadInst>(I) ||
8731+
isa<StoreInst>(I) || isa<CallBase>(I) || isa<InvokeInst>(I) ||
8732+
isa<AtomicRMWInst>(I) || isa<AtomicCmpXchgInst>(I))
8733+
return false;
8734+
8735+
if (!I->mayReadFromMemory() && !I->mayWriteToMemory() && !I->mayHaveSideEffects()) {
8736+
Instruction *cl = I->clone();
8737+
cl->setName(I->getName() + ".remat");
8738+
8739+
// Insert near the start of the entry block so it dominates all uses.
8740+
cl->insertBefore(&*newFunc->getEntryBlock().getFirstInsertionPt());
8741+
8742+
for (unsigned op = 0; op < cl->getNumOperands(); ++op) {
8743+
Value *v = cl->getOperand(op);
8744+
cl->setOperand(op, getNewIfOriginal(v));
8745+
}
8746+
8747+
pp->replaceAllUsesWith(cl);
8748+
erase(pp);
8749+
return true;
8750+
}
8751+
8752+
return false;
8753+
};
8754+
8755+
// 3) Replace/erase fictitious phis
86368756
for (auto pair : phis) {
8637-
auto pp = pair.first;
8757+
PHINode *pp = pair.first;
8758+
Value *orig = pair.second;
8759+
8760+
if (!pp)
8761+
continue;
8762+
86388763
if (pp->getNumUses() != 0) {
8764+
// First: handle fromstack malloc/calloc
8765+
if (tryRematFromstackMalloc(pp, orig)) {
8766+
continue;
8767+
}
8768+
8769+
// Second: handle cheap pure instruction placeholders (like your old add)
8770+
if (tryRematCheapPureInst(pp, orig)) {
8771+
continue;
8772+
}
8773+
8774+
// If still used, this is a genuine bug; keep the diagnostics.
86398775
if (CustomErrorHandler) {
86408776
std::string str;
86418777
raw_string_ostream ss(str);
8642-
ss << "Illegal replace ficticious phi for: " << *pp << " of "
8643-
<< *pair.second;
8644-
CustomErrorHandler(str.c_str(), wrap(pair.second),
8778+
ss << "Illegal replace ficticious phi for: " << *pp << " of " << *orig;
8779+
CustomErrorHandler(str.c_str(), wrap(orig),
86458780
ErrorType::IllegalReplaceFicticiousPHIs, this,
86468781
wrap(pp), nullptr);
86478782
} else {
86488783
llvm::errs() << "mod:" << *oldFunc->getParent() << "\n";
86498784
llvm::errs() << "oldFunc:" << *oldFunc << "\n";
86508785
llvm::errs() << "newFunc:" << *newFunc << "\n";
8651-
llvm::errs() << " pp: " << *pp << " of " << *pair.second << "\n";
8652-
assert(pp->getNumUses() == 0);
8786+
llvm::errs() << " pp: " << *pp << " of " << *orig << "\n";
8787+
8788+
// Extra helpful debug: show direct users
8789+
for (auto *U : pp->users()) {
8790+
if (auto *UI = dyn_cast<Instruction>(U)) {
8791+
llvm::errs() << " use: " << *UI << "\n";
8792+
} else {
8793+
llvm::errs() << " use: " << *U << "\n";
8794+
}
8795+
}
86538796
}
86548797
}
8655-
pp->replaceAllUsesWith(UndefValue::get(pp->getType()));
8798+
8799+
// Last resort: replace with undef (keeps compilation going, but derivative may be wrong)
8800+
if (!pp->use_empty())
8801+
pp->replaceAllUsesWith(UndefValue::get(pp->getType()));
86568802
erase(pp);
86578803
}
86588804
}

0 commit comments

Comments
 (0)