Skip to content

Commit 1036095

Browse files
authored
Sample packing for ConcatDataset (#2278)
1 parent 7747db1 commit 1036095

11 files changed

+46
-16
lines changed

recipes/dev/early_exit_finetune_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def _setup_data(
653653
for single_cfg_dataset in cfg_dataset
654654
]
655655
ds = ConcatDataset(datasets=datasets)
656-
packed = False
656+
packed = getattr(ds, "packed", False)
657657
else:
658658
ds = config.instantiate(cfg_dataset, self._tokenizer)
659659
packed = cfg_dataset.get("packed", False)

recipes/full_finetune_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def _setup_data(
646646
for single_cfg_dataset in cfg_dataset
647647
]
648648
ds = ConcatDataset(datasets=datasets)
649-
packed = False
649+
packed = getattr(ds, "packed", False)
650650
else:
651651
ds = config.instantiate(cfg_dataset, self._tokenizer)
652652
packed = cfg_dataset.get("packed", False)

recipes/full_finetune_single_device.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def _setup_data(
558558
for single_cfg_dataset in cfg_dataset
559559
]
560560
ds = ConcatDataset(datasets=datasets)
561-
packed = False
561+
packed = getattr(ds, "packed", False)
562562
else:
563563
ds = config.instantiate(cfg_dataset, self._tokenizer)
564564
packed = cfg_dataset.get("packed", False)

recipes/knowledge_distillation_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def _setup_data(
652652
for single_cfg_dataset in cfg_dataset
653653
]
654654
ds = ConcatDataset(datasets=datasets)
655-
packed = False
655+
packed = getattr(ds, "packed", False)
656656
else:
657657
ds = config.instantiate(cfg_dataset, self._tokenizer)
658658
packed = cfg_dataset.get("packed", False)

recipes/knowledge_distillation_single_device.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def _setup_data(
531531
for single_cfg_dataset in cfg_dataset
532532
]
533533
ds = ConcatDataset(datasets=datasets)
534-
packed = False
534+
packed = getattr(ds, "packed", False)
535535
else:
536536
ds = config.instantiate(cfg_dataset, self._tokenizer)
537537
packed = cfg_dataset.get("packed", False)

recipes/lora_finetune_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def _setup_data(
591591
for single_cfg_dataset in cfg_dataset
592592
]
593593
ds = ConcatDataset(datasets=datasets)
594-
packed = False
594+
packed = getattr(ds, "packed", False)
595595
else:
596596
ds = config.instantiate(cfg_dataset, self._tokenizer)
597597
packed = cfg_dataset.get("packed", False)

recipes/lora_finetune_single_device.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def _setup_data(
528528
for single_cfg_dataset in cfg_dataset
529529
]
530530
ds = ConcatDataset(datasets=datasets)
531-
packed = False
531+
packed = getattr(ds, "packed", False)
532532
else:
533533
ds = config.instantiate(cfg_dataset, self._tokenizer)
534534
packed = cfg_dataset.get("packed", False)

recipes/qat_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def _setup_data(
606606
for single_cfg_dataset in cfg_dataset
607607
]
608608
ds = ConcatDataset(datasets=datasets)
609-
packed = False
609+
packed = getattr(ds, "packed", False)
610610
else:
611611
ds = config.instantiate(cfg_dataset, self._tokenizer)
612612
packed = cfg_dataset.get("packed", False)

recipes/qat_lora_finetune_distributed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def _setup_data(
633633
for single_cfg_dataset in cfg_dataset
634634
]
635635
ds = ConcatDataset(datasets=datasets)
636-
packed = False
636+
packed = getattr(ds, "packed", False)
637637
else:
638638
ds = config.instantiate(cfg_dataset, self._tokenizer)
639639
packed = cfg_dataset.get("packed", False)

tests/torchtune/datasets/test_concat_dataset.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_invalid_index_type(self, datasets):
8080
with pytest.raises(TypeError):
8181
multi_dataset["invalid_type"] # Non-integer index
8282

83-
def test_packed_dataset(self, torch_datasets):
83+
def test_single_packed_dataset(self, torch_datasets):
8484
torch_datasets[0] = PackedDataset(
8585
torch_datasets[0],
8686
max_seq_len=25,
@@ -90,3 +90,33 @@ def test_packed_dataset(self, torch_datasets):
9090

9191
with pytest.raises(ValueError):
9292
concated_dataset = ConcatDataset(torch_datasets)
93+
94+
def test_all_packed_datasets(self, torch_datasets):
95+
for i in range(len(torch_datasets)):
96+
torch_datasets[i] = PackedDataset(
97+
torch_datasets[i],
98+
max_seq_len=2000,
99+
max_packs=16,
100+
split_across_pack=True,
101+
)
102+
concated_dataset = ConcatDataset(torch_datasets)
103+
assert concated_dataset.packed
104+
105+
# 2k tokens per pack
106+
# 1st ds has 4k tokens, 2nd ds has 8k tokens, 3rd ds has 15k tokens
107+
# 4th ds has 16k tokens, 5th ds has 23k tokens, 6th ds has 42k tokens
108+
109+
assert concated_dataset[0]["seq_lens"][0] == 4
110+
# 2nd packed ds starts at idx 2
111+
assert concated_dataset[2]["seq_lens"][0] == 8
112+
# 3rd packed ds starts at idx 6
113+
assert concated_dataset[6]["seq_lens"][0] == 15
114+
# 4th packed ds starts at idx 14
115+
assert concated_dataset[14]["seq_lens"][0] == 16
116+
# 5th packed ds starts at idx 22
117+
assert concated_dataset[22]["seq_lens"][0] == 23
118+
# 6th packed ds starts at idx 34
119+
assert concated_dataset[34]["seq_lens"][0] == 42
120+
121+
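# Total length is 2 + 4 + 8 + 8 + 12 + 16 = 50 packs: partial packs round up (15k -> 8, 23k -> 12), and the 6th ds would need 21 packs for 42k tokens but is capped at max_packs=16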
122+
assert len(concated_dataset) == 50

torchtune/datasets/_concat.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ class ConcatDataset(Dataset):
6767
def __init__(self, datasets: List[Dataset]):
6868
self._datasets: List[Dataset] = datasets
6969

70-
for dataset in self._datasets:
71-
if isinstance(dataset, PackedDataset):
72-
raise ValueError(
73-
"ConcatDataset can't process instances of PackedDataset."
74-
)
75-
70+
is_packed = [isinstance(dataset, PackedDataset) for dataset in datasets]
71+
if any(is_packed) and not all(is_packed):
72+
raise ValueError(
73+
"ConcatDataset can't process a mix of packed and non-packed datasets."
74+
)
75+
self.packed = all(is_packed)
7676
self._len: int = sum(len(dataset) for dataset in datasets)
7777
self._indexes: List[Tuple[int, int, int]] = []
7878

0 commit comments

Comments
 (0)