@@ -229,7 +229,8 @@ Value MCMC::computeMassMatrixSqrt(OpBuilder &builder, Location loc,
229229std::pair<Value, Value> MCMC::sampleMomentum (OpBuilder &builder, Location loc,
230230 Value rng, Value invMass,
231231 Value massMatrixSqrt,
232- RankedTensorType positionType) {
232+ RankedTensorType positionType,
233+ bool debugDump) {
233234 auto elemType = positionType.getElementType ();
234235 auto scalarType = RankedTensorType::get ({}, elemType);
235236
@@ -244,6 +245,10 @@ std::pair<Value, Value> MCMC::sampleMomentum(OpBuilder &builder, Location loc,
244245 TypeRange{rng.getType ()}, rng);
245246 auto rngForSampling = splitOp.getResult (0 );
246247
248+ rngForSampling =
249+ conditionalDump (builder, loc, rngForSampling,
250+ " sampleMomentum: rng state for sampling" , debugDump);
251+
247252 // Sample eps ~ N(0, I)
248253 auto randomOp = enzyme::RandomOp::create (
249254 builder, loc, TypeRange{rngForSampling.getType (), positionType},
@@ -689,8 +694,9 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q,
689694 auto rngTransition = sampleKernelSplit.getResult (2 );
690695
691696 // 2. Sample fresh momentum p ~ N(0, M)
692- auto [p0, rngAfterMomentum] = sampleMomentum (
693- builder, loc, rngMomentum, ctx.invMass , ctx.massMatrixSqrt , positionType);
697+ auto [p0, rngAfterMomentum] =
698+ sampleMomentum (builder, loc, rngMomentum, ctx.invMass , ctx.massMatrixSqrt ,
699+ positionType, debugDump);
694700
695701 // 3. Compute K0 = 0.5 * p^T * M^-1 * p
696702 auto K0 = computeKineticEnergy (builder, loc, p0, ctx.invMass , positionType);
0 commit comments