@@ -88,43 +88,33 @@ def genjax_polynomial_is_timing(
8888
8989 xs , ys = dataset .xs , dataset .ys
9090
91- # Create and JIT the inference function using faircoin pattern
9291 genjax_infer_is = make_genjax_infer_is (n_particles )
9392 infer_jit = jax .jit (seed (genjax_infer_is ))
9493
95- def task ():
94+ def full_task ():
9695 log_weights = infer_jit (key , xs , ys )
97- jax .block_until_ready (log_weights )
98- return log_weights
96+ # Reconstruct traces to match other frameworks (samples + weights)
97+ constraints = {"ys" : ys }
98+ def sample_particle (_ ):
99+ trace , _ = polynomial_flat .generate (constraints , xs )
100+ return trace
101+ traces = vmap (sample_particle )(jnp .arange (n_particles ))
102+ choices = traces .get_choices ()
103+ jax .block_until_ready ((log_weights , choices ["a" ], choices ["b" ], choices ["c" ]))
104+ return log_weights , choices
99105
100- # Run benchmark with automatic warm-up - more inner repeats for accuracy
101106 times , (mean_time , std_time ) = benchmark_with_warmup (
102- task ,
107+ full_task ,
103108 warmup_runs = 5 ,
104109 repeats = repeats ,
105110 inner_repeats = inner_repeats ,
106111 auto_sync = False ,
107112 )
108113
109- # Get final weights
110- log_weights = task ()
111-
112- # For compatibility, also get samples - run full version once
113- def get_samples ():
114- constraints = {"ys" : ys }
115- def sample_particle (_ ):
116- trace , _ = polynomial_flat .generate (constraints , xs )
117- return trace
118- traces = vmap (sample_particle )(jnp .arange (n_particles ))
119- return traces
120-
121- jitted_samples = jax .jit (seed (get_samples ))
122- traces = jitted_samples (key )
123-
124- # Extract samples
125- samples_a = traces .get_choices ()["a" ]
126- samples_b = traces .get_choices ()["b" ]
127- samples_c = traces .get_choices ()["c" ]
114+ log_weights , choices = full_task ()
115+ samples_a = choices ["a" ]
116+ samples_b = choices ["b" ]
117+ samples_c = choices ["c" ]
128118
129119 return {
130120 "framework" : "genjax" if not use_direct else "genjax_direct" ,
0 commit comments