Commit ec4ba8a

Implement correctly

1 parent 2dfb5e6
File tree

7 files changed: +124, -76 lines changed

examples/multimodal/run_text_generation.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -227,7 +227,7 @@ def generate_samples(model, config: EvaluationConfig, print_output):
         inference_request = VLMInferenceRequest(
             request_id=inference_engine.get_new_request_id(),
             prompt=conv,
-            prompt_tokens=controller.tokenize_prompt(conv),
+            prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
             sampling_params=sampling_params,
             num_img_embeddings_per_tile=num_img_embeddings_per_tile,
             imgs=imgs,
@@ -344,7 +344,7 @@ def generate_samples(model, config: EvaluationConfig, print_output):
         inference_request = VLMInferenceRequest(
             request_id=inference_engine.get_new_request_id(),
             prompt=conv,
-            prompt_tokens=controller.tokenize_prompt(conv),
+            prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
             sampling_params=sampling_params,
             num_img_embeddings_per_tile=num_img_embeddings_per_tile,
             imgs=imgs,
```
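Both call sites keep the instance-style spelling `controller.tokenize_prompt(...)` even though `tokenize_prompt` becomes a `@staticmethod` later in this commit; Python resolves a staticmethod through an instance just as through the class, so the only change needed here is the explicit `tokenizer` argument. A minimal sketch of that call pattern (`ToyTokenizer` and `Controller` are illustrative stand-ins, not code from this commit):

```python
# Illustrative only: shows why instance-style calls keep working after the
# method becomes a @staticmethod that takes the tokenizer explicitly.
class ToyTokenizer:
    def tokenize(self, text):
        # Toy scheme: one token id per character.
        return [ord(c) for c in text]


class Controller:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @staticmethod
    def tokenize_prompt(tokenizer, prompt):
        return tokenizer.tokenize(prompt)


controller = Controller(ToyTokenizer())
# A staticmethod is reachable through the instance, so this spelling is valid:
assert controller.tokenize_prompt(controller.tokenizer, "hi") == [104, 105]
```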

megatron/core/inference/data_parallel_inference_coordinator.py
Lines changed: 8 additions & 11 deletions

```diff
@@ -16,6 +16,9 @@
 from megatron.core.inference.config import PrefixCachingCoordinatorPolicy
 from megatron.core.inference.headers import Headers, UnknownHeaderError
 from megatron.core.inference.inference_request import compute_block_hashes_batched
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
+)
 
 try:
     import zmq
@@ -501,18 +504,12 @@ def detokenize(self, finished_request):
             generated tokens to be detokenized. It is modified in place.
         """
         if finished_request["prompt"] is None:
-            finished_request["prompt"] = self.tokenizer.detokenize(
-                finished_request["prompt_tokens"][1]
+            finished_request["prompt"] = TextGenerationController.detokenize(
+                self.tokenizer, finished_request["prompt_tokens"][1], remove_EOD=False
             )
-        generated_tokens = finished_request["generated_tokens"]
-        termination_id = (finished_request.get("sampling_params", {}) or {}).get("termination_id")
-        while (
-            generated_tokens
-            and termination_id is not None
-            and generated_tokens[-1] == termination_id
-        ):
-            generated_tokens = generated_tokens[:-1]
-        finished_request["generated_text"] = self.tokenizer.detokenize(generated_tokens)
+        finished_request["generated_text"] = TextGenerationController.detokenize(
+            self.tokenizer, finished_request["generated_tokens"]
+        )
 
     @classmethod
     def entrypoint(
```

megatron/core/inference/engines/dynamic_engine.py
Lines changed: 16 additions & 14 deletions

```diff
@@ -890,7 +890,9 @@ def _add_request(
         # Tokenize stop words if provided
         if request.sampling_params.stop_words:
             stop_word_ids = [
-                self.controller.tokenize_prompt(stop_word, add_BOS=False)
+                TextGenerationController.tokenize_prompt(
+                    self.controller.tokenizer, stop_word, add_BOS=False
+                )
                 for stop_word in request.sampling_params.stop_words
             ]
             request.stop_word_ids = stop_word_ids
@@ -931,9 +933,13 @@ def add_request(
         # Tokenize prompt if text. Support legacy single-arg mocks.
         prompt_str = prompt
         try:
-            prompt_token_ids = self.controller.tokenize_prompt(prompt, sampling_params.add_BOS)
+            prompt_token_ids = TextGenerationController.tokenize_prompt(
+                self.controller.tokenizer, prompt, sampling_params.add_BOS
+            )
         except TypeError:
-            prompt_token_ids = self.controller.tokenize_prompt(prompt)
+            prompt_token_ids = TextGenerationController.tokenize_prompt(
+                self.controller.tokenizer, prompt
+            )
         tokens = torch.tensor(
             prompt_token_ids, dtype=torch.int64, device=torch.cuda.current_device()
         )
@@ -1635,18 +1641,14 @@ async def async_bookkeep(
         for record in finished_request_records:
             for request in record.requests:
                 if request.prompt is None:
-                    request.prompt = self.controller.tokenizer.detokenize(
-                        request.prompt_tokens.tolist()
+                    request.prompt = TextGenerationController.detokenize(
+                        self.controller.tokenizer,
+                        request.prompt_tokens.tolist(),
+                        remove_EOD=False,
                     )
-                generated_tokens = request.generated_tokens
-                termination_id = request.sampling_params.termination_id
-                while (
-                    generated_tokens
-                    and termination_id is not None
-                    and generated_tokens[-1] == termination_id
-                ):
-                    generated_tokens = generated_tokens[:-1]
-                request.generated_text = self.controller.tokenizer.detokenize(generated_tokens)
+                request.generated_text = TextGenerationController.detokenize(
+                    self.controller.tokenizer, request.generated_tokens
+                )
         range_pop()
 
         # Handle necessary ZMQ DP coordinator communication.
```
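The `try/except TypeError` in `add_request` keeps the documented support for legacy single-argument `tokenize_prompt` mocks: the new arity is tried first, and a callable that rejects it is retried with the shorter form. A hedged sketch of the dispatch pattern (the helper `call_tokenize` and `legacy_tokenize` are hypothetical, not part of the engine):

```python
# Hypothetical helper demonstrating the arity-fallback pattern used above.
def call_tokenize(tokenize_fn, tokenizer, prompt, add_BOS=False):
    try:
        return tokenize_fn(tokenizer, prompt, add_BOS)
    except TypeError:
        # Legacy callables accept fewer arguments. Trade-off: a TypeError
        # raised *inside* tokenize_fn also lands here and triggers a retry.
        return tokenize_fn(tokenizer, prompt)


def legacy_tokenize(tokenizer, prompt):  # legacy shape: no add_BOS parameter
    return [len(prompt)]


assert call_tokenize(legacy_tokenize, None, "abcd") == [4]
```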

megatron/core/inference/engines/static_engine.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -174,7 +174,9 @@ def add_request(
 
         if inference_request is None:
             # Support legacy single-arg tokenize_prompt mocks in tests.
-            prompt_tokens = self.controller.tokenize_prompt(prompt, add_BOS)
+            prompt_tokens = TextGenerationController.tokenize_prompt(
+                self.controller.tokenizer, prompt, add_BOS
+            )
         else:
             prompt_tokens = inference_request.prompt_tokens
```
megatron/core/inference/text_generation_controllers/text_generation_controller.py
Lines changed: 36 additions & 27 deletions

```diff
@@ -157,58 +157,64 @@ def _init_mtp_sampling_tensor(self):
             * -1
         )
 
-    def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]:
+    @staticmethod
+    def tokenize_prompt(tokenizer, prompt: str, add_BOS: bool = False) -> List[int]:
         """Utility to tokenize the input prompts.
 
         Args:
+            tokenizer: The tokenizer to use.
             prompt (str): The input prompt.
+            add_BOS (bool): Whether to add a BOS token.
 
         Returns:
             List[int]: Returns the tokenized prompt.
         """
 
-        prompt_tokens = self.tokenizer.tokenize(prompt)
+        prompt_tokens = tokenizer.tokenize(prompt)
 
         if add_BOS:
-            assert self.tokenizer.bos is not None
+            assert tokenizer.bos is not None
 
-        while prompt_tokens and prompt_tokens[0] == self.tokenizer.bos:
+        while prompt_tokens and prompt_tokens[0] == tokenizer.bos:
             prompt_tokens.pop(0)
 
         if add_BOS:
-            prompt_tokens = [self.tokenizer.bos] + prompt_tokens
+            prompt_tokens = [tokenizer.bos] + prompt_tokens
 
         return prompt_tokens
 
-    def _detokenize(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
+    @staticmethod
+    def detokenize(
+        tokenizer, tokens: List[int], remove_EOD: bool = True, skip_special_tokens: bool = True
+    ) -> str:
         """
-        Detokenize a sequence of token IDs, handling skip_special_tokens for
-        different tokenizer APIs.
-
-        On the first call, inspects `self.tokenizer.detokenize` to see if it accepts
-        a `skip_special_tokens` keyword argument, and caches that result on `self`.
-        Subsequent calls will use the cached flag to invoke `detokenize` with the
-        correct signature (with or without `skip_special_tokens`).
+        Detokenize a sequence of token IDs, optionally removing trailing EOD
+        tokens and handling skip_special_tokens for different tokenizer APIs.
 
         Args:
+            tokenizer: The tokenizer to use for detokenization.
             tokens (List[int]): The token IDs to convert back to text.
+            remove_EOD (bool): Whether to remove trailing EOD tokens before
+                detokenization. Defaults to True.
             skip_special_tokens (bool): Whether to remove special tokens (e.g. BOS/EOS)
                 during detokenization. Only passed through if the tokenizer supports it.
 
         Returns:
             str: The detokenized string.
         """
-        # cache the check on first call
-        if not hasattr(self, "_detok_accepts_skip"):
-            sig_params = inspect.signature(self.tokenizer.detokenize).parameters.values()
-            self._detok_accepts_skip = any(
-                p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
-                for p in sig_params
-            )
-        if self._detok_accepts_skip:
-            return self.tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
+        if remove_EOD:
+            while tokens and tokens[-1] == tokenizer.eod:
+                tokens = tokens[:-1]
+
+        sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
+        detok_accepts_skip = any(
+            p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
+            for p in sig_params
+        )
+        if detok_accepts_skip:
+            return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
         else:
-            return self.tokenizer.detokenize(tokens)
+            return tokenizer.detokenize(tokens)
 
     def detokenize_generations(
         self,
@@ -237,7 +243,10 @@ def detokenize_generations(
 
         if not detokenize_segments:
             tokens = tokens_gpu_tensor.tolist()
-            return self._detokenize(tokens, skip_special_tokens=skip_special_tokens), None
+            return (
+                self.detokenize(self.tokenizer, tokens, skip_special_tokens=skip_special_tokens),
+                None,
+            )
 
         prompts_plus_generations: List[str] = []
         prompts_plus_generations_segments: List[List[str]] = []
@@ -247,7 +256,7 @@
 
         for sequence_tokens, length in zip(tokens, lengths):
             sequence_tokens = sequence_tokens[:length]
-            detok_str = self._detokenize(sequence_tokens)
+            detok_str = self.detokenize(self.tokenizer, sequence_tokens)
             prompts_plus_generations.append(detok_str)
             offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
             words = [
@@ -256,7 +265,7 @@
 
             prompts_plus_generations_segments.append(words)
 
-        text = self._detokenize(tokens[0], skip_special_tokens=skip_special_tokens)
+        text = self.detokenize(self.tokenizer, tokens[0], skip_special_tokens=skip_special_tokens)
 
         return text, prompts_plus_generations_segments
```
```diff
@@ -2469,7 +2478,7 @@ def stream_token(
 
         return_segments = sampling_params.return_segments
         detokenize_streaming_text = not getattr(
-            sampling_params, "nodetokenize_streaming_text", False
+            sampling_params, "no_detokenize_streaming_text", False
         )
 
         generated_tokens = tokens[prompt_length : prompt_length + generated_length]
```
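With these changes both helpers are pure functions of an explicit `tokenizer`, which is what lets the coordinator and the engines call them without a controller instance. A self-contained sketch of the new `detokenize` contract (the `ToyTokenizer` is an assumption for illustration; the real staticmethod lives on `TextGenerationController`):

```python
import inspect
from typing import List


class ToyTokenizer:
    eod = 2  # assumed toy EOD id

    def detokenize(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
        specials = {self.eod} if skip_special_tokens else set()
        return " ".join(f"<{t}>" for t in tokens if t not in specials)


def detokenize(tokenizer, tokens, remove_EOD=True, skip_special_tokens=True):
    """Sketch of the TextGenerationController.detokenize contract."""
    # 1) Optionally strip trailing EOD ids before handing off to the tokenizer.
    if remove_EOD:
        while tokens and tokens[-1] == tokenizer.eod:
            tokens = tokens[:-1]
    # 2) Pass skip_special_tokens only if the tokenizer's API accepts it.
    params = inspect.signature(tokenizer.detokenize).parameters.values()
    accepts_skip = any(
        p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
        for p in params
    )
    if accepts_skip:
        return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
    return tokenizer.detokenize(tokens)


tok = ToyTokenizer()
assert detokenize(tok, [5, 6, 2, 2]) == "<5> <6>"  # trailing EOD ids removed
assert detokenize(tok, [5, 2, 6], remove_EOD=False) == "<5> <6>"  # interior EOD skipped as special
```

One design note: dropping the cached `_detok_accepts_skip` flag means the signature inspection now runs on every call, a small per-call cost traded for making the helper stateless.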

tests/unit_tests/inference/engines/test_dynamic_engine.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -775,7 +775,7 @@ def test_generate_function(self, model_provider: str) -> None:
         prompts = ["prompt1", "prompt2", "prompt3", "prompt4"]
 
         # Mock the tokenize_prompt method to return predictable token sequences
-        def mock_tokenize_prompt(prompt, add_BOS=False):
+        def mock_tokenize_prompt(tokenizer, prompt, add_BOS=False):
             # Return a token sequence based on the prompt number
             prompt_num = int(prompt[-1])
             return [10 + i for i in range(prompt_num + 2)]
```

tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
Lines changed: 58 additions & 20 deletions

```diff
@@ -598,7 +598,7 @@ def test_add_bos_token(self):
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.bos = 0
         self.mock_tokenizer.eod = self.vocab_size - 1
-        self.mock_tokenizer.detokenize.side_effect = lambda x: ' '.join(
+        self.mock_tokenizer.detokenize.side_effect = lambda x, **_: ' '.join(
             [
                 ''.join(random.choices(string.ascii_letters, k=random.randint(1, len(prompt))))
                 for _ in range(len(x))
@@ -611,35 +611,73 @@ def test_add_bos_token(self):
             random.randint(0, self.vocab_size - 1) for _ in range(len(prompt))
         ]
 
+        tokenizer = self.mock_tokenizer
+
         # Test on a tokenizer that does not add BOS by default
-        no_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert no_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        no_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert no_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert no_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        no_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert no_bos_to_no_bos[0] != tokenizer.bos
+        no_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert no_bos_to_yes_bos[0] == tokenizer.bos
+        assert no_bos_to_yes_bos[1] != tokenizer.bos
 
         # Force the first token to be BOS to emulate a tokenizer that does add BOS by default
-        self.mock_tokenizer.tokenize.return_value[0] = self.mock_tokenizer.bos
+        tokenizer.tokenize.return_value[0] = tokenizer.bos
 
-        yes_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert yes_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        yes_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert yes_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert yes_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        yes_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert yes_bos_to_no_bos[0] != tokenizer.bos
+        yes_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert yes_bos_to_yes_bos[0] == tokenizer.bos
+        assert yes_bos_to_yes_bos[1] != tokenizer.bos
 
         # Test on an input that has had multiple BOS added
-        self.mock_tokenizer.tokenize.return_value[1] = self.mock_tokenizer.bos
+        tokenizer.tokenize.return_value[1] = tokenizer.bos
 
-        many_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert many_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        many_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert many_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert many_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        many_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert many_bos_to_no_bos[0] != tokenizer.bos
+        many_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert many_bos_to_yes_bos[0] == tokenizer.bos
+        assert many_bos_to_yes_bos[1] != tokenizer.bos
 
         # Test the assert triggered when the tokenizer has no bos
-        self.mock_tokenizer.bos = None
+        tokenizer.bos = None
         with pytest.raises(AssertionError):
-            self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
+            TextGenerationController.tokenize_prompt(tokenizer, prompt, add_BOS=True)
+
+    @pytest.mark.parametrize("remove_EOD", [True, False])
+    def test_remove_eod_token(self, remove_EOD):
+        self.setup_model(torch.float32)
+
+        self.mock_tokenizer.vocab_size = self.vocab_size
+        self.mock_tokenizer.bos = 0
+        self.mock_tokenizer.eod = self.vocab_size - 1
+        self.mock_tokenizer.detokenize.side_effect = lambda x, **_: ' '.join(f"T{t}" for t in x)
+
+        tokenizer = self.mock_tokenizer
+        eod = tokenizer.eod
+        detok = TextGenerationController.detokenize
+
+        # No trailing EOD.
+        assert detok(tokenizer, [1, 2, 3], remove_EOD=remove_EOD) == "T1 T2 T3"
+
+        # Single trailing EOD.
+        result = detok(tokenizer, [1, 2, eod], remove_EOD=remove_EOD)
+        assert result == ("T1 T2" if remove_EOD else f"T1 T2 T{eod}")
+
+        # Multiple trailing EOD.
+        result = detok(tokenizer, [1, eod, eod, eod], remove_EOD=remove_EOD)
+        assert result == ("T1" if remove_EOD else f"T1 T{eod} T{eod} T{eod}")
 
     def test_zero_tokens_generated_batch_vs_single(self):
         """
```
