[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154
Open
copybara-service[bot] wants to merge 1 commit into
Open
[pallas] Fix Pallas Mosaic 64-bit mode loop indexing#38154copybara-service[bot] wants to merge 1 commit into
copybara-service[bot] wants to merge 1 commit into
Conversation
318b505 to
28fc538
Compare
This is another attempt for #37970: Prior to this change, when running in 64-bit mode, we get the error `'arith.cmpi' op requires all operands to have the same type` because in the pipeline loop some of the operations get promoted to 64-bit mode. That change was rolled back due to failure attempting to apply `asarray` to tracers. I do not have a reproducing failure from the user, but an LLM suggested that this can happen in presence of disable_jit, and in fact created a failing test case for a pipeline with disable_jit. We had used `jnp.int32(0)` and `jnp.int32(num_steps)` which create tracers instead of constants, which in turn confuses the disable_jit machinery. A better approach is to use `np.int32(0)`, which creates constants of the right type. For `num_steps` we should use `np.int32(num_steps) if isinstance(num_steps, int) else num_steps`. While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same. PiperOrigin-RevId: 925836768
28fc538 to
244ffe1
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[pallas] Fix Pallas Mosaic 64-bit mode loop indexing
This is another attempt for #37970:
Prior to this change, when running in 64-bit mode, we get
the error
'arith.cmpi' op requires all operands to have the same typebecause in the pipeline loop some of the operations get promoted to 64-bit mode.That change was rolled back due to failure attempting to apply
asarrayto tracers. I do not have a reproducing failure from the user, but an LLM
suggested that this can happen in presence of disable_jit, and in fact
created a failing test case for a pipeline with disable_jit.
We had used
jnp.int32(0)andjnp.int32(num_steps)which createtracers instead of constants, which in turn confuses the disable_jit
machinery. A better approach is to use
np.int32(0),which creates constants of the right type. For
num_stepswe shoulduse
np.int32(num_steps) if isinstance(num_steps, int) else num_steps.While fixing this I ran into a lowering error "error: 'arith.trunci' op operand type 'i32' and result type 'i32' are cast incompatible" which I worked around by not emitting a truncation when the types are the same.