Skip to content

Commit fd5e6a7

Browse files
committed
Align IS timing workload for GenJAX
1 parent 37e593c commit fd5e6a7

1 file changed

Lines changed: 15 additions & 25 deletions

File tree

  • examples/perfbench/benchmarks/src/timing_benchmarks/curvefit_benchmarks

examples/perfbench/benchmarks/src/timing_benchmarks/curvefit_benchmarks/genjax.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -88,43 +88,33 @@ def genjax_polynomial_is_timing(
8888

8989
xs, ys = dataset.xs, dataset.ys
9090

91-
# Create and JIT the inference function using faircoin pattern
9291
genjax_infer_is = make_genjax_infer_is(n_particles)
9392
infer_jit = jax.jit(seed(genjax_infer_is))
9493

95-
def task():
94+
def full_task():
9695
log_weights = infer_jit(key, xs, ys)
97-
jax.block_until_ready(log_weights)
98-
return log_weights
96+
# Reconstruct traces to match other frameworks (samples + weights)
97+
constraints = {"ys": ys}
98+
def sample_particle(_):
99+
trace, _ = polynomial_flat.generate(constraints, xs)
100+
return trace
101+
traces = vmap(sample_particle)(jnp.arange(n_particles))
102+
choices = traces.get_choices()
103+
jax.block_until_ready((log_weights, choices["a"], choices["b"], choices["c"]))
104+
return log_weights, choices
99105

100-
# Run benchmark with automatic warm-up - more inner repeats for accuracy
101106
times, (mean_time, std_time) = benchmark_with_warmup(
102-
task,
107+
full_task,
103108
warmup_runs=5,
104109
repeats=repeats,
105110
inner_repeats=inner_repeats,
106111
auto_sync=False,
107112
)
108113

109-
# Get final weights
110-
log_weights = task()
111-
112-
# For compatibility, also get samples - run full version once
113-
def get_samples():
114-
constraints = {"ys": ys}
115-
def sample_particle(_):
116-
trace, _ = polynomial_flat.generate(constraints, xs)
117-
return trace
118-
traces = vmap(sample_particle)(jnp.arange(n_particles))
119-
return traces
120-
121-
jitted_samples = jax.jit(seed(get_samples))
122-
traces = jitted_samples(key)
123-
124-
# Extract samples
125-
samples_a = traces.get_choices()["a"]
126-
samples_b = traces.get_choices()["b"]
127-
samples_c = traces.get_choices()["c"]
114+
log_weights, choices = full_task()
115+
samples_a = choices["a"]
116+
samples_b = choices["b"]
117+
samples_c = choices["c"]
128118

129119
return {
130120
"framework": "genjax" if not use_direct else "genjax_direct",

0 commit comments

Comments
 (0)