Skip to content

Commit a022ce3

Browse files
Merge pull request #2574 from AI-Hypercomputer:chengnuojin-add-tiled
PiperOrigin-RevId: 826678381
2 parents bab34e5 + 3d6b929 commit a022ce3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/forward_pass_logit_checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def main(config, test_args): # pylint: disable=W0621
279279
rngs={"aqt": init_rng},
280280
)
281281

282-
full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
282+
full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True)
283283
# if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size]
284284
if full_train_logits.ndim == 4:
285285
full_train_logits = jnp.reshape(full_train_logits, (-1, config.max_target_length, config.vocab_size))

0 commit comments

Comments
 (0)