Skip to content

Commit c7e88e8

Browse files
committed
more ckpt
1 parent db0d931 commit c7e88e8

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

src/grouping_trainer/pretrain.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ class PretrainingConfig(BaseModel):
131131
# Logging / checkpointing
132132
wandb_project: str = "grouping-trainer"
133133
num_logs: int = 2000
134-
num_checkpoints: int = 10
134+
num_checkpoints: int = 50
135135

136136

137137
def _load_model_and_tokenizer(base_model: str) -> tuple[Any, PreTrainedTokenizerBase]:
@@ -159,9 +159,13 @@ def _get_train_sampler(self, train_dataset=None) -> Sampler | None:
159159

160160

161161
def _sort_dataset_by_length_desc(dataset: Dataset) -> Dataset:
162-
lengths = np.array([len(input_ids) for input_ids in dataset["input_ids"]])
163-
sorted_indices = np.argsort(-lengths)
164-
sorted_dataset = dataset.select(sorted_indices)
162+
dataset_with_length = dataset.map(
163+
lambda batch: {"length": [len(input_ids) for input_ids in batch["input_ids"]]},
164+
batched=True,
165+
)
166+
assert isinstance(dataset_with_length, Dataset)
167+
sorted_dataset = dataset_with_length.sort("length", reverse=True).remove_columns("length")
168+
assert isinstance(sorted_dataset, Dataset)
165169
return sorted_dataset # type: ignore[bad-return]
166170

167171

0 commit comments

Comments (0)