Skip to content

Commit 41dd12b

Browse files
author
bjmsong
committed
resolve conflict
1 parent 7eef7d1 commit 41dd12b

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

  • python/sgl_jax/srt/multimodal/models/dits

python/sgl_jax/srt/multimodal/models/dits/flux.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def _no_shard(x: jax.Array, mesh: Mesh | None) -> jax.Array:
4242
return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
4343

4444

45+
def _data_seq_tensor_shard(x: jax.Array, mesh: Mesh | None) -> jax.Array:
46+
if mesh is None:
47+
return x
48+
output_pspec = P("data", *([None] * (x.ndim - 2)), "tensor")
49+
return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, output_pspec))
50+
51+
4552
def _sdpa_attention(
4653
query: jax.Array,
4754
key: jax.Array,
@@ -294,6 +301,7 @@ def __call__(
294301

295302
hidden_states = hidden_states.reshape(hidden_states.shape[0], hidden_states.shape[1], -1)
296303
hidden_states = hidden_states.astype(query.dtype)
304+
hidden_states = _data_seq_tensor_shard(hidden_states, self.mesh)
297305

298306
if encoder_hidden_states is not None and self.added_kv_proj_dim is not None:
299307
context_len = encoder_hidden_states.shape[1]

0 commit comments

Comments
 (0)