Skip to content

Commit bdd3879

Browse files
author
Grzegorz Pluto-Prondzinski
authored
Switch RAFT dataset source from ought/raft to regisss/raft for compatibility with datasets>=4.0.0 (#2310)
1 parent 5d26212 commit bdd3879

3 files changed

Lines changed: 9 additions & 9 deletions

File tree

examples/language-modeling/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ The format of the text files (with extensions .text or .txt) is expected to be
530530

531531
To run prompt tuning finetuning, you can use `run_prompt_tuning_clm.py`.
532532
Here are single-card command examples for Llama2-7B:
533-
- single-card finetuning of meta-llama/Llama-2-7b-hf with dataset "ought/raft" and config "twitter_complaints":
533+
- single-card finetuning of meta-llama/Llama-2-7b-hf with dataset "regisss/raft" and config "default":
534534
```bash
535535
PT_HPU_LAZY_MODE=1 python3 run_prompt_tuning_clm.py \
536536
--model_name_or_path meta-llama/Llama-2-7b-hf \

examples/language-modeling/run_prompt_tuning_clm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ class DataTrainingArguments:
172172
"""
173173

174174
dataset_name: Optional[str] = field(
175-
default="ought/raft", metadata={"help": "The name of the dataset to use (via the datasets library)."}
175+
default="regisss/raft", metadata={"help": "The name of the dataset to use (via the datasets library)."}
176176
)
177177
dataset_config_name: Optional[str] = field(
178-
default="twitter_complaints",
178+
default="default",
179179
metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
180180
)
181181
max_eval_samples: Optional[int] = field(
@@ -250,11 +250,11 @@ def main():
250250
streaming=data_args.streaming,
251251
trust_remote_code=model_args.trust_remote_code,
252252
)
253-
if data_args.dataset_name == "ought/raft" and data_args.dataset_config_name == "twitter_complaints":
254-
text_column = "Tweet text"
253+
if data_args.dataset_name == "regisss/raft" and data_args.dataset_config_name == "default":
254+
text_column = "Abstract Note"
255255
label_column = "text_label"
256256
else:
257-
raise ValueError("preprocess is only for ought/raft twitter_complaints now")
257+
raise ValueError("preprocess is only for regisss/raft default now")
258258
classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
259259
dataset = dataset.map(
260260
lambda x: {"text_label": [classes[label] for label in x["Label"]]},

tests/test_examples.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,21 +1085,21 @@ class MultiCardCausalLanguageModelingPromptTuningExampleTester(
10851085
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
10861086
):
10871087
TASK_NAME = "prompt-tuning"
1088-
DATASET_NAME = "ought/raft"
1088+
DATASET_NAME = "regisss/raft"
10891089

10901090

10911091
class MultiCardCausalLanguageModelingPrefixTuningExampleTester(
10921092
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
10931093
):
10941094
TASK_NAME = "prefix-tuning"
1095-
DATASET_NAME = "ought/raft"
1095+
DATASET_NAME = "regisss/raft"
10961096

10971097

10981098
class MultiCardCausalLanguageModelingPTuningExampleTester(
10991099
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_prompt_tuning_clm", multi_card=True
11001100
):
11011101
TASK_NAME = "p-tuning"
1102-
DATASET_NAME = "ought/raft"
1102+
DATASET_NAME = "regisss/raft"
11031103

11041104

11051105
class MultiCardMultiTastPromptPeftExampleTester(

0 commit comments

Comments
 (0)