Skip to content

Commit 07e9c5f

Browse files
committed
test
1 parent 4877c96 commit 07e9c5f

File tree

5 files changed

+690
-536
lines changed

5 files changed

+690
-536
lines changed

enzyme/test/MLIR/ProbProg/exp_transform.mlir

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,77 +24,67 @@ module {
2424
}
2525
}
2626

27-
// CHECK-LABEL: func.func @hmc
28-
29-
// ============================================================================
30-
// 1. UNCONSTRAIN TRANSFORM: math.log applied to trace samples
31-
// For POSITIVE support, unconstrain(x) = log(x)
32-
// ============================================================================
33-
34-
// Get flattened samples from trace, then recover and apply log to each sample
3527
// CHECK: %[[FLATTENED:.+]] = enzyme.getFlattenedSamplesFromTrace %{{.+}} {selection = {{\[}}[#enzyme.symbol<1>], [#enzyme.symbol<2>]{{\]}}} : (!enzyme.Trace) -> tensor<2xf64>
28+
3629
// CHECK: %[[SAMPLE1:.+]] = enzyme.recover_sample %[[FLATTENED]][0] : tensor<2xf64> -> tensor<1xf64>
37-
// CHECK-NEXT: %[[LOG1:.+]] = math.log %[[SAMPLE1]] : tensor<1xf64>
30+
// CHECK: %[[LOG1:.+]] = math.log %[[SAMPLE1]] : tensor<1xf64>
31+
3832
// CHECK: %[[SAMPLE2:.+]] = enzyme.recover_sample %[[FLATTENED]][1] : tensor<2xf64> -> tensor<1xf64>
39-
// CHECK-NEXT: %[[LOG2:.+]] = math.log %[[SAMPLE2]] : tensor<1xf64>
33+
// CHECK: %[[LOG2:.+]] = math.log %[[SAMPLE2]] : tensor<1xf64>
4034

41-
// Assemble the unconstrained position vector from the log-transformed samples
4235
// CHECK: %[[EXTRACT1:.+]] = enzyme.dynamic_extract %[[LOG1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
4336
// CHECK: %[[POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
4437
// CHECK: %[[EXTRACT2:.+]] = enzyme.dynamic_extract %[[LOG2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
4538
// CHECK: %[[UNCONSTRAINED_POS:.+]] = enzyme.dynamic_update %[[POS1]], %{{.+}}, %[[EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
4639

47-
// ============================================================================
48-
// 2. JACOBIAN CORRECTION: sum(z) subtracted from -weight
49-
// For POSITIVE support, log|det(J)| = sum(z) where z is unconstrained position
50-
// Potential energy U = -weight - sum(z)
51-
// ============================================================================
52-
53-
// Get weight from trace and negate it: U_base = -weight
5440
// CHECK: %[[WEIGHT:.+]] = enzyme.getWeightFromTrace %{{.+}} : (!enzyme.Trace) -> tensor<f64>
55-
// CHECK-NEXT: %[[NEG_WEIGHT:.+]] = arith.negf %[[WEIGHT]] : tensor<f64>
41+
// CHECK: %[[NEG_WEIGHT:.+]] = arith.negf %[[WEIGHT]] : tensor<f64>
5642

57-
// Compute sum(z) for Jacobian correction using dot product with ones vector
5843
// CHECK: %[[Z1:.+]] = enzyme.recover_sample %[[UNCONSTRAINED_POS]][0] : tensor<2xf64> -> tensor<1xf64>
59-
// CHECK-NEXT: %[[SUM1:.+]] = enzyme.dot %[[Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
44+
// CHECK: %[[SUM1:.+]] = enzyme.dot %[[Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
45+
// CHECK: %[[PARTIAL_SUM:.+]] = arith.addf %[[SUM1]], %{{.+}} : tensor<f64>
6046
// CHECK: %[[Z2:.+]] = enzyme.recover_sample %[[UNCONSTRAINED_POS]][1] : tensor<2xf64> -> tensor<1xf64>
61-
// CHECK-NEXT: %[[SUM2:.+]] = enzyme.dot %[[Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
62-
// CHECK: %[[TOTAL_JACOBIAN:.+]] = arith.addf %{{.+}}, %[[SUM2]] : tensor<f64>
47+
// CHECK: %[[SUM2:.+]] = enzyme.dot %[[Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
48+
// CHECK: %[[TOTAL_JACOBIAN:.+]] = arith.addf %[[PARTIAL_SUM]], %[[SUM2]] : tensor<f64>
6349

64-
// Subtract Jacobian correction from potential energy: U = -weight - sum(z)
6550
// CHECK: %[[U0:.+]] = arith.subf %[[NEG_WEIGHT]], %[[TOTAL_JACOBIAN]] : tensor<f64>
6651

67-
// ============================================================================
68-
// 3. CONSTRAIN TRANSFORM inside autodiff_region: math.exp before UpdateOp
69-
// For POSITIVE support, constrain(z) = exp(z)
70-
// The autodiff_region computes gradient of U(q) w.r.t. unconstrained position
71-
// ============================================================================
72-
73-
// CHECK: %{{.+}} = enzyme.autodiff_region(%[[UNCONSTRAINED_POS]], %{{.+}}) {
52+
// CHECK: %{{.+}}:3 = enzyme.autodiff_region(%{{.+}}, %{{.+}}) {
7453
// CHECK: ^bb0(%[[ARG:.+]]: tensor<2xf64>):
7554

76-
// Recover unconstrained samples and apply exp to constrain them
7755
// CHECK: %[[Z_SAMPLE1:.+]] = enzyme.recover_sample %[[ARG]][0] : tensor<2xf64> -> tensor<1xf64>
78-
// CHECK-NEXT: %[[EXP1:.+]] = math.exp %[[Z_SAMPLE1]] : tensor<1xf64>
56+
// CHECK: %[[EXP1:.+]] = math.exp %[[Z_SAMPLE1]] : tensor<1xf64>
7957
// CHECK: %[[Z_SAMPLE2:.+]] = enzyme.recover_sample %[[ARG]][1] : tensor<2xf64> -> tensor<1xf64>
80-
// CHECK-NEXT: %[[EXP2:.+]] = math.exp %[[Z_SAMPLE2]] : tensor<1xf64>
58+
// CHECK: %[[EXP2:.+]] = math.exp %[[Z_SAMPLE2]] : tensor<1xf64>
8159

82-
// Assemble constrained position vector from exp-transformed samples
8360
// CHECK: %[[C_EXTRACT1:.+]] = enzyme.dynamic_extract %[[EXP1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
8461
// CHECK: %[[C_POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[C_EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
8562
// CHECK: %[[C_EXTRACT2:.+]] = enzyme.dynamic_extract %[[EXP2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
8663
// CHECK: %[[CONSTRAINED_POS:.+]] = enzyme.dynamic_update %[[C_POS1]], %{{.+}}, %[[C_EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
8764

88-
// Call update function with constrained position
89-
// CHECK: %[[UPDATE_RES:.+]]:3 = func.call @test.update{{.*}}(%{{.+}}, %[[CONSTRAINED_POS]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
65+
// CHECK: %[[UPDATE_RES:.+]]:3 = func.call @test.update{{.+}}(%{{.+}}, %[[CONSTRAINED_POS]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
9066

91-
// Compute adjusted potential energy: U = -weight - sum(z)
9267
// CHECK: %[[NEG_UPDATE_WEIGHT:.+]] = arith.negf %[[UPDATE_RES]]#1 : tensor<f64>
93-
// CHECK: enzyme.recover_sample %[[ARG]][0]
94-
// CHECK-NEXT: enzyme.dot
95-
// CHECK: enzyme.recover_sample %[[ARG]][1]
96-
// CHECK-NEXT: enzyme.dot
97-
// CHECK: %[[JAC_SUM:.+]] = arith.addf %{{.+}}, %{{.+}} : tensor<f64>
98-
// CHECK-NEXT: %[[ADJUSTED_U:.+]] = arith.subf %[[NEG_UPDATE_WEIGHT]], %[[JAC_SUM]] : tensor<f64>
99-
// CHECK-NEXT: enzyme.yield %[[ADJUSTED_U]], %{{.+}} : tensor<f64>, tensor<2xui64>
68+
// CHECK: %[[ARG_Z1:.+]] = enzyme.recover_sample %[[ARG]][0] : tensor<2xf64> -> tensor<1xf64>
69+
// CHECK: %[[ARG_DOT1:.+]] = enzyme.dot %[[ARG_Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
70+
// CHECK: %[[ARG_PARTIAL:.+]] = arith.addf %[[ARG_DOT1]], %{{.+}} : tensor<f64>
71+
// CHECK: %[[ARG_Z2:.+]] = enzyme.recover_sample %[[ARG]][1] : tensor<2xf64> -> tensor<1xf64>
72+
// CHECK: %[[ARG_DOT2:.+]] = enzyme.dot %[[ARG_Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
73+
// CHECK: %[[JAC_SUM:.+]] = arith.addf %[[ARG_PARTIAL]], %[[ARG_DOT2]] : tensor<f64>
74+
// CHECK: %[[ADJUSTED_U:.+]] = arith.subf %[[NEG_UPDATE_WEIGHT]], %[[JAC_SUM]] : tensor<f64>
75+
// CHECK: enzyme.yield %[[ADJUSTED_U]], %{{.+}} : tensor<f64>, tensor<2xui64>
10076
// CHECK: }
77+
78+
// CHECK: %[[FINAL_SELECT:.+]] = enzyme.select %{{.+}}, %{{.+}}, %[[UNCONSTRAINED_POS]] : (tensor<i1>, tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64>
79+
80+
// CHECK: %[[FINAL_SAMPLE1:.+]] = enzyme.recover_sample %[[FINAL_SELECT]][0] : tensor<2xf64> -> tensor<1xf64>
81+
// CHECK: %[[FINAL_EXP1:.+]] = math.exp %[[FINAL_SAMPLE1]] : tensor<1xf64>
82+
// CHECK: %[[FINAL_SAMPLE2:.+]] = enzyme.recover_sample %[[FINAL_SELECT]][1] : tensor<2xf64> -> tensor<1xf64>
83+
// CHECK: %[[FINAL_EXP2:.+]] = math.exp %[[FINAL_SAMPLE2]] : tensor<1xf64>
84+
85+
// CHECK: %[[FINAL_EXTRACT1:.+]] = enzyme.dynamic_extract %[[FINAL_EXP1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
86+
// CHECK: %[[FINAL_POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[FINAL_EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
87+
// CHECK: %[[FINAL_EXTRACT2:.+]] = enzyme.dynamic_extract %[[FINAL_EXP2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
88+
// CHECK: %[[FINAL_CONSTRAINED:.+]] = enzyme.dynamic_update %[[FINAL_POS1]], %{{.+}}, %[[FINAL_EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
89+
90+
// CHECK: call @test.update(%{{.+}}, %[[FINAL_CONSTRAINED]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)

0 commit comments

Comments
 (0)