4040
4141METADATA_FNAME = "helmet_metadata.json"
4242
43+ METADATA_TARGET_CHOICE_KEY = "target_choice"
44+
4345
4446class Helmet (SequenceLengthFilteredDataModule ):
4547 """Data module for HELMET benchmark datasets.
@@ -119,10 +121,16 @@ def __init__(
119121 self .max_length = max_length
120122 self .dataset_parent_dir = dataset_parent_dir
121123 self .metadata_dir = metadata_dir
124+ self .target_choices = [None , None , None ]
125+ self ._metadata = None
122126
123- def _metadata_keys (self , split : str ) -> List [str ]:
127+ def _metadata_keys (
128+ self ,
129+ root_key : str ,
130+ split : str ,
131+ ) -> List [str ]:
124132 return [
125- METADATA_SEQ_LENGTHS_KEY ,
133+ root_key ,
126134 self .dataset_key ,
127135 self .max_length ,
128136 self .tokenizer .model_name ,
@@ -137,6 +145,12 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
137145 )
138146 print (f"\n Transforming HELMET '{ self .dataset_key } ' ({ self .max_length } ) ..." )
139147 metadata = self ._load_metadata ()
148+ self ._metadata = metadata # Needed in :meth:`_create_datasets`
149+ self .target_choices = [
150+ self ._get_target_choice (metadata , "train" ),
151+ self ._get_target_choice (metadata , "val" ),
152+ self ._get_target_choice (metadata , "test" ),
153+ ]
140154 train_data , dev_seq_lengths , dev_needs_store = self ._transform (
141155 dev_data , split = "dev" , seq_lengths = self ._get_seq_lengths (metadata , "dev" )
142156 )
@@ -147,9 +161,17 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
147161 if metadata is None :
148162 metadata = dict ()
149163 if dev_needs_store :
150- set_dict (metadata , self ._metadata_keys ("dev" ), dev_seq_lengths )
164+ set_dict (
165+ metadata ,
166+ self ._metadata_keys (METADATA_SEQ_LENGTHS_KEY , "dev" ),
167+ dev_seq_lengths ,
168+ )
151169 if eval_needs_store :
152- set_dict (metadata , self ._metadata_keys ("eval" ), eval_seq_lengths )
170+ set_dict (
171+ metadata ,
172+ self ._metadata_keys (METADATA_SEQ_LENGTHS_KEY , "eval" ),
173+ eval_seq_lengths ,
174+ )
153175 self ._store_metadata (metadata )
154176 return train_data , test_data
155177
@@ -221,7 +243,14 @@ def _transform(
221243 def _get_seq_lengths (
222244 self , metadata : Optional [Dict [str , Any ]], split : str
223245 ) -> Optional [List [int ]]:
224- return get_dict (metadata , self ._metadata_keys (split ))
246+ return get_dict (metadata , self ._metadata_keys (METADATA_SEQ_LENGTHS_KEY , split ))
247+
248+ def _get_target_choice (
249+ self , metadata : Optional [Dict [str , Any ]], split : str
250+ ) -> Optional [List [int ]]:
251+ return get_dict (
252+ metadata , self ._metadata_keys (METADATA_TARGET_CHOICE_KEY , split )
253+ )
225254
226255 def _load_metadata (self ) -> Optional [Dict [str , Any ]]:
227256 if self .metadata_dir is None :
@@ -253,25 +282,48 @@ def _create_datasets(
253282 val_kwargs : Dict [str , Any ],
254283 test_kwargs : Optional [Dict [str , Any ]],
255284 ) -> None :
285+ num_sets = 2
256286 self .train_dataset = SFTDataset (
257287 ** train_kwargs ,
258288 mask_prompt = self .mask_prompt ,
259289 ignore_index = self .ignore_index ,
290+ target_choice = self .target_choices [0 ],
260291 seed = self .seed ,
261292 )
262293 self .val_dataset = SFTDataset (
263294 ** val_kwargs ,
264295 mask_prompt = self .mask_prompt ,
265296 ignore_index = self .ignore_index ,
297+ target_choice = self .target_choices [1 ],
266298 seed = self .seed ,
267299 )
268300 if test_kwargs is not None :
269301 self .test_dataset = SFTDataset (
270302 ** test_kwargs ,
271303 mask_prompt = self .mask_prompt ,
272304 ignore_index = self .ignore_index ,
305+ target_choice = self .target_choices [2 ],
273306 seed = self .seed ,
274307 )
308+ num_sets += 1
309+ # Update meta-data?
310+ do_store_meta = any (x is None for x in self .target_choices [:num_sets ])
311+ if do_store_meta :
312+ for i , (data , split ) in enumerate (
313+ zip (
314+ (self .train_dataset , self .val_dataset , self .test_dataset ),
315+ ("train" , "val" , "test" ),
316+ )
317+ ):
318+ if self .target_choices [i ] is None and data is not None :
319+ new_choices = data .target_choice .copy ()
320+ self .target_choices [i ] = new_choices
321+ set_dict (
322+ self ._metadata ,
323+ self ._metadata_keys (METADATA_TARGET_CHOICE_KEY , split ),
324+ new_choices ,
325+ )
326+ self ._store_metadata (self ._metadata )
275327
276328 def _get_collate_fn (self ) -> MyDataLoader :
277329 return get_sft_collate_fn (ignore_index = self .ignore_index )
0 commit comments