File tree Expand file tree Collapse file tree
python/sgl_jax/srt/multimodal/models/dits Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -42,6 +42,13 @@ def _no_shard(x: jax.Array, mesh: Mesh | None) -> jax.Array:
4242 return jax .lax .with_sharding_constraint (x , NamedSharding (mesh , P ()))
4343
4444
45+ def _data_seq_tensor_shard (x : jax .Array , mesh : Mesh | None ) -> jax .Array :
46+ if mesh is None :
47+ return x
48+ output_pspec = P ("data" , * ([None ] * (x .ndim - 2 )), "tensor" )
49+ return jax .lax .with_sharding_constraint (x , NamedSharding (mesh , output_pspec ))
50+
51+
4552def _sdpa_attention (
4653 query : jax .Array ,
4754 key : jax .Array ,
@@ -294,6 +301,7 @@ def __call__(
294301
295302 hidden_states = hidden_states .reshape (hidden_states .shape [0 ], hidden_states .shape [1 ], - 1 )
296303 hidden_states = hidden_states .astype (query .dtype )
304+ hidden_states = _data_seq_tensor_shard (hidden_states , self .mesh )
297305
298306 if encoder_hidden_states is not None and self .added_kv_proj_dim is not None :
299307 context_len = encoder_hidden_states .shape [1 ]
You can’t perform that action at this time.
0 commit comments