1616from typing import List , Optional , Dict , Any , Tuple , Literal
1717
1818from tokenizers import Tokenizer as HFTokenizer
19+ import torch
1920from tqdm import tqdm
2021
2122from keys_values .data .dataloader import MyDataLoader
2627)
2728from keys_values .data .module import (
2829 SequenceLengthFilteredDataModule ,
30+ SequenceLengthFilteredDataTrainState ,
2931 METADATA_SEQ_LENGTHS_KEY ,
3032 METADATA_KEYS ,
3133 RawDatasetType ,
4042
4143METADATA_FNAME = "helmet_metadata.json"
4244
43- METADATA_TARGET_CHOICE_KEY = "target_choice"
45+
46+ class HelmetDataTrainState (SequenceLengthFilteredDataTrainState ):
47+ """
48+ Also contains the `target_choice` indexes for training, validation and
49+ test split.
50+ """
51+
52+ def __init__ (self ):
53+ super ().__init__ ()
54+ self ._train_target_choice = None
55+ self ._val_target_choice = None
56+ self ._test_target_choice = None
57+
58+ @property
59+ def train_target_choice (self ) -> Optional [List [int ]]:
60+ return self ._train_target_choice
61+
62+ @train_target_choice .setter
63+ def train_target_choice (self , value : Optional [List [int ]]) -> None :
64+ # `>` is OK, as dataset may be padded after splitting
65+ if len (value ) < len (self .train_data_index ):
66+ raise ValueError (
67+ f"len(train_target_choice) = { len (value )} < { len (self .train_data_index )} = len(self.train_data_index)"
68+ )
69+ if not all (x >= 0 for x in value ):
70+ raise ValueError ("All entries of train_target_choice must be >= 0" )
71+ self ._train_target_choice = value .copy ()
72+
73+ @property
74+ def val_target_choice (self ) -> Optional [List [int ]]:
75+ return self ._val_target_choice
76+
77+ @val_target_choice .setter
78+ def val_target_choice (self , value : Optional [List [int ]]) -> None :
79+ # `>` is OK, as dataset may be padded after splitting
80+ if len (value ) < len (self .val_data_index ):
81+ raise ValueError (
82+ f"len(val_target_choice) = { len (value )} < { len (self .val_data_index )} = len(self.val_data_index)"
83+ )
84+ if not all (x >= 0 for x in value ):
85+ raise ValueError ("All entries of val_target_choice must be >= 0" )
86+ self ._val_target_choice = value .copy ()
87+
88+ @property
89+ def test_target_choice (self ) -> Optional [List [int ]]:
90+ return self ._test_target_choice
91+
92+ @test_target_choice .setter
93+ def test_target_choice (self , value : Optional [List [int ]]) -> None :
94+ if not all (x >= 0 for x in value ):
95+ raise ValueError ("All entries of test_target_choice must be >= 0" )
96+ self ._test_target_choice = value .copy ()
97+
98+ def state_dict (self ) -> Dict [str , torch .Tensor ]:
99+ kwargs = dict (dtype = torch .int64 )
100+ result = super ().state_dict ()
101+ result .update (
102+ {
103+ f"{ name } _target_choice" : torch .tensor (value , ** kwargs )
104+ for name , value in zip (
105+ ("train" , "val" , "test" ),
106+ (
107+ self .train_target_choice ,
108+ self .val_target_choice ,
109+ self .test_target_choice ,
110+ ),
111+ )
112+ if value is not None
113+ }
114+ )
115+ return result
116+
117+ def load_state_dict (self , state_dict : Dict [str , torch .Tensor ]):
118+ super ().load_state_dict (state_dict )
119+ train_ind = state_dict .get ("train_target_choice" )
120+ val_ind = state_dict .get ("val_target_choice" )
121+ test_ind = state_dict .get ("test_target_choice" )
122+ self .train_target_choice = None if train_ind is None else train_ind .tolist ()
123+ self .val_target_choice = None if val_ind is None else val_ind .tolist ()
124+ self .test_target_choice = None if test_ind is None else test_ind .tolist ()
44125
45126
46127class Helmet (SequenceLengthFilteredDataModule ):
@@ -121,8 +202,6 @@ def __init__(
121202 self .max_length = max_length
122203 self .dataset_parent_dir = dataset_parent_dir
123204 self .metadata_dir = metadata_dir
124- self .target_choices = [None , None , None ]
125- self ._metadata = None
126205
127206 def _metadata_keys (
128207 self ,
@@ -145,12 +224,6 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
145224 )
146225 print (f"\n Transforming HELMET '{ self .dataset_key } ' ({ self .max_length } ) ..." )
147226 metadata = self ._load_metadata ()
148- self ._metadata = metadata # Needed in :meth:`_create_datasets`
149- self .target_choices = [
150- self ._get_target_choice (metadata , "train" ),
151- self ._get_target_choice (metadata , "val" ),
152- self ._get_target_choice (metadata , "test" ),
153- ]
154227 train_data , dev_seq_lengths , dev_needs_store = self ._transform (
155228 dev_data , split = "dev" , seq_lengths = self ._get_seq_lengths (metadata , "dev" )
156229 )
@@ -245,13 +318,6 @@ def _get_seq_lengths(
245318 ) -> Optional [List [int ]]:
246319 return get_dict (metadata , self ._metadata_keys (METADATA_SEQ_LENGTHS_KEY , split ))
247320
248- def _get_target_choice (
249- self , metadata : Optional [Dict [str , Any ]], split : str
250- ) -> Optional [List [int ]]:
251- return get_dict (
252- metadata , self ._metadata_keys (METADATA_TARGET_CHOICE_KEY , split )
253- )
254-
255321 def _load_metadata (self ) -> Optional [Dict [str , Any ]]:
256322 if self .metadata_dir is None :
257323 return None
@@ -282,48 +348,73 @@ def _create_datasets(
282348 val_kwargs : Dict [str , Any ],
283349 test_kwargs : Optional [Dict [str , Any ]],
284350 ) -> None :
285- num_sets = 2
351+ assert self .training_state is not None # Sanity check
352+ if not isinstance (self .training_state , HelmetDataTrainState ):
353+ # Must have been created in :meth:`SequenceLengthFilteredDataModule.setup`
354+ if not isinstance (
355+ self .training_state , SequenceLengthFilteredDataTrainState
356+ ):
357+ raise TypeError (
358+ f"type(self.training_state) = { type (self .training_state )} : Invalid"
359+ )
360+ # Convert it
361+ new_training_state = HelmetDataTrainState ()
362+ new_training_state .initialize (
363+ self .training_state .train_data_index ,
364+ self .training_state .val_data_index ,
365+ )
366+ self .training_state = new_training_state
367+ else :
368+ for name , value in zip (
369+ ("train" , "val" , "test" ),
370+ (
371+ self .training_state .train_target_choice ,
372+ self .training_state .val_target_choice ,
373+ self .training_state .test_target_choice ,
374+ ),
375+ ):
376+ if value is not None :
377+ print (
378+ f"Loaded { name } _target_choice ({ len (value )} ) from training state"
379+ )
380+ target_choice = self .training_state .train_target_choice
286381 self .train_dataset = SFTDataset (
287382 ** train_kwargs ,
288383 mask_prompt = self .mask_prompt ,
289384 ignore_index = self .ignore_index ,
290- target_choice = self . target_choices [ 0 ] ,
385+ target_choice = target_choice ,
291386 seed = self .seed ,
292387 )
388+ if target_choice is None :
389+ print (
390+ f"Sampled train_target_choice ({ len (self .train_dataset .target_choice )} )"
391+ )
392+ self .training_state .train_target_choice = self .train_dataset .target_choice
393+ target_choice = self .training_state .val_target_choice
293394 self .val_dataset = SFTDataset (
294395 ** val_kwargs ,
295396 mask_prompt = self .mask_prompt ,
296397 ignore_index = self .ignore_index ,
297- target_choice = self . target_choices [ 1 ] ,
398+ target_choice = target_choice ,
298399 seed = self .seed ,
299400 )
401+ if target_choice is None :
402+ print (f"Sampled val_target_choice ({ len (self .val_dataset .target_choice )} )" )
403+ self .training_state .val_target_choice = self .val_dataset .target_choice
300404 if test_kwargs is not None :
405+ target_choice = self .training_state .test_target_choice
301406 self .test_dataset = SFTDataset (
302407 ** test_kwargs ,
303408 mask_prompt = self .mask_prompt ,
304409 ignore_index = self .ignore_index ,
305- target_choice = self . target_choices [ 2 ] ,
410+ target_choice = target_choice ,
306411 seed = self .seed ,
307412 )
308- num_sets += 1
309- # Update meta-data?
310- do_store_meta = any (x is None for x in self .target_choices [:num_sets ])
311- if do_store_meta :
312- for i , (data , split ) in enumerate (
313- zip (
314- (self .train_dataset , self .val_dataset , self .test_dataset ),
315- ("train" , "val" , "test" ),
413+ if target_choice is None :
414+ print (
415+ f"Sampled test_target_choice ({ len (self .test_dataset .target_choice )} )"
316416 )
317- ):
318- if self .target_choices [i ] is None and data is not None :
319- new_choices = data .target_choice .copy ()
320- self .target_choices [i ] = new_choices
321- set_dict (
322- self ._metadata ,
323- self ._metadata_keys (METADATA_TARGET_CHOICE_KEY , split ),
324- new_choices ,
325- )
326- self ._store_metadata (self ._metadata )
417+ self .training_state .test_target_choice = self .test_dataset .target_choice
327418
328419 def _get_collate_fn (self ) -> MyDataLoader :
329420 return get_sft_collate_fn (ignore_index = self .ignore_index )
@@ -397,3 +488,8 @@ def smart_lastrec_info(self, tokenizer: HFTokenizer) -> SmartInitialInformation:
397488 max_initial_fraction = max_initial_fraction ,
398489 include_end_string = include_end_string ,
399490 )
491+
492+ def load_training_state (self , state_dict : Dict [str , torch .Tensor ]):
493+ if self .training_state is None :
494+ self .training_state = HelmetDataTrainState ()
495+ self .training_state .load_state_dict (state_dict )
0 commit comments