Skip to content

Commit 14b5876

Browse files
author
maxtext authors
committed
Merge pull request #1596 from AI-Hypercomputer:jacobplatin/fix-microbenchmark-tokenizer-issue
PiperOrigin-RevId: 748349490
2 parents 00c406d + 179a342 commit 14b5876

File tree

3 files changed

+31
-31
lines changed

3 files changed

+31
-31
lines changed

MaxText/inference_microbenchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def run_benchmarks(config):
407407

408408
text = config.prompt
409409
metadata = engine.get_tokenizer()
410-
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
410+
tokenizer_model = engine.build_tokenizer(metadata)
411411
rng, rng_init_decode = jax.random.split(rng)
412412

413413
generate_executable, params, decode_state_executable = engine.aot_compile(params, pass_rng_shape=True)
@@ -429,8 +429,8 @@ def run_benchmarks(config):
429429
rng_shape = jax.ShapeDtypeStruct([4], jax.numpy.dtype("uint32"))
430430

431431
for prefill_length in prefill_lengths:
432-
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
433-
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
432+
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = tokenizer_model.encode(
433+
text, is_bos=True, prefill_lengths=[prefill_length]
434434
)
435435

436436
key_shape = jax.ShapeDtypeStruct([prefill_length], jax.numpy.dtype("int32"))

MaxText/tests/grpo_trainer_correctness_test.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,31 +82,30 @@ def prepare_maxtext_inputs(input_str, tokenizer_model):
8282
return input_ids, input_segmentation, input_position, completion_segmentation
8383

8484

85-
8685
class GrpoTrainerTest(unittest.TestCase):
8786

8887
def setUp(self):
8988
super().setUp()
9089
command = [
91-
"gsutil",
92-
"cp",
93-
"-r",
94-
"gs://maxtext-dataset/hf/llama3.1-tokenizer",
95-
os.path.join(os.path.dirname(PKG_DIR), "assets", ""),
90+
"gsutil",
91+
"cp",
92+
"-r",
93+
"gs://maxtext-dataset/hf/llama3.1-tokenizer",
94+
os.path.join(os.path.dirname(PKG_DIR), "assets", ""),
9695
]
9796
exit_code = subprocess.call(command, cwd=os.path.dirname(PKG_DIR))
9897
if exit_code != 0:
9998
raise ValueError(f"{command} failed with exit code: {exit_code}")
10099
self.config = pyconfig.initialize(
101100
[None, "MaxText/experimental/rl/grpo_trainer_test.yml"],
102101
run_name="unit_test_grpo_trainer",
103-
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), 'assets', 'llama3.1-tokenizer'),
102+
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "llama3.1-tokenizer"),
104103
enable_checkpointing=False,
105104
)
106105
self.config_inference = pyconfig.initialize(
107106
[None, "MaxText/experimental/rl/grpo_trainer_test.yml"],
108107
run_name="unit_test_grpo_trainer_inference",
109-
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), 'assets', 'llama3.1-tokenizer'),
108+
tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "llama3.1-tokenizer"),
110109
enable_checkpointing=False,
111110
ici_tensor_parallelism=4,
112111
per_device_batch_size=self.config.per_device_batch_size * self.config.num_generations,
@@ -115,11 +114,11 @@ def setUp(self):
115114
self.atol = 1e-08
116115
self.rng = jax.random.PRNGKey(self.config.init_weights_seed)
117116
self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
118-
self.config.tokenizer_path,
119-
add_bos_token=self.config.add_bos,
120-
add_eos_token=self.config.add_eos,
121-
legacy=False,
122-
padding_side="left"
117+
self.config.tokenizer_path,
118+
add_bos_token=self.config.add_bos,
119+
add_eos_token=self.config.add_eos,
120+
legacy=False,
121+
padding_side="left",
123122
)
124123
self.tokenizer_model.add_special_tokens({"pad_token": "<pad>"})
125124

@@ -135,22 +134,23 @@ def test_grpo_trainer_correctness(self):
135134
)
136135
# Obtain per-token logits.
137136
maxtext_per_token_logps, _ = compute_log_probs(
138-
maxtext_model,
139-
state.params,
140-
input_ids,
141-
input_position,
142-
input_segmentation,
143-
completion_segmentation,
144-
self.config,
145-
is_train=False,
146-
rngs=self.rng,
137+
maxtext_model,
138+
state.params,
139+
input_ids,
140+
input_position,
141+
input_segmentation,
142+
completion_segmentation,
143+
self.config,
144+
is_train=False,
145+
rngs=self.rng,
147146
)
148-
jax.debug.print("maxtext_per_token_logps={maxtext_per_token_logps}",maxtext_per_token_logps=maxtext_per_token_logps)
149-
jax.debug.print("golden_per_token_logps={golden_per_token_logps}",golden_per_token_logps=golden_data["maxtext_per_token_logps_no_ckpt_loading"])
150-
golden_maxtext_logits = np.array(golden_data["maxtext_per_token_logps_no_ckpt_loading"])
151-
self.assertTrue(
152-
jnp.all(np.array(golden_data["input_ids"]) == np.array(input_ids[0]))
147+
jax.debug.print("maxtext_per_token_logps={maxtext_per_token_logps}", maxtext_per_token_logps=maxtext_per_token_logps)
148+
jax.debug.print(
149+
"golden_per_token_logps={golden_per_token_logps}",
150+
golden_per_token_logps=golden_data["maxtext_per_token_logps_no_ckpt_loading"],
153151
)
152+
golden_maxtext_logits = np.array(golden_data["maxtext_per_token_logps_no_ckpt_loading"])
153+
self.assertTrue(jnp.all(np.array(golden_data["input_ids"]) == np.array(input_ids[0])))
154154
self.assertTrue(
155155
jax.numpy.allclose(
156156
maxtext_per_token_logps[0],

MaxText/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def setup_train_loop(config):
663663
record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
664664
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
665665

666-
context_parallel_size = mesh.shape['context']
666+
context_parallel_size = mesh.shape["context"]
667667
# Check if context parallelism is being used with sequence packing
668668
if context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic":
669669
raise ValueError(

0 commit comments

Comments
 (0)