@@ -80,7 +80,7 @@ def test_invalid_index_type(self, datasets):
80
80
with pytest .raises (TypeError ):
81
81
multi_dataset ["invalid_type" ] # Non-integer index
82
82
83
- def test_packed_dataset (self , torch_datasets ):
83
+ def test_single_packed_dataset (self , torch_datasets ):
84
84
torch_datasets [0 ] = PackedDataset (
85
85
torch_datasets [0 ],
86
86
max_seq_len = 25 ,
@@ -90,3 +90,33 @@ def test_packed_dataset(self, torch_datasets):
90
90
91
91
with pytest .raises (ValueError ):
92
92
concated_dataset = ConcatDataset (torch_datasets )
93
+
94
+ def test_all_packed_datasets (self , torch_datasets ):
95
+ for i in range (len (torch_datasets )):
96
+ torch_datasets [i ] = PackedDataset (
97
+ torch_datasets [i ],
98
+ max_seq_len = 2000 ,
99
+ max_packs = 16 ,
100
+ split_across_pack = True ,
101
+ )
102
+ concated_dataset = ConcatDataset (torch_datasets )
103
+ assert concated_dataset .packed
104
+
105
+ # 2k tokens per pack
106
+ # 1st ds has 4k tokens, 2nd ds has 8k tokens, 3rd ds has 15k tokens
107
+ # 4th ds has 16k tokens, 5th ds has 23k tokens, 6th ds has 42k tokens
108
+
109
+ assert concated_dataset [0 ]["seq_lens" ][0 ] == 4
110
+ # 2nd packed ds starts at idx 2
111
+ assert concated_dataset [2 ]["seq_lens" ][0 ] == 8
112
+ # 3rd packed ds starts at idx 6
113
+ assert concated_dataset [6 ]["seq_lens" ][0 ] == 15
114
+ # 4th packed ds starts at idx 14
115
+ assert concated_dataset [14 ]["seq_lens" ][0 ] == 16
116
+ # 5th packed ds starts at idx 22
117
+ assert concated_dataset [22 ]["seq_lens" ][0 ] == 23
118
+ # 6th packed ds starts at idx 34
119
+ assert concated_dataset [34 ]["seq_lens" ][0 ] == 42
120
+
121
+ # Total length is 2 + 4 + 8 + 8 + 12 + 16 (because of max_packs) = 50
122
+ assert len (concated_dataset ) == 50
0 commit comments