Skip to content

Commit 36298bd

Browse files
authored
ProbProg: Raw MCMC (#2695)
* logpdf only mode * mcmc op verify
1 parent 127e998 commit 36298bd

File tree

6 files changed

+373
-166
lines changed

6 files changed

+373
-166
lines changed

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,12 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
683683
let description = [{
684684
Runs MCMC inference on selected addresses.
685685

686+
Two modes of operation:
687+
1. Trace-based mode: `fn` and `original_trace` are provided. The model
688+
function with `enzyme.sample` ops defines the density.
689+
2. Custom logpdf mode: `logpdf_fn` and `initial_position` are provided.
690+
The logpdf function maps position → scalar log-density directly.
691+
686692
The `selection` attribute determines which addresses to sample via HMC/NUTS.
687693
All sample addresses are included in the trace tensor for consistency.
688694

@@ -693,9 +699,9 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
693699
}];
694700

695701
let arguments = (ins
696-
FlatSymbolRefAttr:$fn,
702+
OptionalAttr<FlatSymbolRefAttr>:$fn,
697703
Variadic<AnyType>:$inputs,
698-
AnyRankedTensor:$original_trace,
704+
Optional<AnyRankedTensor>:$original_trace,
699705
AddressArrayAttr:$selection,
700706
AddressArrayAttr:$all_addresses,
701707

@@ -711,6 +717,10 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
711717
OptionalAttr<HMCConfigAttr>:$hmc_config,
712718
OptionalAttr<NUTSConfigAttr>:$nuts_config,
713719

720+
// Custom logpdf mode
721+
OptionalAttr<FlatSymbolRefAttr>:$logpdf_fn,
722+
Optional<AnyRankedTensor>:$initial_position,
723+
714724
DefaultValuedStrAttr<StrAttr, "">:$name
715725
);
716726

@@ -721,9 +731,13 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
721731
);
722732

723733
let assemblyFormat = [{
724-
$fn `(` $inputs `)` `given` $original_trace
734+
($fn^)?
735+
`(` $inputs `)`
736+
(`given` $original_trace^)?
725737
(`inverse_mass_matrix` `=` $inverse_mass_matrix^)?
726738
(`step_size` `=` $step_size^)?
739+
(`logpdf_fn` `=` $logpdf_fn^)?
740+
(`initial_position` `=` $initial_position^)?
727741
attr-dict `:` functional-type(operands, results)
728742
}];
729743

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,13 +1062,21 @@ LogicalResult MHOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10621062
//===----------------------------------------------------------------------===//
10631063

10641064
LogicalResult MCMCOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1065-
// TODO: Verify that the result type is same as the type of the referenced
1066-
// func.func op.
1067-
auto global =
1068-
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
1069-
if (!global)
1070-
return emitOpError("'")
1071-
<< getFn() << "' does not reference a valid global funcOp";
1065+
if (auto fnAttr = getFnAttr()) {
1066+
auto global =
1067+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, fnAttr);
1068+
if (!global)
1069+
return emitOpError("'")
1070+
<< getFn().value() << "' does not reference a valid global funcOp";
1071+
}
1072+
1073+
if (auto logpdfAttr = getLogpdfFnAttr()) {
1074+
auto global =
1075+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, logpdfAttr);
1076+
if (!global)
1077+
return emitOpError("'") << logpdfAttr.getValue()
1078+
<< "' does not reference a valid global funcOp";
1079+
}
10721080

10731081
return success();
10741082
}
@@ -1082,7 +1090,19 @@ LogicalResult MCMCOp::verify() {
10821090
"Exactly one of hmc_config or nuts_config must be specified");
10831091
}
10841092

1085-
// TODO: More verification
1093+
// TODO(#2695): More verification
1094+
if (!getFnAttr() && !getLogpdfFnAttr()) {
1095+
return emitOpError("one of `fn` or `logpdf_fn` must be specified");
1096+
}
1097+
1098+
if (getFnAttr() && getLogpdfFnAttr()) {
1099+
return emitOpError("specifying both `fn` and `logpdf_fn` is unsupported");
1100+
}
1101+
1102+
if (getLogpdfFnAttr() && !getInitialPosition()) {
1103+
return emitOpError(
1104+
"custom logpdf mode requires `initial_position` to be provided");
1105+
}
10861106

10871107
return success();
10881108
}

0 commit comments

Comments
 (0)