Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 161 additions & 7 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8633,26 +8633,180 @@ void GradientUtils::eraseFictiousPHIs() {
phis.emplace_back(pair.first, pair.second);
fictiousPHIs.clear();

auto *i8Ty = Type::getInt8Ty(newFunc->getContext());

// Helper: create an i8 alloca in the entry block that dominates all uses.
auto createEntryAllocaBytes = [&](PHINode *pp, Value *byteCount) -> Value * {
IRBuilder<> B(&*pp->getFunction()->getEntryBlock().getFirstInsertionPt());

Value *sz = byteCount;
if (sz->getType()->isIntegerTy() &&
sz->getType()->getIntegerBitWidth() != 64)
sz = B.CreateZExtOrTrunc(sz, B.getInt64Ty(), "enzyme.sz64");
else if (!sz->getType()->isIntegerTy())
sz = B.CreatePtrToInt(sz, B.getInt64Ty(), "enzyme.sz64");

AllocaInst *AI = B.CreateAlloca(i8Ty, sz, "enzyme_fromstack_alloca");
// Alignment is not critical here; 16 is usually safe and helps codegen.
AI->setAlignment(Align(16));

Value *rep = AI;

// If pointer address spaces differ, cast appropriately.
if (pp->getType() != AI->getType() && pp->getType()->isPointerTy() &&
AI->getType()->isPointerTy()) {
unsigned ppAS = cast<PointerType>(pp->getType())->getAddressSpace();
unsigned aiAS = cast<PointerType>(AI->getType())->getAddressSpace();
if (ppAS != aiAS) {
rep = B.CreateAddrSpaceCast(AI, pp->getType(), "enzyme.asc");
} else {
rep = B.CreateBitCast(AI, pp->getType(), "enzyme.bc");
}
}

return rep;
};

// Helper: handle malloc/calloc tagged !enzyme_fromstack by turning into
// alloca
auto tryRematFromstackMalloc = [&](PHINode *pp, Value *orig) -> bool {
auto *CI = dyn_cast<CallInst>(orig);
if (!CI)
return false;

// Must be explicitly tagged as fromstack
if (!CI->getMetadata("enzyme_fromstack"))
return false;

Function *Callee = CI->getCalledFunction();
if (!Callee)
return false;

StringRef name = Callee->getName();
if (name != "malloc" && name != "calloc")
return false;

// Map operands into the new function safely (won't assert if already-new)
IRBuilder<> B(&*pp->getFunction()->getEntryBlock().getFirstInsertionPt());

Value *byteCount = nullptr;
if (name == "malloc") {
byteCount = getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(0)));
if (byteCount->getType()->isIntegerTy() &&
byteCount->getType()->getIntegerBitWidth() != 64)
byteCount =
B.CreateZExtOrTrunc(byteCount, B.getInt64Ty(), "malloc.sz64");
} else { // calloc(nmemb, size)
Value *nmemb =
getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(0)));
Value *esize =
getNewIfOriginal(const_cast<Value *>(CI->getArgOperand(1)));
if (nmemb->getType()->getIntegerBitWidth() != 64)
nmemb = B.CreateZExtOrTrunc(nmemb, B.getInt64Ty(), "calloc.nmemb64");
if (esize->getType()->getIntegerBitWidth() != 64)
esize = B.CreateZExtOrTrunc(esize, B.getInt64Ty(), "calloc.esize64");
byteCount = B.CreateMul(nmemb, esize, "calloc.bytes");
}

Value *rep = createEntryAllocaBytes(pp, byteCount);

// calloc should be zero-initialized
if (name == "calloc") {
// rep might be a casted pointer; memset wants an i8*
Value *raw = rep;
if (raw->getType() != PointerType::getUnqual(i8Ty) &&
raw->getType()->isPointerTy()) {
raw = B.CreateBitCast(raw, PointerType::getUnqual(i8Ty), "calloc.i8p");
}
B.CreateMemSet(raw, B.getInt8(0), byteCount, MaybeAlign(16));
}

pp->replaceAllUsesWith(rep);
erase(pp);
return true;
};

// Optional: rematerialize “cheap pure” instruction placeholders (the integer
// case you had before). Keep it conservative.
auto tryRematCheapPureInst = [&](PHINode *pp, Value *orig) -> bool {
auto *I = dyn_cast<Instruction>(orig);
if (!I)
return false;

if (isa<PHINode>(I) || isa<AllocaInst>(I) || isa<LoadInst>(I) ||
isa<StoreInst>(I) || isa<CallBase>(I) || isa<InvokeInst>(I) ||
isa<AtomicRMWInst>(I) || isa<AtomicCmpXchgInst>(I))
return false;

if (!I->mayReadFromMemory() && !I->mayWriteToMemory() &&
!I->mayHaveSideEffects()) {
Instruction *cl = I->clone();
cl->setName(I->getName() + ".remat");

// Insert near the start of the entry block so it dominates all uses.
cl->insertBefore(&*newFunc->getEntryBlock().getFirstInsertionPt());

for (unsigned op = 0; op < cl->getNumOperands(); ++op) {
Value *v = cl->getOperand(op);
cl->setOperand(op, getNewIfOriginal(v));
}

pp->replaceAllUsesWith(cl);
erase(pp);
return true;
}

return false;
};

// 3) Replace/erase fictitious phis.
// Each entry pairs a placeholder PHI with the original value it stands for.
// A placeholder that still has uses is first offered to the two
// rematerialization helpers; if neither applies, the problem is reported and
// the remaining uses are replaced by undef so the placeholder can be erased.
for (const auto &pair : phis) {
  PHINode *pp = pair.first;
  Value *orig = pair.second;

  if (!pp)
    continue;

  if (!pp->use_empty()) {
    // First: handle fromstack malloc/calloc.
    if (tryRematFromstackMalloc(pp, orig))
      continue;

    // Second: handle cheap pure instruction placeholders.
    if (tryRematCheapPureInst(pp, orig))
      continue;

    // If still used, this is a genuine bug; keep the diagnostics.
    if (CustomErrorHandler) {
      std::string str;
      raw_string_ostream ss(str);
      ss << "Illegal replace ficticious phi for: " << *pp << " of " << *orig;
      // ss.str() flushes the stream into `str` before the text is read.
      CustomErrorHandler(ss.str().c_str(), wrap(orig),
                         ErrorType::IllegalReplaceFicticiousPHIs, this,
                         wrap(pp), nullptr);
    } else {
      llvm::errs() << "mod:" << *oldFunc->getParent() << "\n";
      llvm::errs() << "oldFunc:" << *oldFunc << "\n";
      llvm::errs() << "newFunc:" << *newFunc << "\n";
      llvm::errs() << " pp: " << *pp << " of " << *orig << "\n";

      // Extra helpful debug: show direct users.
      for (auto *U : pp->users())
        llvm::errs() << " use: " << *U << "\n";
    }
  }

  // Last resort: replace with undef (keeps compilation going, but derivative
  // may be wrong).
  if (!pp->use_empty())
    pp->replaceAllUsesWith(UndefValue::get(pp->getType()));
  erase(pp);
}
}
Expand Down
Loading