Skip to content

Commit 53a46ca

Browse files
authored
thread args (#2696)
1 parent 36298bd commit 53a46ca

File tree

4 files changed

+148
-16
lines changed

4 files changed

+148
-16
lines changed

enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,11 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder,
377377
Value qArg = autodiffBlock->getArgument(0);
378378

379379
if (ctx.hasCustomLogpdf()) {
380+
SmallVector<Value> callArgs;
381+
callArgs.push_back(qArg);
382+
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
380383
auto callOp = func::CallOp::create(builder, loc, ctx.logpdfFn,
381-
TypeRange{scalarType}, ValueRange{qArg});
384+
TypeRange{scalarType}, callArgs);
382385
Value U = arith::NegFOp::create(builder, loc, callOp.getResult(0));
383386
// TODO(#2695): handle hybrid case
384387
enzyme::YieldOp::create(builder, loc, {U, rng});
@@ -683,8 +686,11 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
683686

684687
if (ctx.hasCustomLogpdf()) {
685688
q0 = initialPosition;
689+
SmallVector<Value> callArgs;
690+
callArgs.push_back(q0);
691+
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
686692
auto callOp = func::CallOp::create(builder, loc, ctx.logpdfFn,
687-
TypeRange{scalarType}, ValueRange{q0});
693+
TypeRange{scalarType}, callArgs);
688694
U0 = arith::NegFOp::create(builder, loc, callOp.getResult(0));
689695
} else {
690696
auto fullTraceType =
@@ -748,8 +754,11 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng,
748754
auto q0Arg = autodiffInitBlock->getArgument(0);
749755

750756
if (ctx.hasCustomLogpdf()) {
751-
auto callOpInner = func::CallOp::create(
752-
builder, loc, ctx.logpdfFn, TypeRange{scalarType}, ValueRange{q0Arg});
757+
SmallVector<Value> callArgs;
758+
callArgs.push_back(q0Arg);
759+
callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end());
760+
auto callOpInner = func::CallOp::create(builder, loc, ctx.logpdfFn,
761+
TypeRange{scalarType}, callArgs);
753762
auto U0_init =
754763
arith::NegFOp::create(builder, loc, callOpInner.getResult(0));
755764
enzyme::YieldOp::create(builder, loc, {U0_init, rngForAutodiff});

enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,12 @@ struct HMCContext {
123123
trajectoryLength(trajectoryLength), positionSize(positionSize),
124124
supports(supports.begin(), supports.end()) {}
125125

126-
HMCContext(FlatSymbolRefAttr logpdfFn, Value invMass, Value massMatrixSqrt,
127-
Value stepSize, Value trajectoryLength, int64_t positionSize)
128-
: invMass(invMass), massMatrixSqrt(massMatrixSqrt), stepSize(stepSize),
129-
trajectoryLength(trajectoryLength), positionSize(positionSize),
130-
logpdfFn(logpdfFn) {}
126+
HMCContext(FlatSymbolRefAttr logpdfFn, ArrayRef<Value> fnInputs,
127+
Value invMass, Value massMatrixSqrt, Value stepSize,
128+
Value trajectoryLength, int64_t positionSize)
129+
: fnInputs(fnInputs), invMass(invMass), massMatrixSqrt(massMatrixSqrt),
130+
stepSize(stepSize), trajectoryLength(trajectoryLength),
131+
positionSize(positionSize), logpdfFn(logpdfFn) {}
131132

132133
bool hasCustomLogpdf() const { return logpdfFn != nullptr; }
133134

@@ -180,10 +181,11 @@ struct NUTSContext : public HMCContext {
180181
supports),
181182
H0(H0), maxDeltaEnergy(maxDeltaEnergy), maxTreeDepth(maxTreeDepth) {}
182183

183-
NUTSContext(FlatSymbolRefAttr logpdfFn, Value invMass, Value massMatrixSqrt,
184-
Value stepSize, int64_t positionSize, Value H0,
185-
Value maxDeltaEnergy, int64_t maxTreeDepth)
186-
: HMCContext(logpdfFn, invMass, massMatrixSqrt, stepSize,
184+
NUTSContext(FlatSymbolRefAttr logpdfFn, ArrayRef<Value> fnInputs,
185+
Value invMass, Value massMatrixSqrt, Value stepSize,
186+
int64_t positionSize, Value H0, Value maxDeltaEnergy,
187+
int64_t maxTreeDepth)
188+
: HMCContext(logpdfFn, fnInputs, invMass, massMatrixSqrt, stepSize,
187189
/* Unused trajectoryLength */ Value(), positionSize),
188190
H0(H0), maxDeltaEnergy(maxDeltaEnergy), maxTreeDepth(maxTreeDepth) {}
189191

enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
653653

654654
if (hasLogpdfFn) {
655655
logpdfFnAttr = mcmcOp.getLogpdfFnAttr();
656+
fnInputs.assign(inputs.begin() + 1, inputs.end());
656657
auto initialPos = mcmcOp.getInitialPosition();
657658
positionSize =
658659
cast<RankedTensorType>(initialPos.getType()).getShape()[1];
@@ -735,8 +736,9 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
735736
Value currentMassMatrixSqrt,
736737
Value currentStepSize) -> HMCContext {
737738
if (hasLogpdfFn) {
738-
return HMCContext(logpdfFnAttr, currentInvMass, currentMassMatrixSqrt,
739-
currentStepSize, trajectoryLength, positionSize);
739+
return HMCContext(logpdfFnAttr, fnInputs, currentInvMass,
740+
currentMassMatrixSqrt, currentStepSize,
741+
trajectoryLength, positionSize);
740742
} else {
741743
return HMCContext(
742744
mcmcOp.getFnAttr(), fnInputs, fnResultTypes, originalTrace,
@@ -749,7 +751,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
749751
[&](Value currentInvMass, Value currentMassMatrixSqrt,
750752
Value currentStepSize, Value U) -> NUTSContext {
751753
if (hasLogpdfFn) {
752-
return NUTSContext(logpdfFnAttr, currentInvMass,
754+
return NUTSContext(logpdfFnAttr, fnInputs, currentInvMass,
753755
currentMassMatrixSqrt, currentStepSize,
754756
positionSize, U, maxDeltaEnergy, maxTreeDepth);
755757
} else {

enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,123 @@ module {
5757
: (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
5858
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
5959
}
60+
61+
func.func @shifted_logpdf(%x : tensor<1x2xf64>, %mu : tensor<1x2xf64>) -> tensor<f64> {
62+
%diff = arith.subf %x, %mu : tensor<1x2xf64>
63+
%sum_sq = enzyme.dot %diff, %diff {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
64+
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
65+
%result = arith.mulf %neg_half, %sum_sq : tensor<f64>
66+
return %result : tensor<f64>
67+
}
68+
69+
// CHECK-LABEL: func.func @nuts_shifted_logpdf
70+
// CHECK: call @shifted_logpdf
71+
// CHECK-NEXT: %[[U0:.+]] = arith.negf
72+
// CHECK: enzyme.autodiff_region
73+
// CHECK: func.call @shifted_logpdf
74+
// CHECK-NEXT: %[[NEG:.+]] = arith.negf
75+
// CHECK-NEXT: enzyme.yield
76+
// CHECK: enzyme.for_loop
77+
// CHECK: enzyme.autodiff_region
78+
// CHECK: func.call @shifted_logpdf
79+
// CHECK-NEXT: %{{.+}} = arith.negf
80+
// CHECK-NEXT: enzyme.yield
81+
func.func @nuts_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
82+
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
83+
%step_size = arith.constant dense<0.1> : tensor<f64>
84+
%res:3 = enzyme.mcmc (%rng, %mu)
85+
step_size = %step_size
86+
logpdf_fn = @shifted_logpdf
87+
initial_position = %init_pos
88+
{ nuts_config = #enzyme.nuts_config<max_tree_depth = 3, max_delta_energy = 1000.0, adapt_step_size = false, adapt_mass_matrix = false>,
89+
name = "nuts_shifted_logpdf", selection = [], all_addresses = [], num_warmup = 0, num_samples = 1 }
90+
: (tensor<2xui64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
91+
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
92+
}
93+
94+
// CHECK-LABEL: func.func @hmc_shifted_logpdf
95+
// CHECK: call @shifted_logpdf
96+
// CHECK-NEXT: %{{.+}} = arith.negf
97+
// CHECK: enzyme.autodiff_region
98+
// CHECK: func.call @shifted_logpdf
99+
// CHECK-NEXT: %{{.+}} = arith.negf
100+
// CHECK-NEXT: enzyme.yield
101+
// CHECK: enzyme.for_loop
102+
// CHECK: enzyme.autodiff_region
103+
// CHECK: func.call @shifted_logpdf
104+
// CHECK-NEXT: %{{.+}} = arith.negf
105+
// CHECK-NEXT: enzyme.yield
106+
func.func @hmc_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
107+
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
108+
%step_size = arith.constant dense<0.1> : tensor<f64>
109+
%res:3 = enzyme.mcmc (%rng, %mu)
110+
step_size = %step_size
111+
logpdf_fn = @shifted_logpdf
112+
initial_position = %init_pos
113+
{ hmc_config = #enzyme.hmc_config<trajectory_length = 1.000000e+00 : f64, adapt_step_size = false, adapt_mass_matrix = false>,
114+
name = "hmc_shifted_logpdf", selection = [], all_addresses = [], num_warmup = 0, num_samples = 1 }
115+
: (tensor<2xui64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
116+
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
117+
}
118+
119+
func.func @anisotropic_logpdf(%x : tensor<1x2xf64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> tensor<f64> {
120+
%diff = arith.subf %x, %mu : tensor<1x2xf64>
121+
%diff_sq = arith.mulf %diff, %diff : tensor<1x2xf64>
122+
%weighted = arith.mulf %precision, %diff_sq : tensor<1x2xf64>
123+
%ones = arith.constant dense<1.0> : tensor<1x2xf64>
124+
%sum = enzyme.dot %ones, %weighted {lhs_batching_dimensions = array<i64>, rhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0, 1>, rhs_contracting_dimensions = array<i64: 0, 1>} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<f64>
125+
%neg_half = arith.constant dense<-5.000000e-01> : tensor<f64>
126+
%result = arith.mulf %neg_half, %sum : tensor<f64>
127+
return %result : tensor<f64>
128+
}
129+
130+
// CHECK-LABEL: func.func @nuts_anisotropic_logpdf
131+
// CHECK: call @anisotropic_logpdf
132+
// CHECK-NEXT: %[[U0:.+]] = arith.negf
133+
// CHECK: enzyme.autodiff_region
134+
// CHECK: func.call @anisotropic_logpdf
135+
// CHECK-NEXT: %[[NEG:.+]] = arith.negf
136+
// CHECK-NEXT: enzyme.yield
137+
// CHECK: enzyme.for_loop
138+
// CHECK: enzyme.autodiff_region
139+
// CHECK: func.call @anisotropic_logpdf
140+
// CHECK-NEXT: %{{.+}} = arith.negf
141+
// CHECK-NEXT: enzyme.yield
142+
func.func @nuts_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
143+
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
144+
%step_size = arith.constant dense<0.1> : tensor<f64>
145+
%res:3 = enzyme.mcmc (%rng, %mu, %precision)
146+
step_size = %step_size
147+
logpdf_fn = @anisotropic_logpdf
148+
initial_position = %init_pos
149+
{ nuts_config = #enzyme.nuts_config<max_tree_depth = 3, max_delta_energy = 1000.0, adapt_step_size = false, adapt_mass_matrix = false>,
150+
name = "nuts_anisotropic_logpdf", selection = [], all_addresses = [], num_warmup = 0, num_samples = 1 }
151+
: (tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
152+
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
153+
}
154+
155+
// CHECK-LABEL: func.func @hmc_anisotropic_logpdf
156+
// CHECK: call @anisotropic_logpdf
157+
// CHECK-NEXT: %{{.+}} = arith.negf
158+
// CHECK: enzyme.autodiff_region
159+
// CHECK: func.call @anisotropic_logpdf
160+
// CHECK-NEXT: %{{.+}} = arith.negf
161+
// CHECK-NEXT: enzyme.yield
162+
// CHECK: enzyme.for_loop
163+
// CHECK: enzyme.autodiff_region
164+
// CHECK: func.call @anisotropic_logpdf
165+
// CHECK-NEXT: %{{.+}} = arith.negf
166+
// CHECK-NEXT: enzyme.yield
167+
func.func @hmc_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<1x2xf64>, %precision : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
168+
%init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64>
169+
%step_size = arith.constant dense<0.1> : tensor<f64>
170+
%res:3 = enzyme.mcmc (%rng, %mu, %precision)
171+
step_size = %step_size
172+
logpdf_fn = @anisotropic_logpdf
173+
initial_position = %init_pos
174+
{ hmc_config = #enzyme.hmc_config<trajectory_length = 1.000000e+00 : f64, adapt_step_size = false, adapt_mass_matrix = false>,
175+
name = "hmc_anisotropic_logpdf", selection = [], all_addresses = [], num_warmup = 0, num_samples = 1 }
176+
: (tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
177+
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
178+
}
60179
}

0 commit comments

Comments
 (0)