@@ -24,77 +24,67 @@ module {
2424 }
2525}
2626
27- // CHECK-LABEL: func.func @hmc
28-
29- // ============================================================================
30- // 1. UNCONSTRAIN TRANSFORM: math.log applied to trace samples
31- // For POSITIVE support, unconstrain(x) = log(x)
32- // ============================================================================
33-
34- // Get flattened samples from trace, then recover and apply log to each sample
3527// CHECK: %[[FLATTENED:.+]] = enzyme.getFlattenedSamplesFromTrace %{{.+}} {selection = {{\[}}[#enzyme.symbol<1>], [#enzyme.symbol<2>]{{\]}}} : (!enzyme.Trace) -> tensor<2xf64>
28+
3629// CHECK: %[[SAMPLE1:.+]] = enzyme.recover_sample %[[FLATTENED]][0] : tensor<2xf64> -> tensor<1xf64>
37- // CHECK-NEXT: %[[LOG1:.+]] = math.log %[[SAMPLE1]] : tensor<1xf64>
30+ // CHECK: %[[LOG1:.+]] = math.log %[[SAMPLE1]] : tensor<1xf64>
31+
3832// CHECK: %[[SAMPLE2:.+]] = enzyme.recover_sample %[[FLATTENED]][1] : tensor<2xf64> -> tensor<1xf64>
39- // CHECK-NEXT : %[[LOG2:.+]] = math.log %[[SAMPLE2]] : tensor<1xf64>
33+ // CHECK: %[[LOG2:.+]] = math.log %[[SAMPLE2]] : tensor<1xf64>
4034
41- // Assemble the unconstrained position vector from the log-transformed samples
4235// CHECK: %[[EXTRACT1:.+]] = enzyme.dynamic_extract %[[LOG1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
4336// CHECK: %[[POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
4437// CHECK: %[[EXTRACT2:.+]] = enzyme.dynamic_extract %[[LOG2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
4538// CHECK: %[[UNCONSTRAINED_POS:.+]] = enzyme.dynamic_update %[[POS1]], %{{.+}}, %[[EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
4639
47- // ============================================================================
48- // 2. JACOBIAN CORRECTION: sum(z) subtracted from -weight
49- // For POSITIVE support, log|det(J)| = sum(z) where z is unconstrained position
50- // Potential energy U = -weight - sum(z)
51- // ============================================================================
52-
53- // Get weight from trace and negate it: U_base = -weight
5440// CHECK: %[[WEIGHT:.+]] = enzyme.getWeightFromTrace %{{.+}} : (!enzyme.Trace) -> tensor<f64>
55- // CHECK-NEXT : %[[NEG_WEIGHT:.+]] = arith.negf %[[WEIGHT]] : tensor<f64>
41+ // CHECK: %[[NEG_WEIGHT:.+]] = arith.negf %[[WEIGHT]] : tensor<f64>
5642
57- // Compute sum(z) for Jacobian correction using dot product with ones vector
5843// CHECK: %[[Z1:.+]] = enzyme.recover_sample %[[UNCONSTRAINED_POS]][0] : tensor<2xf64> -> tensor<1xf64>
59- // CHECK-NEXT: %[[SUM1:.+]] = enzyme.dot %[[Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
44+ // CHECK: %[[SUM1:.+]] = enzyme.dot %[[Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
45+ // CHECK: %[[PARTIAL_SUM:.+]] = arith.addf %[[SUM1]], %{{.+}} : tensor<f64>
6046// CHECK: %[[Z2:.+]] = enzyme.recover_sample %[[UNCONSTRAINED_POS]][1] : tensor<2xf64> -> tensor<1xf64>
61- // CHECK-NEXT : %[[SUM2:.+]] = enzyme.dot %[[Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
62- // CHECK: %[[TOTAL_JACOBIAN:.+]] = arith.addf %{{.+}} , %[[SUM2]] : tensor<f64>
47+ // CHECK: %[[SUM2:.+]] = enzyme.dot %[[Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
48+ // CHECK: %[[TOTAL_JACOBIAN:.+]] = arith.addf %[[PARTIAL_SUM]] , %[[SUM2]] : tensor<f64>
6349
64- // Subtract Jacobian correction from potential energy: U = -weight - sum(z)
6550// CHECK: %[[U0:.+]] = arith.subf %[[NEG_WEIGHT]], %[[TOTAL_JACOBIAN]] : tensor<f64>
6651
67- // ============================================================================
68- // 3. CONSTRAIN TRANSFORM inside autodiff_region: math.exp before UpdateOp
69- // For POSITIVE support, constrain(z) = exp(z)
70- // The autodiff_region computes gradient of U(q) w.r.t. unconstrained position
71- // ============================================================================
72-
73- // CHECK: %{{.+}} = enzyme.autodiff_region(%[[UNCONSTRAINED_POS]], %{{.+}}) {
52+ // CHECK: %{{.+}}:3 = enzyme.autodiff_region(%{{.+}}, %{{.+}}) {
7453// CHECK: ^bb0(%[[ARG:.+]]: tensor<2xf64>):
7554
76- // Recover unconstrained samples and apply exp to constrain them
7755// CHECK: %[[Z_SAMPLE1:.+]] = enzyme.recover_sample %[[ARG]][0] : tensor<2xf64> -> tensor<1xf64>
78- // CHECK-NEXT : %[[EXP1:.+]] = math.exp %[[Z_SAMPLE1]] : tensor<1xf64>
56+ // CHECK: %[[EXP1:.+]] = math.exp %[[Z_SAMPLE1]] : tensor<1xf64>
7957// CHECK: %[[Z_SAMPLE2:.+]] = enzyme.recover_sample %[[ARG]][1] : tensor<2xf64> -> tensor<1xf64>
80- // CHECK-NEXT : %[[EXP2:.+]] = math.exp %[[Z_SAMPLE2]] : tensor<1xf64>
58+ // CHECK: %[[EXP2:.+]] = math.exp %[[Z_SAMPLE2]] : tensor<1xf64>
8159
82- // Assemble constrained position vector from exp-transformed samples
8360// CHECK: %[[C_EXTRACT1:.+]] = enzyme.dynamic_extract %[[EXP1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
8461// CHECK: %[[C_POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[C_EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
8562// CHECK: %[[C_EXTRACT2:.+]] = enzyme.dynamic_extract %[[EXP2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
8663// CHECK: %[[CONSTRAINED_POS:.+]] = enzyme.dynamic_update %[[C_POS1]], %{{.+}}, %[[C_EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
8764
88- // Call update function with constrained position
89- // CHECK: %[[UPDATE_RES:.+]]:3 = func.call @test.update{{.*}}(%{{.+}}, %[[CONSTRAINED_POS]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
65+ // CHECK: %[[UPDATE_RES:.+]]:3 = func.call @test.update{{.+}}(%{{.+}}, %[[CONSTRAINED_POS]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
9066
91- // Compute adjusted potential energy: U = -weight - sum(z)
9267// CHECK: %[[NEG_UPDATE_WEIGHT:.+]] = arith.negf %[[UPDATE_RES]]#1 : tensor<f64>
93- // CHECK: enzyme.recover_sample %[[ARG]][0]
94- // CHECK-NEXT: enzyme.dot
95- // CHECK: enzyme.recover_sample %[[ARG]][1]
96- // CHECK-NEXT: enzyme.dot
97- // CHECK: %[[JAC_SUM:.+]] = arith.addf %{{.+}}, %{{.+}} : tensor<f64>
98- // CHECK-NEXT: %[[ADJUSTED_U:.+]] = arith.subf %[[NEG_UPDATE_WEIGHT]], %[[JAC_SUM]] : tensor<f64>
99- // CHECK-NEXT: enzyme.yield %[[ADJUSTED_U]], %{{.+}} : tensor<f64>, tensor<2xui64>
68+ // CHECK: %[[ARG_Z1:.+]] = enzyme.recover_sample %[[ARG]][0] : tensor<2xf64> -> tensor<1xf64>
69+ // CHECK: %[[ARG_DOT1:.+]] = enzyme.dot %[[ARG_Z1]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
70+ // CHECK: %[[ARG_PARTIAL:.+]] = arith.addf %[[ARG_DOT1]], %{{.+}} : tensor<f64>
71+ // CHECK: %[[ARG_Z2:.+]] = enzyme.recover_sample %[[ARG]][1] : tensor<2xf64> -> tensor<1xf64>
72+ // CHECK: %[[ARG_DOT2:.+]] = enzyme.dot %[[ARG_Z2]], %{{.+}} {lhs_batching_dimensions = array<i64>, lhs_contracting_dimensions = array<i64: 0>, rhs_batching_dimensions = array<i64>, rhs_contracting_dimensions = array<i64: 0>} : (tensor<1xf64>, tensor<1xf64>) -> tensor<f64>
73+ // CHECK: %[[JAC_SUM:.+]] = arith.addf %[[ARG_PARTIAL]], %[[ARG_DOT2]] : tensor<f64>
74+ // CHECK: %[[ADJUSTED_U:.+]] = arith.subf %[[NEG_UPDATE_WEIGHT]], %[[JAC_SUM]] : tensor<f64>
75+ // CHECK: enzyme.yield %[[ADJUSTED_U]], %{{.+}} : tensor<f64>, tensor<2xui64>
10076// CHECK: }
77+
78+ // CHECK: %[[FINAL_SELECT:.+]] = enzyme.select %{{.+}}, %{{.+}}, %[[UNCONSTRAINED_POS]] : (tensor<i1>, tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64>
79+
80+ // CHECK: %[[FINAL_SAMPLE1:.+]] = enzyme.recover_sample %[[FINAL_SELECT]][0] : tensor<2xf64> -> tensor<1xf64>
81+ // CHECK: %[[FINAL_EXP1:.+]] = math.exp %[[FINAL_SAMPLE1]] : tensor<1xf64>
82+ // CHECK: %[[FINAL_SAMPLE2:.+]] = enzyme.recover_sample %[[FINAL_SELECT]][1] : tensor<2xf64> -> tensor<1xf64>
83+ // CHECK: %[[FINAL_EXP2:.+]] = math.exp %[[FINAL_SAMPLE2]] : tensor<1xf64>
84+
85+ // CHECK: %[[FINAL_EXTRACT1:.+]] = enzyme.dynamic_extract %[[FINAL_EXP1]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
86+ // CHECK: %[[FINAL_POS1:.+]] = enzyme.dynamic_update %{{.+}}, %{{.+}}, %[[FINAL_EXTRACT1]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
87+ // CHECK: %[[FINAL_EXTRACT2:.+]] = enzyme.dynamic_extract %[[FINAL_EXP2]], %{{.+}} : (tensor<1xf64>, tensor<i64>) -> tensor<f64>
88+ // CHECK: %[[FINAL_CONSTRAINED:.+]] = enzyme.dynamic_update %[[FINAL_POS1]], %{{.+}}, %[[FINAL_EXTRACT2]] : (tensor<2xf64>, tensor<i64>, tensor<f64>) -> tensor<2xf64>
89+
90+ // CHECK: call @test.update(%{{.+}}, %[[FINAL_CONSTRAINED]], %{{.+}}, %{{.+}}) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor<f64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
0 commit comments