Skip to content

Commit 65cccde

Browse files
committed
Helmet dataset: max_seq_lenght = None by default
1 parent 1e2ec65 commit 65cccde

6 files changed

Lines changed: 26 additions & 27 deletions

File tree

keys_values/data/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,14 @@
3838
class LongContextDataset(Dataset):
3939
"""
4040
Base class for some datasets we define here.
41-
4241
"""
4342

4443
def __init__(
4544
self,
4645
data: List[Dict[str, str]],
4746
tokenizer: Tokenizer,
4847
prompt_style: Union[str, PromptStyle],
49-
max_seq_length: int = -1,
48+
max_seq_length: Optional[int] = None,
5049
transform: Optional[Callable[[Dict[str, str]], Dict[str, str]]] = None,
5150
) -> None:
5251
self.data = data

keys_values/data/helmet.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
validation. The rest is used for training.
9191
ignore_index: Value used to mask prompt positions in the labels.
9292
max_seq_length: Sequences longer than this (in tokens) are filtered
93-
out. Defaults to no filtering (``100000``).
93+
out. Defaults to no filtering.
9494
seed: Random seed for the train/val split.
9595
metadata_dir: If given, sequence lengths for every case are stored
9696
in a JSON metadata file in this directory so that subsequent
@@ -204,7 +204,7 @@ def _transform(
204204
new_seq_lengths.append(seq_length)
205205
else:
206206
seq_length = seq_lengths[idx]
207-
if seq_length > self.max_seq_length:
207+
if self.max_seq_length is not None and seq_length > self.max_seq_length:
208208
continue
209209
output = instance["output"]
210210
results.append(
@@ -215,10 +215,7 @@ def _transform(
215215
}
216216
)
217217
final_seq_lengths = new_seq_lengths if seq_lengths is None else seq_lengths
218-
print(
219-
f"Kept {len(results)} of {len(dataset)} {split} records "
220-
f"(<= {self.max_seq_length} tokens)"
221-
)
218+
print(f"Kept {len(results)} of {len(dataset)} {split} records")
222219
return results, final_seq_lengths, needs_store
223220

224221
def _get_seq_lengths(

keys_values/data/longbench_v2.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
mask_prompt: bool = True,
8484
val_split_fraction: float = 0.1,
8585
ignore_index: int = -100,
86-
max_seq_length: Optional[int] = None,
86+
max_seq_length: Optional[int] = 100000,
8787
seed: int = 42,
8888
repo_id: str = "THUDM/LongBench-v2",
8989
access_token: Optional[str] = None,
@@ -388,7 +388,7 @@ def smart_lastrec_info(self, tokenizer: HFTokenizer) -> SmartInitialInformation:
388388

389389
def filter_and_transform(
390390
dataset: Any,
391-
max_seq_length: int,
391+
max_seq_length: Optional[int],
392392
tokenizer: Tokenizer,
393393
seq_lengths: Optional[List[int]],
394394
head_model: str,
@@ -414,7 +414,10 @@ def filter_and_transform(
414414
test_results: RawDatasetType = []
415415
num_used = 0
416416
num_total = 0
417-
print(f"\nProcessing dataset, filtering out records with > {max_seq_length} tokens")
417+
if max_seq_length is not None:
418+
print(f"\nProcessing dataset, filtering out records with > {max_seq_length} tokens")
419+
else:
420+
print(f"\nProcessing dataset")
418421
if seq_lengths is None:
419422
# Show progress bar: This takes a while
420423
data_iter = tqdm(dataset)
@@ -445,7 +448,7 @@ def filter_and_transform(
445448
new_seq_lengths.append(seq_length)
446449
else:
447450
seq_length = seq_lengths[idx]
448-
if seq_length <= max_seq_length:
451+
if max_seq_length is None or seq_length <= max_seq_length:
449452
num_used += 1
450453
train_results.append(
451454
{
@@ -465,8 +468,8 @@ def filter_and_transform(
465468
"num_tokens_instruction": seq_length,
466469
}
467470
)
468-
print(f"\nKept {num_used} of {num_total} records: {max_seq_length} tokens or less")
469-
if test_set_tag == "rest":
471+
print(f"\nKept {num_used} of {num_total} records")
472+
if test_set_tag == "rest" and test_results:
470473
# Sort by increasing length
471474
test_results = sorted(
472475
test_results,

keys_values/data/module.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ class SequenceLengthFilteredDataModule(DataModule):
5353
use a :class:`EvaluationDataLoader`, which returns batches coupled with
5454
tasks. Here, we try to form micro batches with sequences of similar length,
5555
but there is no concept of macro batches.
56-
5756
"""
5857

5958
def __init__(
@@ -75,7 +74,7 @@ def __init__(
7574
ignore_index: The index to use for elements to be ignored in the
7675
label.
7776
max_seq_length: Sequences longer than this number of tokens are
78-
filtered out. Defaults to 100000.
77+
filtered out.
7978
seed: The random seed for creating the train/val splits and shuffling
8079
the dataset.
8180
trainloader_longest_first: If set, :meth:`train_dataloader` returns
@@ -94,7 +93,7 @@ def __init__(
9493
self.mask_prompt = mask_prompt
9594
self.val_split_fraction = val_split_fraction
9695
self.ignore_index = ignore_index
97-
self.max_seq_length = 100000 if max_seq_length is None else max_seq_length
96+
self.max_seq_length = max_seq_length
9897
self.seed = seed
9998
self.head_model = None
10099
self._is_sequence_classification = None
@@ -268,7 +267,7 @@ def setup(self, stage: str = "") -> None:
268267
),
269268
tokenizer=self.tokenizer,
270269
prompt_style=Default(),
271-
max_seq_length=-1,
270+
max_seq_length=None,
272271
)
273272
else:
274273
test_kwargs = None

keys_values/data/sequence_classification.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
tokenizer: Tokenizer,
5757
prompt_style: Union[str, PromptStyle],
5858
class_labels: Iterable[str],
59-
max_seq_length: int = -1,
59+
max_seq_length: Optional[int] = None,
6060
transform: Optional[Callable[[Dict[str, str]], Dict[str, str]]] = None,
6161
) -> None:
6262
super().__init__(
@@ -108,8 +108,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
108108
if self.transform is not None:
109109
example = self.transform(example)
110110
prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
111+
max_length = -1 if self.max_seq_length is None else self.max_seq_length
111112
encoded_prompt = self.tokenizer.encode(
112-
prompt, bos=False, eos=True, max_length=self.max_seq_length
113+
prompt, bos=False, eos=True, max_length=max_length,
113114
)
114115
token_counts = {"raw_plus_prompt_template": len(encoded_prompt)}
115116
raw_count = example.get("num_tokens_instruction")

keys_values/data/sft_dataset.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
data: List[Dict[str, str]],
5656
tokenizer: Tokenizer,
5757
prompt_style: Union[str, PromptStyle],
58-
max_seq_length: int = -1,
58+
max_seq_length: Optional[int] = None,
5959
mask_prompt: bool = True,
6060
ignore_index: int = -100,
6161
transform: Optional[Callable[[Dict[str, str]], Dict[str, str]]] = None,
@@ -84,9 +84,10 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
8484
if self.transform is not None:
8585
example = self.transform(example)
8686
prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
87+
max_length = -1 if self.max_seq_length is None else self.max_seq_length
8788
encoded_prompt = self.tokenizer.encode(
8889
prompt,
89-
max_length=self.max_seq_length,
90+
max_length=max_length,
9091
)
9192
targets = example["output"]
9293
if isinstance(targets, list):
@@ -99,15 +100,14 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
99100
_targets,
100101
bos=False,
101102
eos=True,
102-
max_length=self.max_seq_length,
103+
max_length=max_length,
103104
)
104105
encoded_prompt_and_response = torch.cat(
105106
(encoded_prompt, encoded_response)
106107
).type(torch.int64)
107-
msl = self.max_seq_length
108-
if 0 < msl < len(encoded_prompt_and_response):
109-
encoded_prompt_and_response = encoded_prompt_and_response[:msl]
110-
encoded_prompt_and_response[msl - 1] = self.tokenizer.eos_id
108+
if 0 < max_length < len(encoded_prompt_and_response):
109+
encoded_prompt_and_response = encoded_prompt_and_response[:max_length]
110+
encoded_prompt_and_response[max_length - 1] = self.tokenizer.eos_id
111111

112112
# The labels are the full prompt with response, but with the prompt masked out
113113
labels = encoded_prompt_and_response.clone()

0 commit comments

Comments
 (0)