File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ class PretrainingConfig(BaseModel):
131131 # Logging / checkpointing
132132 wandb_project : str = "grouping-trainer"
133133 num_logs : int = 2000
134- num_checkpoints : int = 10
134+ num_checkpoints : int = 50
135135
136136
137137def _load_model_and_tokenizer (base_model : str ) -> tuple [Any , PreTrainedTokenizerBase ]:
@@ -159,9 +159,13 @@ def _get_train_sampler(self, train_dataset=None) -> Sampler | None:
159159
160160
161161def _sort_dataset_by_length_desc (dataset : Dataset ) -> Dataset :
162- lengths = np .array ([len (input_ids ) for input_ids in dataset ["input_ids" ]])
163- sorted_indices = np .argsort (- lengths )
164- sorted_dataset = dataset .select (sorted_indices )
162+ dataset_with_length = dataset .map (
163+ lambda batch : {"length" : [len (input_ids ) for input_ids in batch ["input_ids" ]]},
164+ batched = True ,
165+ )
166+ assert isinstance (dataset_with_length , Dataset )
167+ sorted_dataset = dataset_with_length .sort ("length" , reverse = True ).remove_columns ("length" )
168+ assert isinstance (sorted_dataset , Dataset )
165169 return sorted_dataset # type: ignore[bad-return]
166170
167171
You can’t perform that action at this time.
0 commit comments