Skip to content

Commit a779638

Browse files
committed
remove common_types
1 parent 5114e22 commit a779638

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

MaxText/layers/attentions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def cudnn_jax_flash_attention(
872872
key: Array,
873873
value: Array,
874874
decoder_segment_ids: Array | None,
875-
model_mode: str = common_types.MODEL_MODE_TRAIN,
875+
model_mode: str = MODEL_MODE_TRAIN,
876876
) -> Array:
877877
"""CUDNN Flash Attention with JAX SDPA API.
878878
"""
@@ -885,7 +885,7 @@ def cudnn_jax_flash_attention(
885885

886886
_, _, _, head_dim = query.shape # pylint: disable=unused-variable
887887

888-
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
888+
if model_mode == MODEL_MODE_AUTOREGRESSIVE:
889889
lengths = jnp.sum(decoder_segment_ids, axis=-1)
890890

891891
return dot_product_attention(

0 commit comments

Comments
 (0)