Skip to content

Conversation

@adzcai
Copy link
Contributor

@adzcai adzcai commented Dec 10, 2025

I noticed the comment in rules.py about evaluating all rules. I think this is unavoidable under vmap, but in the single-threaded case, by passing the operands in the way below, Jax should evaluate only the selected one:

# old
grid, agent = jax.lax.switch(
    encoding[0],
    (
        lambda: EmptyRule.decode(encoding)(grid, agent, action, position),
        ...
    ),
)

# new
grid, agent = jax.lax.switch(
    encoding[0],
    (
        EmptyRule.decode(encoding),
        ...
    ),
    grid,
    agent,
    action,
    position,
)

Let me know the best way to benchmark this!

EDIT: Sorry, upon further thought, I was hasty, think these might be exactly the same. Feel free to close this.

@Howuhh
Copy link
Member

Howuhh commented Dec 10, 2025

hi @adzcai!

You can benchmark times with the help of scripts/benchmark_xland.py and scripts/benchmark_xland_all.py scripts.

@Howuhh Howuhh closed this Dec 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants