@@ -57,4 +57,123 @@ module {
5757 : (tensor <2 xui64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >)
5858 return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
5959 }
60+
61+ func.func @shifted_logpdf (%x : tensor <1 x2 xf64 >, %mu : tensor <1 x2 xf64 >) -> tensor <f64 > {
62+ %diff = arith.subf %x , %mu : tensor <1 x2 xf64 >
63+ %sum_sq = enzyme.dot %diff , %diff {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 , 1 >, rhs_contracting_dimensions = array<i64 : 0 , 1 >} : (tensor <1 x2 xf64 >, tensor <1 x2 xf64 >) -> tensor <f64 >
64+ %neg_half = arith.constant dense <-5.000000e-01 > : tensor <f64 >
65+ %result = arith.mulf %neg_half , %sum_sq : tensor <f64 >
66+ return %result : tensor <f64 >
67+ }
68+
69+ // CHECK-LABEL: func.func @nuts_shifted_logpdf
70+ // CHECK: call @shifted_logpdf
71+ // CHECK-NEXT: %[[U0:.+]] = arith.negf
72+ // CHECK: enzyme.autodiff_region
73+ // CHECK: func.call @shifted_logpdf
74+ // CHECK-NEXT: %[[NEG:.+]] = arith.negf
75+ // CHECK-NEXT: enzyme.yield
76+ // CHECK: enzyme.for_loop
77+ // CHECK: enzyme.autodiff_region
78+ // CHECK: func.call @shifted_logpdf
79+ // CHECK-NEXT: %{{.+}} = arith.negf
80+ // CHECK-NEXT: enzyme.yield
81+ func.func @nuts_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
82+ %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
83+ %step_size = arith.constant dense <0.1 > : tensor <f64 >
84+ %res:3 = enzyme.mcmc (%rng , %mu )
85+ step_size = %step_size
86+ logpdf_fn = @shifted_logpdf
87+ initial_position = %init_pos
88+ { nuts_config = #enzyme.nuts_config <max_tree_depth = 3 , max_delta_energy = 1000.0 , adapt_step_size = false , adapt_mass_matrix = false >,
89+ name = " nuts_shifted_logpdf" , selection = [], all_addresses = [], num_warmup = 0 , num_samples = 1 }
90+ : (tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >)
91+ return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
92+ }
93+
94+ // CHECK-LABEL: func.func @hmc_shifted_logpdf
95+ // CHECK: call @shifted_logpdf
96+ // CHECK-NEXT: %{{.+}} = arith.negf
97+ // CHECK: enzyme.autodiff_region
98+ // CHECK: func.call @shifted_logpdf
99+ // CHECK-NEXT: %{{.+}} = arith.negf
100+ // CHECK-NEXT: enzyme.yield
101+ // CHECK: enzyme.for_loop
102+ // CHECK: enzyme.autodiff_region
103+ // CHECK: func.call @shifted_logpdf
104+ // CHECK-NEXT: %{{.+}} = arith.negf
105+ // CHECK-NEXT: enzyme.yield
106+ func.func @hmc_shifted_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
107+ %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
108+ %step_size = arith.constant dense <0.1 > : tensor <f64 >
109+ %res:3 = enzyme.mcmc (%rng , %mu )
110+ step_size = %step_size
111+ logpdf_fn = @shifted_logpdf
112+ initial_position = %init_pos
113+ { hmc_config = #enzyme.hmc_config <trajectory_length = 1.000000e+00 : f64 , adapt_step_size = false , adapt_mass_matrix = false >,
114+ name = " hmc_shifted_logpdf" , selection = [], all_addresses = [], num_warmup = 0 , num_samples = 1 }
115+ : (tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >)
116+ return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
117+ }
118+
119+ func.func @anisotropic_logpdf (%x : tensor <1 x2 xf64 >, %mu : tensor <1 x2 xf64 >, %precision : tensor <1 x2 xf64 >) -> tensor <f64 > {
120+ %diff = arith.subf %x , %mu : tensor <1 x2 xf64 >
121+ %diff_sq = arith.mulf %diff , %diff : tensor <1 x2 xf64 >
122+ %weighted = arith.mulf %precision , %diff_sq : tensor <1 x2 xf64 >
123+ %ones = arith.constant dense <1.0 > : tensor <1 x2 xf64 >
124+ %sum = enzyme.dot %ones , %weighted {lhs_batching_dimensions = array<i64 >, rhs_batching_dimensions = array<i64 >, lhs_contracting_dimensions = array<i64 : 0 , 1 >, rhs_contracting_dimensions = array<i64 : 0 , 1 >} : (tensor <1 x2 xf64 >, tensor <1 x2 xf64 >) -> tensor <f64 >
125+ %neg_half = arith.constant dense <-5.000000e-01 > : tensor <f64 >
126+ %result = arith.mulf %neg_half , %sum : tensor <f64 >
127+ return %result : tensor <f64 >
128+ }
129+
130+ // CHECK-LABEL: func.func @nuts_anisotropic_logpdf
131+ // CHECK: call @anisotropic_logpdf
132+ // CHECK-NEXT: %[[U0:.+]] = arith.negf
133+ // CHECK: enzyme.autodiff_region
134+ // CHECK: func.call @anisotropic_logpdf
135+ // CHECK-NEXT: %[[NEG:.+]] = arith.negf
136+ // CHECK-NEXT: enzyme.yield
137+ // CHECK: enzyme.for_loop
138+ // CHECK: enzyme.autodiff_region
139+ // CHECK: func.call @anisotropic_logpdf
140+ // CHECK-NEXT: %{{.+}} = arith.negf
141+ // CHECK-NEXT: enzyme.yield
142+ func.func @nuts_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >, %precision : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
143+ %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
144+ %step_size = arith.constant dense <0.1 > : tensor <f64 >
145+ %res:3 = enzyme.mcmc (%rng , %mu , %precision )
146+ step_size = %step_size
147+ logpdf_fn = @anisotropic_logpdf
148+ initial_position = %init_pos
149+ { nuts_config = #enzyme.nuts_config <max_tree_depth = 3 , max_delta_energy = 1000.0 , adapt_step_size = false , adapt_mass_matrix = false >,
150+ name = " nuts_anisotropic_logpdf" , selection = [], all_addresses = [], num_warmup = 0 , num_samples = 1 }
151+ : (tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >)
152+ return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
153+ }
154+
155+ // CHECK-LABEL: func.func @hmc_anisotropic_logpdf
156+ // CHECK: call @anisotropic_logpdf
157+ // CHECK-NEXT: %{{.+}} = arith.negf
158+ // CHECK: enzyme.autodiff_region
159+ // CHECK: func.call @anisotropic_logpdf
160+ // CHECK-NEXT: %{{.+}} = arith.negf
161+ // CHECK-NEXT: enzyme.yield
162+ // CHECK: enzyme.for_loop
163+ // CHECK: enzyme.autodiff_region
164+ // CHECK: func.call @anisotropic_logpdf
165+ // CHECK-NEXT: %{{.+}} = arith.negf
166+ // CHECK-NEXT: enzyme.yield
167+ func.func @hmc_anisotropic_logpdf (%rng : tensor <2 xui64 >, %mu : tensor <1 x2 xf64 >, %precision : tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >) {
168+ %init_pos = arith.constant dense <[[0.5 , -0.5 ]]> : tensor <1 x2 xf64 >
169+ %step_size = arith.constant dense <0.1 > : tensor <f64 >
170+ %res:3 = enzyme.mcmc (%rng , %mu , %precision )
171+ step_size = %step_size
172+ logpdf_fn = @anisotropic_logpdf
173+ initial_position = %init_pos
174+ { hmc_config = #enzyme.hmc_config <trajectory_length = 1.000000e+00 : f64 , adapt_step_size = false , adapt_mass_matrix = false >,
175+ name = " hmc_anisotropic_logpdf" , selection = [], all_addresses = [], num_warmup = 0 , num_samples = 1 }
176+ : (tensor <2 xui64 >, tensor <1 x2 xf64 >, tensor <1 x2 xf64 >, tensor <f64 >, tensor <1 x2 xf64 >) -> (tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >)
177+ return %res#0 , %res#1 , %res#2 : tensor <1 x2 xf64 >, tensor <1 xi1 >, tensor <2 xui64 >
178+ }
60179}
0 commit comments