
Invalid Triton code from derivatives of jnp.einsum: 'tt.dot' op expected the last dimension of the first operand to be equal to the second-to-last dimension of the second operand #34649

@am001122

Description


TL;DR: under certain very specific circumstances, jnp.einsum seems to cause XLA to emit invalid Triton tt.dot calls, which, depending on the JAX version, result either in logged errors or in segfaults with no error message.

Issue

I am working with a complex codebase where, after a simple change to the model architecture, JAX v0.8.2 and v0.8.1 started to silently segfault inside XLA compilation. The segfaults were not reproducible in any context narrower than a full run, and occurred only for certain values of the hyperparameters that control feature width; those values were always powers of two.

Updating to JAX 0.9.0 resolves the segfaults: the program runs and appears to be numerically correct, but it prints Triton errors (sometimes several) like:

loc("reduce.177.1"): error: 'tt.dot' op expected the last dimension of the first operand to be equal to the second-to-last dimension of the second operand

Investigation

I was unable to construct a minimal reproducer in Python, or to reproduce the issue with anything smaller than a job running the full code.

I tried to track down the root cause by running with

XLA_FLAGS="--xla_dump_to=/tmp/jax_dump --xla_dump_hlo_as_text"

and found the responsible HLO in a file called module_4999.jit_myfunc.sm_8.9_gpu_after_optimizations.txt:

%fused_reduce.9 (param_0.6500: f32[32], param_1.8410: f32[128,128], param_2.4697: f32[128], param_3.2885: f32[128]) -> f32[128,128] {
  %param_1.8410 = f32[128,128]{1,0} parameter(1)
  %param_3.2885 = f32[128]{0} parameter(3)
  %broadcast.787.12 = f32[128,128]{1,0} broadcast(%param_3.2885), dimensions={1}, metadata={op_name="jit(myfunc)/jvp(jvp(abc,a,b->abc))/dot_general" stack_frame_id=132}
  %param_2.4697 = f32[128]{0} parameter(2)
  %broadcast.788.12 = f32[128,128]{1,0} broadcast(%param_2.4697), dimensions={0}, metadata={op_name="jit(myfunc)/jvp(jvp(abc,a,b->abc))/dot_general" stack_frame_id=132}
  %multiply.296.9 = f32[128,128]{1,0} multiply(%broadcast.787.12, %broadcast.788.12), metadata={op_name="jit(myfunc)/jvp(jvp(abc,a,b->abc))/dot_general" stack_frame_id=132}
  %multiply.406.3 = f32[128,128]{1,0} multiply(%param_1.8410, %multiply.296.9), metadata={op_name="jit(myfunc)/jvp(jvp(abc,a,b->abc))/dot_general" stack_frame_id=132}
  %broadcast.845.4 = f32[128,128,32,1]{3,2,1,0} broadcast(%multiply.406.3), dimensions={0,1}, metadata={op_name="jit(myfunc)/jvp(jvp())/broadcast_in_dim" stack_frame_id=110}
  %param_0.6500 = f32[32]{0} parameter(0)
  %broadcast.466.1 = f32[128,128,32,1]{3,2,1,0} broadcast(%param_0.6500), dimensions={2}, metadata={op_name="jit(myfunc)/jvp(jvp(attn))/dot_general" stack_frame_id=190}
  %multiply.150.3 = f32[128,128,32,1]{3,2,1,0} multiply(%broadcast.845.4, %broadcast.466.1), metadata={op_name="jit(myfunc)/jvp(jvp(attn))/dot_general" stack_frame_id=190}
  %bitcast.4069.1 = f32[128,128,32]{2,1,0} bitcast(%multiply.150.3)
  %constant_496_112 = f32[] constant(0)
  ROOT %reduce.177.1 = f32[128,128]{1,0} reduce(%bitcast.4069.1, %constant_496_112), dimensions={2}, to_apply=%region_0.2.clone.70, metadata={op_name="jit(myfunc)/jvp(jvp(attn))/dot_general" stack_frame_id=190}
}

using

%region_0.2.clone.70 (reduce_sum.525: f32[], reduce_sum.529: f32[]) -> f32[] {
  %reduce_sum.529 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  %reduce_sum.525 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.531.0 = f32[] add(%reduce_sum.525, %reduce_sum.529), metadata={op_name="jit(myfunc)/jvp(jvp())/reduce_sum" stack_frame_id=129}
}

and called by

%input_reduce_fusion.6 = f32[128,128]{1,0} fusion(%constant_497_0, %loop_add_fusion.44, %model_states_0__34_.1, %loop_rsqrt_fusion), kind=kInput, calls=%fused_reduce.9, metadata={op_name="jit(myfunc)/jvp(jvp(attn))/dot_general" stack_frame_id=190}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID","native_emitter_backend_config":{}}

This corresponds to a specific einsum of the form jnp.einsum("abc,a,b->abc", A, B, C) in my Python code, with shapes:

  • A [128, 128, 32]
  • B [128]
  • C [128]
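For reference, here is a minimal sketch of the shapes involved and of the double-jvp structure suggested by the jvp(jvp(...)) in the op metadata. This is my reconstruction, not the actual model code, and as noted above this standalone version does not reproduce the bug, so the surrounding model context must matter:

import jax
import jax.numpy as jnp

def f(A, B, C):
    # The einsum that the failing fusion traces back to.
    return jnp.einsum("abc,a,b->abc", A, B, C)

A = jnp.ones((128, 128, 32))
B = jnp.ones((128,))
C = jnp.ones((128,))

# The op metadata shows jvp(jvp(...)), so differentiate twice.
inner = lambda b: jax.jvp(lambda x: f(A, x, C), (b,), (jnp.ones_like(b),))[1]
out, tangent = jax.jit(lambda b: jax.jvp(inner, (b,), (jnp.ones_like(b),)))(B)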

Workarounds

I found that disabling the XLA fusion autotuner makes the error messages go away:

XLA_FLAGS="--xla_gpu_experimental_enable_fusion_autotuner=false"

I also found that replacing this jnp.einsum with an explicit broadcast-and-multiply appears to resolve the issue:

A * B[:, None, None] * C[None, :, None]
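As a quick sanity check that the two formulations agree numerically (a hypothetical standalone check, using the shapes above):

import jax
import jax.numpy as jnp

kA, kB, kC = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(kA, (128, 128, 32))
B = jax.random.normal(kB, (128,))
C = jax.random.normal(kC, (128,))

# einsum "abc,a,b->abc" multiplies A[a,b,c] * B[a] * C[b] elementwise,
# which is exactly what the broadcasting expression computes.
via_einsum = jnp.einsum("abc,a,b->abc", A, B, C)
via_broadcast = A * B[:, None, None] * C[None, :, None]
assert jnp.allclose(via_einsum, via_broadcast, atol=1e-5)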

Environment

  • JAX 0.9.0 (and previously 0.8.1 and 0.8.2) as jax[cuda12]
  • CUDA compute capability 8.9
  • System CUDA driver version: 560.35.03 and CUDA version: 12.6
  • Ubuntu
