Skip to content

Commit d582844

Browse files
committed
minor
1 parent 07e9c5f commit d582844

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ Value MCMC::computeMassMatrixSqrt(OpBuilder &builder, Location loc,
229229
std::pair<Value, Value> MCMC::sampleMomentum(OpBuilder &builder, Location loc,
230230
Value rng, Value invMass,
231231
Value massMatrixSqrt,
232-
RankedTensorType positionType) {
232+
RankedTensorType positionType,
233+
bool debugDump) {
233234
auto elemType = positionType.getElementType();
234235
auto scalarType = RankedTensorType::get({}, elemType);
235236

@@ -244,6 +245,10 @@ std::pair<Value, Value> MCMC::sampleMomentum(OpBuilder &builder, Location loc,
244245
TypeRange{rng.getType()}, rng);
245246
auto rngForSampling = splitOp.getResult(0);
246247

248+
rngForSampling =
249+
conditionalDump(builder, loc, rngForSampling,
250+
"sampleMomentum: rng state for sampling", debugDump);
251+
247252
// Sample eps ~ N(0, I)
248253
auto randomOp = enzyme::RandomOp::create(
249254
builder, loc, TypeRange{rngForSampling.getType(), positionType},
@@ -689,8 +694,9 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q,
689694
auto rngTransition = sampleKernelSplit.getResult(2);
690695

691696
// 2. Sample fresh momentum p ~ N(0, M)
692-
auto [p0, rngAfterMomentum] = sampleMomentum(
693-
builder, loc, rngMomentum, ctx.invMass, ctx.massMatrixSqrt, positionType);
697+
auto [p0, rngAfterMomentum] =
698+
sampleMomentum(builder, loc, rngMomentum, ctx.invMass, ctx.massMatrixSqrt,
699+
positionType, debugDump);
694700

695701
// 3. Compute K0 = 0.5 * p^T * M^-1 * p
696702
auto K0 = computeKineticEnergy(builder, loc, p0, ctx.invMass, positionType);

enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ Value computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass,
195195
std::pair<Value, Value> sampleMomentum(OpBuilder &builder, Location loc,
196196
Value rng, Value invMass,
197197
Value massMatrixSqrt,
198-
RankedTensorType positionType);
198+
RankedTensorType positionType,
199+
bool debugDump = false);
199200

200201
/// Computes potential energy `U(q) = -log p(q)` and its gradient `dU/dq`
201202
GradientResult computePotentialAndGradient(OpBuilder &builder, Location loc,

0 commit comments

Comments
 (0)