Skip to content

Commit 04dce60

Browse files
committed
Make GenJAX IS timing match other frameworks
1 parent fd5e6a7 commit 04dce60

1 file changed

Lines changed: 27 additions & 23 deletions

File tree

  • examples/perfbench/benchmarks/src/timing_benchmarks/curvefit_benchmarks

examples/perfbench/benchmarks/src/timing_benchmarks/curvefit_benchmarks/genjax.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,19 @@ def infer(xs, ys):
5050
constraints = {"ys": ys}
5151

5252
def importance_(constraints):
53-
_, w = polynomial_flat.generate(constraints, xs)
54-
return w
53+
trace, w = polynomial_flat.generate(constraints, xs)
54+
choices = trace.get_choices()
55+
return w, choices["a"], choices["b"], choices["c"]
5556

5657
# Direct vmap without dummy arrays
5758
imp = vmap(importance_, axis_size=n_particles, in_axes=None)
58-
return imp(constraints)
59+
logw, a_samples, b_samples, c_samples = imp(constraints)
60+
return {
61+
"log_weights": logw,
62+
"a": a_samples,
63+
"b": b_samples,
64+
"c": c_samples,
65+
}
5966

6067
return infer
6168

@@ -91,30 +98,27 @@ def genjax_polynomial_is_timing(
9198
genjax_infer_is = make_genjax_infer_is(n_particles)
9299
infer_jit = jax.jit(seed(genjax_infer_is))
93100

94-
def full_task():
95-
log_weights = infer_jit(key, xs, ys)
96-
# Reconstruct traces to match other frameworks (samples + weights)
97-
constraints = {"ys": ys}
98-
def sample_particle(_):
99-
trace, _ = polynomial_flat.generate(constraints, xs)
100-
return trace
101-
traces = vmap(sample_particle)(jnp.arange(n_particles))
102-
choices = traces.get_choices()
103-
jax.block_until_ready((log_weights, choices["a"], choices["b"], choices["c"]))
104-
return log_weights, choices
101+
def task():
102+
result = infer_jit(key, xs, ys)
103+
jax.block_until_ready(
104+
(
105+
result["log_weights"],
106+
result["a"],
107+
result["b"],
108+
result["c"],
109+
)
110+
)
111+
return result
105112

106113
times, (mean_time, std_time) = benchmark_with_warmup(
107-
full_task,
114+
task,
108115
warmup_runs=5,
109116
repeats=repeats,
110117
inner_repeats=inner_repeats,
111118
auto_sync=False,
112119
)
113120

114-
log_weights, choices = full_task()
115-
samples_a = choices["a"]
116-
samples_b = choices["b"]
117-
samples_c = choices["c"]
121+
result = task()
118122

119123
return {
120124
"framework": "genjax" if not use_direct else "genjax_direct",
@@ -125,11 +129,11 @@ def sample_particle(_):
125129
"mean_time": mean_time,
126130
"std_time": std_time,
127131
"samples": {
128-
"a": samples_a,
129-
"b": samples_b,
130-
"c": samples_c,
132+
"a": result["a"],
133+
"b": result["b"],
134+
"c": result["c"],
131135
},
132-
"log_weights": log_weights,
136+
"log_weights": result["log_weights"],
133137
}
134138

135139

0 commit comments

Comments
 (0)