@@ -50,12 +50,19 @@ def infer(xs, ys):
5050 constraints = {"ys" : ys }
5151
5252 def importance_ (constraints ):
53- _ , w = polynomial_flat .generate (constraints , xs )
54- return w
53+ trace , w = polynomial_flat .generate (constraints , xs )
54+ choices = trace .get_choices ()
55+ return w , choices ["a" ], choices ["b" ], choices ["c" ]
5556
5657 # Direct vmap without dummy arrays
5758 imp = vmap (importance_ , axis_size = n_particles , in_axes = None )
58- return imp (constraints )
59+ logw , a_samples , b_samples , c_samples = imp (constraints )
60+ return {
61+ "log_weights" : logw ,
62+ "a" : a_samples ,
63+ "b" : b_samples ,
64+ "c" : c_samples ,
65+ }
5966
6067 return infer
6168
@@ -91,30 +98,27 @@ def genjax_polynomial_is_timing(
9198 genjax_infer_is = make_genjax_infer_is (n_particles )
9299 infer_jit = jax .jit (seed (genjax_infer_is ))
93100
94- def full_task ():
95- log_weights = infer_jit (key , xs , ys )
96- # Reconstruct traces to match other frameworks (samples + weights)
97- constraints = { "ys" : ys }
98- def sample_particle ( _ ):
99- trace , _ = polynomial_flat . generate ( constraints , xs )
100- return trace
101- traces = vmap ( sample_particle )( jnp . arange ( n_particles ))
102- choices = traces . get_choices ( )
103- jax . block_until_ready (( log_weights , choices [ "a" ], choices [ "b" ], choices [ "c" ]) )
104- return log_weights , choices
101+ def task ():
102+ result = infer_jit (key , xs , ys )
103+ jax . block_until_ready (
104+ (
105+ result [ "log_weights" ],
106+ result [ "a" ],
107+ result [ "b" ],
108+ result [ "c" ],
109+ )
110+ )
111+ return result
105112
106113 times , (mean_time , std_time ) = benchmark_with_warmup (
107- full_task ,
114+ task ,
108115 warmup_runs = 5 ,
109116 repeats = repeats ,
110117 inner_repeats = inner_repeats ,
111118 auto_sync = False ,
112119 )
113120
114- log_weights , choices = full_task ()
115- samples_a = choices ["a" ]
116- samples_b = choices ["b" ]
117- samples_c = choices ["c" ]
121+ result = task ()
118122
119123 return {
120124 "framework" : "genjax" if not use_direct else "genjax_direct" ,
@@ -125,11 +129,11 @@ def sample_particle(_):
125129 "mean_time" : mean_time ,
126130 "std_time" : std_time ,
127131 "samples" : {
128- "a" : samples_a ,
129- "b" : samples_b ,
130- "c" : samples_c ,
132+ "a" : result [ "a" ] ,
133+ "b" : result [ "b" ] ,
134+ "c" : result [ "c" ] ,
131135 },
132- "log_weights" : log_weights ,
136+ "log_weights" : result [ " log_weights" ] ,
133137 }
134138
135139
0 commit comments