Skip to content

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_925836768
Open

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_925836768

Conversation

@copybara-service
Copy link
Copy Markdown

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing

This is another attempt for #37970:
Prior to this change, when running in 64-bit mode, we get
the error 'arith.cmpi' op requires all operands to have the same type because in the pipeline loop some of the operations get promoted to 64-bit mode.

That change was rolled back due to failure attempting to apply asarray
to tracers. I do not have a reproducing failure from the user, but an LLM
suggested that this can happen in presence of disable_jit, and in fact
created a failing test case for a pipeline with disable_jit.

We had used jnp.int32(0) and jnp.int32(num_steps) which create
tracers instead of constants, which in turn confuses the disable_jit
machinery. A better approach is to use np.int32(0),
which creates constants of the right type. For num_steps we should
use np.int32(num_steps) if isinstance(num_steps, int) else num_steps.

While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same.

@copybara-service copybara-service Bot force-pushed the test_925836768 branch 2 times, most recently from 318b505 to 28fc538 Compare June 3, 2026 16:25
This is another attempt for #37970:
Prior to this change, when running in 64-bit mode, we get
the error `'arith.cmpi' op requires all operands to have the same type` because in the pipeline loop some of the operations get promoted to 64-bit mode.

That change was rolled back due to failure attempting to apply `asarray`
to tracers. I do not have a reproducing failure from the user, but an LLM
suggested that this can happen in presence of disable_jit, and in fact
created a failing test case for a pipeline with disable_jit.

We had used `jnp.int32(0)` and `jnp.int32(num_steps)` which create
tracers instead of constants, which in turn confuses the disable_jit
machinery. A better approach is to use `np.int32(0)`,
which creates constants of the right type. For `num_steps` we should
use `np.int32(num_steps) if isinstance(num_steps, int) else num_steps`.

While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same.

PiperOrigin-RevId: 925836768
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant