Skip to content

Commit 893129c

Browse files
committed
fix(data): address collate packing review comments
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent 5a25523 commit 893129c

22 files changed

Lines changed: 185 additions & 154 deletions

File tree

src/megatron/bridge/data/energon/hf_task_encoder.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
"""
2121

2222
import dataclasses
23-
import inspect
2423
from dataclasses import dataclass, field
2524
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
2625

@@ -100,7 +99,7 @@ def __init__(
10099
visual_keys: Sequence[str] = ("pixel_values",),
101100
min_pixels: Optional[int] = None,
102101
max_pixels: Optional[int] = None,
103-
collate_fn: Callable[[list, Any], dict[str, Any]] | None = None,
102+
collate_fn: Callable[..., dict[str, Any]] | None = None,
104103
pad_to_max_length: bool = False,
105104
pad_to_multiple_of: int = 128,
106105
enable_in_batch_packing: bool = False,
@@ -127,31 +126,6 @@ def __init__(
127126
)
128127
self._collate_impl = COLLATE_FNS[collate_key]
129128

130-
def _supported_collate_kwargs(self) -> dict[str, Any]:
131-
"""Return encoder options accepted by the selected collate function."""
132-
try:
133-
parameters = inspect.signature(self._collate_impl).parameters
134-
except (TypeError, ValueError):
135-
return {}
136-
137-
accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values())
138-
candidates: dict[str, Any] = {
139-
"visual_keys": self.visual_keys,
140-
"sequence_length": self.seq_length,
141-
"pad_to_max_length": self.pad_to_max_length,
142-
"pad_to_multiple_of": self.pad_to_multiple_of,
143-
"pack_sequences": self.enable_in_batch_packing,
144-
"pack_sequences_pad_to_multiple_of": self.in_batch_packing_pad_to_multiple_of,
145-
}
146-
if self.min_pixels is not None:
147-
candidates["min_pixels"] = self.min_pixels
148-
if self.max_pixels is not None:
149-
candidates["max_pixels"] = self.max_pixels
150-
151-
if accepts_kwargs:
152-
return candidates
153-
return {key: value for key, value in candidates.items() if key in parameters}
154-
155129
def encode_sample(self, sample: ChatMLSample) -> HFEnergonSample:
156130
"""Normalize a single ChatML sample into a HF-style collate example.
157131
@@ -185,7 +159,18 @@ def collate_fn(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
185159
The exact batch dictionary returned by the selected HF collate
186160
function for this processor type.
187161
"""
188-
return self._collate_impl(examples, self.processor, **self._supported_collate_kwargs())
162+
return self._collate_impl(
163+
examples,
164+
self.processor,
165+
visual_keys=self.visual_keys,
166+
min_pixels=self.min_pixels,
167+
max_pixels=self.max_pixels,
168+
sequence_length=self.seq_length,
169+
pad_to_max_length=self.pad_to_max_length,
170+
pad_to_multiple_of=self.pad_to_multiple_of,
171+
pack_sequences=self.enable_in_batch_packing,
172+
in_batch_packing_pad_to_multiple_of=self.in_batch_packing_pad_to_multiple_of,
173+
)
189174

190175
# ------------------------------------------------------------------
191176
# batch

src/megatron/bridge/data/hf_datasets/conversation_dataset.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
base_examples: List[Dict[str, Any]],
3737
target_length: int,
3838
processor: Any,
39-
collate_impl: Optional[Callable[[list, Any], Dict[str, torch.Tensor]]] = None,
39+
collate_impl: Optional[Callable[..., Dict[str, torch.Tensor]]] = None,
4040
sequence_length: int | None = None,
4141
pad_to_max_length: bool = False,
4242
pad_to_multiple_of: int = 128,
@@ -49,48 +49,27 @@ def __init__(
4949
self._processor = processor
5050
# Choose collate implementation by processor type name when not provided
5151
collate_key = type(processor).__name__ if processor is not None else "default"
52-
if collate_impl is not None:
53-
selected_impl = collate_impl
54-
else:
52+
if collate_impl is None:
5553
from megatron.bridge.data.vlm_datasets.collate import COLLATE_FNS
5654

5755
if collate_key not in COLLATE_FNS:
5856
raise ValueError(
5957
f"No conversation collate function registered for processor type '{collate_key}'. "
6058
"Add it to COLLATE_FNS or pass collate_impl explicitly."
6159
)
62-
selected_impl = COLLATE_FNS[collate_key]
60+
collate_impl = COLLATE_FNS[collate_key]
61+
assert collate_impl is not None
6362

64-
# If in-batch packing is requested, bind the selected collate's packing
65-
# kwargs via functools.partial so the DataLoader just calls f(batch, processor).
66-
import inspect
67-
from functools import partial
68-
69-
sig = inspect.signature(selected_impl)
70-
collate_kwargs: dict[str, Any] = {}
71-
if sequence_length is not None and "sequence_length" in sig.parameters:
72-
collate_kwargs["sequence_length"] = sequence_length
73-
if "pad_to_max_length" in sig.parameters:
74-
collate_kwargs["pad_to_max_length"] = pad_to_max_length
75-
if "pad_to_multiple_of" in sig.parameters:
76-
collate_kwargs["pad_to_multiple_of"] = pad_to_multiple_of
77-
78-
if enable_in_batch_packing:
79-
if "pack_sequences" in sig.parameters:
80-
collate_kwargs["pack_sequences"] = True
81-
if "pack_sequences_pad_to_multiple_of" in sig.parameters:
82-
collate_kwargs["pack_sequences_pad_to_multiple_of"] = in_batch_packing_pad_to_multiple_of
83-
else:
84-
raise ValueError(
85-
f"Collate function {getattr(selected_impl, '__name__', selected_impl)} "
86-
f"does not accept in-batch packing. Use a collate that supports packing "
87-
f"(e.g. nemotron_omni_collate_fn)."
88-
)
89-
if collate_kwargs:
90-
selected_impl = partial(selected_impl, **collate_kwargs)
63+
collate_kwargs: dict[str, Any] = {
64+
"sequence_length": sequence_length,
65+
"pad_to_max_length": pad_to_max_length,
66+
"pad_to_multiple_of": pad_to_multiple_of,
67+
"pack_sequences": enable_in_batch_packing,
68+
"in_batch_packing_pad_to_multiple_of": in_batch_packing_pad_to_multiple_of,
69+
}
9170

9271
def _bound_collate(batch: list) -> Dict[str, torch.Tensor]:
93-
return selected_impl(batch, self._processor) # type: ignore[call-arg]
72+
return collate_impl(batch, self._processor, **collate_kwargs)
9473

9574
self.collate_fn = _bound_collate
9675

src/megatron/bridge/data/hf_datasets/provider.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Provider that builds conversation datasets from HuggingFace datasets."""
1616

17-
import inspect
1817
import logging
1918
from dataclasses import dataclass
2019
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
@@ -78,7 +77,7 @@ class HFConversationDatasetProvider(DatasetProvider):
7877
do_test: bool = True
7978

8079
# Optional collate override. If None, inferred from processor type.
81-
collate_impl: Optional[Callable[[list, Any], Dict[str, torch.Tensor]]] = None
80+
collate_impl: Optional[Callable[..., Dict[str, torch.Tensor]]] = None
8281

8382
# Keep parity with GPTDatasetConfig usage in batching utilities
8483
skip_getting_attention_mask_from_dataset: bool = True
@@ -99,18 +98,6 @@ class HFConversationDatasetProvider(DatasetProvider):
9998
# ConfigContainer fills this from model CP/SP constraints when available.
10099
in_batch_packing_pad_to_multiple_of: int = 1
101100

102-
def _collate_supports_packing(self, processor: Any) -> bool:
103-
collate_key = type(processor).__name__ if processor is not None else "default"
104-
if self.collate_impl is not None:
105-
selected_impl = self.collate_impl
106-
else:
107-
from megatron.bridge.data.vlm_datasets.collate import COLLATE_FNS
108-
109-
selected_impl = COLLATE_FNS.get(collate_key)
110-
if selected_impl is None:
111-
return False
112-
return "pack_sequences" in inspect.signature(selected_impl).parameters
113-
114101
def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]:
115102
return get_hf_dataset_maker(self.maker_name)
116103

@@ -139,7 +126,7 @@ def _build_split_dataset(
139126
sequence_length=self.seq_length,
140127
pad_to_max_length=self.pad_to_max_length,
141128
pad_to_multiple_of=self.pad_to_multiple_of,
142-
enable_in_batch_packing=self.enable_in_batch_packing and self._collate_supports_packing(processor),
129+
enable_in_batch_packing=self.enable_in_batch_packing,
143130
in_batch_packing_pad_to_multiple_of=self.in_batch_packing_pad_to_multiple_of,
144131
)
145132

src/megatron/bridge/data/hf_datasets/text_collate.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,13 @@ def text_chat_collate_fn(
156156
processor: Any,
157157
*,
158158
max_length: int | None = None,
159+
sequence_length: int | None = None,
159160
pad_to_max_length: bool = False,
161+
pad_to_multiple_of: int = 128,
160162
warn_on_all_masked: bool = True,
161163
ignore_index: int = IGNORE_INDEX,
162164
pack_sequences: bool = False,
163-
pack_sequences_pad_to_multiple_of: int = 1,
165+
in_batch_packing_pad_to_multiple_of: int = 1,
164166
) -> dict[str, Any]:
165167
"""Collate text-only HF chat examples using the shared assistant-mask path.
166168
@@ -170,20 +172,26 @@ def text_chat_collate_fn(
170172
processor: A HF tokenizer or processor. It must expose
171173
``apply_chat_template`` directly or through ``processor.tokenizer``.
172174
max_length: Optional tokenizer truncation length.
175+
sequence_length: Optional tokenizer truncation length used by
176+
conversation-dataset providers; applied only when ``max_length`` is None.
173177
pad_to_max_length: If set with ``max_length``, pad every row to
174178
``max_length`` instead of the longest row in the batch.
179+
pad_to_multiple_of: Accepted for parity with VLM collate functions; ignored by this function.
175180
warn_on_all_masked: Forwarded to assistant-mask construction.
176181
ignore_index: Label ignore value for masked targets.
177182
pack_sequences: If True, flatten the padded microbatch and emit
178183
packed-sequence metadata for GPT-style training steps.
179-
pack_sequences_pad_to_multiple_of: Optional per-sequence length multiple
184+
in_batch_packing_pad_to_multiple_of: Optional per-sequence length multiple
180185
used when ``pack_sequences`` inserts padding for CP/SP constraints.
181186
182187
Returns:
183188
Batch dictionary with VLM-style ``input_ids`` and GPT-style ``tokens``
184189
aliases, shifted ``labels`` and ``loss_mask``, ``position_ids``, and
185190
optional tokenizer fields such as ``attention_mask``.
186191
"""
192+
del pad_to_multiple_of
193+
194+
max_length = max_length if max_length is not None else sequence_length
187195
tokenizer = get_processor_tokenizer(processor)
188196
conversations = [_normalize_text_conversation(example) for example in examples]
189197
rendered_texts = [_render_chat(conversation, processor, tokenizer) for conversation in conversations]
@@ -232,6 +240,6 @@ def text_chat_collate_fn(
232240
batch,
233241
pad_token_id=int(pad_token_id),
234242
ignore_index=ignore_index,
235-
pad_to_multiple_of=pack_sequences_pad_to_multiple_of,
243+
pad_to_multiple_of=in_batch_packing_pad_to_multiple_of,
236244
)
237245
return batch

src/megatron/bridge/data/sequence_packing.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616

1717
from __future__ import annotations
1818

19-
import logging
2019
from collections.abc import MutableMapping
2120
from typing import Any
2221

2322
import torch
2423

2524

26-
logger = logging.getLogger(__name__)
27-
28-
2925
def _sequence_lengths(tokens: torch.Tensor, *, pad_token_id: int, padding_mask: torch.Tensor | None) -> list[int]:
3026
lengths = []
3127
batch_size, seq_len = tokens.shape
@@ -202,25 +198,11 @@ def pack_batch_sequences(
202198
"attention_mask": padding_mask,
203199
"position_ids": position_ids,
204200
}
205-
try:
206-
pack_padded_sequences_in_batch(
207-
batch,
208-
pad_token_id=pad_token_id,
209-
pad_to_multiple_of=pad_to_multiple_of,
210-
)
211-
except ValueError as exc:
212-
if str(exc) != "Cannot pack a batch with no non-padding tokens.":
213-
raise
214-
logger.warning("No valid sequences found in batch, skipping packing")
215-
return (
216-
tokens[:, :0],
217-
labels[:, :0] if labels is not None else None,
218-
loss_mask[:, :0] if loss_mask is not None else None,
219-
attention_mask,
220-
position_ids[:, :0],
221-
torch.tensor([0], dtype=torch.int32, device=tokens.device),
222-
torch.tensor(0, dtype=torch.int32, device=tokens.device),
223-
)
201+
pack_padded_sequences_in_batch(
202+
batch,
203+
pad_token_id=pad_token_id,
204+
pad_to_multiple_of=pad_to_multiple_of,
205+
)
224206

225207
return (
226208
batch["input_ids"],

src/megatron/bridge/data/vlm_batching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def prepare_vlm_batch_for_training(
107107
pad_to_max_length: bool = False,
108108
pad_to_multiple_of: int = 128,
109109
pack_sequences: bool = False,
110-
pack_sequences_pad_to_multiple_of: int = 1,
110+
in_batch_packing_pad_to_multiple_of: int = 1,
111111
pad_token_id: int = 0,
112112
ignore_index: int = IGNORE_INDEX,
113113
) -> None:
@@ -125,7 +125,7 @@ def prepare_vlm_batch_for_training(
125125
``pad_to_max_length`` is false.
126126
pack_sequences: If true, flatten the microbatch and emit packed-sequence
127127
metadata instead of returning a padded attention mask.
128-
pack_sequences_pad_to_multiple_of: Per-sequence packed length multiple
128+
in_batch_packing_pad_to_multiple_of: Per-sequence packed length multiple
129129
for CP/SP constraints.
130130
pad_token_id: Token value for inserted padding.
131131
ignore_index: Label value for inserted padding.
@@ -147,7 +147,7 @@ def prepare_vlm_batch_for_training(
147147
batch,
148148
pad_token_id=pad_token_id,
149149
ignore_index=ignore_index,
150-
pad_to_multiple_of=pack_sequences_pad_to_multiple_of,
150+
pad_to_multiple_of=in_batch_packing_pad_to_multiple_of,
151151
tokens_key=token_key,
152152
)
153153
# Legacy VLM packing always carried both padded and unpadded metadata,

src/megatron/bridge/models/gemma_vl/data/collate_fn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def gemma3_vl_collate_fn(
3939
pad_to_max_length: bool = False,
4040
pad_to_multiple_of: int = 128,
4141
pack_sequences: bool = False,
42-
pack_sequences_pad_to_multiple_of: int = 1,
42+
in_batch_packing_pad_to_multiple_of: int = 1,
4343
) -> dict[str, torch.Tensor]:
4444
"""Collate function for Gemma3 VL models."""
4545
skipped_tokens = extract_skipped_token_ids(processor)
@@ -113,7 +113,7 @@ def gemma3_vl_collate_fn(
113113
pad_to_max_length=pad_to_max_length,
114114
pad_to_multiple_of=pad_to_multiple_of,
115115
pack_sequences=pack_sequences,
116-
pack_sequences_pad_to_multiple_of=pack_sequences_pad_to_multiple_of,
116+
in_batch_packing_pad_to_multiple_of=in_batch_packing_pad_to_multiple_of,
117117
ignore_index=IGNORE_INDEX,
118118
)
119119
return batch

src/megatron/bridge/models/glm_vl/data/collate_fn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ def glm4v_collate_fn(
2828
examples: list,
2929
processor,
3030
*,
31+
visual_keys: object = None,
32+
min_pixels: int | None = None,
33+
max_pixels: int | None = None,
3134
sequence_length: int | None = None,
3235
pad_to_max_length: bool = False,
3336
pad_to_multiple_of: int = 128,
3437
pack_sequences: bool = False,
35-
pack_sequences_pad_to_multiple_of: int = 1,
38+
in_batch_packing_pad_to_multiple_of: int = 1,
3639
) -> dict[str, torch.Tensor]:
3740
"""Collate function for GLM-4.5V model.
3841
@@ -42,6 +45,8 @@ def glm4v_collate_fn(
4245
defaults). We wrap all visual tensors — including ``mm_token_type_ids`` — in
4346
:class:`GenericVisualInputs` so they flow through ``vlm_step.py`` to the model.
4447
"""
48+
del visual_keys, min_pixels, max_pixels
49+
4550
skipped_tokens = extract_skipped_token_ids(processor)
4651

4752
batch = processor.apply_chat_template(
@@ -88,7 +93,7 @@ def glm4v_collate_fn(
8893
pad_to_max_length=pad_to_max_length,
8994
pad_to_multiple_of=pad_to_multiple_of,
9095
pack_sequences=pack_sequences,
91-
pack_sequences_pad_to_multiple_of=pack_sequences_pad_to_multiple_of,
96+
in_batch_packing_pad_to_multiple_of=in_batch_packing_pad_to_multiple_of,
9297
ignore_index=IGNORE_INDEX,
9398
)
9499

src/megatron/bridge/models/kimi_vl/data/collate_fn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,14 @@ def kimi_k25_vl_collate_fn(
138138
processor,
139139
max_length: int | None = None,
140140
*,
141+
visual_keys: object = None,
142+
min_pixels: int | None = None,
143+
max_pixels: int | None = None,
141144
sequence_length: int | None = None,
142145
pad_to_max_length: bool = False,
143146
pad_to_multiple_of: int = 128,
144147
pack_sequences: bool = False,
145-
pack_sequences_pad_to_multiple_of: int = 1,
148+
in_batch_packing_pad_to_multiple_of: int = 1,
146149
) -> dict[str, torch.Tensor]:
147150
"""Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.
148151
@@ -152,6 +155,8 @@ def kimi_k25_vl_collate_fn(
152155
3. Pads all sequences to fixed max_length
153156
This ensures the model forward pass doesn't change sequence length dynamically.
154157
"""
158+
del visual_keys, min_pixels, max_pixels
159+
155160
skipped_tokens = extract_skipped_token_ids(processor)
156161

157162
# Get media token ID
@@ -310,7 +315,7 @@ def kimi_k25_vl_collate_fn(
310315
pad_to_max_length=pad_to_max_length,
311316
pad_to_multiple_of=pad_to_multiple_of,
312317
pack_sequences=pack_sequences,
313-
pack_sequences_pad_to_multiple_of=pack_sequences_pad_to_multiple_of,
318+
in_batch_packing_pad_to_multiple_of=in_batch_packing_pad_to_multiple_of,
314319
ignore_index=IGNORE_INDEX,
315320
)
316321
return result

0 commit comments

Comments
 (0)