Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/megatron/bridge/data/energon/energon_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,21 @@ class EnergonProvider(DatasetProvider):
task_encoder: Optional[Any] = None
# Enable batch-level online sequence packing
enable_in_batch_packing: bool = False
pad_to_max_length: bool = False
pad_to_multiple_of: int = 128
in_batch_packing_pad_to_multiple_of: int = 1

def _sync_task_encoder_sequence_batching(self) -> None:
if self.task_encoder is None:
return
self.task_encoder.pad_to_max_length = self.pad_to_max_length
self.task_encoder.pad_to_multiple_of = self.pad_to_multiple_of
self.task_encoder.enable_in_batch_packing = self.enable_in_batch_packing
self.task_encoder.in_batch_packing_pad_to_multiple_of = self.in_batch_packing_pad_to_multiple_of

def build_datasets(self, context: DatasetBuildContext):
assert self.path, "EnergonProvider.path must be set. Use CLI override: dataset.path=<path>"
if (
self.enable_in_batch_packing
and self.task_encoder is not None
and hasattr(self.task_encoder, "pack_sequences")
):
self.task_encoder.pack_sequences = True
self._sync_task_encoder_sequence_batching()
dataset = EnergonMultiModalDataModule(
path=self.path,
tokenizer=context.tokenizer if context.tokenizer is not None else self.tokenizer,
Expand Down
72 changes: 49 additions & 23 deletions src/megatron/bridge/data/energon/hf_task_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""

import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -60,6 +59,11 @@ class HFEnergonBatch(Batch):
position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0)) # [B, seq_len]
visual_inputs: GenericVisualInputs | None = None
attention_mask: torch.Tensor | None = None
cu_seqlens: torch.Tensor | None = None
cu_seqlens_unpadded: torch.Tensor | None = None
cu_seqlens_argmin: torch.Tensor | None = None
cu_seqlens_unpadded_argmin: torch.Tensor | None = None
max_seqlen: torch.Tensor | None = None


class HFTaskEncoder(DefaultTaskEncoder[ChatMLSample, HFEnergonSample, HFEnergonBatch, dict]):
Expand All @@ -75,6 +79,17 @@ class HFTaskEncoder(DefaultTaskEncoder[ChatMLSample, HFEnergonSample, HFEnergonB
the selected collate function.
max_pixels: Optional max pixel constraint forwarded when supported by
the selected collate function.
collate_fn: Optional collate implementation override. If omitted, the
implementation is selected from the processor type.
pad_to_max_length: Whether collate-time padding should pad non-packed
batches to ``seq_length`` when the selected collate supports it.
pad_to_multiple_of: Non-packed collate-time padding multiple used when
``pad_to_max_length`` is false and the selected collate supports it.
enable_in_batch_packing: Whether the selected collate should do
in-batch sequence packing.
in_batch_packing_pad_to_multiple_of: Per-sample padding multiple used
only by the in-batch packed path, typically to satisfy CP/SP
divisibility.
"""

def __init__(
Expand All @@ -84,14 +99,22 @@ def __init__(
visual_keys: Sequence[str] = ("pixel_values",),
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
collate_fn: Callable[[list, Any], dict[str, Any]] | None = None,
collate_fn: Callable[..., dict[str, Any]] | None = None,
pad_to_max_length: bool = False,
pad_to_multiple_of: int = 128,
enable_in_batch_packing: bool = False,
in_batch_packing_pad_to_multiple_of: int = 1,
):
super().__init__()
self.processor = processor
self.seq_length = seq_length
self.visual_keys: Tuple[str, ...] = tuple(visual_keys)
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.pad_to_max_length = pad_to_max_length
self.pad_to_multiple_of = pad_to_multiple_of
self.enable_in_batch_packing = enable_in_batch_packing
self.in_batch_packing_pad_to_multiple_of = in_batch_packing_pad_to_multiple_of
collate_key = type(processor).__name__ if processor is not None else "default"
if collate_fn is not None:
self._collate_impl = collate_fn
Expand All @@ -103,24 +126,6 @@ def __init__(
)
self._collate_impl = COLLATE_FNS[collate_key]

def _supported_collate_kwargs(self) -> dict[str, Any]:
"""Return encoder options accepted by the selected collate function."""
try:
parameters = inspect.signature(self._collate_impl).parameters
except (TypeError, ValueError):
return {}

accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values())
candidates: dict[str, Any] = {"visual_keys": self.visual_keys}
if self.min_pixels is not None:
candidates["min_pixels"] = self.min_pixels
if self.max_pixels is not None:
candidates["max_pixels"] = self.max_pixels

if accepts_kwargs:
return candidates
return {key: value for key, value in candidates.items() if key in parameters}

def encode_sample(self, sample: ChatMLSample) -> HFEnergonSample:
"""Normalize a single ChatML sample into a HF-style collate example.

Expand Down Expand Up @@ -154,7 +159,18 @@ def collate_fn(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
The exact batch dictionary returned by the selected HF collate
function for this processor type.
"""
return self._collate_impl(examples, self.processor, **self._supported_collate_kwargs())
return self._collate_impl(
examples,
self.processor,
visual_keys=self.visual_keys,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
sequence_length=self.seq_length,
pad_to_max_length=self.pad_to_max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
pack_sequences=self.enable_in_batch_packing,
in_batch_packing_pad_to_multiple_of=self.in_batch_packing_pad_to_multiple_of,
)

# ------------------------------------------------------------------
# batch
Expand All @@ -164,9 +180,14 @@ def batch(self, samples: List[HFEnergonSample]) -> HFEnergonBatch:
"""Collate normalized samples with the selected HF VLM collator."""
examples = [sample.example for sample in samples]
collated = self.collate_fn(examples)
if collated["input_ids"].shape[1] > self.seq_length:
collated_seq_len = (
int(collated["max_seqlen"].max().item())
if collated.get("max_seqlen") is not None
else collated["input_ids"].shape[1]
)
if collated_seq_len > self.seq_length:
raise ValueError(
f"Collated seq_len {collated['input_ids'].shape[1]} exceeds seq_length {self.seq_length}. "
f"Collated seq_len {collated_seq_len} exceeds seq_length {self.seq_length}. "
"The selected HF VLM collator must enforce seq_length while preserving visual metadata."
)

Expand All @@ -181,6 +202,11 @@ def batch(self, samples: List[HFEnergonSample]) -> HFEnergonBatch:
attention_mask=collated.get("attention_mask"),
position_ids=collated["position_ids"],
visual_inputs=collated.get("visual_inputs"),
cu_seqlens=collated.get("cu_seqlens"),
cu_seqlens_unpadded=collated.get("cu_seqlens_unpadded"),
cu_seqlens_argmin=collated.get("cu_seqlens_argmin"),
cu_seqlens_unpadded_argmin=collated.get("cu_seqlens_unpadded_argmin"),
max_seqlen=collated.get("max_seqlen"),
)

return HFEnergonBatch(**batch_kwargs)
Expand Down
47 changes: 41 additions & 6 deletions src/megatron/bridge/data/energon/nemotron_omni_task_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
find_pattern_indices,
get_ltor_masks_and_position_ids,
)
from megatron.bridge.data.vlm_batching import prepare_vlm_batch_for_training
from megatron.bridge.training.utils.visual_inputs import GenericVisualInputs


Expand Down Expand Up @@ -84,7 +85,7 @@ class NemotronOmniTaskBatch(Batch):
imgs_sizes: Optional[torch.Tensor] = None # [total_frames, 2]
num_frames: Optional[torch.Tensor] = None # [num_media_items]
num_image_tiles: Optional[torch.Tensor] = None # [total_images] LM-side token count per image
# Packed-sequence metadata (only populated when pack_sequences=True).
# Packed-sequence metadata (only populated when enable_in_batch_packing=True).
cu_seqlens: Optional[torch.Tensor] = None
cu_seqlens_unpadded: Optional[torch.Tensor] = None
cu_seqlens_argmin: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -115,6 +116,14 @@ class NemotronOmniTaskEncoder(DefaultTaskEncoder[ChatMLSample, NemotronOmniTaskS
num_mel_bins: Number of mel frequency bins (must match the sound
encoder config, typically 128 for Parakeet).
visual_keys: Processor output keys to capture as visual tensors.
pad_to_max_length: Whether collate-time padding should pad non-packed
batches to ``seq_length`` when supported.
pad_to_multiple_of: Non-packed collate-time padding multiple used when
``pad_to_max_length`` is false and supported.
enable_in_batch_packing: Whether to do in-batch sequence packing.
in_batch_packing_pad_to_multiple_of: Per-sample padding multiple used
only by the in-batch packed path, typically to satisfy CP/SP
divisibility.
"""

def __init__(
Expand All @@ -129,7 +138,10 @@ def __init__(
video_nframes: int = 8,
use_temporal_video_embedder: bool = False,
patch_dim: int = 16,
pack_sequences: bool = False,
pad_to_max_length: bool = False,
pad_to_multiple_of: int = 128,
enable_in_batch_packing: bool = False,
in_batch_packing_pad_to_multiple_of: int = 1,
):
super().__init__()
self.processor = processor
Expand All @@ -142,7 +154,10 @@ def __init__(
self.video_nframes = video_nframes
self.use_temporal_video_embedder = use_temporal_video_embedder
self.patch_dim = patch_dim
self.pack_sequences = pack_sequences
self.pad_to_max_length = pad_to_max_length
self.pad_to_multiple_of = pad_to_multiple_of
self.enable_in_batch_packing = enable_in_batch_packing
self.in_batch_packing_pad_to_multiple_of = in_batch_packing_pad_to_multiple_of

@staticmethod
def _decode_video_bytes(video_bytes: bytes, nframes: int = 8, fps: float = 1.0):
Expand Down Expand Up @@ -484,7 +499,7 @@ def encode_sample(self, sample: ChatMLSample) -> NemotronOmniTaskSample:

def batch(self, samples: List[NemotronOmniTaskSample]) -> NemotronOmniTaskBatch:
"""Pad-and-collate (default) OR pack samples along the seq dim when
``pack_sequences=True``. Packing emits ``cu_seqlens`` / ``cu_seqlens_unpadded``
``enable_in_batch_packing=True``. Packing emits ``cu_seqlens`` / ``cu_seqlens_unpadded``
/ ``max_seqlen`` so TE's THD kernels handle cross-sample masking (and CP
partitioning via ``thd_get_partitioned_indices``) without an attention mask.
"""
Expand All @@ -496,7 +511,7 @@ def batch(self, samples: List[NemotronOmniTaskSample]) -> NemotronOmniTaskBatch:
cu_seqlens_argmin_t: Optional[torch.Tensor] = None
max_seqlen_t: Optional[torch.Tensor] = None

if self.pack_sequences:
if self.enable_in_batch_packing:
# Concatenate samples along the seq dim into a single [1, total_len]
# microbatch. TE attention kernels use cu_seqlens for per-sample
# masking; no attention_mask needed.
Expand Down Expand Up @@ -550,6 +565,26 @@ def batch(self, samples: List[NemotronOmniTaskSample]) -> NemotronOmniTaskBatch:
reset_attention_mask=False,
reset_position_ids=False,
)
text_batch = {
"input_ids": tokens,
"labels": labels,
"loss_mask": loss_mask_t,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
prepare_vlm_batch_for_training(
text_batch,
sequence_length=self.seq_length,
pad_to_max_length=self.pad_to_max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
pad_token_id=0,
ignore_index=IGNORE_INDEX,
)
tokens = text_batch["input_ids"]
labels = text_batch["labels"]
loss_mask_t = text_batch["loss_mask"]
position_ids = text_batch["position_ids"]
attention_mask = text_batch["attention_mask"]

# Aggregate visual tensors.
# The temporal video path ships pixel_values as [1, N_i*patches_per_frame, feat]
Expand All @@ -565,7 +600,7 @@ def batch(self, samples: List[NemotronOmniTaskSample]) -> NemotronOmniTaskBatch:
if not tensors:
batched_visual[key] = None
continue
if self.pack_sequences and tensors[0].dim() == 3:
if self.enable_in_batch_packing and tensors[0].dim() == 3:
batched_visual[key] = torch.cat(tensors, dim=1)
else:
batched_visual[key] = torch.cat(tensors, dim=0)
Expand Down
43 changes: 17 additions & 26 deletions src/megatron/bridge/data/hf_datasets/conversation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,49 +36,40 @@ def __init__(
base_examples: List[Dict[str, Any]],
target_length: int,
processor: Any,
collate_impl: Optional[Callable[[list, Any], Dict[str, torch.Tensor]]] = None,
pack_sequences: bool = False,
pack_sequences_pad_to_multiple_of: int = 1,
collate_impl: Optional[Callable[..., Dict[str, torch.Tensor]]] = None,
sequence_length: int | None = None,
pad_to_max_length: bool = False,
pad_to_multiple_of: int = 128,
enable_in_batch_packing: bool = False,
in_batch_packing_pad_to_multiple_of: int = 1,
) -> None:
assert isinstance(base_examples, list) and len(base_examples) > 0, "base_examples must be a non-empty list"
self._base_examples = base_examples
self._length = int(max(0, target_length))
self._processor = processor
# Choose collate implementation by processor type name when not provided
collate_key = type(processor).__name__ if processor is not None else "default"
if collate_impl is not None:
selected_impl = collate_impl
else:
if collate_impl is None:
from megatron.bridge.data.vlm_datasets.collate import COLLATE_FNS

if collate_key not in COLLATE_FNS:
raise ValueError(
f"No conversation collate function registered for processor type '{collate_key}'. "
"Add it to COLLATE_FNS or pass collate_impl explicitly."
)
selected_impl = COLLATE_FNS[collate_key]
collate_impl = COLLATE_FNS[collate_key]
assert collate_impl is not None

# If packing requested, only collates that advertise `pack_sequences` support it;
# bind via functools.partial so the DataLoader just calls f(batch, processor).
if pack_sequences:
import inspect
from functools import partial

sig = inspect.signature(selected_impl)
if "pack_sequences" in sig.parameters:
pack_kwargs: dict[str, Any] = {"pack_sequences": True}
if "pack_sequences_pad_to_multiple_of" in sig.parameters:
pack_kwargs["pack_sequences_pad_to_multiple_of"] = pack_sequences_pad_to_multiple_of
selected_impl = partial(selected_impl, **pack_kwargs)
else:
raise ValueError(
f"Collate function {getattr(selected_impl, '__name__', selected_impl)} "
f"does not accept pack_sequences=True. Use a collate that supports packing "
f"(e.g. nemotron_omni_collate_fn)."
)
collate_kwargs: dict[str, Any] = {
"sequence_length": sequence_length,
"pad_to_max_length": pad_to_max_length,
"pad_to_multiple_of": pad_to_multiple_of,
"pack_sequences": enable_in_batch_packing,
"in_batch_packing_pad_to_multiple_of": in_batch_packing_pad_to_multiple_of,
}

def _bound_collate(batch: list) -> Dict[str, torch.Tensor]:
return selected_impl(batch, self._processor) # type: ignore[call-arg]
return collate_impl(batch, self._processor, **collate_kwargs)

self.collate_fn = _bound_collate

Expand Down
Loading
Loading