Skip to content

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154

Merged
copybara-service[bot] merged 1 commit into
mainfrom
test_925836768
Jun 4, 2026
Merged

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154
copybara-service[bot] merged 1 commit into
mainfrom
test_925836768

Conversation

@copybara-service

Copy link
Copy Markdown

[pallas] Fix Pallas Mosaic 64-bit mode loop indexing

This is another attempt for #37970:
Prior to this change, when running in 64-bit mode, we get
the error 'arith.cmpi' op requires all operands to have the same type because in the pipeline loop some of the operations get promoted to 64-bit mode.

That change was rolled back due to failure attempting to apply asarray
to tracers. I do not have a reproducing failure from the user, but an LLM
suggested that this can happen in presence of disable_jit, and in fact
created a failing test case for a pipeline with disable_jit.

We had used jnp.int32(0) and jnp.int32(num_steps) which create
tracers instead of constants, which in turn confuses the disable_jit
machinery. A better approach is to use np.int32(0),
which creates constants of the right type. For num_steps we should
use np.int32(num_steps) if isinstance(num_steps, int) else num_steps.

While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same.

@copybara-service copybara-service Bot force-pushed the test_925836768 branch 6 times, most recently from d7aea04 to 8c9752d Compare June 4, 2026 08:07
This is another attempt for #37970:
Prior to this change, when running in 64-bit mode, we get
the error `'arith.cmpi' op requires all operands to have the same type` because in the pipeline loop some of the operations get promoted to 64-bit mode.

That change was rolled back due to failure attempting to apply `asarray`
to tracers. I do not have a reproducing failure from the user, but an LLM
suggested that this can happen in presence of disable_jit, and in fact
created a failing test case for a pipeline with disable_jit.

We had used `jnp.int32(0)` and `jnp.int32(num_steps)` which create
tracers instead of constants, which in turn confuses the disable_jit
machinery. A better approach is to use `np.int32(0)`,
which creates constants of the right type. For `num_steps` we should
use `np.int32(num_steps) if isinstance(num_steps, int) else num_steps`.

While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same.

PiperOrigin-RevId: 926542136
@copybara-service copybara-service Bot merged commit 9b82b11 into main Jun 4, 2026
1 check was pending
@copybara-service copybara-service Bot deleted the test_925836768 branch June 4, 2026 08:32
copybara-service Bot pushed a commit that referenced this pull request Jun 4, 2026
That change was #38154.

The problem is that now constants that were 0 or 1 were turned into np.int32 and there are several `is_instance(v, int)` checks throughout Pallas that behave differently than before.
I need to rethink how to make 64-bit more work for pipelines.

Reverts 9b82b11

PiperOrigin-RevId: 926690252
copybara-service Bot pushed a commit that referenced this pull request Jun 4, 2026
That change was #38154.

The problem is that now constants that were 0 or 1 were turned into np.int32 and there are several `is_instance(v, int)` checks throughout Pallas that behave differently than before.
I need to rethink how to make 64-bit more work for pipelines.

Reverts 9b82b11

PiperOrigin-RevId: 926690252
copybara-service Bot pushed a commit that referenced this pull request Jun 4, 2026
That change was #38154.

The problem is that now constants that were 0 or 1 were turned into np.int32 and there are several `is_instance(v, int)` checks throughout Pallas that behave differently than before.
I need to rethink how to make 64-bit more work for pipelines.

Reverts 9b82b11

PiperOrigin-RevId: 926690252
copybara-service Bot pushed a commit that referenced this pull request Jun 4, 2026
That change was #38154.

The problem is that now constants that were 0 or 1 were turned into np.int32 and there are several `is_instance(v, int)` checks throughout Pallas that behave differently than before.
I need to rethink how to make 64-bit more work for pipelines.

Reverts 9b82b11

PiperOrigin-RevId: 926751271
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant