Skip to content

Commit d24d7ea

Browse files
committed
revert loop
1 parent a87af6f commit d24d7ea

File tree

4 files changed

+2
-9
lines changed

4 files changed

+2
-9
lines changed

adirondax/simulation.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,9 @@ def step_fn(carry, _):
224224
return (new_state, new_V), None
225225

226226
# Run the entire loop as a single JIT-compiled function
227-
# def run_loop(carry):
228-
# final_carry, _ = jax.lax.scan(
229-
# step_fn, carry, xs=None, length=nt, unroll=True
230-
# )
231-
# return final_carry
232-
233227
def run_loop(carry):
234-
for _ in range(nt):
235-
carry, _ = step_fn(carry, None)
236-
return carry
228+
final_carry, _ = jax.lax.scan(step_fn, carry, xs=None, length=nt)
229+
return final_carry
237230

238231
# Execute the compiled loop
239232
state, V = run_loop(carry)
0 Bytes
Loading

examples/orszag_tang/output.png

-3 Bytes
Loading
54 Bytes
Loading

0 commit comments

Comments
 (0)